diff --git a/.asf.yaml b/.asf.yaml index 44a819080e00e..755bb42e6661e 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -60,12 +60,18 @@ github: v2-3-stable: required_pull_request_reviews: required_approving_review_count: 1 + v2-4-stable: + required_pull_request_reviews: + required_approving_review_count: 1 collaborators: - auvipy - paolaperaza - petedejoy - gmcdonald + - o-nikolas + - ferruzzi + - Taragolis notifications: jobs: jobs@airflow.apache.org diff --git a/.coveragerc b/.coveragerc index 5e6918f9262b9..cdcc55d446971 100644 --- a/.coveragerc +++ b/.coveragerc @@ -23,7 +23,6 @@ omit = dev/* airflow/migrations/* airflow/www/node_modules/** - airflow/ui/node_modules/** airflow/_vendor/** [run] diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index d3b2c86215aef..faf30cdd7e686 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -2,9 +2,11 @@ "name": "Apache Airflow - sqlite", "dockerComposeFile": [ "../scripts/ci/docker-compose/devcontainer.yml", - "../scripts/ci/docker-compose/local.yml", - "../scripts/ci/docker-compose/backend-sqlite.yml", + "../scripts/ci/docker-compose/backend-sqlite.yml" ], + "settings": { + "terminal.integrated.defaultProfile.linux": "bash" + }, "extensions": [ "ms-python.python", "ms-python.vscode-pylance", @@ -19,5 +21,10 @@ "rogalmic.bash-debug" ], "service": "airflow", - "forwardPorts": [8080,5555,5432,6379] + "forwardPorts": [8080, 5555, 5432, 6379], + "workspaceFolder": "/opt/airflow", + // for users who use non-standard git config patterns + // https://github.com/microsoft/vscode-remote-release/issues/2084#issuecomment-989756268 + "initializeCommand": "cd \"${localWorkspaceFolder}\" && git config --local user.email \"$(git config user.email)\" && git config --local user.name \"$(git config user.name)\"", + "overrideCommand": true } diff --git a/.devcontainer/mysql/devcontainer.json b/.devcontainer/mysql/devcontainer.json index aa9696b12a33a..5a25b6ad50625 100644 --- a/.devcontainer/mysql/devcontainer.json +++ b/.devcontainer/mysql/devcontainer.json @@ -1,11 +1,13 @@ { "name": "Apache Airflow - mysql", "dockerComposeFile": [ - "../scripts/ci/docker-compose/devcontainer.yml", - "../scripts/ci/docker-compose/local.yml", - "../scripts/ci/docker-compose/backend-mysql.yml", - "../scripts/ci/docker-compose/devcontainer-mysql.yml" + "../../scripts/ci/docker-compose/devcontainer.yml", + "../../scripts/ci/docker-compose/backend-mysql.yml", + "../../scripts/ci/docker-compose/devcontainer-mysql.yml" ], + "settings": { + "terminal.integrated.defaultProfile.linux": "bash" + }, "extensions": [ "ms-python.python", "ms-python.vscode-pylance", diff --git a/.devcontainer/postgres/devcontainer.json b/.devcontainer/postgres/devcontainer.json index fec09cb0500d7..46ba305b58554 100644 --- a/.devcontainer/postgres/devcontainer.json +++ b/.devcontainer/postgres/devcontainer.json @@ -1,11 +1,13 @@ { "name": "Apache Airflow - postgres", "dockerComposeFile": [ - "../scripts/ci/docker-compose/devcontainer.yml", - "../scripts/ci/docker-compose/local.yml", - "../scripts/ci/docker-compose/backend-postgres.yml", - "../scripts/ci/docker-compose/devcontainer-postgres.yml" + "../../scripts/ci/docker-compose/devcontainer.yml", + "../../scripts/ci/docker-compose/backend-postgres.yml", + "../../scripts/ci/docker-compose/devcontainer-postgres.yml" ], + "settings": { + "terminal.integrated.defaultProfile.linux": "bash" + }, "extensions": [ "ms-python.python", "ms-python.vscode-pylance", diff --git a/.dockerignore b/.dockerignore index 
69a3bbfca68c0..045b730630678 100644 --- a/.dockerignore +++ b/.dockerignore @@ -34,7 +34,6 @@ !chart !docs !licenses -!metastore_browser # Add those folders to the context so that they are available in the CI container !scripts/in_container @@ -74,19 +73,14 @@ !setup.cfg !setup.py !manifests +!generated # Now - ignore unnecessary files inside allowed directories # This goes after the allowed directories # Git version is dynamically generated airflow/git_version - -# Exclude static www files generated by NPM -airflow/www/static/coverage -airflow/www/static/dist +# Exclude mode_modules pulled by "yarn" for compilation of www files generated by NPM airflow/www/node_modules -# Exclude static ui files generated by NPM -airflow/ui/build -airflow/ui/node_modules # Exclude link to docs airflow/www/static/docs @@ -99,7 +93,7 @@ airflow/www/static/docs **/env/ **/build/ **/develop-eggs/ -**/dist/ +/dist/ **/downloads/ **/eggs/ **/.eggs/ @@ -128,3 +122,7 @@ airflow/www/static/docs docs/_build/ docs/_api/ docs/_doctrees/ + +# files generated by memray +*.py.*.html +*.py.*.bin diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000000..639153b3c0326 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,11 @@ +# Black enabled. +4e8f9cc8d02b29c325b8a5a76b4837671bdf5f68 +fdd9b6f65b608c516b8a062b058972d9a45ec9e3 + +# PEP-563 (Postponed Evaluation of Annotations). +d67ac5932dabbf06ae733fc57b48491a8029b8c2 + +# Mass converting string literals to use double quotes. +2a34dc9e8470285b0ed2db71109ef4265e29688b +bfcae349b88fd959e32bfacd027a5be976fe2132 +01a819a42daa7990c30ab9776208b3dcb9f3a28b diff --git a/.gitattributes b/.gitattributes index 497db03fbcfc5..4ee74a9d8ab9d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,6 +1,3 @@ -breeze export-ignore -breeze-legacy export-ignore -breeze-complete export-ignore clients export-ignore clients export-ignore dev export-ignore @@ -16,8 +13,6 @@ tests export-ignore Dockerfile.ci export-ignore ISSUE_TRIAGE_PROCESS.rst export-ignore -PULL_REQUEST_WORKFLOW.rst export-ignore -SELECTIVE_CHECKS.md export-ignore STATIC_CODE_CHECKS.rst export-ignore TESTING.rst export-ignore LOCAL_VIRTUALENV.rst export-ignore diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 594f939e46bb4..15300027e31b2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -31,9 +31,6 @@ /airflow/api/ @mik-laj @ephraimbuddy /airflow/api_connexion/ @mik-laj @ephraimbuddy -# UI -/airflow/ui/ @ryanahamilton @ashb @bbovenzi - # WWW /airflow/www/ @ryanahamilton @ashb @bbovenzi @@ -46,6 +43,15 @@ /airflow/timetables/ @uranusjr /docs/apache-airflow/concepts/timetable.rst @uranusjr +# Task expansion, scheduling, and rendering +/airflow/models/abstractoperator.py @uranusjr +/airflow/models/baseoperator.py @uranusjr +/airflow/models/expandinput.py @uranusjr +/airflow/models/mappedoperator.py @uranusjr +/airflow/models/operator.py @uranusjr +/airflow/models/xcom_arg.py @uranusjr +/docs/apache-airflow/concepts/dynamic-task-mapping.rst @uranusjr + # Async Operators & Triggerer /airflow/jobs/triggerer_job.py @dstandish /airflow/cli/commands/triggerer_command.py @dstandish @@ -59,12 +65,21 @@ /airflow/providers/google/ @turbaszek /airflow/providers/snowflake/ @turbaszek @potiuk @mik-laj /airflow/providers/cncf/kubernetes @jedcunningham +/airflow/providers/dbt/cloud/ @josh-fell +/airflow/providers/tabular/ @Fokko +/airflow/providers/amazon/ @eladkal +/airflow/providers/common/sql/ @eladkal +/airflow/providers/slack/ @eladkal +/docs/apache-airflow-providers-amazon/ 
@eladkal +/docs/apache-airflow-providers-common-sql/ @eladkal +/docs/apache-airflow-providers-slack/ @eladkal /docs/apache-airflow-providers-cncf-kubernetes @jedcunningham +/tests/providers/amazon/ @eladkal +/tests/providers/common/sql/ @eladkal +/tests/providers/slack/ @eladkal # Dev tools /.github/workflows/ @potiuk @ashb @kaxil -breeze @potiuk -breeze-complete @potiuk Dockerfile @potiuk @ashb @mik-laj Dockerfile.ci @potiuk @ashb /dev/ @potiuk @ashb @jedcunningham diff --git a/.github/ISSUE_TEMPLATE/airflow_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_bug_report.yml index fc8cf8f324bc3..913d869e0486b 100644 --- a/.github/ISSUE_TEMPLATE/airflow_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_bug_report.yml @@ -21,27 +21,13 @@ body: attributes: label: Apache Airflow version description: > - What Apache Airflow version are you using? Only Airflow 2 is supported for bugs. If you wish to - discuss Airflow 1.10, open a [discussion](https://github.com/apache/airflow/discussions) instead! + What Apache Airflow version are you using? If you do not see your version, please (ideally) test on + the latest release or main to see if the issue is fixed before reporting it. multiple: false options: - - "2.3.1 (latest released)" - - "2.3.0" - - "2.2.5" - - "2.2.4" - - "2.2.3" - - "2.2.2" - - "2.2.1" - - "2.2.0" - - "2.1.4" - - "2.1.3" - - "2.1.2" - - "2.1.1" - - "2.1.0" - - "2.0.2" - - "2.0.1" - - "2.0.0" + - "2.4.3" - "main (development)" + - "Other Airflow 2 version (please specify below)" validations: required: true - type: textarea diff --git a/.github/ISSUE_TEMPLATE/airflow_helmchart_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_helmchart_bug_report.yml index dc3d788b848ea..3111db5957de8 100644 --- a/.github/ISSUE_TEMPLATE/airflow_helmchart_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_helmchart_bug_report.yml @@ -28,7 +28,8 @@ body: What Apache Airflow Helm Chart version are you using? multiple: false options: - - "1.6.0 (latest released)" + - "1.7.0 (latest released)" + - "1.6.0" - "1.5.0" - "1.4.0" - "1.3.0" @@ -38,31 +39,12 @@ body: - "main (development)" validations: required: true - - type: dropdown + - type: input attributes: label: Apache Airflow version description: > - What Apache Airflow version are you using? Only Airflow 2 is supported for bugs. If you wish to - discuss Airflow 1.10, open a [discussion](https://github.com/apache/airflow/discussions) instead! - multiple: false - options: - - "2.3.1 (latest released)" - - "2.3.0" - - "2.2.5" - - "2.2.4" - - "2.2.3" - - "2.2.2" - - "2.2.1" - - "2.2.0" - - "2.1.4" - - "2.1.3" - - "2.1.2" - - "2.1.1" - - "2.1.0" - - "2.0.2" - - "2.0.1" - - "2.0.0" - - "main (development)" + What Apache Airflow version are you using? + [Only Airflow 2 is supported](https://github.com/apache/airflow#version-life-cycle) for bugs. 
validations: required: true - type: input diff --git a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml index 10bf533ac98a9..198c11e284fd5 100644 --- a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml @@ -40,9 +40,11 @@ body: - apache-sqoop - arangodb - asana + - atlassian-jira - celery - cloudant - cncf-kubernetes + - common-sql - databricks - datadog - dbt-cloud @@ -62,7 +64,6 @@ body: - influxdb - jdbc - jenkins - - jira - microsoft-azure - microsoft-mssql - microsoft-psrp @@ -92,6 +93,7 @@ body: - sqlite - ssh - tableau + - tabular - telegram - trino - vertica @@ -104,31 +106,12 @@ body: label: Versions of Apache Airflow Providers description: What Apache Airflow Providers versions are you using? placeholder: You can use `pip freeze | grep apache-airflow-providers` (you can leave only relevant ones) - - type: dropdown + - type: input attributes: label: Apache Airflow version description: > - What Apache Airflow version are you using? Only Airflow 2 is supported for bugs. If you wish to - discuss Airflow 1.10, open a [discussion](https://github.com/apache/airflow/discussions) instead! - multiple: false - options: - - "2.3.1 (latest released)" - - "2.3.0" - - "2.2.5" - - "2.2.4" - - "2.2.3" - - "2.2.2" - - "2.2.1" - - "2.2.0" - - "2.1.4" - - "2.1.3" - - "2.1.2" - - "2.1.1" - - "2.1.0" - - "2.0.2" - - "2.0.1" - - "2.0.0" - - "main (development)" + What Apache Airflow version are you using? + [Only Airflow 2 is supported](https://github.com/apache/airflow#version-life-cycle) for bugs. validations: required: true - type: input diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 388b7d19a5277..837a88f698d7a 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -12,9 +12,21 @@ body: We really appreciate the community efforts to improve Airflow. - Note, that you do not need to create an issue if you have a change ready to submit! - + Features should be small improvements that do not dramatically change Airflow's assumptions. + Note that in this case you do not even need to create an issue if you have a code change ready to submit! You can open [Pull Request](https://github.com/apache/airflow/pulls) immediately instead. + + Bigger features - those that impact Airflow's architecture, security assumptions, + backwards compatibility etc. - should be discussed on the [airflow devlist](https://lists.apache.org/list.html?dev@airflow.apache.org). + Such features will need initial discussion - possibly in a [discussion](https://github.com/apache/airflow/discussions), followed by an + [Airflow Improvement Proposal](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals) and formal voting. + If you want to introduce such a feature, you need to be prepared to lead a discussion, get consensus + among the community and eventually conduct a successful + [vote](https://www.apache.org/foundation/voting.html) in the community. + + If unsure - open a [discussion](https://github.com/apache/airflow/discussions) first to gather + initial feedback on your idea. +
" # yamllint enable rule:line-length @@ -37,10 +49,12 @@ body: attributes: label: Are you willing to submit a PR? description: > - This is absolutely not required, but we are happy to guide you in the contribution process - especially if you already have a good understanding of how to implement the feature. + If you want to submit a PR, you do not need to open a feature request - just create the PR! + Especially if you already have a good understanding of how to implement the feature. Airflow is a community-managed project and we love to bring new contributors in. - Find us in #airflow-how-to-pr on Slack! + Find us in #airflow-how-to-pr on Slack! It's optional though - if you have a good idea for a small + feature, others might implement it if they take an interest in it, so feel free to leave that + checkbox unchecked. options: - label: Yes I am willing to submit a PR! - type: checkboxes diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index be2a102e92d15..021313b68eb72 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,7 +5,7 @@ remember to adjust the documentation. Feel free to ping committers for the review! -In case of existing issue, reference it using one of the following: +In case of an existing issue, reference it using one of the following: closes: #ISSUE related: #ISSUE @@ -18,6 +18,6 @@ http://chris.beams.io/posts/git-commit/ **^ Add meaningful description above** Read the **[Pull Request Guidelines](https://github.com/apache/airflow/blob/main/CONTRIBUTING.rst#pull-request-guidelines)** for more information. -In case of fundamental code change, Airflow Improvement Proposal ([AIP](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvements+Proposals)) is needed. +In case of fundamental code changes, an Airflow Improvement Proposal ([AIP](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals)) is needed. In case of a new dependency, check compliance with the [ASF 3rd Party License Policy](https://www.apache.org/legal/resolved.html#category-x). -In case of backwards incompatible changes please leave a note in a newsfragement file, named `{pr_number}.significant.rst`, in [newsfragments](https://github.com/apache/airflow/tree/main/newsfragments). +In case of backwards incompatible changes please leave a note in a newsfragment file, named `{pr_number}.significant.rst` or `{issue_number}.significant.rst`, in [newsfragments](https://github.com/apache/airflow/tree/main/newsfragments). diff --git a/.github/actions/breeze/action.yml b/.github/actions/breeze/action.yml new file mode 100644 index 0000000000000..4db1164e2fe9b --- /dev/null +++ b/.github/actions/breeze/action.yml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +--- +name: 'Setup Breeze' +description: 'Sets up Python and Breeze' +outputs: + host-python-version: + description: Python version used in host + value: ${{ steps.host-python-version.outputs.host-python-version }} +runs: + using: "composite" + steps: + - name: "Setup python" + uses: actions/setup-python@v4 + with: + python-version: 3.7 + cache: 'pip' + cache-dependency-path: ./dev/breeze/setup* + - name: Cache breeze + uses: actions/cache@v3 + with: + path: ~/.local/pipx + key: "breeze-${{ hashFiles('dev/breeze/README.md') }}" # README has the latest breeze's hash + restore-keys: breeze- + - name: "Install Breeze" + shell: bash + run: ./scripts/ci/install_breeze.sh + - name: "Free space" + shell: bash + run: breeze ci free-space + - name: "Get Python version" + shell: bash + run: > + echo "host-python-version=$(python -c 'import platform; print(platform.python_version())')" + >> ${GITHUB_OUTPUT} + id: host-python-version + - name: "Disable cheatsheet" + shell: bash + run: breeze setup config --no-cheatsheet --no-asciiart diff --git a/.github/actions/build-ci-images/action.yml b/.github/actions/build-ci-images/action.yml new file mode 100644 index 0000000000000..08259aadabaaa --- /dev/null +++ b/.github/actions/build-ci-images/action.yml @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +--- +name: 'Build CI images' +description: 'Build CI images' +runs: + using: "composite" + steps: + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: "Regenerate dependencies" + # This is done in case someone updated provider dependencies and did not regenerate + # them - in which case the image build might fail because of the missing new dependencies + shell: bash + run: | + pip install rich>=12.4.4 pyyaml + python scripts/ci/pre_commit/pre_commit_update_providers_dependencies.py + if: env.UPGRADE_TO_NEWER_DEPENDENCIES != 'false' + - name: "Build & Push AMD64 CI images ${{ env.IMAGE_TAG }} ${{ env.PYTHON_VERSIONS }}" + shell: bash + run: breeze ci-image build --push --tag-as-latest --run-in-parallel --upgrade-on-failure + - name: "Show dependencies to be upgraded" + shell: bash + run: > + breeze release-management generate-constraints --run-in-parallel + --airflow-constraints-mode constraints-source-providers + if: env.UPGRADE_TO_NEWER_DEPENDENCIES != 'false' + - name: Push empty CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE }} + if: failure() || cancelled() + shell: bash + run: breeze ci-image build --push --empty-image --run-in-parallel + env: + IMAGE_TAG: ${{ env.IMAGE_TAG }} + - name: "Candidates for pip resolver backtrack triggers" + shell: bash + run: breeze ci find-newer-dependencies --max-age 1 --python 3.7 + if: failure() || cancelled() + - name: "Fix ownership" + shell: bash + run: breeze ci fix-ownership + if: always() diff --git a/.github/actions/build-prod-images/action.yml b/.github/actions/build-prod-images/action.yml new file mode 100644 index 0000000000000..b4cec13ddc1a1 --- /dev/null +++ b/.github/actions/build-prod-images/action.yml @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +--- +name: 'Build PROD images' +description: 'Build PROD images' +inputs: + build-provider-packages: + description: 'Whether to build provider packages from sources' + required: true +runs: + using: "composite" + steps: + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: "Regenerate dependencies in case they were modified manually so that we can build an image" + shell: bash + run: | + pip install rich>=12.4.4 pyyaml + python scripts/ci/pre_commit/pre_commit_update_providers_dependencies.py + if: env.UPGRADE_TO_NEWER_DEPENDENCIES != 'false' + - name: "Pull CI image for PROD build: ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG }}" + shell: bash + run: breeze ci-image pull --tag-as-latest + env: + PYTHON_MAJOR_MINOR_VERSION: "3.7" + - name: "Cleanup dist and context file" + shell: bash + run: rm -fv ./dist/* ./docker-context-files/* + - name: "Prepare providers packages" + shell: bash + run: > + breeze release-management prepare-provider-packages + --package-list-file ./scripts/ci/installed_providers.txt + --package-format wheel --version-suffix-for-pypi dev0 + if: ${{ inputs.build-provider-packages == 'true' }} + - name: "Prepare airflow package" + shell: bash + run: > + breeze release-management prepare-airflow-package + --package-format wheel --version-suffix-for-pypi dev0 + - name: "Move dist packages to docker-context files" + shell: bash + run: mv -v ./dist/*.whl ./docker-context-files + - name: "Build & Push PROD images ${{ env.IMAGE_TAG }}:${{ env.PYTHON_VERSIONS }}" + shell: bash + run: > + breeze prod-image build --tag-as-latest --run-in-parallel --push + --install-packages-from-context --upgrade-on-failure + - name: Push empty PROD images ${{ env.IMAGE_TAG }} + shell: bash + run: breeze prod-image build --cleanup-context --push --empty-image --run-in-parallel + if: failure() || cancelled() + - name: "Fix ownership" + shell: bash + run: breeze ci fix-ownership + if: always() diff --git a/.github/actions/checks-action b/.github/actions/checks-action deleted file mode 160000 index 9f02872da71b6..0000000000000 --- a/.github/actions/checks-action +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9f02872da71b6f558c6a6f190f925dde5e4d8798 diff --git a/.github/actions/get-workflow-origin b/.github/actions/get-workflow-origin deleted file mode 160000 index 588cc14f9f1cd..0000000000000 --- a/.github/actions/get-workflow-origin +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 588cc14f9f1cdf1b8be3db816855e96422204fec diff --git a/.github/actions/label-when-approved-action b/.github/actions/label-when-approved-action deleted file mode 160000 index 0058d0094da27..0000000000000 --- a/.github/actions/label-when-approved-action +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0058d0094da27e116fad6e0da516ebe1107f26de diff --git a/.github/actions/migration_tests/action.yml b/.github/actions/migration_tests/action.yml new file mode 100644 index 0000000000000..81ea7e838335a --- /dev/null +++ b/.github/actions/migration_tests/action.yml @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +--- +name: 'Run migration tests' +description: 'Runs migration tests' +runs: + using: "composite" + steps: + - name: "Test downgrade migration file ${{env.BACKEND}}" + shell: bash + run: > + breeze shell "airflow db reset --skip-init -y && + airflow db upgrade --to-revision heads && + airflow db downgrade -r e959f08ac86c -y && + airflow db upgrade" + - name: "Test downgrade ORM ${{env.BACKEND}}" + shell: bash + run: > + breeze shell "airflow db reset -y && + airflow db upgrade && + airflow db downgrade -r e959f08ac86c -y && + airflow db upgrade" diff --git a/.github/actions/post_tests/action.yml b/.github/actions/post_tests/action.yml new file mode 100644 index 0000000000000..96bed9211bb5b --- /dev/null +++ b/.github/actions/post_tests/action.yml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +--- +name: 'Post tests' +description: 'Run post tests actions' +runs: + using: "composite" + steps: + - name: "Upload airflow logs" + uses: actions/upload-artifact@v3 + if: failure() + with: + name: airflow-logs-${{env.JOB_ID}} + path: './files/airflow_logs*' + retention-days: 7 + - name: "Upload container logs" + uses: actions/upload-artifact@v3 + if: failure() + with: + name: container-logs-${{env.JOB_ID}} + path: "./files/container_logs*" + retention-days: 7 + - name: "Upload artifact for coverage" + uses: actions/upload-artifact@v3 + if: env.COVERAGE == 'true' + with: + name: coverage-${{env.JOB_ID}} + path: ./files/coverage*.xml + retention-days: 7 + - name: "Upload artifact for warnings" + uses: actions/upload-artifact@v3 + with: + name: test-warnings-${{env.JOB_ID}} + path: ./files/warnings-*.txt + retention-days: 7 + - name: "Fix ownership" + shell: bash + run: breeze ci fix-ownership + if: always() diff --git a/.github/actions/prepare_breeze_and_image/action.yml b/.github/actions/prepare_breeze_and_image/action.yml new file mode 100644 index 0000000000000..61967e1e52bb2 --- /dev/null +++ b/.github/actions/prepare_breeze_and_image/action.yml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +--- +name: 'Prepare breeze && current python image' +description: 'Installs breeze and pulls current python image' +inputs: + pull-image-type: + description: 'Which image to pull' + default: CI +outputs: + host-python-version: + description: Python version used in host + value: ${{ steps.breeze.outputs.host-python-version }} +runs: + using: "composite" + steps: + - name: "Install Breeze" + uses: ./.github/actions/breeze + id: breeze + - name: Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG }} + shell: bash + run: breeze ci-image pull --tag-as-latest + if: inputs.pull-image-type == 'CI' + - name: Pull PROD image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG }} + shell: bash + run: breeze prod-image pull --tag-as-latest + if: inputs.pull-image-type == 'PROD' diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml index 607d4fb6cdc55..7a444cf253378 100644 --- a/.github/boring-cyborg.yml +++ b/.github/boring-cyborg.yml @@ -39,6 +39,16 @@ labelPRBasedOnFilePath: - docs/apache-airflow-providers-apache-*/**/* - tests/providers/apache/**/* + provider:Common-sql: + - airflow/providers/common/sql/**/* + - docs/apache-airflow-providers-common-sql/**/* + - tests/providers/common/sql/**/* + + provider:Databricks: + - airflow/providers/databricks/**/* + - docs/apache-airflow-providers-databricks/**/* + - tests/providers/databricks/**/* + provider:Snowflake: - airflow/providers/snowflake/**/* - docs/apache-airflow-providers-snowflake/**/* @@ -57,8 +67,8 @@ labelPRBasedOnFilePath: - airflow/kubernetes_executor_templates/**/* - airflow/executors/kubernetes_executor.py - airflow/executors/celery_kubernetes_executor.py - - docs/apache-airflow/executor/kubernetes.rst - - docs/apache-airflow/executor/celery_kubernetes.rst + - docs/apache-airflow/core-concepts/executor/kubernetes.rst + - docs/apache-airflow/core-concepts/executor/celery_kubernetes.rst - docs/apache-airflow-providers-cncf-kubernetes/**/* - kubernetes_tests/**/* @@ -73,14 +83,11 @@ labelPRBasedOnFilePath: - tests/www/api/**/* area:dev-tools: - - airflow/mypy/**/* - scripts/**/* - dev/**/* - .github/**/* - - breeze - Dockerfile.ci - BREEZE.rst - - breeze-complete - CONTRIBUTING.* - LOCAL_VIRTUALENV.rst - STATIC_CODE_CHECKS.rst @@ -119,28 +126,26 @@ labelPRBasedOnFilePath: - airflow/www/.eslintrc - airflow/www/.stylelintignore - airflow/www/.stylelintrc - - airflow/www/compile_assets.sh - airflow/www/package.json - airflow/www/webpack.config.js - airflow/www/yarn.lock - docs/apache-airflow/ui.rst - - airflow/ui/**/* area:CLI: - airflow/cli/**/*.py - tests/cli/**/*.py - docs/apache-airflow/cli-and-env-variables-ref.rst - - docs/apache-airflow/usage-cli.rst + - docs/apache-airflow/howto/usage-cli.rst area:Lineage: - airflow/lineage/**/* - tests/lineage/**/* - - docs/apache-airflow/lineage.rst + - docs/apache-airflow/administration-and-deployment/lineage.rst area:Logging: - airflow/providers/**/log/* - airflow/utils/log/**/* - - 
docs/apache-airflow/logging-monitoring/logging-*.rst + - docs/apache-airflow/administration-and-deployment/logging-monitoring/logging-*.rst - tests/providers/**/log/* - tests/utils/log/**/* @@ -149,15 +154,15 @@ labelPRBasedOnFilePath: - airflow/plugins_manager.py - tests/cli/commands/test_plugins_command.py - tests/plugins/**/* - - docs/apache-airflow/plugins.rst + - docs/apache-airflow/authoring-and-scheduling/plugins.rst area:Scheduler/Executor: - airflow/executors/**/* - airflow/jobs/**/* - airflow/task/task_runner/**/* - airflow/dag_processing/**/* - - docs/apache-airflow/executor/**/* - - docs/apache-airflow/concepts/scheduler.rst + - docs/apache-airflow/core-concepts/executor/**/* + - docs/apache-airflow/administration-and-deployment/scheduler.rst - tests/executors/**/* - tests/jobs/**/* @@ -166,14 +171,14 @@ labelPRBasedOnFilePath: - airflow/providers/**/secrets/* - tests/secrets/**/* - tests/providers/**/secrets/* - - docs/apache-airflow/security/secrets/**/* + - docs/apache-airflow/administration-and-deployment/security/secrets/**/* area:Serialization: - airflow/serialization/**/* - airflow/models/serialized_dag.py - tests/serialization/**/* - tests/models/test_serialized_dag.py - - docs/apache-airflow/dag-serialization.rst + - docs/apache-airflow/administration-and-deployment/dag-serialization.rst area:core-operators: - airflow/operators/**/* @@ -190,6 +195,9 @@ labelPRBasedOnFilePath: - docs/docker-stack/**/* - docker_tests/**/* + area:system-tests: + - tests/system/**/* + # Various Flags to control behaviour of the "Labeler" labelerFlags: # If this flag is changed to 'false', labels would only be added when the PR is first created @@ -216,7 +224,7 @@ firstPRWelcomeComment: > Consider adding an example DAG that shows how users should use it. - Consider using [Breeze environment](https://github.com/apache/airflow/blob/main/BREEZE.rst) for testing - locally, it’s a heavy docker but it ships with a working Airflow and a lot of integrations. + locally, it's a heavy docker but it ships with a working Airflow and a lot of integrations. - Be patient and persistent. It might take some time to get a review or get the final approval from Committers. 
@@ -245,7 +253,4 @@ firstIssueWelcomeComment: > checkUpToDate: targetBranch: main - files: - - airflow/migrations/* - - airflow/migrations/**/* - - airflow/alembic.ini + files: [] diff --git a/.github/workflows/build-images.yml b/.github/workflows/build-images.yml index 9970a82e6c213..b6934b655538c 100644 --- a/.github/workflows/build-images.yml +++ b/.github/workflows/build-images.yml @@ -17,17 +17,16 @@ # --- name: "Build Images" +run-name: > + Build images for ${{ github.event.pull_request.title }} ${{ github.event.pull_request._links.html.href }} on: # yamllint disable-line rule:truthy pull_request_target: permissions: # all other permissions are set to none contents: read + pull-requests: read env: - MOUNT_SELECTED_LOCAL_SOURCES: "false" ANSWER: "yes" - CHECK_IMAGE_FOR_REBUILD: "true" - SKIP_CHECK_REMOTE_IMAGE: "true" - DEBIAN_VERSION: "bullseye" DB_RESET: "true" VERBOSE: "true" GITHUB_REPOSITORY: ${{ github.repository }} @@ -39,7 +38,9 @@ env: secrets.CONSTRAINTS_GITHUB_REPOSITORY || 'apache/airflow' }} # This token is WRITE one - pull_request_target type of events always have the WRITE token GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - IMAGE_TAG_FOR_THE_BUILD: "${{ github.event.pull_request.head.sha || github.sha }}" + IMAGE_TAG: "${{ github.event.pull_request.head.sha || github.sha }}" + USE_SUDO: "true" + INCLUDE_SUCCESS_OUTPUTS: "true" concurrency: group: build-${{ github.event.pull_request.number || github.ref }} @@ -51,20 +52,21 @@ jobs: name: "Build Info" runs-on: ${{ github.repository == 'apache/airflow' && 'self-hosted' || 'ubuntu-20.04' }} env: - targetBranch: ${{ github.event.pull_request.base.ref }} + TARGET_BRANCH: ${{ github.event.pull_request.base.ref }} outputs: - runsOn: ${{ github.repository == 'apache/airflow' && '["self-hosted"]' || '["ubuntu-20.04"]' }} - pythonVersions: "${{ steps.selective-checks.python-versions }}" - upgradeToNewerDependencies: ${{ steps.selective-checks.outputs.upgrade-to-newer-dependencies }} - allPythonVersions: ${{ steps.selective-checks.outputs.all-python-versions }} - defaultPythonVersion: ${{ steps.selective-checks.outputs.default-python-version }} + runs-on: ${{ github.repository == 'apache/airflow' && 'self-hosted' || 'ubuntu-20.04' }} + python-versions: "${{ steps.selective-checks.python-versions }}" + upgrade-to-newer-dependencies: ${{ steps.selective-checks.outputs.upgrade-to-newer-dependencies }} + all-python-versions-list-as-string: >- + ${{ steps.selective-checks.outputs.all-python-versions-list-as-string }} + default-python-version: ${{ steps.selective-checks.outputs.default-python-version }} run-tests: ${{ steps.selective-checks.outputs.run-tests }} run-kubernetes-tests: ${{ steps.selective-checks.outputs.run-kubernetes-tests }} - image-build: ${{ steps.dynamic-outputs.outputs.image-build }} - cacheDirective: ${{ steps.dynamic-outputs.outputs.cacheDirective }} - targetBranch: ${{ steps.dynamic-outputs.outputs.targetBranch }} - defaultBranch: ${{ steps.selective-checks.outputs.default-branch }} - targetCommitSha: "${{steps.discover-pr-merge-commit.outputs.targetCommitSha || + image-build: ${{ steps.selective-checks.outputs.image-build }} + cache-directive: ${{ steps.selective-checks.outputs.cache-directive }} + default-branch: ${{ steps.selective-checks.outputs.default-branch }} + default-constraints-branch: ${{ steps.selective-checks.outputs.default-constraints-branch }} + target-commit-sha: "${{steps.discover-pr-merge-commit.outputs.target-commit-sha || github.event.pull_request.head.sha || github.sha }}" @@ -73,15 +75,15 @@ 
jobs: id: discover-pr-merge-commit run: | TARGET_COMMIT_SHA="$(gh api '${{ github.event.pull_request.url }}' --jq .merge_commit_sha)" - echo "TARGET_COMMIT_SHA=$TARGET_COMMIT_SHA" >> $GITHUB_ENV - echo "::set-output name=targetCommitSha::${TARGET_COMMIT_SHA}" + echo "TARGET_COMMIT_SHA=$TARGET_COMMIT_SHA" >> ${GITHUB_ENV} + echo "target-commit-sha=${TARGET_COMMIT_SHA}" >> ${GITHUB_OUTPUT} if: github.event_name == 'pull_request_target' # The labels in the event aren't updated when re-triggering the job, So lets hit the API to get # up-to-date values - name: Get latest PR labels id: get-latest-pr-labels run: | - echo -n "::set-output name=pullRequestLabels::" + echo -n "pull-request-labels=" >> ${GITHUB_OUTPUT} gh api graphql --paginate -F node_id=${{github.event.pull_request.node_id}} -f query=' query($node_id: ID!, $endCursor: String) { node(id:$node_id) { @@ -92,269 +94,246 @@ jobs: } } } - }' --jq '.data.node.labels.nodes[]' | jq --slurp -c '[.[].name]' + }' --jq '.data.node.labels.nodes[]' | jq --slurp -c '[.[].name]' >> ${GITHUB_OUTPUT} if: github.event_name == 'pull_request_target' # Retrieve it to be able to determine which files has changed in the incoming commit of the PR # we checkout the target commit and it's parent to be able to compare them - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: ref: ${{ env.TARGET_COMMIT_SHA }} persist-credentials: false fetch-depth: 2 - - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + - name: "Setup python" + uses: actions/setup-python@v4 with: + python-version: 3.7 + - name: "Retrieve defaults from branch_defaults.py" + # We cannot "execute" the branch_defaults.py python code here because that would be + # a security problem (we cannot run any code that comes from the sources coming from the PR). + # Therefore, we extract the branches via embedded Python code. + # We need to do it before the next step replaces the checked-out breeze and scripts code coming from + # the PR, because the PR defaults have to be retrieved here. + id: defaults + run: | + python - <<EOF >>${GITHUB_ENV} + from pathlib import Path + import re + import sys + + DEFAULTS_CONTENT = Path('dev/breeze/src/airflow_breeze/branch_defaults.py').read_text() + BRANCH_PATTERN = r'^AIRFLOW_BRANCH = "(.*)"$' + CONSTRAINTS_BRANCH_PATTERN = r'^DEFAULT_AIRFLOW_CONSTRAINTS_BRANCH = "(.*)"$' + + branch = re.search(BRANCH_PATTERN, DEFAULTS_CONTENT, re.MULTILINE).group(1) + constraints_branch = re.search(CONSTRAINTS_BRANCH_PATTERN, DEFAULTS_CONTENT, re.MULTILINE).group(1) + + output = f""" + DEFAULT_BRANCH={branch} + DEFAULT_CONSTRAINTS_BRANCH={constraints_branch} + """.strip() + + print(output) + # Stdout is redirected to GITHUB_ENV but we also print it to stderr to see it in ci log + print(output, file=sys.stderr) + EOF + - name: Checkout main branch to use breeze from there.
+ uses: actions/checkout@v3 + with: + ref: "main" persist-credentials: false submodules: recursive + - name: "Install Breeze" + uses: ./.github/actions/breeze - name: Selective checks id: selective-checks env: - PR_LABELS: ${{ steps.get-latest-pr-labels.outputs.pullRequestLabels }} - run: | - if [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then - # Run selective checks - ./scripts/ci/selective_ci_checks.sh "${TARGET_COMMIT_SHA}" - else - # Run all checks - ./scripts/ci/selective_ci_checks.sh - fi - - name: Compute dynamic outputs - id: dynamic-outputs - run: | - set -x - if [[ "${{ github.event_name }}" == 'pull_request_target' ]]; then - echo "::set-output name=targetBranch::${targetBranch}" - else - # Direct push to branch, or scheduled build - echo "::set-output name=targetBranch::${GITHUB_REF#refs/heads/}" - fi - - if [[ "${{ github.event_name }}" == 'schedule' ]]; then - echo "::set-output name=cacheDirective::disabled" - else - echo "::set-output name=cacheDirective:registry" - fi - - if [[ "$SELECTIVE_CHECKS_IMAGE_BUILD" == "true" ]]; then - echo "::set-output name=image-build::true" - else - echo "::set-output name=image-build::false" - fi - env: - SELECTIVE_CHECKS_IMAGE_BUILD: ${{ steps.selective-checks.outputs.image-build }} + PR_LABELS: "${{ steps.get-latest-pr-labels.outputs.pull-request-labels }}" + COMMIT_REF: "${{ env.TARGET_COMMIT_SHA }}" + VERBOSE: "false" + run: breeze ci selective-check >> ${GITHUB_OUTPUT} - name: env run: printenv env: - dynamicOutputs: ${{ toJSON(steps.dynamic-outputs.outputs) }} - PR_LABELS: ${{ steps.get-latest-pr-labels.outputs.pullRequestLabels }} + PR_LABELS: ${{ steps.get-latest-pr-labels.outputs.pull-request-labels }} GITHUB_CONTEXT: ${{ toJson(github) }} build-ci-images: permissions: packages: write timeout-minutes: 80 - name: "Build CI image ${{matrix.python-version}}" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + name: > + Build CI images ${{needs.build-info.outputs.all-python-versions-list-as-string}} + runs-on: ${{ needs.build-info.outputs.runs-on }} needs: [build-info] - strategy: - matrix: - python-version: ${{ fromJson(needs.build-info.outputs.allPythonVersions) }} - fail-fast: true if: | needs.build-info.outputs.image-build == 'true' && github.event.pull_request.head.repo.full_name != 'apache/airflow' env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn)[0] }} - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + DEFAULT_BRANCH: ${{ needs.build-info.outputs.default-branch }} + DEFAULT_CONSTRAINTS_BRANCH: ${{ needs.build-info.outputs.default-constraints-branch }} + RUNS_ON: ${{ needs.build-info.outputs.runs-on }} BACKEND: sqlite - outputs: ${{toJSON(needs.build-info.outputs) }} steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: - ref: ${{ needs.build-info.outputs.targetCommitSha }} + ref: ${{ needs.build-info.outputs.target-commit-sha }} persist-credentials: false submodules: recursive - - name: "Retrieve DEFAULTS from the _initialization.sh" - # We cannot "source" the script here because that would be a security problem (we cannot run - # any code that comes from the sources coming from the PR. 
Therefore, we extract the - # DEFAULT_BRANCH and DEFAULT_CONSTRAINTS_BRANCH and DEBIAN_VERSION via custom grep/awk/sed commands - id: defaults - run: | - DEFAULT_BRANCH=$(grep "export DEFAULT_BRANCH" scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_BRANCH=${DEFAULT_BRANCH}" >> $GITHUB_ENV - DEFAULT_CONSTRAINTS_BRANCH=$(grep "export DEFAULT_CONSTRAINTS_BRANCH" \ - scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_CONSTRAINTS_BRANCH=${DEFAULT_CONSTRAINTS_BRANCH}" >> $GITHUB_ENV - DEBIAN_VERSION=$(grep "export DEBIAN_VERSION" scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEBIAN_VERSION=${DEBIAN_VERSION}" >> $GITHUB_ENV - name: > - Checkout "${{ needs.build-info.outputs.targetBranch }}" branch to 'main-airflow' folder + Checkout "main branch to 'main-airflow' folder to use ci/scripts from there. - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: "main-airflow" - ref: "${{ needs.build-info.outputs.targetBranch }}" + ref: "main" persist-credentials: false submodules: recursive - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - name: > - Override "scripts/ci" with the "${{ needs.build-info.outputs.targetBranch }}" branch + Override "scripts/ci", "dev" and "./github/actions" with the "main" branch so that the PR does not override it # We should not override those scripts which become part of the image as they will not be # changed in the image built - we should only override those that are executed to build # the image. run: | rm -rfv "scripts/ci" - rm -rfv "dev" mv -v "main-airflow/scripts/ci" "scripts" + rm -rfv "dev" mv -v "main-airflow/dev" "." 
- - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Build & Push CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze build-image --push-image --tag-as-latest - env: - UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgradeToNewerDependencies }} - DOCKER_CACHE: ${{ needs.build-info.outputs.cacheDirective }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: Push empty CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - if: failure() || cancelled() - run: breeze build-image --push-image --empty-image + rm -rfv "./github/actions" + mv -v "main-airflow/.github/actions" "actions" + - name: > + Build CI Images ${{needs.build-info.outputs.all-python-versions-list-as-string}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/build-ci-images env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Candidates for pip resolver backtrack triggers: ${{ matrix.python-version }}" - if: failure() || cancelled() - run: breeze find-newer-dependencies --max-age 1 --python "${{ matrix.python-version }}" - - name: "Fix ownership" - run: breeze fix-ownership - if: always() + UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgrade-to-newer-dependencies }} + DOCKER_CACHE: ${{ needs.build-info.outputs.cache-directive }} + PYTHON_VERSIONS: ${{needs.build-info.outputs.all-python-versions-list-as-string}} + DEBUG_RESOURCES: ${{ needs.build-info.outputs.debug-resources }} build-prod-images: permissions: packages: write timeout-minutes: 80 - name: "Build PROD image ${{matrix.python-version}}" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + name: > + Build PROD images + ${{needs.build-info.outputs.all-python-versions-list-as-string}} + runs-on: ${{ needs.build-info.outputs.runs-on }} needs: [build-info, build-ci-images] - strategy: - matrix: - python-version: ${{ fromJson(needs.build-info.outputs.allPythonVersions) }} - fail-fast: true if: | needs.build-info.outputs.image-build == 'true' && github.event.pull_request.head.repo.full_name != 'apache/airflow' env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn)[0] }} - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + DEFAULT_BRANCH: ${{ needs.build-info.outputs.default-branch }} + DEFAULT_CONSTRAINTS_BRANCH: ${{ needs.build-info.outputs.default-constraints-branch }} + RUNS_ON: ${{ needs.build-info.outputs.runs-on }} BACKEND: sqlite steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: - ref: ${{ needs.build-info.outputs.targetCommitSha }} + ref: ${{ needs.build-info.outputs.target-commit-sha }} persist-credentials: false submodules: recursive - - name: "Retrieve DEFAULTS from the _initialization.sh" - # We cannot "source" the script here because that would be a security problem (we cannot run - # any code that comes from the sources coming from the PR. 
Therefore we extract the - # DEFAULT_BRANCH and DEFAULT_CONSTRAINTS_BRANCH and DEBIAN_VERSION via custom grep/awk/sed commands - id: defaults - run: | - DEFAULT_BRANCH=$(grep "export DEFAULT_BRANCH" scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_BRANCH=${DEFAULT_BRANCH}" >> $GITHUB_ENV - DEFAULT_CONSTRAINTS_BRANCH=$(grep "export DEFAULT_CONSTRAINTS_BRANCH" \ - scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_CONSTRAINTS_BRANCH=${DEFAULT_CONSTRAINTS_BRANCH}" >> $GITHUB_ENV - DEBIAN_VERSION=$(grep "export DEBIAN_VERSION" scripts/ci/libraries/_initialization.sh | \ - cut -d "=" -f 3 | sed s'/["}]//g') - echo "DEBIAN_VERSION=${DEBIAN_VERSION}" >> $GITHUB_ENV - name: > - Checkout "${{ needs.build-info.outputs.targetBranch }}" branch to 'main-airflow' folder + Checkout "main" branch to 'main-airflow' folder to use ci/scripts from there. - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: "main-airflow" - ref: "${{ needs.build-info.outputs.targetBranch }}" + ref: "main" persist-credentials: false submodules: recursive - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - name: > - Override "scripts/ci" with the "${{ needs.build-info.outputs.targetBranch }}" branch + Override "scripts/ci", "dev" and "./github/actions" with the "main" branch so that the PR does not override it # We should not override those scripts which become part of the image as they will not be # changed in the image built - we should only override those that are executed to build # the image. run: | rm -rfv "scripts/ci" - rm -rfv "dev" mv -v "main-airflow/scripts/ci" "scripts" + rm -rfv "dev" mv -v "main-airflow/dev" "." 
- - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space + rm -rfv "./github/actions" + mv -v "main-airflow/.github/actions" "actions" - name: > - Pull CI image for PROD build: - ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest + Build PROD Images + ${{needs.build-info.outputs.all-python-versions-list-as-string}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/build-prod-images + with: + build-provider-packages: ${{ needs.build-info.outputs.default-branch == 'main' }} env: - # Always use default Python version of CI image for preparing packages - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Cleanup dist and context file" - run: rm -fv ./dist/* ./docker-context-files/* - - name: "Prepare providers packages" - run: > - breeze prepare-provider-packages - --package-list-file ./scripts/ci/installed_providers.txt - --package-format wheel - --version-suffix-for-pypi dev0 - - name: "Prepare airflow package" - run: breeze prepare-airflow-package --package-format wheel --version-suffix-for-pypi dev0 - - name: "Move dist packages to docker-context files" - run: mv -v ./dist/*.whl ./docker-context-files - - name: Build & Push PROD image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} + UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgrade-to-newer-dependencies }} + DOCKER_CACHE: ${{ needs.build-info.outputs.cache-directive }} + PYTHON_VERSIONS: ${{needs.build-info.outputs.all-python-versions-list-as-string}} + DEBUG_RESOURCES: ${{ needs.build-info.outputs.debug-resources }} + + build-ci-images-arm: + timeout-minutes: 50 + name: "Build ARM CI images ${{needs.build-info.outputs.all-python-versions-list-as-string}}" + runs-on: ${{ needs.build-info.outputs.runs-on }} + needs: [build-info, build-prod-images] + if: | + needs.build-info.outputs.image-build == 'true' && + needs.build-info.outputs.upgrade-to-newer-dependencies != 'false' && + github.event.pull_request.head.repo.full_name != 'apache/airflow' + env: + DEFAULT_BRANCH: ${{ needs.build-info.outputs.default-branch }} + DEFAULT_CONSTRAINTS_BRANCH: ${{ needs.build-info.outputs.default-constraints-branch }} + RUNS_ON: ${{ needs.build-info.outputs.runs-on }} + BACKEND: sqlite + outputs: ${{toJSON(needs.build-info.outputs) }} + steps: + - name: Cleanup repo + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - uses: actions/checkout@v3 + with: + ref: ${{ needs.build-info.outputs.target-commit-sha }} + persist-credentials: false + submodules: recursive + - name: > + Checkout "main" branch to 'main-airflow' folder + to use ci/scripts from there. + uses: actions/checkout@v3 + with: + path: "main-airflow" + ref: "main" + persist-credentials: false + submodules: recursive + - name: > + Override "scripts/ci", "dev" and "./github/actions" with the "main" branch + so that the PR does not override it + # We should not override those scripts which become part of the image as they will not be + # changed in the image built - we should only override those that are executed to build + # the image. + run: | + rm -rfv "scripts/ci" + mv -v "main-airflow/scripts/ci" "scripts" + rm -rfv "dev" + mv -v "main-airflow/dev" "." 
+ rm -rfv "./github/actions" + mv -v "main-airflow/.github/actions" "actions" + - name: "Start ARM instance" + run: ./scripts/ci/images/ci_start_arm_instance_and_connect_to_docker.sh + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: > + Build ARM CI images ${{ env.IMAGE_TAG }} + ${{needs.build-info.outputs.all-python-versions-list-as-string}} run: > - breeze build-prod-image - --tag-as-latest - --push-image - --install-packages-from-context - --disable-airflow-repo-cache - --airflow-is-in-context + breeze ci-image build --run-in-parallel --builder airflow_cache --platform "linux/arm64" env: - UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgradeToNewerDependencies }} - DOCKER_CACHE: ${{ needs.build-info.outputs.cacheDirective }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: Push empty PROD image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - if: failure() || cancelled() - run: breeze build-prod-image --cleanup-context --push-image --empty-image - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgrade-to-newer-dependencies }} + DOCKER_CACHE: ${{ needs.build-info.outputs.cache-directive }} + PYTHON_VERSIONS: ${{needs.build-info.outputs.all-python-versions-list-as-string}} + - name: "Stop ARM instance" + run: ./scripts/ci/images/ci_stop_arm_instance.sh + if: always() - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 05781d33d7cc2..d4163671bde51 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,11 +29,7 @@ permissions: contents: read packages: read env: - MOUNT_SELECTED_LOCAL_SOURCES: "false" ANSWER: "yes" - CHECK_IMAGE_FOR_REBUILD: "true" - SKIP_CHECK_REMOTE_IMAGE: "true" - DEBIAN_VERSION: "bullseye" DB_RESET: "true" VERBOSE: "true" GITHUB_REPOSITORY: ${{ github.repository }} @@ -46,7 +42,9 @@ env: # In builds from forks, this token is read-only. For scheduler/direct push it is WRITE one GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ENABLE_TEST_COVERAGE: "${{ github.event_name == 'push' }}" - IMAGE_TAG_FOR_THE_BUILD: "${{ github.event.pull_request.head.sha || github.sha }}" + IMAGE_TAG: "${{ github.event.pull_request.head.sha || github.sha }}" + USE_SUDO: "true" + INCLUDE_SUCCESS_OUTPUTS: "true" concurrency: group: ci-${{ github.event.pull_request.number || github.ref }} @@ -61,13 +59,11 @@ jobs: # is checked again by the runner using it's own list, so a PR author cannot # change this and get access to our self-hosted runners # - # When changing this list, ensure that it is kept in sync with the - # /runners/apache/airflow/configOverlay - # parameter in AWS SSM ParameterStore (which is what the runner uses) - # and restart the self-hosted runners. - # - # This list of committers can be generated with: - # https://github.com/apache/airflow-ci-infra/blob/main/scripts/list_committers + # This list is kept up-to-date from the list of authors found in the + # 'airflow-ci-infra' by the 'sync_authors' Github workflow. It uses a regexp + # to find the list of authors and replace them, so any changes to the + # formatting of the contains(fromJSON()) structure below will need to be + # reflected in that workflow too. 
runs-on: >- ${{ ( ( @@ -77,10 +73,10 @@ jobs: "BasPH", "Fokko", "KevinYang21", + "Taragolis", "XD-DENG", "aijamalnk", "alexvanboxel", - "aneesh-joseph", "aoen", "artwr", "ashb", @@ -107,6 +103,9 @@ jobs: "milton0825", "mistercrunch", "msumit", + "o-nikolas", + "pierrejeambrun", + "pingzh", "potiuk", "r39132", "ryanahamilton", @@ -125,35 +124,32 @@ jobs: env: GITHUB_CONTEXT: ${{ toJson(github) }} outputs: - defaultBranch: ${{ steps.selective-checks.outputs.default-branch }} - cacheDirective: ${{ steps.dynamic-outputs.outputs.cacheDirective }} - waitForImage: ${{ steps.wait-for-image.outputs.wait-for-image }} - allPythonVersions: ${{ steps.selective-checks.outputs.all-python-versions }} - upgradeToNewerDependencies: ${{ steps.selective-checks.outputs.upgrade-to-newer-dependencies }} - pythonVersions: ${{ steps.selective-checks.outputs.python-versions }} - pythonVersionsListAsString: ${{ steps.selective-checks.outputs.python-versions-list-as-string }} - defaultPythonVersion: ${{ steps.selective-checks.outputs.default-python-version }} - kubernetesVersions: ${{ steps.selective-checks.outputs.kubernetes-versions }} - kubernetesVersionsListAsString: ${{ steps.selective-checks.outputs.kubernetes-versions-list-as-string }} - defaultKubernetesVersion: ${{ steps.selective-checks.outputs.default-kubernetes-version }} - kubernetesModes: ${{ steps.selective-checks.outputs.kubernetes-modes }} - defaultKubernetesMode: ${{ steps.selective-checks.outputs.default-kubernetes-mode }} - postgresVersions: ${{ steps.selective-checks.outputs.postgres-versions }} - defaultPostgresVersion: ${{ steps.selective-checks.outputs.default-postgres-version }} - mysqlVersions: ${{ steps.selective-checks.outputs.mysql-versions }} - mssqlVersions: ${{ steps.selective-checks.outputs.mssql-versions }} - defaultMySQLVersion: ${{ steps.selective-checks.outputs.default-mysql-version }} - helmVersions: ${{ steps.selective-checks.outputs.helm-versions }} - defaultHelmVersion: ${{ steps.selective-checks.outputs.default-helm-version }} - kindVersions: ${{ steps.selective-checks.outputs.kind-versions }} - defaultKindVersion: ${{ steps.selective-checks.outputs.default-kind-version }} - testTypes: ${{ steps.selective-checks.outputs.test-types }} - postgresExclude: ${{ steps.selective-checks.outputs.postgres-exclude }} - mysqlExclude: ${{ steps.selective-checks.outputs.mysql-exclude }} - mssqlExclude: ${{ steps.selective-checks.outputs.mssql-exclude }} - sqliteExclude: ${{ steps.selective-checks.outputs.sqlite-exclude }} + cache-directive: ${{ steps.selective-checks.outputs.cache-directive }} + upgrade-to-newer-dependencies: ${{ steps.selective-checks.outputs.upgrade-to-newer-dependencies }} + python-versions: ${{ steps.selective-checks.outputs.python-versions }} + python-versions-list-as-string: ${{ steps.selective-checks.outputs.python-versions-list-as-string }} + all-python-versions-list-as-string: >- + ${{ steps.selective-checks.outputs.all-python-versions-list-as-string }} + default-python-version: ${{ steps.selective-checks.outputs.default-python-version }} + kubernetes-versions-list-as-string: >- + ${{ steps.selective-checks.outputs.kubernetes-versions-list-as-string }} + kubernetes-combos: ${{ steps.selective-checks.outputs.kubernetes-combos }} + default-kubernetes-version: ${{ steps.selective-checks.outputs.default-kubernetes-version }} + postgres-versions: ${{ steps.selective-checks.outputs.postgres-versions }} + default-postgres-version: ${{ steps.selective-checks.outputs.default-postgres-version }} + mysql-versions: ${{ 
steps.selective-checks.outputs.mysql-versions }} + mssql-versions: ${{ steps.selective-checks.outputs.mssql-versions }} + default-mysql-version: ${{ steps.selective-checks.outputs.default-mysql-version }} + default-helm-version: ${{ steps.selective-checks.outputs.default-helm-version }} + default-kind-version: ${{ steps.selective-checks.outputs.default-kind-version }} + full-tests-needed: ${{ steps.selective-checks.outputs.full-tests-needed }} + test-types: ${{ steps.selective-checks.outputs.test-types }} + postgres-exclude: ${{ steps.selective-checks.outputs.postgres-exclude }} + mysql-exclude: ${{ steps.selective-checks.outputs.mysql-exclude }} + mssql-exclude: ${{ steps.selective-checks.outputs.mssql-exclude }} + sqlite-exclude: ${{ steps.selective-checks.outputs.sqlite-exclude }} + providers-package-format-exclude: ${{ steps.selective-checks.outputs.providers-package-format-exclude }} run-tests: ${{ steps.selective-checks.outputs.run-tests }} - run-ui-tests: ${{ steps.selective-checks.outputs.run-ui-tests }} run-www-tests: ${{ steps.selective-checks.outputs.run-www-tests }} run-kubernetes-tests: ${{ steps.selective-checks.outputs.run-kubernetes-tests }} basic-checks-only: ${{ steps.selective-checks.outputs.basic-checks-only }} @@ -163,353 +159,298 @@ jobs: needs-api-tests: ${{ steps.selective-checks.outputs.needs-api-tests }} needs-api-codegen: ${{ steps.selective-checks.outputs.needs-api-codegen }} default-branch: ${{ steps.selective-checks.outputs.default-branch }} - sourceHeadRepo: ${{ steps.source-run-info.outputs.sourceHeadRepo }} - pullRequestNumber: ${{ steps.source-run-info.outputs.pullRequestNumber }} - pullRequestLabels: ${{ steps.source-run-info.outputs.pullRequestLabels }} - runsOn: ${{ steps.set-runs-on.outputs.runsOn }} - runCoverage: ${{ steps.set-run-coverage.outputs.runCoverage }} - inWorkflowBuild: ${{ steps.set-in-workflow-build.outputs.inWorkflowBuild }} - buildJobDescription: ${{ steps.set-in-workflow-build.outputs.buildJobDescription }} - mergeRun: ${{ steps.set-merge-run.outputs.merge-run }} + default-constraints-branch: ${{ steps.selective-checks.outputs.default-constraints-branch }} + docs-filter: ${{ steps.selective-checks.outputs.docs-filter }} + skip-pre-commits: ${{ steps.selective-checks.outputs.skip-pre-commits }} + debug-resources: ${{ steps.selective-checks.outputs.debug-resources }} + source-head-repo: ${{ steps.source-run-info.outputs.source-head-repo }} + pull-request-labels: ${{ steps.source-run-info.outputs.pr-labels }} + in-workflow-build: ${{ steps.source-run-info.outputs.in-workflow-build }} + build-job-description: ${{ steps.source-run-info.outputs.build-job-description }} + runs-on: ${{ steps.source-run-info.outputs.runs-on }} + canary-run: ${{ steps.source-run-info.outputs.canary-run }} + run-coverage: ${{ steps.source-run-info.outputs.run-coverage }} steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false submodules: recursive - - name: "Get information about the PR" - uses: ./.github/actions/get-workflow-origin - id: source-run-info - with: - token: ${{ secrets.GITHUB_TOKEN }} - name: Fetch incoming commit ${{ github.sha }} with its parent - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.sha }} fetch-depth: 2 persist-credentials: false - if: github.event_name == 'pull_request' - - name: Selective checks - 
id: selective-checks - env: - PR_LABELS: "${{ steps.source-run-info.outputs.pullRequestLabels }}" - run: | - if [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then - # Run selective checks - ./scripts/ci/selective_ci_checks.sh "${GITHUB_SHA}" - else - # Run all checks - ./scripts/ci/selective_ci_checks.sh - fi - # Avoid having to specify the runs-on logic every time. We use the custom - # env var AIRFLOW_SELF_HOSTED_RUNNER set only on our runners, but never - # on the public runners - - name: Set runs-on - id: set-runs-on - env: - PR_LABELS: "${{ steps.source-run-info.outputs.pullRequestLabels }}" - run: | - if [[ ${PR_LABELS=} == *"use public runners"* ]]; then - echo "Forcing running on Public Runners via `use public runners` label" - echo "::set-output name=runsOn::\"ubuntu-20.04\"" - elif [[ ${AIRFLOW_SELF_HOSTED_RUNNER} == "" ]]; then - echo "Regular PR running with Public Runner" - echo "::set-output name=runsOn::\"ubuntu-20.04\"" - else - echo "Maintainer or main run running with self-hosted runner" - echo "::set-output name=runsOn::\"self-hosted\"" - fi - # Avoid having to specify the coverage logic every time. - - name: Set run coverage - id: set-run-coverage - run: echo "::set-output name=runCoverage::true" - if: > - github.ref == 'refs/heads/main' && github.repository == 'apache/airflow' && - github.event_name == 'push' && - steps.selective-checks.outputs.default-branch == 'main' - - name: Determine where to run image builds - id: set-in-workflow-build - # Run in-workflow build image when: - # * direct push is run - # * schedule build is run - # * pull request is run not from fork - run: | - set -x - if [[ ${GITHUB_EVENT_NAME} == "push" || ${GITHUB_EVENT_NAME} == "push" || \ - ${{steps.source-run-info.outputs.sourceHeadRepo}} == "apache/airflow" ]]; then - echo "Images will be built in current workflow" - echo "::set-output name=inWorkflowBuild::true" - echo "::set-output name=buildJobDescription::Build" - else - echo "Images will be built in pull_request_target workflow" - echo "::set-output name=inWorkflowBuild::false" - echo "::set-output name=buildJobDescription::Skip Build (pull_request_target)" - fi - - name: Determine if this is merge run - id: set-merge-run - run: echo "::set-output name=merge-run::true" - # Only in Apache Airflow repo, when there is a merge run to main or any of v2*test branches - if: | - github.repository == 'apache/airflow' && github.event_name == 'push' && - ( - github.ref_name == 'main' || - startsWith(github.ref_name, 'v2') && endsWith(github.ref_name, 'test') - ) - - name: Compute dynamic outputs - id: dynamic-outputs + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: "Retrieve defaults from branch_defaults.py" + id: defaults + # We could retrieve it differently here - by just importing the variables and + # printing them from python code, however we want to have the same code as used in + # the build-images.yml (there we cannot import python code coming from the PR - we need to + # treat the python code as text and extract the variables from there. 
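The alternative mentioned in the comment above - simply importing the constants and printing them - would be shorter, but it executes Python coming from the PR, which is exactly what this step must avoid. A minimal sketch of that rejected alternative, with the package root inferred from the branch_defaults.py path and the variable names taken from the regex patterns used below:

import sys

# Assumed source layout: dev/breeze/src is the package root for airflow_breeze.
sys.path.insert(0, "dev/breeze/src")

# Importing executes branch_defaults.py - fine locally, unsafe for code taken from a PR.
from airflow_breeze.branch_defaults import AIRFLOW_BRANCH, DEFAULT_AIRFLOW_CONSTRAINTS_BRANCH

print(f"DEFAULT_BRANCH={AIRFLOW_BRANCH}")
print(f"DEFAULT_CONSTRAINTS_BRANCH={DEFAULT_AIRFLOW_CONSTRAINTS_BRANCH}")

Treating branch_defaults.py as text and pulling the two values out with regular expressions, as the step below does, produces the same two lines without running any PR-provided code.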
run: | - set -x - if [[ "${{ github.event_name }}" == 'schedule' ]]; then - echo "::set-output name=cacheDirective::disabled" - else - echo "::set-output name=cacheDirective::registry" - fi + python - <<EOF >> ${GITHUB_ENV} + from pathlib import Path + import re + import sys - if [[ "$SELECTIVE_CHECKS_IMAGE_BUILD" == "true" ]]; then - echo "::set-output name=image-build::true" - else - echo "::set-output name=image-build::false" - fi + DEFAULTS_CONTENT = Path('dev/breeze/src/airflow_breeze/branch_defaults.py').read_text() + BRANCH_PATTERN = r'^AIRFLOW_BRANCH = "(.*)"$' + CONSTRAINTS_BRANCH_PATTERN = r'^DEFAULT_AIRFLOW_CONSTRAINTS_BRANCH = "(.*)"$' + + branch = re.search(BRANCH_PATTERN, DEFAULTS_CONTENT, re.MULTILINE).group(1) + constraints_branch = re.search(CONSTRAINTS_BRANCH_PATTERN, DEFAULTS_CONTENT, re.MULTILINE).group(1) + + output = f""" + DEFAULT_BRANCH={branch} + DEFAULT_CONSTRAINTS_BRANCH={constraints_branch} + """.strip() + + print(output) + # Stdout is redirected to GITHUB_ENV but we also print it to stderr to see it in the CI log + print(output, file=sys.stderr) + EOF + - name: "Get information about the Workflow" + id: source-run-info + run: breeze ci get-workflow-info 2>> ${GITHUB_OUTPUT} + - name: Selective checks + id: selective-checks env: - SELECTIVE_CHECKS_IMAGE_BUILD: ${{ steps.selective-checks.outputs.image-build }} + PR_LABELS: "${{ steps.source-run-info.outputs.pr-labels }}" + COMMIT_REF: "${{ github.sha }}" + VERBOSE: "false" + run: breeze ci selective-check >> ${GITHUB_OUTPUT} - name: env run: printenv env: - dynamicOutputs: ${{ toJSON(steps.dynamic-outputs.outputs) }} - PR_LABELS: ${{ steps.get-latest-pr-labels.outputs.pullRequestLabels }} + PR_LABELS: ${{ steps.source-run-info.outputs.pr-labels }} GITHUB_CONTEXT: ${{ toJson(github) }} + # Push early BuildX cache to GitHub Registry in the Apache repository. This cache does not wait for all the + # tests to complete - it is run very early in the build process for "main" merges in order to refresh + # cache using the current constraints. This will speed up cache refresh in cases when setup.py + # changes or in case of Dockerfile changes. Failure in this step is not a problem (at most it will + # delay cache refresh). It does not attempt to upgrade to newer dependencies.
+ # We only push CI cache as PROD cache usually does not gain as much from fresh cache because + # it uses prepared airflow and provider packages that invalidate the cache anyway most of the time + push-early-buildx-cache-to-github-registry: + permissions: + packages: write + timeout-minutes: 50 + name: "Push Early Image Cache" + runs-on: "${{needs.build-info.outputs.runs-on}}" + needs: + - build-info + strategy: + fail-fast: false + matrix: + platform: ["linux/amd64", "linux/arm64"] + env: + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + UPGRADE_TO_NEWER_DEPENDENCIES: false + continue-on-error: true + steps: + - name: Cleanup repo + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + if: needs.build-info.outputs.canary-run == 'true' && needs.build-info.outputs.default-branch == 'main' + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 + with: + persist-credentials: false + if: needs.build-info.outputs.canary-run == 'true' && needs.build-info.outputs.default-branch == 'main' + - name: "Install Breeze" + uses: ./.github/actions/breeze + if: needs.build-info.outputs.canary-run == 'true' && needs.build-info.outputs.default-branch == 'main' + - name: "Start ARM instance" + run: ./scripts/ci/images/ci_start_arm_instance_and_connect_to_docker.sh + if: > + matrix.platform == 'linux/arm64' && needs.build-info.outputs.canary-run == 'true' + && needs.build-info.outputs.default-branch == 'main' + - name: "Push CI cache ${{ matrix.platform }}" + run: > + breeze ci-image build + --builder airflow_cache + --prepare-buildx-cache + --run-in-parallel + --force-build + --platform ${{ matrix.platform }} + env: + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} + if: needs.build-info.outputs.canary-run == 'true' && needs.build-info.outputs.default-branch == 'main' + - name: "Push CI latest image ${{ matrix.platform }}" + run: > + breeze ci-image build + --tag-as-latest --push --run-in-parallel --platform ${{ matrix.platform }} + env: + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} + # We only push "amd" image as it is really only needed for any kind of automated builds in CI + # and currently there is not an easy way to make multi-platform image from two separate builds + if: > + matrix.platform == 'linux/amd64' && needs.build-info.outputs.canary-run == 'true' + && needs.build-info.outputs.default-branch == 'main' + - name: "Stop ARM instance" + run: ./scripts/ci/images/ci_stop_arm_instance.sh + if: > + always() && matrix.platform == 'linux/arm64' && needs.build-info.outputs.canary-run == 'true' + && needs.build-info.outputs.default-branch == 'main' + - name: "Clean docker cache for ${{ matrix.platform }}" + run: docker system prune --all --force + if: > + matrix.platform == 'linux/amd64' && needs.build-info.outputs.canary-run == 'true' + && needs.build-info.outputs.default-branch == 'main' + - name: "Fix ownership" + run: breeze ci fix-ownership + if: > + always() && needs.build-info.outputs.canary-run == 'true' + && needs.build-info.outputs.default-branch == 'main' + # Check that after earlier cache push, breeze command will build quickly + check-that-image-builds-quickly: + timeout-minutes: 5 + name: "Check that image builds quickly" + runs-on: "${{needs.build-info.outputs.runs-on}}" + needs: + - build-info + - push-early-buildx-cache-to-github-registry + env: + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + UPGRADE_TO_NEWER_DEPENDENCIES: false + PLATFORM: "linux/amd64" + if: 
needs.build-info.outputs.canary-run == 'true' + steps: + - name: Cleanup repo + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 + with: + persist-credentials: false + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: "Check that image builds quickly" + run: breeze shell --max-time 120 + - name: "Fix ownership" + run: breeze ci fix-ownership + if: always() + build-ci-images: permissions: packages: write timeout-minutes: 80 - name: "${{needs.build-info.outputs.buildJobDescription}} CI image ${{matrix.python-version}}" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + name: >- + ${{needs.build-info.outputs.build-job-description}} CI images + ${{needs.build-info.outputs.all-python-versions-list-as-string}} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info] - strategy: - matrix: - python-version: ${{ fromJson(needs.build-info.outputs.allPythonVersions) }} - fail-fast: true env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn)[0] }} + DEFAULT_BRANCH: ${{ needs.build-info.outputs.default-branch }} + DEFAULT_CONSTRAINTS_BRANCH: ${{ needs.build-info.outputs.default-constraints-branch }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - uses: actions/checkout@v2 + if: needs.build-info.outputs.in-workflow-build == 'true' + - uses: actions/checkout@v3 with: ref: ${{ needs.build-info.outputs.targetCommitSha }} persist-credentials: false submodules: recursive - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Retrieve DEFAULTS from the _initialization.sh" - # We cannot "source" the script here because that would be a security problem (we cannot run - # any code that comes from the sources coming from the PR. 
Therefore we extract the - # DEFAULT_BRANCH and DEFAULT_CONSTRAINTS_BRANCH and DEBIAN_VERSION via custom grep/awk/sed commands - id: defaults - run: | - DEFAULT_BRANCH=$(grep "export DEFAULT_BRANCH" scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_BRANCH=${DEFAULT_BRANCH}" >> $GITHUB_ENV - DEFAULT_CONSTRAINTS_BRANCH=$(grep "export DEFAULT_CONSTRAINTS_BRANCH" \ - scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_CONSTRAINTS_BRANCH=${DEFAULT_CONSTRAINTS_BRANCH}" >> $GITHUB_ENV - DEBIAN_VERSION=$(grep "export DEBIAN_VERSION" scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEBIAN_VERSION=${DEBIAN_VERSION}" >> $GITHUB_ENV - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - run: ./scripts/ci/install_breeze.sh - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Free space" - run: breeze free-space - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: Build & Push CI image ${{ matrix.python-version }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze build-image --push-image --tag-as-latest + if: needs.build-info.outputs.in-workflow-build == 'true' + - name: > + Build CI Images + ${{needs.build-info.outputs.all-python-versions-list-as-string}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/build-ci-images + if: needs.build-info.outputs.in-workflow-build == 'true' env: - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} - UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgradeToNewerDependencies }} - DOCKER_CACHE: ${{ needs.build-info.outputs.cacheDirective }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Candidates for pip resolver backtrack triggers: ${{ matrix.python-version }}" - if: failure() || cancelled() - run: breeze find-newer-dependencies --max-age 1 --python "${{ matrix.python-version }}" - - name: "Fix ownership" - run: breeze fix-ownership - if: always() && needs.build-info.outputs.inWorkflowBuild == 'true' + UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgrade-to-newer-dependencies }} + DOCKER_CACHE: ${{ needs.build-info.outputs.cache-directive }} + PYTHON_VERSIONS: ${{needs.build-info.outputs.all-python-versions-list-as-string}} + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} build-prod-images: permissions: packages: write timeout-minutes: 80 - name: "${{needs.build-info.outputs.buildJobDescription}} PROD image ${{matrix.python-version}}" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + name: > + ${{needs.build-info.outputs.build-job-description}} PROD images + ${{needs.build-info.outputs.all-python-versions-list-as-string}} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, build-ci-images] - strategy: - matrix: - python-version: ${{ fromJson(needs.build-info.outputs.allPythonVersions) }} - fail-fast: true env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn)[0] }} + DEFAULT_BRANCH: ${{ needs.build-info.outputs.default-branch }} + DEFAULT_CONSTRAINTS_BRANCH: ${{ needs.build-info.outputs.default-constraints-branch }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" BACKEND: sqlite - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} - DOCKER_CACHE: ${{ needs.build-info.outputs.cacheDirective }} + DOCKER_CACHE: ${{ needs.build-info.outputs.cache-directive }} VERSION_SUFFIX_FOR_PYPI: "dev0" + 
DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - uses: actions/checkout@v2 + if: needs.build-info.outputs.in-workflow-build == 'true' + - uses: actions/checkout@v3 with: ref: ${{ needs.build-info.outputs.targetCommitSha }} persist-credentials: false submodules: recursive - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Retrieve DEFAULTS from the _initialization.sh" - # We cannot "source" the script here because that would be a security problem (we cannot run - # any code that comes from the sources coming from the PR. Therefore we extract the - # DEFAULT_BRANCH and DEFAULT_CONSTRAINTS_BRANCH and DEBIAN_VERSION via custom grep/awk/sed commands - id: defaults - run: | - DEFAULT_BRANCH=$(grep "export DEFAULT_BRANCH" scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_BRANCH=${DEFAULT_BRANCH}" >> $GITHUB_ENV - DEFAULT_CONSTRAINTS_BRANCH=$(grep "export DEFAULT_CONSTRAINTS_BRANCH" \ - scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEFAULT_CONSTRAINTS_BRANCH=${DEFAULT_CONSTRAINTS_BRANCH}" >> $GITHUB_ENV - DEBIAN_VERSION=$(grep "export DEBIAN_VERSION" scripts/ci/libraries/_initialization.sh | \ - awk 'BEGIN{FS="="} {print $3}' | sed s'/["}]//g') - echo "DEBIAN_VERSION=${DEBIAN_VERSION}" >> $GITHUB_ENV - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - run: ./scripts/ci/install_breeze.sh - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Free space" - run: breeze free-space - if: needs.build-info.outputs.inWorkflowBuild == 'true' + if: needs.build-info.outputs.in-workflow-build == 'true' - name: > - Pull CI image for PROD build: - ${{ needs.build-info.outputs.defaultPythonVersion }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }}" - run: breeze pull-image --tag-as-latest - env: - # Always use default Python version of CI image for preparing packages - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Cleanup dist and context file" - run: rm -fv ./dist/* ./docker-context-files/* - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Prepare providers packages" - run: > - breeze prepare-provider-packages - --package-list-file ./scripts/ci/installed_providers.txt - --package-format wheel --version-suffix-for-pypi dev0 - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Prepare airflow package" - run: breeze prepare-airflow-package --package-format wheel --version-suffix-for-pypi dev0 - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Move dist packages to docker-context files" - run: mv -v ./dist/*.whl ./docker-context-files - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: Build & Push PROD image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: > - breeze build-prod-image - --tag-as-latest - --push-image - --install-packages-from-context - --disable-airflow-repo-cache - --airflow-is-in-context + Build PROD Images + 
${{needs.build-info.outputs.all-python-versions-list-as-string}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/build-prod-images + if: needs.build-info.outputs.in-workflow-build == 'true' + with: + build-provider-packages: ${{ needs.build-info.outputs.default-branch == 'main' }} env: - UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgradeToNewerDependencies }} - DOCKER_CACHE: ${{ needs.build-info.outputs.cacheDirective }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - if: needs.build-info.outputs.inWorkflowBuild == 'true' - - name: "Fix ownership" - run: breeze fix-ownership - if: always() && needs.build-info.outputs.inWorkflowBuild == 'true' + UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgrade-to-newer-dependencies }} + DOCKER_CACHE: ${{ needs.build-info.outputs.cache-directive }} + PYTHON_VERSIONS: ${{needs.build-info.outputs.all-python-versions-list-as-string}} + DEBUG_RESOURCES: ${{ needs.build-info.outputs.debug-resources }} run-new-breeze-tests: timeout-minutes: 10 name: Breeze unit tests - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info] steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: persist-credentials: false - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} + python-version: "${{needs.build-info.outputs.default-python-version}}" cache: 'pip' cache-dependency-path: ./dev/breeze/setup* - run: python -m pip install --editable ./dev/breeze/ - run: python -m pytest ./dev/breeze/ -n auto --color=yes - - run: breeze version - - tests-ui: - timeout-minutes: 10 - name: React UI tests - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} - needs: [build-info] - if: needs.build-info.outputs.run-ui-tests == 'true' - steps: - - name: Cleanup repo - run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 - with: - persist-credentials: false - - name: "Setup node" - uses: actions/setup-node@v2 - with: - node-version: 14 - - name: "Cache eslint" - uses: actions/cache@v2 - with: - path: 'airflow/ui/node_modules' - key: ${{ runner.os }}-ui-node-modules-${{ hashFiles('airflow/ui/**/yarn.lock') }} - - run: yarn --cwd airflow/ui/ install --frozen-lockfile --non-interactive - - run: yarn --cwd airflow/ui/ run test - env: - FORCE_COLOR: 2 + - run: breeze setup version tests-www: timeout-minutes: 10 name: React WWW tests - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info] if: needs.build-info.outputs.run-www-tests == 'true' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - name: "Setup node" - uses: actions/setup-node@v2 + uses: actions/setup-node@v3 with: node-version: 14 - name: "Cache eslint" - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: 'airflow/www/node_modules' - key: ${{ runner.os }}-ui-node-modules-${{ hashFiles('airflow/ui/**/yarn.lock') }} + key: ${{ runner.os }}-www-node-modules-${{ hashFiles('airflow/www/**/yarn.lock') }} - 
run: yarn --cwd airflow/www/ install --frozen-lockfile --non-interactive - run: yarn --cwd airflow/www/ run test env: @@ -519,14 +460,14 @@ jobs: test-openapi-client-generation: timeout-minutes: 10 name: "Test OpenAPI client generation" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info] if: needs.build-info.outputs.needs-api-codegen == 'true' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 2 persist-credentials: false @@ -536,21 +477,21 @@ jobs: test-examples-of-prod-image-building: timeout-minutes: 60 name: "Test examples of production image building" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info] if: needs.build-info.outputs.image-build == 'true' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 2 persist-credentials: false - name: "Setup python" - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} + python-version: "${{needs.build-info.outputs.default-python-version}}" cache: 'pip' cache-dependency-path: ./dev/requirements.txt - name: "Test examples of PROD image building" @@ -558,90 +499,88 @@ jobs: python -m pip install -r ./docker_tests/requirements.txt && python -m pytest docker_tests/test_examples_of_prod_image_building.py -n auto --color=yes + test-git-clone-on-windows: + timeout-minutes: 5 + name: "Test git clone on Windows" + runs-on: windows-latest + needs: [build-info] + steps: + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 + with: + fetch-depth: 2 + persist-credentials: false + if: needs.build-info.outputs.runs-on != 'self-hosted' + wait-for-ci-images: timeout-minutes: 120 name: "Wait for CI images" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, build-ci-images] if: needs.build-info.outputs.image-build == 'true' env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" BACKEND: sqlite + # Force more parallelism for pull even on public images + PARALLELISM: 6 steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Wait for CI images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: Wait for CI images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG }} id: wait-for-images - run: breeze pull-image --run-in-parallel --verify-image --wait-for-image --tag-as-latest + run: breeze ci-image 
pull --run-in-parallel --verify --wait-for-image --tag-as-latest env: - PYTHON_VERSIONS: ${{ needs.build-info.outputs.pythonVersionsListAsString }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + PYTHON_VERSIONS: ${{ needs.build-info.outputs.python-versions-list-as-string }} + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() static-checks: timeout-minutes: 30 name: "Static checks" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" if: needs.build-info.outputs.basic-checks-only == 'false' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* + - name: > + Prepare breeze & CI image: ${{needs.build-info.outputs.default-python-version}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/prepare_breeze_and_image + id: breeze - name: Cache pre-commit envs - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pre-commit - key: "pre-commit-${{steps.host-python-version.outputs.host-python-version}}-\ -${{ hashFiles('.pre-commit-config.yaml') }}" - restore-keys: pre-commit-${{steps.host-python-version.outputs.host-python-version}} - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: > - Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Get Python version" - run: "echo \"::set-output name=host-python-version::$(python -c - 'import platform; print(platform.python_version())')\"" - id: host-python-version + # yamllint disable-line rule:line-length + key: "pre-commit-full-${{steps.breeze.outputs.host-python-version}}-${{ hashFiles('.pre-commit-config.yaml') }}" + restore-keys: | + pre-commit-full-${{steps.breeze.outputs.host-python-version}} + pre-commit-full - name: "Static checks" run: breeze static-checks --all-files --show-diff-on-failure --color always env: VERBOSE: "false" - SKIP: "identity" + SKIP: ${{ needs.build-info.outputs.skip-pre-commits }} COLUMNS: "250" + SKIP_GROUP_OUTPUT: "true" + DEFAULT_BRANCH: ${{ needs.build-info.outputs.default-branch }} - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() # Those checks are run if no image needs to be built for checks. 
This is for simple changes that @@ -650,44 +589,46 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" static-checks-basic-checks-only: timeout-minutes: 30 name: "Static checks: basic checks only" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info] env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" if: needs.build-info.outputs.basic-checks-only == 'true' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - name: "Setup python" - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} + python-version: "${{needs.build-info.outputs.default-python-version}}" cache: 'pip' cache-dependency-path: ./dev/breeze/setup* + - name: "Install Breeze" + uses: ./.github/actions/breeze + id: breeze - name: Cache pre-commit envs - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pre-commit - key: "pre-commit-basic-${{steps.host-python-version.outputs.host-python-version}}-\ -${{ hashFiles('.pre-commit-config.yaml') }}" - restore-keys: pre-commit-basic-${{steps.host-python-version.outputs.host-python-version}} + # yamllint disable-line rule:line-length + key: "pre-commit-basic-${{steps.breeze.outputs.host-python-version}}-${{ hashFiles('.pre-commit-config.yaml') }}" + restore-keys: "\ + pre-commit-full-${{steps.breeze.outputs.host-python-version}}-\ + ${{ hashFiles('.pre-commit-config.yaml') }}\n + pre-commit-basic-${{steps.breeze.outputs.host-python-version}}\n + pre-commit-full-${{steps.breeze.outputs.host-python-version}}\n + pre-commit-basic-\n + pre-commit-full-" - name: Fetch incoming commit ${{ github.sha }} with its parent - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.sha }} fetch-depth: 2 persist-credentials: false - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: "Get Python version" - run: "echo \"::set-output name=host-python-version::$(python -c - 'import platform; print(platform.python_version())')\"" - id: host-python-version - name: "Static checks: basic checks only" run: > breeze static-checks --all-files --show-diff-on-failure --color always @@ -695,42 +636,33 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" env: VERBOSE: "false" SKIP_IMAGE_PRE_COMMITS: "true" - SKIP: "identity" + SKIP: ${{ needs.build-info.outputs.skip-pre-commits }} COLUMNS: "250" - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() docs: timeout-minutes: 45 name: "Build docs" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] if: needs.build-info.outputs.docs-build == 'true' env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + 
uses: actions/checkout@v3 with: persist-credentials: false submodules: recursive - - uses: actions/setup-python@v2 - with: - python-version: ${{needs.build-info.outputs.defaultPythonVersion}} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - uses: actions/cache@v2 + - name: > + Prepare breeze & CI image: ${{needs.build-info.outputs.default-python-version}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/prepare_breeze_and_image + - uses: actions/cache@v3 id: cache-doc-inventories with: path: ./docs/_inventory_cache/ @@ -739,7 +671,7 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" docs-inventory-${{ hashFiles('setup.py','setup.cfg','pyproject.toml;') }} docs-inventory- - name: "Build docs" - run: breeze build-docs + run: breeze build-docs ${{ needs.build-info.outputs.docs-filter }} - name: Configure AWS credentials uses: ./.github/actions/configure-aws-credentials if: > @@ -755,145 +687,88 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" github.event_name == 'push' run: aws s3 sync --delete ./files/documentation s3://apache-airflow-docs - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() - prepare-test-provider-packages-wheel: - timeout-minutes: 40 - name: "Build and test provider packages wheel" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + prepare-test-provider-packages: + timeout-minutes: 80 + name: "Provider packages ${{matrix.package-format}}" + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] + strategy: + matrix: + package-format: ["sdist", "wheel"] + exclude: "${{fromJson(needs.build-info.outputs.providers-package-format-exclude)}}" + fail-fast: false env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" + PACKAGE_FORMAT: "${{matrix.package-format}}" + USE_AIRFLOW_VERSION: "${{matrix.package-format}}" if: needs.build-info.outputs.image-build == 'true' && needs.build-info.outputs.default-branch == 'main' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - if: needs.build-info.outputs.default-branch == 'main' - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - name: > - Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + Prepare breeze & CI image: ${{needs.build-info.outputs.default-python-version}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/prepare_breeze_and_image - name: "Cleanup dist files" run: rm -fv ./dist/* - name: "Prepare provider documentation" - run: breeze prepare-provider-documentation 
--answer yes - - name: "Prepare provider packages: wheel" - run: breeze prepare-provider-packages --package-format wheel --version-suffix-for-pypi dev0 - - name: "Prepare airflow package: wheel" - run: breeze prepare-airflow-package --package-format wheel --version-suffix-for-pypi dev0 - - name: "Install and test provider packages and airflow via wheel files" - run: > - breeze verify-provider-packages --use-airflow-version wheel --use-packages-from-dist - --package-format wheel - - name: "Remove airflow package and replace providers with 2.1-compliant versions" + run: breeze release-management prepare-provider-documentation + if: matrix.package-format == 'wheel' + - name: "Prepare provider packages: ${{matrix.package-format}}" + run: breeze release-management prepare-provider-packages --version-suffix-for-pypi dev0 + - name: "Prepare airflow package: ${{matrix.package-format}}" + run: breeze release-management prepare-airflow-package --version-suffix-for-pypi dev0 + - name: "Verify wheel packages with twine" + run: pipx install twine && twine check dist/*.whl + if: matrix.package-format == 'wheel' + - name: "Verify sdist packages with twine" + run: pipx install twine && twine check dist/*.tar.gz + if: matrix.package-format == 'sdist' + - name: "Install and test provider packages and airflow via ${{matrix.package-format}} files" + run: breeze release-management verify-provider-packages --use-packages-from-dist + env: + SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgrade-to-newer-dependencies }}" + - name: "Remove airflow package and replace providers with 2.3-compliant versions" run: | - rm -vf dist/apache_airflow-*.whl \ - dist/apache_airflow_providers_cncf_kubernetes*.whl \ - dist/apache_airflow_providers_celery*.whl - pip download --no-deps --dest dist \ - apache-airflow-providers-cncf-kubernetes==3.0.0 \ - apache-airflow-providers-celery==2.1.3 - - name: "Install and test provider packages and airflow on Airflow 2.1 files" + rm -vf dist/apache_airflow-*.whl dist/apache_airflow_providers_docker*.whl + pip download --no-deps --dest dist apache-airflow-providers-docker==3.1.0 + if: matrix.package-format == 'wheel' + - name: "Get all provider extras as AIRFLOW_EXTRAS env variable" run: > - breeze verify-provider-packages --use-airflow-version 2.1.0 - --use-packages-from-dist --package-format wheel --airflow-constraints-reference constraints-2.1.0 - env: - # The extras below are all extras that should be installed with Airflow 2.1.0 - AIRFLOW_EXTRAS: "airbyte,alibaba,amazon,apache.atlas.apache.beam,apache.cassandra,apache.drill,\ - apache.druid,apache.hdfs,apache.hive,apache.kylin,apache.livy,apache.pig,apache.pinot,\ - apache.spark,apache.sqoop,apache.webhdfs,arangodb,asana,async,\ - celery,cgroups,cloudant,cncf.kubernetes,dask,databricks,datadog,dbt.cloud,\ - deprecated_api,dingding,discord,docker,\ - elasticsearch,exasol,facebook,ftp,github,github_enterprise,google,google_auth,\ - grpc,hashicorp,http,imap,influxdb,jdbc,jenkins,jira,kerberos,ldap,\ - leveldb,microsoft.azure,microsoft.mssql,microsoft.psrp,microsoft.winrm,mongo,mysql,\ - neo4j,odbc,openfaas,opsgenie,oracle,pagerduty,pandas,papermill,password,plexus,\ - postgres,presto,qubole,rabbitmq,redis,salesforce,samba,segment,sendgrid,sentry,\ - sftp,singularity,slack,snowflake,sqlite,ssh,statsd,tableau,telegram,trino,vertica,\ - virtualenv,yandex,zendesk" - - name: "Fix ownership" - run: breeze fix-ownership - if: always() - - prepare-test-provider-packages-sdist: - timeout-minutes: 40 - name: "Build and test provider packages 
sdist" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} - needs: [build-info, wait-for-ci-images] - env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} - if: needs.build-info.outputs.image-build == 'true' && needs.build-info.outputs.default-branch == 'main' - steps: - - name: Cleanup repo - run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 - with: - persist-credentials: false - if: needs.build-info.outputs.default-branch == 'main' - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: > - Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Cleanup dist files" - run: rm -fv ./dist/* - - name: "Prepare provider packages: sdist" - run: breeze prepare-provider-packages --package-format sdist --version-suffix-for-pypi dev0 - - name: "Prepare airflow package: sdist" - run: breeze prepare-airflow-package --package-format sdist --version-suffix-for-pypi dev0 - - name: "Upload provider distribution artifacts" - uses: actions/upload-artifact@v2 - with: - name: airflow-provider-packages - path: "./dist/apache-airflow-providers-*.tar.gz" - retention-days: 1 - - name: "Install and test provider packages and airflow via sdist files" + python -c 'from pathlib import Path; import json; + providers = json.loads(Path("generated/provider_dependencies.json").read_text()); + provider_keys = ",".join(providers.keys()); + print("AIRFLOW_EXTRAS={}".format(provider_keys))' >> $GITHUB_ENV + if: matrix.package-format == 'wheel' + - name: "Install and test provider packages and airflow on Airflow 2.3 files" run: > - breeze verify-provider-packages --use-airflow-version sdist --use-packages-from-dist - --package-format sdist + breeze release-management verify-provider-packages --use-airflow-version 2.3.0 + --use-packages-from-dist --airflow-constraints-reference constraints-2.3.0 + if: matrix.package-format == 'wheel' - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() tests-helm: timeout-minutes: 80 - name: "Python unit tests for helm chart" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + name: "Python unit tests for Helm chart" + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - MOUNT_SELECTED_LOCAL_SOURCES: "true" + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" TEST_TYPES: "Helm" BACKEND: "" DB_RESET: "false" - PYTHON_MAJOR_MINOR_VERSION: ${{needs.build-info.outputs.defaultPythonVersion}} + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" + JOB_ID: "helm-tests" + COVERAGE: "${{needs.build-info.outputs.run-coverage}}" if: > needs.build-info.outputs.needs-helm-tests == 'true' && (github.repository == 'apache/airflow' || github.event_name != 'schedule') && @@ -902,420 +777,320 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm 
-rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - name: > - Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Tests: Helm" - run: ./scripts/ci/testing/ci_run_airflow_testing.sh - env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload airflow logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: airflow-logs-helm - path: "./files/airflow_logs*" - retention-days: 7 - - name: "Upload container logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: container-logs-helm - path: "./files/container_logs*" - retention-days: 7 - - name: "Upload artifact for coverage" - uses: actions/upload-artifact@v2 - if: needs.build-info.outputs.runCoverage == 'true' - with: - name: coverage-helm - path: "./files/coverage*.xml" - retention-days: 7 - - name: "Fix ownership" - run: breeze fix-ownership - if: always() + Prepare breeze & CI image: ${{needs.build-info.outputs.default-python-version}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/prepare_breeze_and_image + - name: "Helm Unit Tests" + run: breeze testing helm-tests + - name: "Post Helm Tests" + uses: ./.github/actions/post_tests tests-postgres: timeout-minutes: 130 name: > Postgres${{matrix.postgres-version}},Py${{matrix.python-version}}: - ${{needs.build-info.outputs.testTypes}} - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + ${{needs.build-info.outputs.test-types}} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] strategy: matrix: - python-version: ${{ fromJson(needs.build-info.outputs.pythonVersions) }} - postgres-version: ${{ fromJson(needs.build-info.outputs.postgresVersions) }} - exclude: ${{ fromJson(needs.build-info.outputs.postgresExclude) }} + python-version: "${{fromJson(needs.build-info.outputs.python-versions)}}" + postgres-version: "${{fromJson(needs.build-info.outputs.postgres-versions)}}" + exclude: "${{fromJson(needs.build-info.outputs.postgres-exclude)}}" fail-fast: false env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - BACKEND: postgres - POSTGRES_VERSION: ${{ matrix.postgres-version }} - TEST_TYPES: "${{needs.build-info.outputs.testTypes}}" - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + TEST_TYPES: "${{needs.build-info.outputs.test-types}}" + PR_LABELS: "${{needs.build-info.outputs.pull-request-labels}}" + FULL_TESTS_NEEDED: "${{needs.build-info.outputs.full-tests-needed}}" + DEBUG_RESOURCES: "${{needs.build-info.outputs.debug-resources}}" + BACKEND: "postgres" + PYTHON_MAJOR_MINOR_VERSION: "${{matrix.python-version}}" + POSTGRES_VERSION: "${{matrix.postgres-version}}" + BACKEND_VERSION: "${{matrix.postgres-version}}" + JOB_ID: "postgres-${{matrix.postgres-version}}-${{matrix.python-version}}" + COVERAGE: "${{needs.build-info.outputs.run-coverage}}" if: needs.build-info.outputs.run-tests == 'true' steps: - name: Cleanup repo + shell: bash run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 
0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Test downgrade" - run: ./scripts/ci/testing/run_downgrade_test.sh - - name: "Test Offline SQL generation" - run: ./scripts/ci/testing/run_offline_sql_test.sh - - name: "Tests: ${{needs.build-info.outputs.testTypes}}" - run: ./scripts/ci/testing/ci_run_airflow_testing.sh - env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload airflow logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: airflow-logs-${{matrix.python-version}}-${{matrix.postgres-version}} - path: "./files/airflow_logs*" - retention-days: 7 - - name: "Upload container logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: container-logs-postgres-${{matrix.python-version}}-${{matrix.postgres-version}} - path: "./files/container_logs*" - retention-days: 7 - - name: "Upload artifact for coverage" - uses: actions/upload-artifact@v2 - if: needs.build-info.outputs.runCoverage == 'true' - with: - name: coverage-postgres-${{matrix.python-version}}-${{matrix.postgres-version}} - path: "./files/coverage*.xml" - retention-days: 7 - - name: "Fix ownership" - run: breeze fix-ownership - if: always() + - name: "Prepare breeze & CI image: ${{matrix.python-version}}:${{env.IMAGE_TAG}}" + uses: ./.github/actions/prepare_breeze_and_image + - name: "Migration Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/migration_tests + - name: "Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}} (w/Kerberos)" + run: breeze testing tests --run-in-parallel + - name: "Post Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/post_tests tests-mysql: timeout-minutes: 130 name: > - MySQL${{matrix.mysql-version}}, Py${{matrix.python-version}}: ${{needs.build-info.outputs.testTypes}} - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + MySQL${{matrix.mysql-version}}, Py${{matrix.python-version}}: ${{needs.build-info.outputs.test-types}} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] strategy: matrix: - python-version: ${{ fromJson(needs.build-info.outputs.pythonVersions) }} - mysql-version: ${{ fromJson(needs.build-info.outputs.mysqlVersions) }} - exclude: ${{ fromJson(needs.build-info.outputs.mysqlExclude) }} + python-version: "${{fromJson(needs.build-info.outputs.python-versions)}}" + mysql-version: "${{fromJson(needs.build-info.outputs.mysql-versions)}}" + exclude: "${{fromJson(needs.build-info.outputs.mysql-exclude)}}" fail-fast: false env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - BACKEND: mysql - MYSQL_VERSION: ${{ matrix.mysql-version }} - TEST_TYPES: "${{needs.build-info.outputs.testTypes}}" - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + PR_LABELS: 
"${{needs.build-info.outputs.pull-request-labels}}" + FULL_TESTS_NEEDED: "${{needs.build-info.outputs.full-tests-needed}}" + TEST_TYPES: "${{needs.build-info.outputs.test-types}}" + DEBUG_RESOURCES: "${{needs.build-info.outputs.debug-resources}}" + BACKEND: "mysql" + PYTHON_MAJOR_MINOR_VERSION: "${{matrix.python-version}}" + MYSQL_VERSION: "${{matrix.mysql-version}}" + BACKEND_VERSION: "${{matrix.mysql-version}}" + JOB_ID: "mysql-${{matrix.mysql-version}}-${{matrix.python-version}}" if: needs.build-info.outputs.run-tests == 'true' steps: - name: Cleanup repo + shell: bash run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Test downgrade" - run: ./scripts/ci/testing/run_downgrade_test.sh - - name: "Test Offline SQL generation" - run: ./scripts/ci/testing/run_offline_sql_test.sh - - name: "Tests: ${{needs.build-info.outputs.testTypes}}" - run: ./scripts/ci/testing/ci_run_airflow_testing.sh - env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload airflow logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: airflow-logs-${{matrix.python-version}}-${{matrix.mysql-version}} - path: "./files/airflow_logs*" - retention-days: 7 - - name: "Upload container logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: container-logs-mysql-${{matrix.python-version}}-${{matrix.mysql-version}} - path: "./files/container_logs*" - retention-days: 7 - - name: "Upload artifact for coverage" - uses: actions/upload-artifact@v2 - if: needs.build-info.outputs.runCoverage == 'true' - with: - name: coverage-mysql-${{matrix.python-version}}-${{matrix.mysql-version}} - path: "./files/coverage*.xml" - retention-days: 7 - - name: "Fix ownership" - run: breeze fix-ownership - if: always() + - name: "Prepare breeze & CI image: ${{matrix.python-version}}:${{env.IMAGE_TAG}}" + uses: ./.github/actions/prepare_breeze_and_image + - name: "Migration Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/migration_tests + - name: "Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + run: breeze testing tests --run-in-parallel + - name: "Post Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/post_tests tests-mssql: timeout-minutes: 130 name: > - MSSQL${{matrix.mssql-version}}, Py${{matrix.python-version}}: ${{needs.build-info.outputs.testTypes}} - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + MSSQL${{matrix.mssql-version}}, Py${{matrix.python-version}}: ${{needs.build-info.outputs.test-types}} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] strategy: matrix: - python-version: ${{ fromJson(needs.build-info.outputs.pythonVersions) }} - mssql-version: ${{ fromJson(needs.build-info.outputs.mssqlVersions) }} - exclude: ${{ 
fromJson(needs.build-info.outputs.mssqlExclude) }} + python-version: "${{fromJson(needs.build-info.outputs.python-versions)}}" + mssql-version: "${{fromJson(needs.build-info.outputs.mssql-versions)}}" + exclude: "${{fromJson(needs.build-info.outputs.mssql-exclude)}}" fail-fast: false env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - BACKEND: mssql - MSSQL_VERSION: ${{ matrix.mssql-version }} - TEST_TYPES: "${{needs.build-info.outputs.testTypes}}" - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + TEST_TYPES: "${{needs.build-info.outputs.test-types}}" + PR_LABELS: "${{needs.build-info.outputs.pull-request-labels}}" + FULL_TESTS_NEEDED: "${{needs.build-info.outputs.full-tests-needed}}" + DEBUG_RESOURCES: "${{needs.build-info.outputs.debug-resources}}" + BACKEND: "mssql" + PYTHON_MAJOR_MINOR_VERSION: "${{matrix.python-version}}" + MSSQL_VERSION: "${{matrix.mssql-version}}" + BACKEND_VERSION: "${{matrix.mssql-version}}" + JOB_ID: "mssql-${{matrix.mssql-version}}-${{matrix.python-version}}" + COVERAGE: "${{needs.build-info.outputs.run-coverage}}" if: needs.build-info.outputs.run-tests == 'true' steps: - name: Cleanup repo + shell: bash run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Test downgrade" - run: ./scripts/ci/testing/run_downgrade_test.sh - - name: "Tests: ${{needs.build-info.outputs.testTypes}}" - run: ./scripts/ci/testing/ci_run_airflow_testing.sh - env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload airflow logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: airflow-logs-${{matrix.python-version}}-${{matrix.mssql-version}} - path: "./files/airflow_logs*" - retention-days: 7 - - name: "Upload container logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: container-logs-mssql-${{matrix.python-version}}-${{matrix.mssql-version}} - path: "./files/container_logs*" - retention-days: 7 - - name: "Upload artifact for coverage" - uses: actions/upload-artifact@v2 - if: needs.build-info.outputs.runCoverage == 'true' - with: - name: coverage-mssql-${{matrix.python-version}}-${{matrix.mssql-version}} - path: "./files/coverage*.xml" - retention-days: 7 - - name: "Fix ownership" - run: breeze fix-ownership - if: always() + - name: "Prepare breeze & CI image: ${{matrix.python-version}}:${{env.IMAGE_TAG}}" + uses: ./.github/actions/prepare_breeze_and_image + - name: "Migration Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/migration_tests + - name: "Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + run: breeze testing tests --run-in-parallel + - name: "Post Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/post_tests tests-sqlite: timeout-minutes: 130 
name: > - Sqlite Py${{matrix.python-version}}: ${{needs.build-info.outputs.testTypes}} - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + Sqlite Py${{matrix.python-version}}: ${{needs.build-info.outputs.test-types}} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images] strategy: matrix: - python-version: ${{ fromJson(needs.build-info.outputs.pythonVersions) }} - exclude: ${{ fromJson(needs.build-info.outputs.sqliteExclude) }} + python-version: ${{ fromJson(needs.build-info.outputs.python-versions) }} + exclude: ${{ fromJson(needs.build-info.outputs.sqlite-exclude) }} fail-fast: false - env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - BACKEND: sqlite - TEST_TYPES: "${{needs.build-info.outputs.testTypes}}" - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} if: needs.build-info.outputs.run-tests == 'true' + env: + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + TEST_TYPES: "${{needs.build-info.outputs.test-types}}" + PR_LABELS: "${{needs.build-info.outputs.pull-request-labels}}" + PYTHON_MAJOR_MINOR_VERSION: "${{matrix.python-version}}" + FULL_TESTS_NEEDED: "${{needs.build-info.outputs.full-tests-needed}}" + DEBUG_RESOURCES: "${{needs.build-info.outputs.debug-resources}}" + BACKEND: "sqlite" + BACKEND_VERSION: "" + JOB_ID: "sqlite-${{matrix.python-version}}" + COVERAGE: "${{needs.build-info.outputs.run-coverage}}" steps: - name: Cleanup repo + shell: bash run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Test downgrade" - run: ./scripts/ci/testing/run_downgrade_test.sh - - name: "Tests: ${{needs.build-info.outputs.testTypes}}" - run: ./scripts/ci/testing/ci_run_airflow_testing.sh - env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload airflow logs" - uses: actions/upload-artifact@v2 - if: failure() - with: - name: airflow-logs-${{matrix.python-version}} - path: './files/airflow_logs*' - retention-days: 7 - - name: "Upload container logs" - uses: actions/upload-artifact@v2 - if: failure() + - name: "Prepare breeze & CI image: ${{matrix.python-version}}:${{env.IMAGE_TAG}}" + uses: ./.github/actions/prepare_breeze_and_image + - name: "Migration Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/migration_tests + - name: "Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + run: breeze testing tests --run-in-parallel + - name: "Post Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/post_tests + + tests-integration-postgres: + timeout-minutes: 130 + name: Integration Tests Postgres + runs-on: "${{needs.build-info.outputs.runs-on}}" + needs: [build-info, wait-for-ci-images] + env: + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + TEST_TYPES: "${{needs.build-info.outputs.test-types}}" + PR_LABELS: 
"${{needs.build-info.outputs.pull-request-labels}}" + FULL_TESTS_NEEDED: "${{needs.build-info.outputs.full-tests-needed}}" + DEBUG_RESOURCES: "${{needs.build-info.outputs.debug-resources}}" + BACKEND: "postgres" + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" + POSTGRES_VERSION: "${{needs.build-info.outputs.default-postgres-version}}" + BACKEND_VERSION: "${{needs.build-info.outputs.default-python-version}}" + JOB_ID: "integration" + COVERAGE: "${{needs.build-info.outputs.run-coverage}}" + if: needs.build-info.outputs.run-tests == 'true' && needs.build-info.outputs.default-branch == 'main' + steps: + - name: Cleanup repo + shell: bash + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 with: - name: container-logs-sqlite-${{matrix.python-version}} - path: "./files/container_logs*" - retention-days: 7 - - name: "Upload artifact for coverage" - uses: actions/upload-artifact@v2 - if: needs.build-info.outputs.runCoverage == 'true' + persist-credentials: false + - name: "Prepare breeze & CI image: ${{env.PYTHON_MAJOR_MINOR_VERSION}}:${{env.IMAGE_TAG}}" + uses: ./.github/actions/prepare_breeze_and_image + - name: "Integration Tests Postgres: cassandra" + run: | + breeze testing integration-tests --integration cassandra + breeze stop + if: needs.build-info.outputs.runs-on != 'self-hosted' + - name: "Integration Tests Postgres: mongo" + run: | + breeze testing integration-tests --integration mongo + breeze stop + if: needs.build-info.outputs.runs-on != 'self-hosted' + - name: "Integration Tests Postgres: pinot" + run: | + breeze testing integration-tests --integration pinot + breeze stop + if: needs.build-info.outputs.runs-on != 'self-hosted' + - name: "Integration Tests Postgres: celery" + run: | + breeze testing integration-tests --integration celery + breeze stop + if: needs.build-info.outputs.runs-on != 'self-hosted' + - name: "Integration Tests Postgres: trino, kerberos" + run: | + breeze testing integration-tests --integration trino --integration kerberos + breeze stop + if: needs.build-info.outputs.runs-on != 'self-hosted' + - name: "Integration Tests Postgres: all" + run: breeze testing integration-tests --integration all + if: needs.build-info.outputs.runs-on == 'self-hosted' + - name: "Post Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/post_tests + + tests-integration-mysql: + timeout-minutes: 130 + name: Integration Tests MySQL + runs-on: "${{needs.build-info.outputs.runs-on}}" + needs: [build-info, wait-for-ci-images] + env: + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + TEST_TYPES: "${{needs.build-info.outputs.test-types}}" + PR_LABELS: "${{needs.build-info.outputs.pull-request-labels}}" + FULL_TESTS_NEEDED: "${{needs.build-info.outputs.full-tests-needed}}" + DEBUG_RESOURCES: "${{needs.build-info.outputs.debug-resources}}" + BACKEND: "postgres" + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" + POSTGRES_VERSION: "${{needs.build-info.outputs.default-postgres-version}}" + BACKEND_VERSION: "${{needs.build-info.outputs.default-python-version}}" + JOB_ID: "integration" + COVERAGE: "${{needs.build-info.outputs.run-coverage}}" + if: needs.build-info.outputs.run-tests == 'true' && needs.build-info.outputs.default-branch == 'main' + steps: + - name: Cleanup repo + shell: bash + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 
bash -c "rm -rf /workspace/*" + if: needs.build-info.outputs.runs-on == 'self-hosted' + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 with: - name: coverage-sqlite-${{matrix.python-version}} - path: ./files/coverage*.xml - retention-days: 7 - - name: "Fix ownership" - run: breeze fix-ownership - if: always() + persist-credentials: false + if: needs.build-info.outputs.runs-on == 'self-hosted' + - name: "Prepare breeze & CI image: ${{env.PYTHON_MAJOR_MINOR_VERSION}}:${{env.IMAGE_TAG}}" + uses: ./.github/actions/prepare_breeze_and_image + if: needs.build-info.outputs.runs-on == 'self-hosted' + - name: "Integration Tests MySQL: all" + run: breeze testing integration-tests --integration all + if: needs.build-info.outputs.runs-on == 'self-hosted' + - name: "Post Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.test-types}}" + uses: ./.github/actions/post_tests + if: needs.build-info.outputs.runs-on == 'self-hosted' + tests-quarantined: timeout-minutes: 60 name: "Quarantined tests" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" continue-on-error: true needs: [build-info, wait-for-ci-images] env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - MYSQL_VERSION: ${{needs.build-info.outputs.defaultMySQLVersion}} - POSTGRES_VERSION: ${{needs.build-info.outputs.defaultPostgresVersion}} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" TEST_TYPES: "Quarantined" - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} + PR_LABELS: "${{needs.build-info.outputs.pull-request-labels}}" + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" + DEBUG_RESOURCES: "${{needs.build-info.outputs.debug-resources}}" + BACKEND: "sqlite" + BACKEND_VERSION: "" + JOB_ID: "quarantined-${{needs.build-info.outputs.default-python-version}}" + COVERAGE: "${{needs.build-info.outputs.run-coverage}}" if: needs.build-info.outputs.run-tests == 'true' steps: - name: Cleanup repo + shell: bash run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - name: "Set issue id for main" - if: github.ref == 'refs/heads/main' - run: | - echo "ISSUE_ID=10118" >> $GITHUB_ENV - - name: "Set issue id for v1-10-stable" - if: github.ref == 'refs/heads/v1-10-stable' - run: | - echo "ISSUE_ID=10127" >> $GITHUB_ENV - - name: "Set issue id for v1-10-test" - if: github.ref == 'refs/heads/v1-10-test' - run: | - echo "ISSUE_ID=10128" >> $GITHUB_ENV - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull CI image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Tests: Quarantined" - run: ./scripts/ci/testing/ci_run_quarantined_tests.sh - env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload Quarantine test results" - uses: actions/upload-artifact@v2 - if: always() - with: - name: quarantined-tests - path: "files/test_result-*.xml" - retention-days: 7 - - name: "Upload airflow logs" - uses: 
upload-coverage: timeout-minutes: 15 name: "Upload coverage" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" continue-on-error: true needs: - build-info @@ -1324,20 +1099,22 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" - tests-mysql - tests-mssql - tests-quarantined + - tests-integration-postgres + - tests-integration-mysql env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" # Only upload coverage on merges to main - if: needs.build-info.outputs.runCoverage == 'true' + if: needs.build-info.outputs.run-coverage == 'true' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false submodules: recursive - name: "Download all artifacts from the current build" - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: path: ./coverage-files - name: "Removes unnecessary artifacts" @@ -1347,107 +1124,114 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" with: directory: "./coverage-files" + summarize-warnings: + timeout-minutes: 15 + name: "Summarize warnings" + runs-on: "${{needs.build-info.outputs.runs-on}}" + needs: + - build-info + - tests-postgres + - tests-sqlite + - tests-mysql + - tests-mssql + - tests-quarantined + - tests-integration-postgres + - tests-integration-mysql + env: + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + steps: + - name: Cleanup repo + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 + with: + persist-credentials: false + submodules: recursive + - name: "Download all artifacts from the current build" + uses: actions/download-artifact@v3 + with: + path: ./artifacts + - name: "Summarize all warnings" + run: | + ls -R ./artifacts/ + cat ./artifacts/test-warnings*/* | sort | uniq + echo + echo Total number of unique warnings $(cat ./artifacts/test-warnings*/* | sort | uniq | wc -l) + + wait-for-prod-images: timeout-minutes: 120 name: "Wait for PROD images" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-ci-images,
build-prod-images] if: needs.build-info.outputs.image-build == 'true' env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" BACKEND: sqlite - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" + # Force more parallelism for pull even on public images + PARALLELISM: 6 steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: "Cache virtualenv environment" - uses: actions/cache@v2 - with: - path: '.build/.docker_venv' - key: ${{ runner.os }}-docker-venv-${{ hashFiles('scripts/ci/images/ci_run_docker_tests.py') }} - - name: Wait for PROD images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: Wait for PROD images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG }} # We wait for the images to be available either from "build-images.yml' run as pull_request_target - # or from build-prod-image above. + # or from build-prod-images above. # We are utilising single job to wait for all images because this job merely waits # For the images to be available and test them. - run: breeze pull-prod-image --verify-image --wait-for-image --run-in-parallel + run: breeze prod-image pull --verify --wait-for-image --run-in-parallel env: - PYTHON_VERSIONS: ${{ needs.build-info.outputs.pythonVersionsListAsString }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + PYTHON_VERSIONS: ${{ needs.build-info.outputs.python-versions-list-as-string }} + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() test-docker-compose-quick-start: timeout-minutes: 60 name: "Test docker-compose quick start" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-prod-images] if: needs.build-info.outputs.image-build == 'true' env: - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 2 persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 + - name: > + Prepare breeze & PROD image: ${{needs.build-info.outputs.default-python-version}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/prepare_breeze_and_image with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull PROD image ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} 
- run: breeze pull-prod-image --tag-as-latest - env: - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + pull-image-type: 'PROD' - name: "Test docker-compose quick start" - run: breeze docker-compose-tests + run: breeze testing docker-compose-tests - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() tests-kubernetes: - timeout-minutes: 70 - name: Helm Chart; ${{matrix.executor}} - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + timeout-minutes: 240 + name: "Helm: ${{matrix.executor}} - ${{needs.build-info.outputs.kubernetes-versions-list-as-string}}" + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: [build-info, wait-for-prod-images] strategy: matrix: executor: [KubernetesExecutor, CeleryExecutor, LocalExecutor] fail-fast: false env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - BACKEND: postgres - RUN_TESTS: "true" - RUNTIME: "kubernetes" - KUBERNETES_MODE: "image" - EXECUTOR: ${{matrix.executor}} - KIND_VERSION: "${{ needs.build-info.outputs.defaultKindVersion }}" - HELM_VERSION: "${{ needs.build-info.outputs.defaultHelmVersion }}" - CURRENT_PYTHON_MAJOR_MINOR_VERSIONS_AS_STRING: > - ${{needs.build-info.outputs.pythonVersionsListAsString}} - CURRENT_KUBERNETES_VERSIONS_AS_STRING: > - ${{needs.build-info.outputs.kubernetesVersionsListAsString}} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} if: > ( needs.build-info.outputs.run-kubernetes-tests == 'true' || needs.build-info.outputs.needs-helm-tests == 'true' ) && @@ -1456,130 +1240,59 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull PROD images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-prod-image --run-in-parallel --tag-as-latest + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: Pull PROD images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG }} + run: breeze prod-image pull --run-in-parallel --tag-as-latest env: - PYTHON_VERSIONS: ${{ needs.build-info.outputs.pythonVersionsListAsString }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + PYTHON_VERSIONS: ${{ needs.build-info.outputs.python-versions-list-as-string }} - name: "Cache bin folder with tools for kubernetes testing" - uses: actions/cache@v2 - with: - path: ".build/kubernetes-bin" - key: "kubernetes-binaries --${{ needs.build-info.outputs.defaultKindVersion }}\ --${{ needs.build-info.outputs.defaultHelmVersion }}" - restore-keys: "kubernetes-binaries" - - name: "Kubernetes Tests" - run: ./scripts/ci/kubernetes/ci_setup_clusters_and_run_kubernetes_tests_in_parallel.sh + uses: actions/cache@v3 + with: + path: ".build/.k8s-env" + key: "\ + k8s-env-${{ hashFiles('scripts/ci/kubernetes/k8s_requirements.txt','setup.cfg',\ + 'setup.py','pyproject.toml','generated/provider_dependencies.json') }}" + - name: Run complete K8S tests ${{needs.build-info.outputs.kubernetes-combos}} + run: breeze k8s run-complete-tests --run-in-parallel --upgrade 
env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload KinD logs" - uses: actions/upload-artifact@v2 + PYTHON_VERSIONS: ${{ needs.build-info.outputs.python-versions-list-as-string }} + KUBERNETES_VERSIONS: ${{needs.build-info.outputs.kubernetes-versions-list-as-string}} + EXECUTOR: ${{matrix.executor}} + VERBOSE: false + - name: Upload KinD logs on failure ${{needs.build-info.outputs.kubernetes-combos}} + uses: actions/upload-artifact@v3 if: failure() || cancelled() with: name: kind-logs-${{matrix.executor}} path: /tmp/kind_logs_* retention-days: 7 - - name: "Fix ownership" - run: breeze fix-ownership - if: always() - - tests-helm-executor-upgrade: - timeout-minutes: 150 - name: Helm Chart Executor Upgrade - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} - needs: [build-info, wait-for-prod-images] - env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - BACKEND: postgres - RUN_TESTS: "true" - RUNTIME: "kubernetes" - KUBERNETES_MODE: "image" - EXECUTOR: "KubernetesExecutor" - KIND_VERSION: "${{ needs.build-info.outputs.defaultKindVersion }}" - HELM_VERSION: "${{ needs.build-info.outputs.defaultHelmVersion }}" - CURRENT_PYTHON_MAJOR_MINOR_VERSIONS_AS_STRING: > - ${{needs.build-info.outputs.pythonVersionsListAsString}} - CURRENT_KUBERNETES_VERSIONS_AS_STRING: > - ${{needs.build-info.outputs.kubernetesVersionsListAsString}} - if: > - needs.build-info.outputs.run-kubernetes-tests == 'true' && - needs.build-info.outputs.default-branch == 'main' - steps: - - name: Cleanup repo - run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 - with: - persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull PROD images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-prod-image --run-in-parallel --tag-as-latest - env: - PYTHON_VERSIONS: ${{ needs.build-info.outputs.pythonVersionsListAsString }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} - - name: "Cache virtualenv for kubernetes testing" - uses: actions/cache@v2 - with: - path: ".build/.kubernetes_venv" - key: "kubernetes-${{ needs.build-info.outputs.defaultPythonVersion }}\ - -${{needs.build-info.outputs.kubernetesVersionsListAsString}} - -${{needs.build-info.outputs.pythonVersionsListAsString}} - -${{ hashFiles('setup.py','setup.cfg') }}" - restore-keys: "kubernetes-${{ needs.build-info.outputs.defaultPythonVersion }}-\ - -${{needs.build-info.outputs.kubernetesVersionsListAsString}} - -${{needs.build-info.outputs.pythonVersionsListAsString}}" - - name: "Cache bin folder with tools for kubernetes testing" - uses: actions/cache@v2 - with: - path: ".build/kubernetes-bin" - key: "kubernetes-binaries - -${{ needs.build-info.outputs.defaultKindVersion }}\ - -${{ needs.build-info.outputs.defaultHelmVersion }}" - restore-keys: "kubernetes-binaries" - - name: "Kubernetes Helm Chart Executor Upgrade Tests" - run: ./scripts/ci/kubernetes/ci_upgrade_cluster_with_different_executors_in_parallel.sh - env: - PR_LABELS: "${{ needs.build-info.outputs.pullRequestLabels }}" - - name: "Upload KinD logs" - uses: actions/upload-artifact@v2 + - name: Upload test resource logs on failure 
${{needs.build-info.outputs.kubernetes-combos}} + uses: actions/upload-artifact@v3 if: failure() || cancelled() with: - name: kind-logs-KubernetesExecutor - path: /tmp/kind_logs_* + name: k8s-test-resources-${{matrix.executor}} + path: /tmp/k8s_test_resources_* retention-days: 7 + - name: "Delete clusters just in case they are left" + run: breeze k8s delete-cluster --all + if: always() - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() constraints: permissions: contents: write - timeout-minutes: 40 + timeout-minutes: 80 name: "Constraints" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: - build-info + - docs - wait-for-ci-images - wait-for-prod-images - static-checks @@ -1587,62 +1300,60 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" - tests-mysql - tests-mssql - tests-postgres + - tests-integration-postgres + - tests-integration-mysql + - push-early-buildx-cache-to-github-registry env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - if: needs.build-info.outputs.upgradeToNewerDependencies != 'false' + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} + if: needs.build-info.outputs.upgrade-to-newer-dependencies != 'false' steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false submodules: recursive - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - - name: Pull CI images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }} - run: breeze pull-image --run-in-parallel --tag-as-latest + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: Pull CI images ${{ env.PYTHON_VERSIONS }}:${{ env.IMAGE_TAG }} + run: breeze ci-image pull --run-in-parallel --tag-as-latest env: - PYTHON_VERSIONS: ${{ needs.build-info.outputs.pythonVersionsListAsString }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + PYTHON_VERSIONS: ${{ needs.build-info.outputs.python-versions-list-as-string }} - name: "Generate constraints" run: | - breeze generate-constraints --run-in-parallel \ + breeze release-management generate-constraints --run-in-parallel \ --airflow-constraints-mode constraints-source-providers - breeze generate-constraints --run-in-parallel --airflow-constraints-mode constraints-no-providers - breeze generate-constraints --run-in-parallel --airflow-constraints-mode constraints + breeze release-management generate-constraints \ + --run-in-parallel --airflow-constraints-mode constraints-no-providers + breeze release-management generate-constraints \ + --run-in-parallel --airflow-constraints-mode constraints env: - PYTHON_VERSIONS: ${{ needs.build-info.outputs.pythonVersionsListAsString }} + PYTHON_VERSIONS: ${{ needs.build-info.outputs.python-versions-list-as-string }} - name: "Set constraints branch name" id: constraints-branch - run: ./scripts/ci/constraints/ci_branch_constraints.sh - if: needs.build-info.outputs.mergeRun == 'true' + run: ./scripts/ci/constraints/ci_branch_constraints.sh >> ${GITHUB_OUTPUT} + if: needs.build-info.outputs.canary-run == 'true' - 
name: Checkout ${{ steps.constraints-branch.outputs.branch }} - uses: actions/checkout@v2 - if: needs.build-info.outputs.mergeRun == 'true' + uses: actions/checkout@v3 + if: needs.build-info.outputs.canary-run == 'true' with: path: "repo" ref: ${{ steps.constraints-branch.outputs.branch }} persist-credentials: false - - name: "Commit changed constraint files for ${{needs.build-info.outputs.pythonVersions}}" + - name: "Commit changed constraint files for ${{needs.build-info.outputs.python-versions}}" run: ./scripts/ci/constraints/ci_commit_constraints.sh - if: needs.build-info.outputs.mergeRun == 'true' + if: needs.build-info.outputs.canary-run == 'true' - name: "Push changes" uses: ./.github/actions/github-push-action - if: needs.build-info.outputs.mergeRun == 'true' + if: needs.build-info.outputs.canary-run == 'true' with: github_token: ${{ secrets.GITHUB_TOKEN }} branch: ${{ steps.constraints-branch.outputs.branch }} directory: "repo" - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership if: always() # Push BuildX cache to GitHub Registry in Apache repository, if all tests are successful and build @@ -1652,85 +1363,140 @@ ${{ hashFiles('.pre-commit-config.yaml') }}" push-buildx-cache-to-github-registry: permissions: packages: write - timeout-minutes: 120 + timeout-minutes: 50 name: "Push Image Cache" - runs-on: ${{ fromJson(needs.build-info.outputs.runsOn) }} + runs-on: "${{needs.build-info.outputs.runs-on}}" needs: - build-info - constraints - docs - if: needs.build-info.outputs.mergeRun == 'true' + if: needs.build-info.outputs.canary-run == 'true' strategy: fail-fast: false matrix: - python-version: ${{ fromJson(needs.build-info.outputs.pythonVersions) }} platform: ["linux/amd64", "linux/arm64"] env: - RUNS_ON: ${{ fromJson(needs.build-info.outputs.runsOn) }} - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" steps: - name: Cleanup repo run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false - - name: "Setup python" - uses: actions/setup-python@v2 - with: - python-version: ${{ needs.build-info.outputs.defaultPythonVersion }} - cache: 'pip' - cache-dependency-path: ./dev/breeze/setup* - - run: ./scripts/ci/install_breeze.sh - - name: "Free space" - run: breeze free-space - name: > - Pull CI image for PROD build - ${{ env.PYTHON_MAJOR_MINOR_VERSION }}:${{ env.IMAGE_TAG_FOR_THE_BUILD }}" - run: breeze pull-image --tag-as-latest + Prepare breeze & CI image: ${{needs.build-info.outputs.default-python-version}}:${{env.IMAGE_TAG}} + uses: ./.github/actions/prepare_breeze_and_image env: # Always use default Python version of CI image for preparing packages - PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} - IMAGE_TAG: ${{ env.IMAGE_TAG_FOR_THE_BUILD }} + PYTHON_MAJOR_MINOR_VERSION: "${{needs.build-info.outputs.default-python-version}}" - name: "Cleanup dist and context file" run: rm -fv ./dist/* ./docker-context-files/* + - name: "Prepare airflow package for PROD build" + run: breeze release-management prepare-airflow-package --package-format wheel + env: + VERSION_SUFFIX_FOR_PYPI: "dev0" - name: "Prepare providers packages for PROD build" run: > - breeze prepare-provider-packages + breeze release-management prepare-provider-packages --package-list-file ./scripts/ci/installed_providers.txt 
--package-format wheel env: VERSION_SUFFIX_FOR_PYPI: "dev0" - - name: "Prepare airflow package for PROD build" - run: breeze prepare-airflow-package --package-format wheel - env: - VERSION_SUFFIX_FOR_PYPI: "dev0" + if: needs.build-info.outputs.default-branch == 'main' - name: "Start ARM instance" run: ./scripts/ci/images/ci_start_arm_instance_and_connect_to_docker.sh if: matrix.platform == 'linux/arm64' - - name: "Push CI cache ${{ matrix.python-version }} ${{ matrix.platform }}" + - name: "Push CI cache ${{ matrix.platform }}" run: > - breeze build-image + breeze ci-image build + --builder airflow_cache --prepare-buildx-cache + --run-in-parallel --force-build --platform ${{ matrix.platform }} env: - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} + - name: "Push CI latest image ${{ matrix.platform }}" + run: > + breeze ci-image build + --tag-as-latest --push --run-in-parallel --platform ${{ matrix.platform }} + env: + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} + if: matrix.platform == 'linux/amd64' - name: "Move dist packages to docker-context files" run: mv -v ./dist/*.whl ./docker-context-files - name: "Push PROD cache ${{ matrix.python-version }} ${{ matrix.platform }}" run: > - breeze build-prod-image - --airflow-is-in-context + breeze prod-image build + --builder airflow_cache --install-packages-from-context --prepare-buildx-cache - --disable-airflow-repo-cache --platform ${{ matrix.platform }} + - name: "Push PROD latest image ${{ matrix.platform }}" + run: > + breeze prod-image build --tag-as-latest --install-packages-from-context + --push --run-in-parallel --platform ${{ matrix.platform }} env: - PYTHON_MAJOR_MINOR_VERSION: ${{ matrix.python-version }} + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} + if: matrix.platform == 'linux/amd64' - name: "Stop ARM instance" run: ./scripts/ci/images/ci_stop_arm_instance.sh if: always() && matrix.platform == 'linux/arm64' - name: "Fix ownership" - run: breeze fix-ownership + run: breeze ci fix-ownership + if: always() + + build-ci-arm-images: + timeout-minutes: 50 + name: > + Build CI ARM images + ${{needs.build-info.outputs.all-python-versions-list-as-string}} + runs-on: "${{needs.build-info.outputs.runs-on}}" + needs: + - build-info + - wait-for-ci-images + - wait-for-prod-images + - static-checks + - tests-sqlite + - tests-mysql + - tests-mssql + - tests-postgres + - tests-integration-postgres + - tests-integration-mysql + env: + DEFAULT_BRANCH: ${{ needs.build-info.outputs.default-branch }} + DEFAULT_CONSTRAINTS_BRANCH: ${{ needs.build-info.outputs.default-constraints-branch }} + RUNS_ON: "${{needs.build-info.outputs.runs-on}}" + if: > + needs.build-info.outputs.upgrade-to-newer-dependencies != 'false' && + needs.build-info.outputs.in-workflow-build == 'true' && + needs.build-info.outputs.canary-run != 'true' + steps: + - name: Cleanup repo + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - uses: actions/checkout@v3 + with: + ref: ${{ needs.build-info.outputs.targetCommitSha }} + persist-credentials: false + submodules: recursive + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: "Start ARM instance" + run: ./scripts/ci/images/ci_start_arm_instance_and_connect_to_docker.sh + - name: > + Build CI ARM images ${{ env.IMAGE_TAG }} + ${{needs.build-info.outputs.all-python-versions-list-as-string}}:${{env.IMAGE_TAG}} + run: > + breeze ci-image build --run-in-parallel 
--builder airflow_cache --platform "linux/arm64" + env: + UPGRADE_TO_NEWER_DEPENDENCIES: ${{ needs.build-info.outputs.upgrade-to-newer-dependencies }} + DOCKER_CACHE: ${{ needs.build-info.outputs.cache-directive }} + PYTHON_VERSIONS: ${{needs.build-info.outputs.all-python-versions-list-as-string}} + DEBUG_RESOURCES: ${{needs.build-info.outputs.debug-resources}} + - name: "Stop ARM instance" + run: ./scripts/ci/images/ci_stop_arm_instance.sh + if: always() + - name: "Fix ownership" + run: breeze ci fix-ownership if: always() diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 6d6f4d02562d5..858b8a5d6b5e1 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -39,23 +39,18 @@ jobs: needs-javascript-scans: ${{ steps.selective-checks.outputs.needs-javascript-scans }} steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: fetch-depth: 2 persist-credentials: false + - name: "Install Breeze" + uses: ./.github/actions/breeze - name: Selective checks id: selective-checks env: - EVENT_NAME: ${{ github.event_name }} - TARGET_COMMIT_SHA: ${{ github.sha }} - run: | - if [[ ${EVENT_NAME} == "pull_request" ]]; then - # Run selective checks - ./scripts/ci/selective_ci_checks.sh "${TARGET_COMMIT_SHA}" - else - # Run all checks - ./scripts/ci/selective_ci_checks.sh - fi + COMMIT_REF: "${{ github.sha }}" + VERBOSE: "false" + run: breeze ci selective-check >> ${GITHUB_OUTPUT} analyze: name: Analyze @@ -74,7 +69,7 @@ jobs: security-events: write steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: persist-credentials: false if: | @@ -83,7 +78,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v1 + uses: github/codeql-action/init@v2 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -97,13 +92,13 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v1 + uses: github/codeql-action/autobuild@v2 if: | matrix.language == 'python' && needs.selective-checks.outputs.needs-python-scans == 'true' || matrix.language == 'javascript' && needs.selective-checks.outputs.needs-javascript-scans == 'true' - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v2 if: | matrix.language == 'python' && needs.selective-checks.outputs.needs-python-scans == 'true' || matrix.language == 'javascript' && needs.selective-checks.outputs.needs-javascript-scans == 'true' diff --git a/.github/workflows/label_when_reviewed.yml b/.github/workflows/label_when_reviewed.yml deleted file mode 100644 index 189a2d7343b9b..0000000000000 --- a/.github/workflows/label_when_reviewed.yml +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ---- -name: Label when reviewed -on: pull_request_review # yamllint disable-line rule:truthy -jobs: - - label-when-reviewed: - name: "Label PRs when reviewed" - runs-on: ubuntu-20.04 - steps: - - name: "Do nothing. Only trigger corresponding workflow_run event" - run: echo diff --git a/.github/workflows/label_when_reviewed_workflow_run.yml b/.github/workflows/label_when_reviewed_workflow_run.yml deleted file mode 100644 index 9b11d71ad2498..0000000000000 --- a/.github/workflows/label_when_reviewed_workflow_run.yml +++ /dev/null @@ -1,177 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ---- -name: Label when reviewed workflow run -on: # yamllint disable-line rule:truthy - workflow_run: - workflows: ["Label when reviewed"] - types: ['requested'] -permissions: - # All other permissions are set to none - checks: write - contents: read - pull-requests: write -jobs: - - label-when-reviewed: - name: "Label PRs when reviewed workflow run" - runs-on: ubuntu-20.04 - outputs: - labelSet: ${{ steps.label-when-reviewed.outputs.labelSet }} - steps: - - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 - with: - persist-credentials: false - submodules: recursive - - name: "Get information about the original trigger of the run" - uses: ./.github/actions/get-workflow-origin - id: source-run-info - with: - token: ${{ secrets.GITHUB_TOKEN }} - sourceRunId: ${{ github.event.workflow_run.id }} - - name: Initiate Selective Build check - uses: ./.github/actions/checks-action - id: selective-build-check - with: - token: ${{ secrets.GITHUB_TOKEN }} - name: "Selective build check" - status: "in_progress" - sha: ${{ steps.source-run-info.outputs.sourceHeadSha }} - details_url: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} - output: > - {"summary": - "Checking selective status of the build in - [the run](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}) - "} - - name: > - Event: ${{ steps.source-run-info.outputs.sourceEvent }} - Repo: ${{ steps.source-run-info.outputs.sourceHeadRepo }} - Branch: ${{ steps.source-run-info.outputs.sourceHeadBranch }} - Run id: ${{ github.run_id }} - Source Run id: ${{ github.event.workflow_run.id }} - Sha: ${{ github.sha }} - Source Sha: ${{ steps.source-run-info.outputs.sourceHeadSha }} - Merge commit Sha: ${{ 
steps.source-run-info.outputs.mergeCommitSha }} - Target commit Sha: ${{ steps.source-run-info.outputs.targetCommitSha }} - run: printenv - - name: > - Fetch incoming commit ${{ steps.source-run-info.outputs.targetCommitSha }} with its parent - uses: actions/checkout@v2 - with: - ref: ${{ steps.source-run-info.outputs.targetCommitSha }} - fetch-depth: 2 - persist-credentials: false - # checkout the main branch again, to use the right script in main workflow - - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" - uses: actions/checkout@v2 - with: - persist-credentials: false - submodules: recursive - - name: Selective checks - id: selective-checks - env: - EVENT_NAME: ${{ steps.source-run-info.outputs.sourceEvent }} - TARGET_COMMIT_SHA: ${{ steps.source-run-info.outputs.targetCommitSha }} - PR_LABELS: ${{ steps.source-run-info.outputs.pullRequestLabels }} - run: | - if [[ ${EVENT_NAME} == "pull_request_review" ]]; then - # Run selective checks - ./scripts/ci/selective_ci_checks.sh "${TARGET_COMMIT_SHA}" - else - # Run all checks - ./scripts/ci/selective_ci_checks.sh - fi - - name: "Label when approved by committers for PRs that require full tests" - uses: ./.github/actions/label-when-approved-action - id: label-full-test-prs-when-approved-by-commiters - if: > - steps.selective-checks.outputs.run-tests == 'true' && - contains(steps.selective-checks.outputs.test-types, 'Core') - with: - token: ${{ secrets.GITHUB_TOKEN }} - label: 'full tests needed' - require_committers_approval: 'true' - remove_label_when_approval_missing: 'false' - pullRequestNumber: ${{ steps.source-run-info.outputs.pullRequestNumber }} - comment: > - The PR most likely needs to run full matrix of tests because it modifies parts of the core - of Airflow. However, committers might decide to merge it quickly and take the risk. - If they don't merge it quickly - please rebase it to the latest main at your convenience, - or amend the last commit of the PR, and push it with --force-with-lease. - - name: "Initiate GitHub Check forcing rerun of SH ${{ github.event.pull_request.head.sha }}" - uses: ./.github/actions/checks-action - id: full-test-check - if: steps.label-full-test-prs-when-approved-by-commiters.outputs.labelSet == 'true' - with: - token: ${{ secrets.GITHUB_TOKEN }} - name: "Please rebase or amend, and force push the PR to run full tests" - status: "in_progress" - sha: ${{ steps.source-run-info.outputs.sourceHeadSha }} - details_url: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} - output: > - {"summary": - "The PR likely needs to run all tests! This was determined via selective check in - [the run](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}) - "} - - name: "Label when approved by committers for PRs that do not require full tests" - uses: ./.github/actions/label-when-approved-action - id: label-simple-test-prs-when-approved-by-commiters - if: > - steps.selective-checks.outputs.run-tests == 'true' && - ! contains(steps.selective-checks.outputs.test-types, 'Core') - with: - token: ${{ secrets.GITHUB_TOKEN }} - label: 'okay to merge' - require_committers_approval: 'true' - pullRequestNumber: ${{ steps.source-run-info.outputs.pullRequestNumber }} - comment: > - The PR is likely OK to be merged with just subset of tests for default Python and Database - versions without running the full matrix of tests, because it does not modify the core of - Airflow. 
If the committers decide that the full tests matrix is needed, they will add the label - 'full tests needed'. Then you should rebase to the latest main or amend the last commit - of the PR, and push it with --force-with-lease. - - name: "Label when approved by committers for PRs that do not require tests at all" - uses: ./.github/actions/label-when-approved-action - id: label-no-test-prs-when-approved-by-commiters - if: steps.selective-checks.outputs.run-tests != 'true' - with: - token: ${{ secrets.GITHUB_TOKEN }} - label: 'okay to merge' - pullRequestNumber: ${{ steps.source-run-info.outputs.pullRequestNumber }} - require_committers_approval: 'true' - comment: > - The PR is likely ready to be merged. No tests are needed as no important environment files, - nor python files were modified by it. However, committers might decide that full test matrix is - needed and add the 'full tests needed' label. Then you should rebase it to the latest main - or amend the last commit of the PR, and push it with --force-with-lease. - - name: Update Selective Build check - uses: ./.github/actions/checks-action - if: always() - with: - token: ${{ secrets.GITHUB_TOKEN }} - check_id: ${{ steps.selective-build-check.outputs.check_id }} - status: "completed" - sha: ${{ steps.source-run-info.outputs.sourceHeadSha }} - conclusion: ${{ job.status }} - details_url: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} - output: > - {"summary": - "Checking selective status of the build completed in - [the run](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}) - "} diff --git a/.github/workflows/release_dockerhub_image.yml b/.github/workflows/release_dockerhub_image.yml new file mode 100644 index 0000000000000..65527e2c9b6a0 --- /dev/null +++ b/.github/workflows/release_dockerhub_image.yml @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +--- +name: "Release PROD images" +on: # yamllint disable-line rule:truthy + workflow_dispatch: + inputs: + airflowVersion: + description: 'Airflow version' + required: true + skipLatest: + description: 'Skip Latest: Set to true if not latest.' 
+ default: '' + required: false +permissions: + contents: read +concurrency: + group: ${{ github.event.inputs.airflowVersion }} + cancel-in-progress: true +jobs: + build-info: + timeout-minutes: 10 + name: "Build Info" + runs-on: ${{ github.repository == 'apache/airflow' && 'self-hosted' || 'ubuntu-20.04' }} + outputs: + pythonVersions: ${{ steps.selective-checks.outputs.python-versions }} + allPythonVersions: ${{ steps.selective-checks.outputs.all-python-versions }} + defaultPythonVersion: ${{ steps.selective-checks.outputs.default-python-version }} + skipLatest: ${{ github.event.inputs.skipLatest == '' && ' ' || '--skip-latest' }} + limitPlatform: ${{ github.repository == 'apache/airflow' && ' ' || '--limit-platform linux/amd64' }} + env: + GITHUB_CONTEXT: ${{ toJson(github) }} + steps: + - name: Cleanup repo + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 + with: + persist-credentials: false + submodules: recursive + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: Selective checks + id: selective-checks + env: + VERBOSE: "false" + run: breeze ci selective-check >> ${GITHUB_OUTPUT} + release-images: + timeout-minutes: 120 + name: "Release images: ${{ github.event.inputs.airflowVersion }}, ${{ matrix.python-version }}" + runs-on: ${{ github.repository == 'apache/airflow' && 'self-hosted' || 'ubuntu-20.04' }} + needs: [build-info] + strategy: + fail-fast: false + matrix: + python-version: ${{ fromJson(needs.build-info.outputs.pythonVersions) }} + env: + RUNS_ON: ${{ github.repository == 'apache/airflow' && 'self-hosted' || 'ubuntu-20.04' }} + if: contains(fromJSON('[ + "ashb", + "ephraimbuddy", + "jedcunningham", + "kaxil", + "potiuk", + ]'), github.event.sender.login) + steps: + - name: Cleanup repo + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 + with: + persist-credentials: false + - name: "Install Breeze" + uses: ./.github/actions/breeze + - name: Build CI image for PROD build ${{ needs.build-info.outputs.defaultPythonVersion }} + run: breeze ci-image build + env: + PYTHON_MAJOR_MINOR_VERSION: ${{ needs.build-info.outputs.defaultPythonVersion }} + - name: "Cleanup dist and context file" + run: rm -fv ./dist/* ./docker-context-files/* + - name: "Start ARM instance" + run: ./scripts/ci/images/ci_start_arm_instance_and_connect_to_docker.sh + if: github.repository == 'apache/airflow' + - name: "Login to docker" + run: > + echo ${{ secrets.DOCKERHUB_TOKEN }} | + docker login --password-stdin --username ${{ secrets.DOCKERHUB_USER }} + - name: > + Release regular images: ${{ github.event.inputs.airflowVersion }}, ${{ matrix.python-version }} + run: > + breeze release-management release-prod-images + --dockerhub-repo ${{ github.repository }} + --airflow-version ${{ github.event.inputs.airflowVersion }} + ${{ needs.build-info.outputs.skipLatest }} + ${{ needs.build-info.outputs.limitPlatform }} + --limit-python ${{ matrix.python-version }} + - name: > + Release slim images: ${{ github.event.inputs.airflowVersion }}, ${{ matrix.python-version }} + run: > + breeze release-management release-prod-images + --dockerhub-repo ${{ github.repository }} + --airflow-version ${{ github.event.inputs.airflowVersion }} + ${{ needs.build-info.outputs.skipLatest }} + ${{ needs.build-info.outputs.limitPlatform }} + --limit-python 
${{ matrix.python-version }} --slim-images + - name: "Stop ARM instance" + run: ./scripts/ci/images/ci_stop_arm_instance.sh + if: always() && github.repository == 'apache/airflow' + - name: > + Verify regular AMD64 image: ${{ github.event.inputs.airflowVersion }}, ${{ matrix.python-version }} + run: > + breeze prod-image verify + --pull + --image-name + ${{github.repository}}:${{github.event.inputs.airflowVersion}}-python${{matrix.python-version}} + - name: > + Verify slim AMD64 image: ${{ github.event.inputs.airflowVersion }}, ${{ matrix.python-version }} + run: > + breeze prod-image verify + --pull + --slim-image + --image-name + ${{github.repository}}:slim-${{github.event.inputs.airflowVersion}}-python${{matrix.python-version}} + - name: "Docker logout" + run: docker logout + if: always() + - name: "Fix ownership" + run: breeze ci fix-ownership + if: always() diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index a1bdd2b3d4d02..0dc1e093885e9 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -29,7 +29,7 @@ jobs: stale: runs-on: ubuntu-20.04 steps: - - uses: actions/stale@v4 + - uses: actions/stale@v5 with: stale-pr-message: > This pull request has been automatically marked as stale because it has not had diff --git a/.github/workflows/sync_authors.yml b/.github/workflows/sync_authors.yml new file mode 100644 index 0000000000000..5ef30ba9c4093 --- /dev/null +++ b/.github/workflows/sync_authors.yml @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +--- +name: Sync authors list + +on: # yamllint disable-line rule:truthy + schedule: + # min hr dom mon dow + - cron: '11 01 * * *' # daily at 1.11am + workflow_dispatch: + # only users with write access to apache/airflow can run manually + # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow + +permissions: + contents: write + pull-requests: write + +jobs: + sync: + name: Sync + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Requests + run: | + pip install requests toml + + - name: Sync the authors list + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + python scripts/ci/runners/sync_authors.py + git config user.name "GitHub Actions" + git config user.email "actions@users.noreply.github.com" + if [ -n "$(git status --porcelain)" ]; then + branch=update-$(date +%s) + git add -A + git checkout -b $branch + git commit --message "Authors list automatic update" + git push origin $branch + gh pr create --title "Authors list automatic update" --body '' + fi diff --git a/.gitignore b/.gitignore index 9a00d53fa3bda..edd1362f96c3e 100644 --- a/.gitignore +++ b/.gitignore @@ -13,11 +13,11 @@ unittests.db # Airflow temporary artifacts airflow/git_version airflow/www/static/coverage/ -airflow/www/static/dist - +airflow/www/*.log /logs/ airflow-webserver.pid standalone_admin_password.txt +warnings.txt # Byte-compiled / optimized / DLL files __pycache__/ @@ -206,9 +206,6 @@ log.txt* # Terraform variables *.tfvars -# Chart dependencies -**/charts/*.tgz - # Might be generated when you build wheels pip-wheel-metadata @@ -225,3 +222,11 @@ licenses/LICENSES-ui.txt # Packaged breeze on Windows /breeze.exe + +# Generated out dir + +/out + +# files generated by memray +*.py.*.html +*.py.*.bin diff --git a/.gitmodules b/.gitmodules index e03978e263653..aa1358f88496d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,6 @@ [submodule ".github/actions/get-workflow-origin"] path = .github/actions/get-workflow-origin url = https://github.com/potiuk/get-workflow-origin -[submodule ".github/actions/checks-action"] - path = .github/actions/checks-action - url = https://github.com/LouisBrunner/checks-action [submodule ".github/actions/configure-aws-credentials"] path = .github/actions/configure-aws-credentials url = https://github.com/aws-actions/configure-aws-credentials @@ -13,6 +10,3 @@ [submodule ".github/actions/github-push-action"] path = .github/actions/github-push-action url = https://github.com/ad-m/github-push-action -[submodule ".github/actions/label-when-approved-action"] - path = .github/actions/label-when-approved-action - url = https://github.com/TobKed/label-when-approved-action diff --git a/.gitpod.yml b/.gitpod.yml index b4115c3801285..f255f38dd8f40 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -20,7 +20,7 @@ # Docs: https://www.gitpod.io/docs/config-gitpod-file/ tasks: - - init: ./breeze-legacy -y + - init: ./scripts/ci/install_breeze.sh - name: Install pre-commit openMode: split-right command: | diff --git a/.mailmap b/.mailmap index e8cfb301a051f..c3274e8700623 100644 --- a/.mailmap +++ b/.mailmap @@ -40,7 +40,7 @@ Gerard Toonstra Greg Neiheisel Hossein Torabi James Timmins -Jarek Potiuk +Jarek Potiuk Jeremiah Lowin Jeremiah Lowin Jeremiah Lowin diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7de8e753565ea..248029bfe582c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ 
-17,8 +17,8 @@ --- default_stages: [commit, push] default_language_version: - # force all unspecified python hooks to run python3 python: python3 + node: 18.6.0 minimum_pre_commit_version: "2.0.0" repos: - repo: meta @@ -41,13 +41,13 @@ repos: - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.2.0 hooks: - - id: forbid-tabs - name: Fail if tabs are used in the project - exclude: ^airflow/_vendor/|^clients/gen/go\.sh$|^\.gitmodules$ - id: insert-license name: Add license for all SQL files files: \.sql$ - exclude: ^\.github/.*$|^airflow/_vendor/ + exclude: | + (?x) + ^\.github/| + ^airflow/_vendor/ args: - --comment-style - "/*||*/" @@ -87,7 +87,7 @@ repos: - id: insert-license name: Add license for all shell files exclude: ^\.github/.*$|^airflow/_vendor/|^dev/breeze/autocomplete/.*$ - files: ^breeze-legacy$|^breeze-complete$|\.bash$|\.sh$ + files: \.bash$|\.sh$ args: - --comment-style - "|#|" @@ -146,19 +146,30 @@ repos: - --fuzzy-match-generates-todo files: > \.cfg$|\.conf$|\.ini$|\.ldif$|\.properties$|\.readthedocs$|\.service$|\.tf$|Dockerfile.*$ - # Keep version of black in sync wit blackend-docs and pre-commit-hook-names + - repo: https://github.com/PyCQA/isort + rev: 5.11.2 + hooks: + - id: isort + name: Run isort to sort imports in Python files + # Keep version of black in sync wit blacken-docs and pre-commit-hook-names - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 22.12.0 hooks: - id: black - name: Run Black (the uncompromising Python code formatter) + name: Run black (python formatter) args: [--config=./pyproject.toml] - exclude: ^airflow/_vendor/ + exclude: ^airflow/_vendor/|^airflow/contrib/ - repo: https://github.com/asottile/blacken-docs rev: v1.12.1 hooks: - id: blacken-docs name: Run black on python code blocks in documentation files + args: + - --line-length=110 + - --target-version=py37 + - --target-version=py38 + - --target-version=py39 + - --target-version=py310 alias: black additional_dependencies: [black==22.3.0] - repo: https://github.com/pre-commit/pre-commit-hooks @@ -175,7 +186,7 @@ repos: name: Detect if private key is added to the repository - id: end-of-file-fixer name: Make sure that there is an empty line at the end - exclude: ^airflow/_vendor/ + exclude: ^airflow/_vendor/|^docs/apache-airflow/img/.*\.dot|^docs/apache-airflow/img/.*\.sha256 - id: mixed-line-ending name: Detect if mixed line ending is used (\r vs. \r\n) exclude: ^airflow/_vendor/ @@ -187,7 +198,7 @@ repos: exclude: ^airflow/_vendor/ - id: trailing-whitespace name: Remove trailing whitespace at end of line - exclude: ^airflow/_vendor/|^images/breeze/output.*$ + exclude: ^airflow/_vendor/|^images/breeze/output.*$|^docs/apache-airflow/img/.*\.dot|^docs/apache-airflow/img/.*\.dot - id: fix-encoding-pragma name: Remove encoding header from python files exclude: ^airflow/_vendor/ @@ -204,7 +215,7 @@ repos: pass_filenames: true # TODO: Bump to Python 3.8 when support for Python 3.7 is dropped in Airflow. 
- repo: https://github.com/asottile/pyupgrade - rev: v2.32.1 + rev: v3.3.1 hooks: - id: pyupgrade name: Upgrade Python code automatically @@ -227,14 +238,6 @@ repos: entry: yamllint -c yamllint-config.yml --strict types: [yaml] exclude: ^.*init_git_sync\.template\.yaml$|^.*airflow\.template\.yaml$|^chart/(?:templates|files)/.*\.yaml$|openapi/.*\.yaml$|^\.pre-commit-config\.yaml$|^airflow/_vendor/ - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort - name: Run isort to sort imports in Python files - files: \.py$|\.pyi$ - # To keep consistent with the global isort skip config defined in setup.cfg - exclude: ^airflow/_vendor/|^build/.*$|^venv/.*$|^\.tox/.*$ - repo: https://github.com/pycqa/pydocstyle rev: 6.1.1 hooks: @@ -242,7 +245,7 @@ repos: name: Run pydocstyle args: - --convention=pep257 - - --add-ignore=D100,D102,D103,D104,D105,D107,D202,D205,D400,D401 + - --add-ignore=D100,D102,D103,D104,D105,D107,D205,D400,D401 exclude: | (?x) ^tests/.*\.py$| @@ -256,7 +259,7 @@ repos: ^airflow/_vendor/ additional_dependencies: ['toml'] - repo: https://github.com/asottile/yesqa - rev: v1.3.0 + rev: v1.4.0 hooks: - id: yesqa name: Remove unnecessary noqa statements @@ -265,7 +268,7 @@ repos: ^airflow/_vendor/ additional_dependencies: ['flake8>=4.0.1'] - repo: https://github.com/ikamensh/flynt - rev: '0.76' + rev: '0.77' hooks: - id: flynt name: Run flynt string format converter for Python @@ -288,13 +291,20 @@ repos: The word(s) should be in lowercase." && exec codespell "$@"' -- language: python types: [text] - exclude: ^airflow/_vendor/|^RELEASE_NOTES\.txt$|^airflow/www/static/css/material-icons\.css$|^images/.*$ + exclude: ^airflow/_vendor/|^RELEASE_NOTES\.txt$|^airflow/www/static/css/material-icons\.css$|^images/.*$|^.*package-lock.json$ args: - --ignore-words=docs/spelling_wordlist.txt - - --skip=docs/*/commits.rst,airflow/providers/*/*.rst,*.lock,INTHEWILD.md,*.min.js,docs/apache-airflow/pipeline_example.csv + - --skip=docs/*/commits.rst,airflow/providers/*/*.rst,*.lock,INTHEWILD.md,*.min.js,docs/apache-airflow/tutorial/pipeline_example.csv,airflow/www/*.log - --exclude-file=.codespellignorelines - repo: local hooks: + - id: replace-bad-characters + name: Replace bad characters + entry: ./scripts/ci/pre_commit/pre_commit_replace_bad_characters.py + language: python + types: [file, text] + exclude: ^airflow/_vendor/|^clients/gen/go\.sh$|^\.gitmodules$ + additional_dependencies: ['rich>=12.4.4'] - id: static-check-autoflake name: Remove all unused code entry: autoflake --remove-all-unused-imports --ignore-init-module-imports --in-place @@ -313,8 +323,8 @@ repos: files: ^airflow/api_connexion/openapi/ - id: lint-dockerfile name: Lint dockerfile - language: system - entry: ./scripts/ci/pre_commit/pre_commit_lint_dockerfile.sh + language: python + entry: ./scripts/ci/pre_commit/pre_commit_lint_dockerfile.py files: Dockerfile.*$ pass_filenames: true require_serial: true @@ -324,7 +334,7 @@ repos: files: ^setup\.cfg$|^setup\.py$ pass_filenames: false entry: ./scripts/ci/pre_commit/pre_commit_check_order_setup.py - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4'] - id: check-extra-packages-references name: Checks setup extra packages description: Checks if all the libraries in setup.py are listed in extra-packages-ref.rst file @@ -332,42 +342,28 @@ repos: files: ^setup\.py$|^docs/apache-airflow/extra-packages-ref\.rst$ pass_filenames: false entry: ./scripts/ci/pre_commit/pre_commit_check_setup_extra_packages_ref.py - 
additional_dependencies: ['rich>=12.4.1'] - # This check might be removed when min-airflow-version in providers is 2.2 - - id: check-airflow-2-1-compatibility - name: Check that providers are 2.1 compatible. - entry: ./scripts/ci/pre_commit/pre_commit_check_2_1_compatibility.py + additional_dependencies: ['rich>=12.4.4'] + - id: check-airflow-provider-compatibility + name: Check compatibility of Providers with Airflow + entry: ./scripts/ci/pre_commit/pre_commit_check_provider_airflow_compatibility.py language: python pass_filenames: true files: ^airflow/providers/.*\.py$ - additional_dependencies: ['rich>=12.4.1'] - - id: update-breeze-file - name: Update output of breeze commands in BREEZE.rst - entry: ./scripts/ci/pre_commit/pre_commit_breeze_cmd_line.py - language: python - files: ^BREEZE\.rst$|^dev/breeze/.*$ - pass_filenames: false - additional_dependencies: ['rich>=12.4.1', 'rich-click'] + additional_dependencies: ['rich>=12.4.4'] - id: update-local-yml-file name: Update mounts in the local yml file entry: ./scripts/ci/pre_commit/pre_commit_local_yml_mounts.py language: python files: ^dev/breeze/src/airflow_breeze/utils/docker_command_utils\.py$|^scripts/ci/docker_compose/local\.yml$ pass_filenames: false - additional_dependencies: ['rich>=12.4.1'] - - id: update-setup-cfg-file - name: Update setup.cfg file with all licenses - entry: ./scripts/ci/pre_commit/pre_commit_setup_cfg_file.sh - language: system - files: ^setup\.cfg$ - pass_filenames: false + additional_dependencies: ['rich>=12.4.4'] - id: update-providers-dependencies name: Update cross-dependencies for providers packages - entry: ./scripts/ci/pre_commit/pre_commit_build_providers_dependencies.sh + entry: ./scripts/ci/pre_commit/pre_commit_update_providers_dependencies.py language: python - files: ^airflow/providers/.*\.py$|^tests/providers/.*\.py$ + files: ^airflow/providers/.*\.py$|^tests/providers/.*\.py$|^tests/system/providers/.*\.py$|^airflow/providers/.*/provider.yaml$ pass_filenames: false - additional_dependencies: ['setuptools'] + additional_dependencies: ['setuptools', 'rich>=12.4.4', 'pyyaml'] - id: update-extras name: Update extras in documentation entry: ./scripts/ci/pre_commit/pre_commit_insert_extras.py @@ -380,7 +376,7 @@ repos: language: python files: ^Dockerfile$ pass_filenames: false - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4'] - id: update-supported-versions name: Updates supported versions in documentation entry: ./scripts/ci/pre_commit/pre_commit_supported_versions.py @@ -393,6 +389,7 @@ repos: language: python entry: ./scripts/ci/pre_commit/pre_commit_version_heads_map.py pass_filenames: false + additional_dependencies: ['packaging'] - id: update-version name: Update version to the latest version in the documentation entry: ./scripts/ci/pre_commit/pre_commit_update_versions.py @@ -405,6 +402,19 @@ repos: entry: "pydevd.*settrace\\(" pass_filenames: true files: \.py$ + - id: check-links-to-example-dags-do-not-use-hardcoded-versions + name: Check that example dags do not use hard-coded version numbers + description: The links to example dags should use |version| as version specification + language: pygrep + entry: > + (?i) + .*https://github.*[0-9]/tests/system/providers| + .*https://github.*/main/tests/system/providers| + .*https://github.*/master/tests/system/providers| + .*https://github.*/main/airflow/providers/.*/example_dags/| + .*https://github.*/master/airflow/providers/.*/example_dags/ + pass_filenames: true + files: 
^docs/apache-airflow-providers-.*\.rst - id: check-safe-filter-usage-in-html language: pygrep name: Don't use safe in templates @@ -422,11 +432,11 @@ repos: - id: check-no-relative-imports language: pygrep name: No relative imports - description: Airflow style is to use absolute imports only + description: Airflow style is to use absolute imports only (except docs building) entry: "^\\s*from\\s+\\." pass_filenames: true files: \.py$ - exclude: ^tests/|^airflow/_vendor/ + exclude: ^tests/|^airflow/_vendor/|^docs/ - id: check-for-inclusive-language language: pygrep name: Check for language that we do not accept as community @@ -449,12 +459,13 @@ repos: ^airflow/www/static/| ^airflow/providers/| ^tests/providers/apache/cassandra/hooks/test_cassandra.py$| + ^tests/integration/providers/apache/cassandra/hooks/test_cassandra.py$| + ^tests/system/providers/apache/spark/example_spark_dag.py$| ^docs/apache-airflow-providers-apache-cassandra/connections/cassandra.rst$| ^docs/apache-airflow-providers-apache-hive/commits.rst$| ^airflow/api_connexion/openapi/v1.yaml$| ^tests/cli/commands/test_webserver_command.py$| ^airflow/cli/commands/webserver_command.py$| - ^airflow/ui/yarn.lock$| ^airflow/config_templates/default_airflow.cfg$| ^airflow/config_templates/config.yml$| ^docs/*.*$| @@ -468,7 +479,13 @@ repos: language: python entry: ./scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py pass_filenames: false - files: ^airflow/models/(?:base|mapped)operator.py$ + files: ^airflow/models/(?:base|mapped)operator\.py$ + - id: check-init-decorator-arguments + name: Check model __init__ and decorator arguments are in sync + language: python + entry: ./scripts/ci/pre_commit/pre_commit_sync_init_decorator.py + pass_filenames: false + files: ^airflow/models/dag\.py$|^airflow/(?:decorators|utils)/task_group.py$ - id: check-base-operator-usage language: pygrep name: Check BaseOperator[Link] core imports @@ -494,6 +511,19 @@ repos: (?x) ^airflow/providers/.*\.py$ exclude: ^airflow/_vendor/ + - id: check-decorated-operator-implements-custom-name + name: Check @task decorator implements custom_operator_name + language: python + entry: ./scripts/ci/pre_commit/pre_commit_decorator_operator_implements_custom_name.py + pass_filenames: true + files: ^airflow/.*\.py$ + - id: check-core-deprecation-classes + language: pygrep + name: Verify using of dedicated Airflow deprecation classes in core + entry: category=DeprecationWarning|category=PendingDeprecationWarning + files: \.py$ + exclude: ^airflow/configuration.py|^airflow/providers|^scripts/in_container/verify_providers.py + pass_filenames: true - id: check-provide-create-sessions-imports language: pygrep name: Check provide_session and create_session imports @@ -524,16 +554,10 @@ repos: files: \.*example_dags.*\.py$ exclude: ^airflow/_vendor/ pass_filenames: true - - id: check-integrations-are-consistent - name: Check if integration list is consistent in various places - entry: ./scripts/ci/pre_commit/pre_commit_check_integrations.sh - language: system - pass_filenames: false - files: ^common/_common_values\.sh$|^breeze-complete$ - id: check-apache-license-rat name: Check if licenses are OK for Apache - entry: ./scripts/ci/pre_commit/pre_commit_check_license.sh - language: system + entry: ./scripts/ci/pre_commit/pre_commit_check_license.py + language: python files: ^.*LICENSE.*$|^.*LICENCE.*$ pass_filenames: false - id: check-airflow-config-yaml-consistent @@ -553,28 +577,31 @@ repos: additional_dependencies: ['pyyaml', 'termcolor==1.1.0', 'wcmatch==8.2'] 
- id: update-in-the-wild-to-be-sorted name: Sort INTHEWILD.md alphabetically - entry: ./scripts/ci/pre_commit/pre_commit_sort_in_the_wild.sh - language: system + entry: ./scripts/ci/pre_commit/pre_commit_sort_in_the_wild.py + language: python files: ^\.pre-commit-config\.yaml$|^INTHEWILD\.md$ + pass_filenames: false require_serial: true - id: update-spelling-wordlist-to-be-sorted name: Sort alphabetically and uniquify spelling_wordlist.txt - entry: ./scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.sh - language: system + entry: ./scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py + language: python files: ^\.pre-commit-config\.yaml$|^docs/spelling_wordlist\.txt$ require_serial: true + pass_filenames: false - id: lint-helm-chart name: Lint Helm Chart - entry: ./scripts/ci/pre_commit/pre_commit_helm_lint.sh - language: system + entry: ./scripts/ci/pre_commit/pre_commit_helm_lint.py + language: python pass_filenames: false files: ^chart require_serial: true + additional_dependencies: ['rich>=12.4.4','requests'] - id: run-shellcheck name: Check Shell scripts syntax correctness language: docker_image entry: koalaman/shellcheck:v0.8.0 -x -a - files: ^breeze-legacy$|^breeze-complete$|\.sh$|^hooks/build$|^hooks/push$|\.bash$ + files: ^.*\.sh$|^hooks/build$|^hooks/push$|\.bash$ exclude: ^dev/breeze/autocomplete/.*$ - id: lint-css name: stylelint @@ -583,12 +610,30 @@ repos: files: ^airflow/www/.*\.(css|scss|sass)$ # Keep dependency versions in sync w/ airflow/www/package.json additional_dependencies: ['stylelint@13.3.1', 'stylelint-config-standard@20.0.0'] + - id: compile-www-assets + name: Compile www assets + language: node + stages: ['manual'] + 'types_or': [javascript, tsx, ts] + files: ^airflow/www/ + entry: ./scripts/ci/pre_commit/pre_commit_compile_www_assets.py + pass_filenames: false + additional_dependencies: ['yarn@1.22.19'] + - id: compile-www-assets-dev + name: Compile www assets in dev mode + language: node + stages: ['manual'] + 'types_or': [javascript, tsx, ts] + files: ^airflow/www/ + entry: ./scripts/ci/pre_commit/pre_commit_compile_www_assets_dev.py + pass_filenames: false + additional_dependencies: ['yarn@1.22.19'] - id: check-providers-init-file-missing name: Provider init file is missing pass_filenames: false always_run: true - entry: ./scripts/ci/pre_commit/pre_commit_check_providers_init.sh - language: system + entry: ./scripts/ci/pre_commit/pre_commit_check_providers_init.py + language: python - id: check-providers-subpackages-init-file-exist name: Provider subpackage init files are there pass_filenames: false @@ -596,19 +641,6 @@ repos: entry: ./scripts/ci/pre_commit/pre_commit_check_providers_subpackages_all_have_init.py language: python require_serial: true - - id: check-provider-yaml-valid - name: Validate providers.yaml files - pass_filenames: false - entry: ./scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py - language: python - require_serial: true - files: ^docs/|provider\.yaml$|^scripts/ci/pre_commit/pre_commit_check_provider_yaml_files\.py$ - additional_dependencies: - - 'PyYAML==5.3.1' - - 'jsonschema>=3.2.0,<5.0.0' - - 'tabulate==0.8.8' - - 'jsonpath-ng==1.5.3' - - 'rich>=12.4.1' - id: check-pre-commit-information-consistent name: Update information re pre-commit hooks and verify ids and names entry: ./scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py @@ -616,22 +648,14 @@ repos: - --max-length=64 language: python files: ^\.pre-commit-config\.yaml$|^scripts/ci/pre_commit/pre_commit_check_pre_commit_hook_names\.py$ - 
additional_dependencies: ['pyyaml', 'jinja2', 'black==22.3.0', 'tabulate', 'rich>=12.4.1'] + additional_dependencies: ['pyyaml', 'jinja2', 'black==22.3.0', 'tabulate', 'rich>=12.4.4'] require_serial: true pass_filenames: false - - id: check-airflow-providers-have-extras - name: Checks providers available when declared by extras in setup.py - language: python - entry: ./scripts/ci/pre_commit/pre_commit_check_extras_have_providers.py - files: ^setup\.py$|^airflow/providers/.*\.py$ - pass_filenames: false - require_serial: true - additional_dependencies: ['rich>=12.4.1'] - id: update-breeze-readme-config-hash name: Update Breeze README.md with config files hash language: python entry: ./scripts/ci/pre_commit/pre_commit_update_breeze_config_hash.py - files: ^dev/breeze/setup.*$|^dev/breeze/pyproject.toml$|^dev/breeze/README.md$ + files: dev/breeze/setup.py|dev/breeze/setup.cfg|dev/breeze/pyproject.toml|dev/breeze/README.md pass_filenames: false require_serial: true - id: check-breeze-top-dependencies-limited @@ -641,15 +665,15 @@ repos: files: ^dev/breeze/.*$ pass_filenames: false require_serial: true - additional_dependencies: ['click', 'rich>=12.4.1'] + additional_dependencies: ['click', 'rich>=12.4.4'] - id: check-system-tests-present name: Check if system tests have required segments of code entry: ./scripts/ci/pre_commit/pre_commit_check_system_tests.py language: python files: ^tests/system/.*/example_[^/]*.py$ - exclude: ^tests/system/providers/google/bigquery/example_bigquery_queries\.py$ + exclude: ^tests/system/providers/google/cloud/bigquery/example_bigquery_queries\.py$ pass_filenames: true - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4'] - id: lint-markdown name: Run markdownlint description: Checks the style of Markdown files. 
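The entries above are ordinary pre-commit hook definitions, so each of them can also be exercised locally by its ``id`` instead of waiting for CI. A minimal sketch, assuming ``pre-commit`` itself is already installed (for example via ``pipx``) and that it is run from the repository root:

.. code-block:: bash

    # run a single hook from this configuration against the whole repository
    pre-commit run update-in-the-wild-to-be-sorted --all-files
    pre-commit run check-system-tests-present --all-files

    # run every configured hook only against files changed relative to main
    pre-commit run --from-ref origin/main --to-ref HEAD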
@@ -740,7 +764,8 @@ repos: language: python pass_filenames: true files: ^\.github/workflows/.*\.yml$ - additional_dependencies: ['PyYAML', 'rich>=12.4.1'] + exclude: ^\.github/workflows/sync_authors\.yml$ + additional_dependencies: ['PyYAML', 'rich>=12.4.4'] - id: check-docstring-param-types name: Check that docstrings do not specify param types entry: ./scripts/ci/pre_commit/pre_commit_docstring_param_type.py @@ -748,7 +773,7 @@ repos: pass_filenames: true files: \.py$ exclude: ^airflow/_vendor/ - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4'] - id: lint-chart-schema name: Lint chart/values.schema.json file entry: ./scripts/ci/pre_commit/pre_commit_chart_schema.py @@ -778,6 +803,52 @@ repos: # We sometimes won't have newsfragments in the repo, so always run it so `check-hooks-apply` passes # This is fast, so not too much downside always_run: true + - id: update-breeze-cmd-output + name: Update output of breeze commands in BREEZE.rst + entry: ./scripts/ci/pre_commit/pre_commit_breeze_cmd_line.py + language: python + files: ^BREEZE\.rst$|^dev/breeze/.*$|^\.pre-commit-config\.yaml$ + require_serial: true + pass_filenames: false + additional_dependencies: ['rich>=12.4.4', 'rich-click>=1.5', 'inputimeout'] + - id: check-example-dags-urls + name: Check that example dags url include provider versions + entry: ./scripts/ci/pre_commit/pre_commit_update_example_dags_paths.py + language: python + pass_filenames: true + files: ^docs/.*index\.rst$|^docs/.*example-dags\.rst$ + additional_dependencies: ['rich>=12.4.4', 'pyyaml'] + always_run: true + - id: check-system-tests-tocs + name: Check that system tests is properly added + entry: ./scripts/ci/pre_commit/pre_commit_check_system_tests_hidden_in_index.py + language: python + pass_filenames: true + files: ^docs/apache-airflow-providers-[^/]*/index\.rst$ + additional_dependencies: ['rich>=12.4.4', 'pyyaml'] + - id: check-lazy-logging + name: Check that all logging methods are lazy + entry: ./scripts/ci/pre_commit/pre_commit_check_lazy_logging.py + language: python + pass_filenames: true + files: \.py$ + exclude: ^airflow/_vendor/ + additional_dependencies: ['rich>=12.4.4', 'astor'] + - id: create-missing-init-py-files-tests + name: Create missing init.py files in tests + entry: ./scripts/ci/pre_commit/pre_commit_check_init_in_tests.py + language: python + additional_dependencies: ['rich>=12.4.4'] + pass_filenames: false + files: ^tests/.*\.py$ + - id: ts-compile-and-lint-javascript + name: TS types generation and ESLint against current UI files + language: node + 'types_or': [javascript, tsx, ts, yaml] + files: ^airflow/www/static/js/|^airflow/api_connexion/openapi/v1.yaml + entry: ./scripts/ci/pre_commit/pre_commit_www_lint.py + additional_dependencies: ['yarn@1.22.19'] + pass_filenames: false ## ADD MOST PRE-COMMITS ABOVE THAT LINE # The below pre-commits are those requiring CI image to be built - id: run-mypy @@ -786,22 +857,22 @@ repos: entry: ./scripts/ci/pre_commit/pre_commit_mypy.py files: ^dev/.*\.py$ require_serial: true - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4', 'inputimeout'] - id: run-mypy name: Run mypy for core language: python entry: ./scripts/ci/pre_commit/pre_commit_mypy.py --namespace-packages files: \.py$ - exclude: ^provider_packages|^docs|^airflow/_vendor/|^airflow/providers|^airflow/migrations|^dev + exclude: ^provider_packages|^docs|^airflow/_vendor/|^airflow/providers|^airflow/migrations|^dev|^tests/system/providers|^tests/providers 
require_serial: true - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4', 'inputimeout'] - id: run-mypy name: Run mypy for providers language: python entry: ./scripts/ci/pre_commit/pre_commit_mypy.py --namespace-packages - files: ^airflow/providers/.*\.py$ + files: ^airflow/providers/.*\.py$|^tests/system/providers/\*.py|^tests/providers/\*.py require_serial: true - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4', 'inputimeout'] - id: run-mypy name: Run mypy for /docs/ folder language: python @@ -809,7 +880,7 @@ repos: files: ^docs/.*\.py$ exclude: ^docs/rtd-deprecation require_serial: true - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4', 'inputimeout'] - id: run-flake8 name: Run flake8 language: python @@ -817,28 +888,27 @@ repos: files: \.py$ pass_filenames: true exclude: ^airflow/_vendor/ - additional_dependencies: ['rich>=12.4.1'] - - id: lint-javascript - name: ESLint against airflow/ui - language: python - 'types_or': [javascript, tsx, ts] - files: ^airflow/ui/ - entry: ./scripts/ci/pre_commit/pre_commit_ui_lint.py + additional_dependencies: ['rich>=12.4.4', 'inputimeout'] + - id: check-provider-yaml-valid + name: Validate provider.yaml files pass_filenames: false - additional_dependencies: ['rich>=12.4.1'] - - id: lint-javascript - name: ESLint against current UI JavaScript files + entry: ./scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py language: python - 'types_or': [javascript] - files: ^airflow/www/static/js/ - entry: ./scripts/ci/pre_commit/pre_commit_www_lint.py - pass_filenames: false - additional_dependencies: ['rich>=12.4.1'] + require_serial: true + files: ^docs/|provider\.yaml$|^scripts/ci/pre_commit/pre_commit_check_provider_yaml_files\.py$ + additional_dependencies: ['rich>=12.4.4', 'inputimeout', 'markdown-it-py'] - id: update-migration-references name: Update migration ref doc language: python entry: ./scripts/ci/pre_commit/pre_commit_migration_reference.py pass_filenames: false files: ^airflow/migrations/versions/.*\.py$|^docs/apache-airflow/migrations-ref\.rst$ - additional_dependencies: ['rich>=12.4.1'] + additional_dependencies: ['rich>=12.4.4', 'inputimeout', 'markdown-it-py'] + - id: update-er-diagram + name: Update ER diagram + language: python + entry: ./scripts/ci/pre_commit/pre_commit_update_er_diagram.py + pass_filenames: false + files: ^airflow/migrations/versions/.*\.py$|^docs/apache-airflow/migrations-ref\.rst$ + additional_dependencies: ['rich>=12.4.4'] ## ONLY ADD PRE-COMMITS HERE THAT REQUIRE CI IMAGE diff --git a/.rat-excludes b/.rat-excludes index fa4663ce65188..1e16d61a67f5d 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,6 +1,7 @@ # Note: these patterns are applied to single files or directories, not full paths # coverage/* will ignore any coverage dir, but airflow/www/static/coverage/* will match nothing +.git-blame-ignore-revs .github/* .gitignore .gitattributes @@ -77,6 +78,7 @@ flake8_diff.sh coverage*.xml _sources/* +robots.txt rat-results.txt apache-airflow-.*\+source.tar.gz.* apache-airflow-.*\+bin.tar.gz.* @@ -113,11 +115,21 @@ chart/values.schema.json chart/Chart.lock chart/values_schema.schema.json -# A simplistic Robots.txt -airflow/www/static/robots.txt - # Generated autocomplete files -dev/breeze/autocomplete/* +./dev/breeze/autocomplete/* # Newsfragments are snippets that will be, eventually, consumed into RELEASE_NOTES newsfragments/* + +# Warning file generated +warnings.txt + +# Dev stuff +tests/* 
+scripts/* +images/* +dev/* +chart/*.iml + +# Sha files +.*sha256 diff --git a/BREEZE.rst b/BREEZE.rst index b80db9c45356d..341f8b8dbeede 100644 --- a/BREEZE.rst +++ b/BREEZE.rst @@ -36,29 +36,6 @@ We call it *Airflow Breeze* as **It's a Breeze to contribute to Airflow**. The advantages and disadvantages of using the Breeze environment vs. other ways of testing Airflow are described in `CONTRIBUTING.rst `_. -.. note:: - We are currently migrating old Bash-based ./breeze-legacy to the Python-based breeze. Some of the - commands are already converted to breeze, but some old commands should use breeze-legacy. The - documentation mentions when ``./breeze-legacy`` is involved. - - The new ``breeze`` after installing is available on your PATH and you should launch it simply as - ``breeze ``. Previously you had to prepend breeze with ``./`` but this is not needed - any more. For convenience, we will keep ``./breeze`` script for a while to run the new breeze and you - can still use the legacy Breeze with ``./breeze-legacy``. - -Watch the video below about Airflow Breeze. It explains the motivation for Breeze -and screencast all its uses. The video describes old ``./breeze-legacy`` (in video it still -called ``./breeze`` ). - -.. raw:: html - -
-   [video embed removed: Airflow Breeze - Development and Test Environment for Apache Airflow]
- Prerequisites ============= @@ -89,6 +66,15 @@ Here is an example configuration with more than 200GB disk space for Docker: alt="Disk space MacOS"> + +- **Docker is not running** - even if it is running with Docker Desktop. This is an issue + specific to Docker Desktop 4.13.0 (released in late October 2022). Please upgrade Docker + Desktop to 4.13.1 or later to resolve the issue. For technical details, see also + `docker/for-mac#6529 `_. + +Note: If you use Colima, please follow instructions at: `Contributors Quick Start Guide `__ + Docker Compose -------------- @@ -174,6 +160,12 @@ environments. This can be done automatically by the following command (follow in pipx ensurepath +In Mac + +.. code-block:: bash + + python -m pipx ensurepath + Resources required ================== @@ -201,7 +193,7 @@ periodically. For details see On WSL2 you might want to increase your Virtual Hard Disk by following: `Expanding the size of your WSL 2 Virtual Hard Disk `_ -There is a command ``breeze resource-check`` that you can run to check available resources. See below +There is a command ``breeze ci resource-check`` that you can run to check available resources. See below for details. Cleaning the environment @@ -243,24 +235,19 @@ command. Those are all available commands for Breeze and details about the commands are described below: .. image:: ./images/breeze/output-commands.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output-commands.svg :width: 100% :alt: Breeze commands Breeze installed this way is linked to your checked out sources of Airflow so Breeze will automatically use latest version of sources from ``./dev/breeze``. Sometimes, when dependencies are -updated ``breeze`` commands with offer you to ``self-upgrade`` (you just need to answer ``y`` when asked). +updated ``breeze`` commands with offer you to run self-upgrade. You can always run such self-upgrade at any time: .. code-block:: bash - breeze self-upgrade - -Those are all available flags of ``self-upgrade`` command: - -.. image:: ./images/breeze/output-self-upgrade.svg - :width: 100% - :alt: Breeze self-upgrade + breeze setup self-upgrade If you have several checked out Airflow sources, Breeze will warn you if you are using it from a different source tree and will offer you to re-install from those sources - to make sure that you are using the right @@ -272,17 +259,11 @@ By default Breeze works on the version of Airflow that you run it in - in case y sources of Airflow and you installed Breeze from a directory - Breeze will be run on Airflow sources from where it was installed. -You can run ``breeze version`` command to see where breeze installed from and what are the current sources +You can run ``breeze setup version`` command to see where breeze installed from and what are the current sources that Breeze works on -Those are all available flags of ``version`` command: - -.. image:: ./images/breeze/output-version.svg - :width: 100% - :alt: Breeze version - Running Breeze for the first time -================================= +--------------------------------- The First time you run Breeze, it pulls and builds a local version of Docker images. It pulls the latest Airflow CI images from the @@ -302,25 +283,17 @@ You should set up the autocomplete option automatically by running: .. code-block:: bash - breeze setup-autocomplete - -You get the auto-completion working when you re-enter the shell (follow the instructions printed). 
-The command will warn you and not reinstall autocomplete if you already did, but you can -also force reinstalling the autocomplete via: - -.. code-block:: bash - - breeze setup-autocomplete --force - -Those are all available flags of ``setup-autocomplete`` command: + breeze setup autocomplete -.. image:: ./images/breeze/output-setup-autocomplete.svg - :width: 100% - :alt: Breeze setup autocomplete +Automating breeze installation +------------------------------ +Breeze on POSIX-compliant systems (Linux, MacOS) can be automatically installed by running the +``scripts/tools/setup_breeze`` bash script. This includes checking and installing ``pipx``, setting up +``breeze`` with it and setting up autocomplete. -Customize your environment --------------------------- +Customizing your environment +---------------------------- When you enter the Breeze environment, automatically an environment file is sourced from ``files/airflow-breeze-config/variables.env``. @@ -329,6 +302,12 @@ You can also add ``files/airflow-breeze-config/init.sh`` and the script will be when you enter Breeze. For example you can add ``pip install`` commands if you want to install custom dependencies - but there are no limits to add your own customizations. +You can override the name of the init script by setting ``INIT_SCRIPT_FILE`` environment variable before +running the breeze environment. + +You can also customize your environment by setting ``BREEZE_INIT_COMMAND`` environment variable. This variable +will be evaluated at entering the environment. + The ``files`` folder from your local sources is automatically mounted to the container under ``/files`` path and you can put there any files you want to make available for the Breeze container. @@ -357,27 +336,18 @@ inside container, to enable modified tmux configurations. -Running tests in the CI interactive environment -=============================================== - -Breeze helps with running tests in the same environment/way as CI tests are run. You can run various -types of tests while you enter Breeze CI interactive environment - this is described in detail -in ``_ +Regular development tasks +========================= -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command and it is not yet available in the new ``breeze`` command): - -.. raw:: html +The regular Breeze development tasks are available as top-level commands. Those tasks are most often +used during the development, that's why they are available without any sub-command. More advanced +commands are separated to sub-commands. -
-   [video embed removed: Airflow Breeze - Running tests]
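A few examples of that split between top-level commands and grouped sub-commands (the exact grouping may vary between Breeze versions):

.. code-block:: bash

    # frequently used development tasks stay at the top level
    breeze shell
    breeze start-airflow
    breeze build-docs
    breeze static-checks -t run-mypy

    # more advanced operations live under dedicated sub-command groups
    breeze ci-image build
    breeze setup autocomplete
    breeze setup self-upgrade
    breeze release-management release-prod-images --help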
+Entering Breeze shell +--------------------- -Choosing different Breeze environment configuration -=================================================== +This is the most often used feature of breeze. It simply allows to enter the shell inside the Breeze +development environment (inside the Breeze container). You can use additional ``breeze`` flags to choose your environment. You can specify a Python version to use, and backend (the meta-data database). Thanks to that, with Breeze, you can recreate the same @@ -398,20 +368,6 @@ default settings. You can see which value of the parameters that can be stored persistently in cache marked with >VALUE< in the help of the commands. -Another part of configuration is enabling/disabling cheatsheet, asciiart. The cheatsheet and asciiart can -be disabled - they are "nice looking" and cheatsheet -contains useful information for first time users but eventually you might want to disable both if you -find it repetitive and annoying. - -With the config setting colour-blind-friendly communication for Breeze messages. By default we communicate -with the users about information/errors/warnings/successes via colour-coded messages, but we can switch -it off by passing ``--no-colour`` to config in which case the messages to the user printed by Breeze -will be printed using different schemes (italic/bold/underline) to indicate different kind of messages -rather than colours. - -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command but it is very similar to current ``breeze`` command): - .. raw:: html
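For instance, to recreate one specific CI combination locally (a sketch: any of the supported Python version and backend pairs can be substituted):

.. code-block:: bash

    # enter the Breeze shell with an explicit environment selection
    breeze shell --python 3.7 --backend mysql

    # later plain invocations reuse the values cached from the previous run
    breeze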
@@ -421,199 +377,149 @@ command but it is very similar to current ``breeze`` command):
-Those are all available flags of ``config`` command: - -.. image:: ./images/breeze/output-config.svg - :width: 100% - :alt: Breeze config +Building the documentation +-------------------------- +To build documentation in Breeze, use the ``build-docs`` command: -You can also dump hash of the configuration options used - this is mostly use to generate the dump -of help of the commands only when they change. +.. code-block:: bash -.. image:: ./images/breeze/output-command-hash-export.svg - :width: 100% - :alt: Breeze command-hash-export + breeze build-docs +Results of the build can be found in the ``docs/_build`` folder. -Starting complete Airflow installation -====================================== +The documentation build consists of three steps: -For testing Airflow oyou often want to start multiple components (in multiple terminals). Breeze has -built-in ``start-airflow`` command that start breeze container, launches multiple terminals using tmux -and launches all Airflow necessary components in those terminals. +* verifying consistency of indexes +* building documentation +* spell checking -You can also use it to start any released version of Airflow from ``PyPI`` with the -``--use-airflow-version`` flag. +You can choose only one stage of the two by providing ``--spellcheck-only`` or ``--docs-only`` after +extra ``--`` flag. .. code-block:: bash - breeze --python 3.7 --backend mysql --use-airflow-version 2.2.5 start-airflow + breeze build-docs --spellcheck-only -Those are all available flags of ``start-airflow`` command: +This process can take some time, so in order to make it shorter you can filter by package, using the flag +``--package-filter ``. The package name has to be one of the providers or ``apache-airflow``. For +instance, for using it with Amazon, the command would be: -.. image:: ./images/breeze/output-start-airflow.svg - :width: 100% - :alt: Breeze start-airflow +.. code-block:: bash + breeze build-docs --package-filter apache-airflow-providers-amazon -Troubleshooting -=============== +Often errors during documentation generation come from the docstrings of auto-api generated classes. +During the docs building auto-api generated files are stored in the ``docs/_api`` folder. This helps you +easily identify the location the problems with documentation originated from. -If you are having problems with the Breeze environment, try the steps below. After each step you -can check whether your problem is fixed. +Those are all available flags of ``build-docs`` command: -1. If you are on macOS, check if you have enough disk space for Docker (Breeze will warn you if not). -2. Stop Breeze with ``breeze stop``. -3. Delete the ``.build`` directory and run ``breeze build-image``. -4. Clean up Docker images via ``breeze cleanup`` command. -5. Restart your Docker Engine and try again. -6. Restart your machine and try again. -7. Re-install Docker Desktop and try again. +.. image:: ./images/breeze/output_build-docs.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_build-docs.svg + :width: 100% + :alt: Breeze build documentation -In case the problems are not solved, you can set the VERBOSE_COMMANDS variable to "true": -.. code-block:: +.. raw:: html - export VERBOSE_COMMANDS="true" +
+   [video embed added: Airflow Breeze - Build docs]
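A sketch combining the ``build-docs`` flags described above (assuming the stage selection and the package filter can be mixed; the package names are interchangeable):

.. code-block:: bash

    # spell-check only the Amazon provider documentation
    breeze build-docs --spellcheck-only --package-filter apache-airflow-providers-amazon

    # build (without spell checking) only the core apache-airflow documentation
    breeze build-docs --docs-only --package-filter apache-airflow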
+Running static checks +--------------------- -Then run the failed command, copy-and-paste the output from your terminal to the -`Airflow Slack `_ #airflow-breeze channel and -describe your problem. +You can run static checks via Breeze. You can also run them via pre-commit command but with auto-completion +Breeze makes it easier to run selective static checks. If you press after the static-check and if +you have auto-complete setup you should see auto-completable list of all checks available. -Uses of the Airflow Breeze environment -====================================== +.. code-block:: bash -Airflow Breeze is a bash script serving as a "swiss-army-knife" of Airflow testing. Under the -hood it uses other scripts that you can also run manually if you have problem with running the Breeze -environment. Breeze script allows performing the following tasks: + breeze static-checks -t run-mypy + +The above will run mypy check for currently staged files. + +You can also pass specific pre-commit flags for example ``--all-files`` : -Development tasks ------------------ +.. code-block:: bash -Those are commands mostly used by contributors: + breeze static-checks -t run-mypy --all-files -* Execute arbitrary command in the test environment with ``breeze shell`` command -* Enter interactive shell in CI container when ``shell`` (or no command) is specified -* Start containerised, development-friendly airflow installation with ``breeze start-airflow`` command -* Build documentation with ``breeze build-docs`` command -* Initialize local virtualenv with ``./scripts/tools/initialize_virtualenv.py`` command -* Run static checks with autocomplete support ``breeze static-checks`` command -* Run test specified with ``breeze tests`` command -* Build CI docker image with ``breeze build-image`` command -* Cleanup breeze with ``breeze cleanup`` command +The above will run mypy check for all files. -Additional management tasks: +There is a convenience ``--last-commit`` flag that you can use to run static check on last commit only: -* Join running interactive shell with ``breeze exec`` command -* Stop running interactive environment with ``breeze stop`` command -* Execute arbitrary docker-compose command with ``./breeze-legacy docker-compose`` command +.. code-block:: bash -Tests ------ + breeze static-checks -t run-mypy --last-commit -* Run docker-compose tests with ``breeze docker-compose-tests`` command. -* Run test specified with ``breeze tests`` command. +The above will run mypy check for all files in the last commit. -.. image:: ./images/breeze/output-tests.svg - :width: 100% - :alt: Breeze tests +There is another convenience ``--commit-ref`` flag that you can use to run static check on specific commit: -Kubernetes tests ----------------- +.. code-block:: bash -* Manage KinD Kubernetes cluster and deploy Airflow to KinD cluster ``./breeze-legacy kind-cluster`` commands -* Run Kubernetes tests specified with ``./breeze-legacy kind-cluster tests`` command -* Enter the interactive kubernetes test environment with ``./breeze-legacy kind-cluster shell`` command + breeze static-checks -t run-mypy --commit-ref 639483d998ecac64d0fef7c5aa4634414065f690 -CI Image tasks --------------- +The above will run mypy check for all files in the 639483d998ecac64d0fef7c5aa4634414065f690 commit. +Any ``commit-ish`` reference from Git will work here (branch, tag, short/long hash etc.) 
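For example, a branch or tag name works just as well as a hash (a sketch, assuming such a ref exists in your clone):

.. code-block:: bash

    # run the same check against the tip of a release branch instead of a specific hash
    breeze static-checks -t run-mypy --commit-ref v2-4-stable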
-The image building is usually run for users automatically when needed, -but sometimes Breeze users might want to manually build, pull or verify the CI images. +If you ever need to get a list of the files that will be checked (for troubleshooting) use these commands: -* Build CI docker image with ``breeze build-image`` command -* Pull CI images in parallel ``breeze pull-image`` command -* Verify CI image ``breeze verify-image`` command +.. code-block:: bash -PROD Image tasks ----------------- + breeze static-checks -t identity --verbose # currently staged files + breeze static-checks -t identity --verbose --from-ref $(git merge-base main HEAD) --to-ref HEAD # branch updates -Users can also build Production images when they are developing them. However when you want to -use the PROD image, the regular docker build commands are recommended. See -`building the image `_ +Those are all available flags of ``static-checks`` command: -* Build PROD image with ``breeze build-prod-image`` command -* Pull PROD image in parallel ``breeze pull-prod-image`` command -* Verify CI image ``breeze verify-prod-image`` command +.. image:: ./images/breeze/output_static-checks.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_static-checks.svg + :width: 100% + :alt: Breeze static checks -Configuration and maintenance ------------------------------ -* Cleanup breeze with ``breeze cleanup`` command -* Self-upgrade breeze with ``breeze self-upgrade`` command -* Setup autocomplete for Breeze with ``breeze setup-autocomplete`` command -* Checking available resources for docker with ``breeze resource-check`` command -* Freeing space needed to run CI tests with ``breeze free-space`` command -* Fixing ownership of files in your repository with ``breeze fix-ownership`` command -* Print Breeze version with ``breeze version`` command -* Outputs hash of commands defined by ``breeze`` with ``command-hash-export`` (useful to avoid needless - regeneration of Breeze images) - -Release tasks -------------- +.. note:: -Maintainers also can use Breeze for other purposes (those are commands that regular contributors likely -do not need or have no access to run). Those are usually connected with releasing Airflow: + When you run static checks, some of the artifacts (mypy_cache) is stored in docker-compose volume + so that it can speed up static checks execution significantly. However, sometimes, the cache might + get broken, in which case you should run ``breeze stop`` to clean up the cache. -* Prepare cache for CI: ``breeze build-image --prepare-build-cache`` and - ``breeze build-prod image --prepare-build-cache``(needs buildx plugin and write access to registry ghcr.io) -* Generate constraints with ``breeze generate-constraints`` (needed when conflicting changes are merged) -* Prepare airflow packages: ``breeze prepare-airflow-package`` (when releasing Airflow) -* Verify providers: ``breeze verify-provider-packages`` (when releasing provider packages) - including importing - the providers in an earlier airflow version. 
-* Prepare provider documentation ``breeze prepare-provider-documentation`` and prepare provider packages - ``breeze prepare-provider-packages`` (when releasing provider packages) -* Finding the updated dependencies since the last successful build when we have conflict with - ``breeze find-newer-dependencies`` command -* Release production images to DockerHub with ``breeze release-prod-images`` command +Starting Airflow +---------------- -Details of Breeze usage -======================= +For testing Airflow you often want to start multiple components (in multiple terminals). Breeze has +built-in ``start-airflow`` command that start breeze container, launches multiple terminals using tmux +and launches all Airflow necessary components in those terminals. -Database volumes in Breeze --------------------------- +When you are starting airflow from local sources, www asset compilation is automatically executed before. -Breeze keeps data for all it's integration in named docker volumes. Each backend and integration -keeps data in their own volume. Those volumes are persisted until ``breeze stop`` command. -You can also preserve the volumes by adding flag ``--preserve-volumes`` when you run the command. -Then, next time when you start Breeze, it will have the data pre-populated. +.. code-block:: bash -Those are all available flags of ``stop`` command: + breeze --python 3.7 --backend mysql start-airflow -.. image:: ./images/breeze/output-stop.svg - :width: 100% - :alt: Breeze stop -Image cleanup --------------- +You can also use it to start any released version of Airflow from ``PyPI`` with the +``--use-airflow-version`` flag. -Breeze uses docker images heavily and those images are rebuild periodically. This might cause extra -disk usage by the images. If you need to clean-up the images periodically you can run -``breeze cleanup`` command (by default it will skip removing your images before cleaning up but you -can also remove the images to clean-up everything by adding ``--all``). +.. code-block:: bash -Those are all available flags of ``cleanup`` command: + breeze start-airflow --python 3.7 --backend mysql --use-airflow-version 2.2.5 +Those are all available flags of ``start-airflow`` command: -.. image:: ./images/breeze/output-cleanup.svg +.. image:: ./images/breeze/output_start-airflow.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_start-airflow.svg :width: 100% - :alt: Breeze cleanup + :alt: Breeze start-airflow -Launching multiple terminals ----------------------------- +Launching multiple terminals in the same environment +---------------------------------------------------- Often if you want to run full airflow in the Breeze environment you need to launch multiple terminals and run ``airflow webserver``, ``airflow scheduler``, ``airflow worker`` in separate terminals. @@ -624,299 +530,696 @@ capability of creating multiple virtual terminals and multiplex between them. Mo found at `tmux GitHub wiki page `_ . Tmux has several useful shortcuts that allow you to split the terminals, open new tabs etc - it's pretty useful to learn it. -Here is the part of Breeze video which is relevant: - -.. raw:: html - -
-   [video embed removed: Airflow Breeze - Using tmux]
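Inside the container each tmux pane (or terminal) then simply runs one of the standard Airflow commands mentioned above; a minimal sketch:

.. code-block:: bash

    # one component per tmux pane, all inside the Breeze container
    airflow webserver

    # split the current tmux window and start the next component in the new pane
    tmux split-window -v
    airflow scheduler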
Another way is to exec into the Breeze terminal from the host's terminal. Often you can have multiple terminals open on the host (Linux/MacOS/WSL2 on Windows) and you can simply use those terminals to enter the running container. It's as easy as launching ``breeze exec`` once you have already started the Breeze environment. You will be dropped into bash and the environment variables will be read in the same way as when you enter the environment. You can do it multiple times and open as many terminals as you need. -Here is the part of Breeze video which is relevant: - -.. raw:: html - -
- - Those are all available flags of ``exec`` command: -.. image:: ./images/breeze/output-exec.svg +.. image:: ./images/breeze/output_exec.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_exec.svg :width: 100% :alt: Breeze exec -Additional tools ----------------- -To shrink the Docker image, not all tools are pre-installed in the Docker image. But we have made sure that there -is an easy process to install additional tools. +Compiling www assets +-------------------- -Additional tools are installed in ``/files/bin``. This path is added to ``$PATH``, so your shell will -automatically autocomplete files that are in that directory. You can also keep the binaries for your tools -in this directory if you need to. +Airflow webserver needs to prepare www assets - compiled with node and yarn. The ``compile-www-assets`` +command takes care about it. This is needed when you want to run webserver inside of the breeze. -**Installation scripts** +.. image:: ./images/breeze/output_compile-www-assets.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_compile-www-assets.svg + :width: 100% + :alt: Breeze compile-www-assets -For the development convenience, we have also provided installation scripts for commonly used tools. They are -installed to ``/files/opt/``, so they are preserved after restarting the Breeze environment. Each script -is also available in ``$PATH``, so just type ``install_`` to get a list of tools. +Breeze cleanup +-------------- -Currently available scripts: +Breeze uses docker images heavily and those images are rebuild periodically. This might cause extra +disk usage by the images. If you need to clean-up the images periodically you can run +``breeze setup cleanup`` command (by default it will skip removing your images before cleaning up but you +can also remove the images to clean-up everything by adding ``--all``). -* ``install_aws.sh`` - installs `the AWS CLI `__ including -* ``install_az.sh`` - installs `the Azure CLI `__ including -* ``install_gcloud.sh`` - installs `the Google Cloud SDK `__ including - ``gcloud``, ``gsutil``. -* ``install_imgcat.sh`` - installs `imgcat - Inline Images Protocol `__ - for iTerm2 (Mac OS only) -* ``install_java.sh`` - installs `the OpenJDK 8u41 `__ -* ``install_kubectl.sh`` - installs `the Kubernetes command-line tool, kubectl `__ -* ``install_snowsql.sh`` - installs `SnowSQL `__ -* ``install_terraform.sh`` - installs `Terraform `__ +Those are all available flags of ``cleanup`` command: -Launching Breeze integrations ------------------------------ -When Breeze starts, it can start additional integrations. Those are additional docker containers -that are started in the same docker-compose command. Those are required by some of the tests -as described in ``_. +.. image:: ./images/breeze/output_cleanup.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_cleanup.svg + :width: 100% + :alt: Breeze setup cleanup -By default Breeze starts only airflow container without any integration enabled. If you selected -``postgres`` or ``mysql`` backend, the container for the selected backend is also started (but only the one -that is selected). You can start the additional integrations by passing ``--integration`` flag -with appropriate integration name when starting Breeze. You can specify several ``--integration`` flags -to start more than one integration at a time. -Finally you can specify ``--integration all`` to start all integrations. 
+Running arbitrary commands in container +--------------------------------------- -Once integration is started, it will continue to run until the environment is stopped with -``breeze stop`` command. or restarted via ``breeze restart`` command +More sophisticated usages of the breeze shell is using the ``breeze shell`` command - it has more parameters +and you can also use it to execute arbitrary commands inside the container. -Note that running integrations uses significant resources - CPU and memory. +.. code-block:: bash -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command but it is very similar to current ``breeze`` command): + breeze shell "ls -la" -.. raw:: html +Those are all available flags of ``shell`` command: -
-   [video embed removed: Airflow Breeze - Integrations]
+.. image:: ./images/breeze/output_shell.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_shell.svg + :width: 100% + :alt: Breeze shell -Managing CI images ------------------- -With Breeze you can build images that are used by Airflow CI and production ones. +Stopping the environment +------------------------ -For all development tasks, unit tests, integration tests, and static code checks, we use the -**CI image** maintained in GitHub Container Registry. +After starting up, the environment runs in the background and takes quite some memory which you might +want to free for other things you are running on your host. -The CI image is built automatically as needed, however it can be rebuilt manually with -``build-image`` command. The production -image should be built manually - but also a variant of this image is built automatically when -kubernetes tests are executed see `Running Kubernetes tests <#running-kubernetes-tests>`_ +You can always stop it via: -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command but it is very similar to current ``breeze`` command): +.. code-block:: bash -.. raw:: html + breeze stop -
-    [embedded video: Airflow Breeze - Building images]
+Those are all available flags of ``stop`` command: -Building the image first time pulls a pre-built version of images from the Docker Hub, which may take some -time. But for subsequent source code changes, no wait time is expected. -However, changes to sensitive files like ``setup.py`` or ``Dockerfile.ci`` will trigger a rebuild -that may take more time though it is highly optimized to only rebuild what is needed. +.. image:: ./images/breeze/output_stop.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_stop.svg + :width: 100% + :alt: Breeze stop -Breeze has built in mechanism to check if your local image has not diverged too much from the -latest image build on CI. This might happen when for example latest patches have been released as new -Python images or when significant changes are made in the Dockerfile. In such cases, Breeze will -download the latest images before rebuilding because this is usually faster than rebuilding the image. +Troubleshooting +=============== -Those are all available flags of ``build-image`` command: +If you are having problems with the Breeze environment, try the steps below. After each step you +can check whether your problem is fixed. -.. image:: ./images/breeze/output-build-image.svg - :width: 100% - :alt: Breeze build-image +1. If you are on macOS, check if you have enough disk space for Docker (Breeze will warn you if not). +2. Stop Breeze with ``breeze stop``. +3. Delete the ``.build`` directory and run ``breeze ci-image build``. +4. Clean up Docker images via ``breeze cleanup`` command. +5. Restart your Docker Engine and try again. +6. Restart your machine and try again. +7. Re-install Docker Desktop and try again. -You can also pull the CI images locally in parallel with optional verification. +In case the problems are not solved, you can set the VERBOSE_COMMANDS variable to "true": -Those are all available flags of ``pull-image`` command: +.. code-block:: -.. image:: ./images/breeze/output-pull-image.svg - :width: 100% - :alt: Breeze pull-image + export VERBOSE_COMMANDS="true" -Finally, you can verify CI image by running tests - either with the pulled/built images or -with an arbitrary image. -Those are all available flags of ``verify-image`` command: +Then run the failed command, copy-and-paste the output from your terminal to the +`Airflow Slack `_ #airflow-breeze channel and +describe your problem. + +Advanced commands +================= + +Airflow Breeze is a bash script serving as a "swiss-army-knife" of Airflow testing. Under the +hood it uses other scripts that you can also run manually if you have problem with running the Breeze +environment. Breeze script allows performing the following tasks: + +Running tests +------------- + +You can run tests with ``breeze``. There are various tests type and breeze allows to run different test +types easily. You can run unit tests in different ways, either interactively run tests with the default +``shell`` command or via the ``testing`` commands. The latter allows to run more kinds of tests easily. -.. image:: ./images/breeze/output-verify-image.svg +Here is the detailed set of options for the ``breeze testing`` command. + +.. 
image:: ./images/breeze/output_testing.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing.svg :width: 100% - :alt: Breeze verify-image + :alt: Breeze testing -Verifying providers -------------------- -Breeze can also be used to verify if provider classes are importable and if they are following the -right naming conventions. This happens automatically on CI but you can also run it manually. +Iterate on tests interactively via ``shell`` command +.................................................... -.. code-block:: bash +You can simply enter the ``breeze`` container and run ``pytest`` command there. You can enter the +container via just ``breeze`` command or ``breeze shell`` command (the latter has more options +useful when you run integration or system tests). This is the best way if you want to interactively +run selected tests and iterate with the tests. Once you enter ``breeze`` environment it is ready +out-of-the-box to run your tests by running the right ``pytest`` command (autocomplete should help +you with autocompleting test name if you start typing ``pytest tests``). - breeze verify-provider-packages +Here are few examples: -You can also run the verification with an earlier airflow version to check for compatibility. +Running single test: .. code-block:: bash - breeze verify-provider-packages --use-airflow-version 2.1.0 + pytest tests/core/test_core.py::TestCore::test_check_operators -All the command parameters are here: +To run the whole test class: -.. image:: ./images/breeze/output-verify-provider-packages.svg - :width: 100% - :alt: Breeze verify-provider-packages +.. code-block:: bash -Preparing packages ------------------- + pytest tests/core/test_core.py::TestCore -Breeze can also be used to prepare airflow packages - both "apache-airflow" main package and -provider packages. +You can re-run the tests interactively, add extra parameters to pytest and modify the files before +re-running the test to iterate over the tests. You can also add more flags when starting the +``breeze shell`` command when you run integration tests or system tests. Read more details about it +in the ``TESTING.rst `` where all the test types of our are explained and more information +on how to run them. -You can read more about testing provider packages in -`TESTING.rst `_ +This applies to all kind of tests - all our tests can be run using pytest. -There are several commands that you can run in Breeze to manage and build packages: +Running unit tests +.................. -* preparing Provider documentation files -* preparing Airflow packages -* preparing Provider packages +Another option you have is that you can also run tests via built-in ``breeze testing tests`` command. +The iterative ``pytest`` command allows to run test individually, or by class or in any other way +pytest allows to test them and run them interactively, but ``breeze testing tests`` command allows to +run the tests in the same test "types" that are used to run the tests in CI: for example Core, Always +API, Providers. This how our CI runs them - running each group in parallel to other groups and you can +replicate this behaviour. -Preparing provider documentation files is part of the release procedure by the release managers -and it is described in detail in `dev `_ . +Another interesting use of the ``breeze testing tests`` command is that you can easily specify sub-set of the +tests for Providers. -The below example perform documentation preparation for provider packages. 
+For example this will only run provider tests for airbyte and http providers: .. code-block:: bash - breeze prepare-provider-documentation + breeze testing tests --test-type "Providers[airbyte,http]" -By default, the documentation preparation runs package verification to check if all packages are -importable, but you can add ``--skip-package-verification`` to skip it. +You can also run parallel tests with ``--run-in-parallel`` flag - by default it will run all tests types +in parallel, but you can specify the test type that you want to run with space separated list of test +types passed to ``--test-types`` flag. + +For example this will run API and WWW tests in parallel: .. code-block:: bash - breeze prepare-provider-documentation --skip-package-verification + breeze testing tests --test-types "API WWW" --run-in-parallel -You can also add ``--answer yes`` to perform non-interactive build. -.. image:: ./images/breeze/output-prepare-provider-documentation.svg +Here is the detailed set of options for the ``breeze testing tests`` command. + +.. image:: ./images/breeze/output_testing_tests.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_tests.svg :width: 100% - :alt: Breeze prepare-provider-documentation + :alt: Breeze testing tests -The packages are prepared in ``dist`` folder. Note, that this command cleans up the ``dist`` folder -before running, so you should run it before generating airflow package below as it will be removed. +Running integration tests +......................... -The below example builds provider packages in the wheel format. +You can also run integration tests via built-in ``breeze testing integration-tests`` command. Some of our +tests require additional integrations to be started in docker-compose. The integration tests command will +run the expected integration and tests that need that integration. + +For example this will only run kerberos tests: .. code-block:: bash - breeze prepare-provider-packages + breeze testing integration-tests --integration Kerberos -If you run this command without packages, you will prepare all packages, you can however specify -providers that you would like to build. By default ``both`` types of packages are prepared ( -``wheel`` and ``sdist``, but you can change it providing optional --package-format flag. -.. code-block:: bash +Here is the detailed set of options for the ``breeze testing integration-tests`` command. - breeze prepare-provider-packages google amazon +.. image:: ./images/breeze/output_testing_integration-tests.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_integration_tests.svg + :width: 100% + :alt: Breeze testing integration-tests -You can see all providers available by running this command: -.. code-block:: bash +Running Helm tests +.................. - breeze prepare-provider-packages --help +You can use Breeze to run all Helm tests. Those tests are run inside the breeze image as there are all +necessary tools installed there. -.. image:: ./images/breeze/output-prepare-provider-packages.svg +.. image:: ./images/breeze/output_testing_helm-tests.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_helm-tests.svg :width: 100% - :alt: Breeze prepare-provider-packages + :alt: Breeze testing helm-tests -You can prepare airflow packages using breeze: +You can also iterate over those tests with pytest commands, similarly as in case of regular unit tests. 
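For instance, assuming the chart tests live in the ``tests/chart`` folder mentioned below, a single
iteration from inside the Breeze shell could look like this (the ``-k`` filter is purely illustrative):

.. code-block:: bash

    pytest tests/chart -k "scheduler"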
+The helm tests can be found in ``tests/chart`` folder in the main repo. -.. code-block:: bash +Running docker-compose tests +............................ - breeze prepare-airflow-package +You can use Breeze to run all docker-compose tests. Those tests are run using Production image +and they are running test with the Quick-start docker compose we have. -This prepares airflow .whl package in the dist folder. +.. image:: ./images/breeze/output_testing_docker-compose-tests.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_docker-compose-tests.svg + :width: 100% + :alt: Breeze testing docker-compose-tests -Again, you can specify optional ``--package-format`` flag to build selected formats of airflow packages, -default is to build ``both`` type of packages ``sdist`` and ``wheel``. +You can also iterate over those tests with pytest command, but - unlike regular unit tests and +Helm tests, they need to be run in local virtual environment. They also require to have +``DOCKER_IMAGE`` environment variable set, pointing to the image to test if you do not run them +through ``breeze testing docker-compose-tests`` command. -.. code-block:: bash +The docker-compose tests are in ``docker-tests/`` folder in the main repo. - breeze prepare-airflow-package --package-format=wheel +Running Kubernetes tests +------------------------ -.. image:: ./images/breeze/output-prepare-airflow-package.svg +Breeze helps with running Kubernetes tests in the same environment/way as CI tests are run. +Breeze helps to setup KinD cluster for testing, setting up virtualenv and downloads the right tools +automatically to run the tests. + +You can: + +* Setup environment for k8s tests with ``breeze k8s setup-env`` +* Build airflow k8S images with ``breeze k8s build-k8s-image`` +* Manage KinD Kubernetes cluster and upload image and deploy Airflow to KinD cluster via + ``breeze k8s create-cluster``, ``breeze k8s configure-cluster``, ``breeze k8s deploy-airflow``, ``breeze k8s status``, + ``breeze k8s upload-k8s-image``, ``breeze k8s delete-cluster`` commands +* Run Kubernetes tests specified with ``breeze k8s tests`` command +* Run complete test run with ``breeze k8s run-complete-tests`` - performing the full cycle of creating + cluster, uploading the image, deploying airflow, running tests and deleting the cluster +* Enter the interactive kubernetes test environment with ``breeze k8s shell`` and ``breeze k8s k9s`` command +* Run multi-cluster-operations ``breeze k8s list-all-clusters`` and + ``breeze k8s delete-all-clusters`` commands as well as running complete tests in parallel + via ``breeze k8s dump-logs`` command + +This is described in detail in `Testing Kubernetes `_. + +You can read more about KinD that we use in `The documentation `_ + +Here is the detailed set of options for the ``breeze k8s`` command. + +.. image:: ./images/breeze/output_k8s.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s.svg :width: 100% - :alt: Breeze prepare-airflow-package + :alt: Breeze k8s -Managing Production images --------------------------- + +Setting up K8S environment +.......................... + +Kubernetes environment can be set with the ``breeze k8s setup-env`` command. +It will create appropriate virtualenv to run tests and download the right set of tools to run +the tests: ``kind``, ``kubectl`` and ``helm`` in the right versions. 
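The initial setup is a single command with no required arguments - a minimal sketch of the typical
first run:

.. code-block:: bash

    breeze k8s setup-env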
You can re-run the command +when you want to make sure the expected versions of the tools are installed properly in the +virtualenv. The Virtualenv is available in ``.build/.k8s-env/bin`` subdirectory of your Airflow +installation. + +.. image:: ./images/breeze/output_k8s_setup-env.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_setup-env.svg + :width: 100% + :alt: Breeze k8s setup-env + +Creating K8S cluster +.................... + +You can create kubernetes cluster (separate cluster for each python/kubernetes version) via +``breeze k8s create-cluster`` command. With ``--force`` flag the cluster will be +deleted if exists. You can also use it to create multiple clusters in parallel with +``--run-in-parallel`` flag - this is what happens in our CI. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_create-cluster.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_create-cluster.svg + :width: 100% + :alt: Breeze k8s create-cluster + +Deleting K8S cluster +.................... + +You can delete current kubernetes cluster via ``breeze k8s delete-cluster`` command. You can also add +``--run-in-parallel`` flag to delete all clusters. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_delete-cluster.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_delete-cluster.svg + :width: 100% + :alt: Breeze k8s delete-cluster + +Building Airflow K8s images +........................... + +Before deploying Airflow Helm Chart, you need to make sure the appropriate Airflow image is build (it has +embedded test dags, pod templates and webserver is configured to refresh immediately. This can +be done via ``breeze k8s build-k8s-image`` command. It can also be done in parallel for all images via +``--run-in-parallel`` flag. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_build-k8s-image.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_build-k8s-image.svg + :width: 100% + :alt: Breeze k8s build-k8s-image + +Uploading Airflow K8s images +............................ + +The K8S airflow images need to be uploaded to the KinD cluster. This can be done via +``breeze k8s upload-k8s-image`` command. It can also be done in parallel for all images via +``--run-in-parallel`` flag. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_upload-k8s-image.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_upload-k8s-image.svg + :width: 100% + :alt: Breeze k8s upload-k8s-image + +Configuring K8S cluster +....................... + +In order to deploy Airflow, the cluster needs to be configured. Airflow namespace needs to be created +and test resources should be deployed. By passing ``--run-in-parallel`` the configuration can be run +for all clusters in parallel. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_configure-cluster.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_configure-cluster.svg + :width: 100% + :alt: Breeze k8s configure-cluster + +Deploying Airflow to the Cluster +................................ + +Airflow can be deployed to the Cluster with ``breeze k8s deploy-airflow``. This step will automatically +(unless disabled by switches) will rebuild the image to be deployed. 
It also uses the latest version +of the Airflow Helm Chart to deploy it. You can also choose to upgrade existing airflow deployment +and pass extra arguments to ``helm install`` or ``helm upgrade`` commands that are used to +deploy airflow. By passing ``--run-in-parallel`` the deployment can be run +for all clusters in parallel. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_deploy-airflow.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_deploy-airflow.svg + :width: 100% + :alt: Breeze k8s deploy-airflow + +Checking status of the K8S cluster +.................................. + +You can delete kubernetes cluster and airflow deployed in the current cluster +via ``breeze k8s status`` command. It can be also checked fora all clusters created so far by passing +``--all`` flag. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_status.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_status.svg + :width: 100% + :alt: Breeze k8s status + +Running k8s tests +................. + +You can run ``breeze k8s tests`` command to run ``pytest`` tests with your cluster. Those testa are placed +in ``kubernetes_tests/`` and you can either specify the tests to run as parameter of the tests command or +you can leave them empty to run all tests. By passing ``--run-in-parallel`` the tests can be run +for all clusters in parallel. + +Run all tests: + +.. code-block::bash + + breeze k8s tests + +Run selected tests: + +.. code-block::bash + + breeze k8s tests kubernetes_tests/test_kubernetes_executor.py + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_tests.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_tests.svg + :width: 100% + :alt: Breeze k8s tests + +You can also specify any pytest flags as extra parameters - they will be passed to the +shell command directly. In case the shell parameters are the same as the parameters of the command, you +can pass them after ``--``. For example this is the way how you can see all available parameters of the shell +you have: + +.. code-block::bash + + breeze k8s tests -- --help + +The options that are not overlapping with the ``tests`` command options can be passed directly and mixed +with the specifications of tests you want to run. For example the command below will only run +``test_kubernetes_executor.py`` and will suppress capturing output from Pytest so that you can see the +output during test execution. + +.. code-block::bash + + breeze k8s tests -- kubernetes_tests/test_kubernetes_executor.py -s + +Running k8s complete tests +.......................... + +You can run ``breeze k8s run-complete-tests`` command to combine all previous steps in one command. That +command will create cluster, deploy airflow and run tests and finally delete cluster. It is used in CI +to run the whole chains in parallel. + +Run all tests: + +.. code-block::bash + + breeze k8s run-complete-tests + +Run selected tests: + +.. code-block::bash + + breeze k8s run-complete-tests kubernetes_tests/test_kubernetes_executor.py + +All parameters of the command are here: + +.. 
image:: ./images/breeze/output_k8s_run-complete-tests.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_run-complete-tests.svg + :width: 100% + :alt: Breeze k8s tests + +You can also specify any pytest flags as extra parameters - they will be passed to the +shell command directly. In case the shell parameters are the same as the parameters of the command, you +can pass them after ``--``. For example this is the way how you can see all available parameters of the shell +you have: + +.. code-block::bash + + breeze k8s run-complete-tests -- --help + +The options that are not overlapping with the ``tests`` command options can be passed directly and mixed +with the specifications of tests you want to run. For example the command below will only run +``test_kubernetes_executor.py`` and will suppress capturing output from Pytest so that you can see the +output during test execution. + +.. code-block::bash + + breeze k8s run-complete-tests -- kubernetes_tests/test_kubernetes_executor.py -s + + +Entering k8s shell +.................. + +You can have multiple clusters created - with different versions of Kubernetes and Python at the same time. +Breeze enables you to interact with the chosen cluster by entering dedicated shell session that has the +cluster pre-configured. This is done via ``breeze k8s shell`` command. + +Once you are in the shell, the prompt will indicate which cluster you are interacting with as well +as executor you use, similar to: + +.. code-block::bash + + (kind-airflow-python-3.9-v1.24.0:KubernetesExecutor)> + + +The shell automatically activates the virtual environment that has all appropriate dependencies +installed and you can interactively run all k8s tests with pytest command (of course the cluster need to +be created and airflow deployed to it before running the tests): + +.. code-block::bash + + (kind-airflow-python-3.9-v1.24.0:KubernetesExecutor)> pytest kubernetes_tests/test_kubernetes_executor.py + ================================================= test session starts ================================================= + platform linux -- Python 3.10.6, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /home/jarek/code/airflow/.build/.k8s-env/bin/python + cachedir: .pytest_cache + rootdir: /home/jarek/code/airflow, configfile: pytest.ini + plugins: anyio-3.6.1 + collected 2 items + + kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag PASSED [ 50%] + kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag_with_scheduler_failure PASSED [100%] + + ================================================== warnings summary =================================================== + .build/.k8s-env/lib/python3.10/site-packages/_pytest/config/__init__.py:1233 + /home/jarek/code/airflow/.build/.k8s-env/lib/python3.10/site-packages/_pytest/config/__init__.py:1233: PytestConfigWarning: Unknown config option: asyncio_mode + + self._warn_or_fail_if_strict(f"Unknown config option: {key}\n") + + -- Docs: https://docs.pytest.org/en/stable/warnings.html + ============================================ 2 passed, 1 warning in 38.62s ============================================ + (kind-airflow-python-3.9-v1.24.0:KubernetesExecutor)> + + +All parameters of the command are here: + +.. 
image:: ./images/breeze/output_k8s_shell.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_shell.svg + :width: 100% + :alt: Breeze k8s shell + +You can also specify any shell flags and commands as extra parameters - they will be passed to the +shell command directly. In case the shell parameters are the same as the parameters of the command, you +can pass them after ``--``. For example this is the way how you can see all available parameters of the shell +you have: + +.. code-block::bash + + breeze k8s shell -- --help + +Running k9s tool +................ + +The ``k9s`` is a fantastic tool that allows you to interact with running k8s cluster. Since we can have +multiple clusters capability, ``breeze k8s k9s`` allows you to start k9s without setting it up or +downloading - it uses k9s docker image to run it and connect it to the right cluster. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_k9s.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_k9s.svg + :width: 100% + :alt: Breeze k8s k9s + +You can also specify any ``k9s`` flags and commands as extra parameters - they will be passed to the +``k9s`` command directly. In case the ``k9s`` parameters are the same as the parameters of the command, you +can pass them after ``--``. For example this is the way how you can see all available parameters of the +``k9s`` you have: + +.. code-block::bash + + breeze k8s k9s -- --help + +Dumping logs from all k8s clusters +.................................. + +KinD allows to export logs from the running cluster so that you can troubleshoot your deployment. +This can be done with ``breeze k8s logs`` command. Logs can be also dumped fora all clusters created +so far by passing ``--all`` flag. + +All parameters of the command are here: + +.. image:: ./images/breeze/output_k8s_logs.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_logs.svg + :width: 100% + :alt: Breeze k8s logs + + +CI Image tasks +-------------- + +The image building is usually run for users automatically when needed, +but sometimes Breeze users might want to manually build, pull or verify the CI images. + +.. image:: ./images/breeze/output_ci-image.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image.svg + :width: 100% + :alt: Breeze ci-image + +For all development tasks, unit tests, integration tests, and static code checks, we use the +**CI image** maintained in GitHub Container Registry. + +The CI image is built automatically as needed, however it can be rebuilt manually with +``ci image build`` command. + +Building the image first time pulls a pre-built version of images from the Docker Hub, which may take some +time. But for subsequent source code changes, no wait time is expected. +However, changes to sensitive files like ``setup.py`` or ``Dockerfile.ci`` will trigger a rebuild +that may take more time though it is highly optimized to only rebuild what is needed. + +Breeze has built in mechanism to check if your local image has not diverged too much from the +latest image build on CI. This might happen when for example latest patches have been released as new +Python images or when significant changes are made in the Dockerfile. In such cases, Breeze will +download the latest images before rebuilding because this is usually faster than rebuilding the image. + +Building CI image +................. 
+ +Those are all available flags of ``ci-image build`` command: + +.. image:: ./images/breeze/output_ci-image_build.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image_build.svg + :width: 100% + :alt: Breeze ci-image build + +Pulling CI image +................ + +You can also pull the CI images locally in parallel with optional verification. + +Those are all available flags of ``pull`` command: + +.. image:: ./images/breeze/output_ci-image_pull.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image_pull.svg + :width: 100% + :alt: Breeze ci-image pull + +Verifying CI image +.................. + +Finally, you can verify CI image by running tests - either with the pulled/built images or +with an arbitrary image. + +Those are all available flags of ``verify`` command: + +.. image:: ./images/breeze/output_ci-image_verify.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image_verify.svg + :width: 100% + :alt: Breeze ci-image verify + +PROD Image tasks +---------------- + +Users can also build Production images when they are developing them. However when you want to +use the PROD image, the regular docker build commands are recommended. See +`building the image `_ + +.. image:: ./images/breeze/output_prod-image.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image.svg + :width: 100% + :alt: Breeze prod-image The **Production image** is also maintained in GitHub Container Registry for Caching and in ``apache/airflow`` manually pushed for released versions. This Docker image (built using official Dockerfile) contains size-optimised Airflow installation with selected extras and dependencies. However in many cases you want to add your own custom version of the image - with added apt dependencies, -python dependencies, additional Airflow extras. Breeze's ``build-image`` command helps to build your own, +python dependencies, additional Airflow extras. Breeze's ``prod-image build`` command helps to build your own, customized variant of the image that contains everything you need. -You can switch to building the production image by using ``build-prod-image`` command. +You can building the production image manually by using ``prod-image build`` command. Note, that the images can also be built using ``docker build`` command by passing appropriate build-args as described in `IMAGES.rst `_ , but Breeze provides several flags that -makes it easier to do it. You can see all the flags by running ``breeze build-prod-image --help``, +makes it easier to do it. You can see all the flags by running ``breeze prod-image build --help``, but here typical examples are presented: .. code-block:: bash - breeze build-prod-image --additional-extras "jira" + breeze prod-image build --additional-extras "jira" This installs additional ``jira`` extra while installing airflow in the image. .. code-block:: bash - breeze build-prod-image --additional-python-deps "torchio==0.17.10" + breeze prod-image build --additional-python-deps "torchio==0.17.10" This install additional pypi dependency - torchio in specified version. - .. 
code-block:: bash - breeze build-prod-image --additional-dev-apt-deps "libasound2-dev" \ + breeze prod-image build --additional-dev-apt-deps "libasound2-dev" \ --additional-runtime-apt-deps "libasound2" This installs additional apt dependencies - ``libasound2-dev`` in the build image and ``libasound`` in the @@ -929,202 +1232,389 @@ suffix and they need to also be paired with corresponding runtime dependency add .. code-block:: bash - breeze build-prod-image --python 3.7 --additional-dev-deps "libasound2-dev" \ + breeze prod-image build --python 3.7 --additional-dev-deps "libasound2-dev" \ --additional-runtime-apt-deps "libasound2" Same as above but uses python 3.7. +Building PROD image +................... + Those are all available flags of ``build-prod-image`` command: -.. image:: ./images/breeze/output-build-prod-image.svg +.. image:: ./images/breeze/output_prod-image_build.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image_build.svg :width: 100% - :alt: Breeze commands - -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command but it is very similar to current ``breeze`` command): + :alt: Breeze prod-image build -.. raw:: html - -
-    [embedded video: Airflow Breeze - Building Production images]
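The customization flags shown above can also be combined in one build. A hypothetical example (the
Python version, extra, and dependency versions are illustrative only, not a recommended set) could
look like this:

.. code-block:: bash

    breeze prod-image build --python 3.9 \
        --additional-extras "jira" \
        --additional-python-deps "torchio==0.17.10" \
        --additional-dev-apt-deps "libasound2-dev" \
        --additional-runtime-apt-deps "libasound2"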
+Pulling PROD image +.................. You can also pull PROD images in parallel with optional verification. Those are all available flags of ``pull-prod-image`` command: -.. image:: ./images/breeze/output-pull-prod-image.svg +.. image:: ./images/breeze/output_prod-image_pull.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image_pull.svg :width: 100% - :alt: Breeze pull-prod-image + :alt: Breeze prod-image pull + +Verifying PROD image +.................... Finally, you can verify PROD image by running tests - either with the pulled/built images or with an arbitrary image. Those are all available flags of ``verify-prod-image`` command: -.. image:: ./images/breeze/output-verify-prod-image.svg +.. image:: ./images/breeze/output_prod-image_verify.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image_verify.svg :width: 100% - :alt: Breeze verify-prod-image + :alt: Breeze prod-image verify -Releasing Production images to DockerHub ----------------------------------------- -The **Production image** can be released by release managers who have permissions to push the image. This -happens only when there is an RC candidate or final version of Airflow released. +Breeze setup +------------ -You release "regular" and "slim" images as separate steps. +Breeze has tools that you can use to configure defaults and breeze behaviours and perform some maintenance +operations that might be necessary when you add new commands in Breeze. It also allows to configure your +host operating system for Breeze autocompletion. -Releasing "regular" images: +Those are all available flags of ``setup`` command: -.. code-block:: bash +.. image:: ./images/breeze/output_setup.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup.svg + :width: 100% + :alt: Breeze setup - breeze release-prod-images --airflow-version 2.4.0 +Breeze configuration +.................... -Or "slim" images: +You can configure and inspect settings of Breeze command via this command: Python version, Backend used as +well as backend versions. + +Another part of configuration is enabling/disabling cheatsheet, asciiart. The cheatsheet and asciiart can +be disabled - they are "nice looking" and cheatsheet +contains useful information for first time users but eventually you might want to disable both if you +find it repetitive and annoying. + +With the config setting colour-blind-friendly communication for Breeze messages. By default we communicate +with the users about information/errors/warnings/successes via colour-coded messages, but we can switch +it off by passing ``--no-colour`` to config in which case the messages to the user printed by Breeze +will be printed using different schemes (italic/bold/underline) to indicate different kind of messages +rather than colours. + +Those are all available flags of ``setup config`` command: + +.. image:: ./images/breeze/output_setup_config.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_config.svg + :width: 100% + :alt: Breeze setup config + +Setting up autocompletion +......................... + +You get the auto-completion working when you re-enter the shell (follow the instructions printed). +The command will warn you and not reinstall autocomplete if you already did, but you can +also force reinstalling the autocomplete via: .. 
code-block:: bash - breeze release-prod-images --airflow-version 2.4.0 --slim-images + breeze setup autocomplete --force -By default when you are releasing the "final" image, we also tag image with "latest" tags but this -step can be skipped if you pass the ``--skip-latest`` flag. +Those are all available flags of ``setup-autocomplete`` command: -These are all of the available flags for the ``release-prod-images`` command: +.. image:: ./images/breeze/output_setup_autocomplete.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_autocomplete.svg + :width: 100% + :alt: Breeze setup autocomplete + +Breeze version +.............. -.. image:: ./images/breeze/output-release-prod-images.svg +You can display Breeze version and with ``--verbose`` flag it can provide more information: where +Breeze is installed from and details about setup hashes. + +Those are all available flags of ``version`` command: + +.. image:: ./images/breeze/output_setup_version.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_version.svg :width: 100% - :alt: Release prod images + :alt: Breeze version -Running static checks ---------------------- +Breeze self-upgrade +................... -You can run static checks via Breeze. You can also run them via pre-commit command but with auto-completion -Breeze makes it easier to run selective static checks. If you press after the static-check and if -you have auto-complete setup you should see auto-completable list of all checks available. +You can self-upgrade breeze automatically. Those are all available flags of ``self-upgrade`` command: -.. code-block:: bash +.. image:: ./images/breeze/output_setup_self-upgrade.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_self-upgrade.svg + :width: 100% + :alt: Breeze setup self-upgrade - breeze static-checks -t mypy -The above will run mypy check for currently staged files. +Regenerating images for documentation +..................................... -You can also pass specific pre-commit flags for example ``--all-files`` : +This documentation contains exported images with "help" of their commands and parameters. You can +regenerate those images that need to be regenerated because their commands changed (usually after +the breeze code has been changed) via ``regenerate-command-images`` command. Usually this is done +automatically via pre-commit, but sometimes (for example when ``rich`` or ``rich-click`` library changes) +you need to regenerate those images. -.. code-block:: bash +You can add ``--force`` flag (or ``FORCE="true"`` environment variable to regenerate all images (not +only those that need regeneration). You can also run the command with ``--check-only`` flag to simply +check if there are any images that need regeneration. - breeze static-checks -t mypy --all-files +.. image:: ./images/breeze/output_setup_regenerate-command-images.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_regenerate-command-images.svg + :width: 100% + :alt: Breeze setup regenerate-command-images -The above will run mypy check for all files. -There is a convenience ``--last-commit`` flag that you can use to run static check on last commit only: +CI tasks +-------- -.. code-block:: bash +Breeze hase a number of commands that are mostly used in CI environment to perform cleanup. - breeze static-checks -t mypy --last-commit +.. 
image:: ./images/breeze/output_ci.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci.svg + :width: 100% + :alt: Breeze ci commands -The above will run mypy check for all files in the last commit. +Running resource check +...................... -There is another convenience ``--commit-ref`` flag that you can use to run static check on specific commit: +Breeze requires certain resources to be available - disk, memory, CPU. When you enter Breeze's shell, +the resources are checked and information if there is enough resources is displayed. However you can +manually run resource check any time by ``breeze ci resource-check`` command. + +Those are all available flags of ``resource-check`` command: + +.. image:: ./images/breeze/output_ci_resource-check.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_resource-check.svg + :width: 100% + :alt: Breeze ci resource-check + +Freeing the space +................. + +When our CI runs a job, it needs all memory and disk it can have. We have a Breeze command that frees +the memory and disk space used. You can also use it clear space locally but it performs a few operations +that might be a bit invasive - such are removing swap file and complete pruning of docker disk space used. + +Those are all available flags of ``free-space`` command: + +.. image:: ./images/breeze/output_ci_free-space.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_free-space.svg + :width: 100% + :alt: Breeze ci free-space + +Fixing File/Directory Ownership +............................... + +On Linux, there is a problem with propagating ownership of created files (a known Docker problem). The +files and directories created in the container are not owned by the host user (but by the root user in our +case). This may prevent you from switching branches, for example, if files owned by the root user are +created within your sources. In case you are on a Linux host and have some files in your sources created +by the root user, you can fix the ownership of those files by running : + +.. code-block:: + + breeze ci fix-ownership + +Those are all available flags of ``fix-ownership`` command: + +.. image:: ./images/breeze/output_ci_fix-ownership.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_fix-ownership.svg + :width: 100% + :alt: Breeze ci fix-ownership + +Selective check +............... + +When our CI runs a job, it needs to decide which tests to run, whether to build images and how much the test +should be run on multiple combinations of Python, Kubernetes, Backend versions. In order to optimize time +needed to run the CI Builds. You can also use the tool to test what tests will be run when you provide +a specific commit that Breeze should run the tests on. + +The selective-check command will produce the set of ``name=value`` pairs of outputs derived +from the context of the commit/PR to be merged via stderr output. + +More details about the algorithm used to pick the right tests and the available outputs can be +found in `Selective Checks `_. + +Those are all available flags of ``selective-check`` command: + +.. image:: ./images/breeze/output_ci_selective-check.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_selective-check.svg + :width: 100% + :alt: Breeze ci selective-check + +Getting workflow information +............................ 
+ +When our CI runs a job, it might be within one of several workflows. Information about those workflows +is stored in GITHUB_CONTEXT. Rather than using some jq/bash commands, we retrieve the necessary information +(like PR labels, event_type, where the job runs on, job description and convert them into GA outputs. + +Those are all available flags of ``get-workflow-info`` command: + +.. image:: ./images/breeze/output_ci_get-workflow-info.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_get-workflow-info.svg + :width: 100% + :alt: Breeze ci get-workflow-info + +Tracking backtracking issues for CI builds +.......................................... + +When our CI runs a job, we automatically upgrade our dependencies in the ``main`` build. However, this might +lead to conflicts and ``pip`` backtracking for a long time (possibly forever) for dependency resolution. +Unfortunately those issues are difficult to diagnose so we had to invent our own tool to help us with +diagnosing them. This tool is ``find-newer-dependencies`` and it works in the way that it helps to guess +which new dependency might have caused the backtracking. The whole process is described in +`tracking backtracking issues `_. + +Those are all available flags of ``find-newer-dependencies`` command: + +.. image:: ./images/breeze/output_ci_find-newer-dependencies.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_find-newer-dependencies.svg + :width: 100% + :alt: Breeze ci find-newer-dependencies + +Release management tasks +------------------------ + +Maintainers also can use Breeze for other purposes (those are commands that regular contributors likely +do not need or have no access to run). Those are usually connected with releasing Airflow: + +.. image:: ./images/breeze/output_release-management.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management.svg + :width: 100% + :alt: Breeze release management + +Breeze can be used to prepare airflow packages - both "apache-airflow" main package and +provider packages. + +Preparing provider documentation +................................ + +You can read more about testing provider packages in +`TESTING.rst `_ + +There are several commands that you can run in Breeze to manage and build packages: + +* preparing Provider documentation files +* preparing Airflow packages +* preparing Provider packages + +Preparing provider documentation files is part of the release procedure by the release managers +and it is described in detail in `dev `_ . + +The below example perform documentation preparation for provider packages. .. code-block:: bash - breeze static-checks -t mypy --commit-ref 639483d998ecac64d0fef7c5aa4634414065f690 + breeze release-management prepare-provider-documentation -The above will run mypy check for all files in the 639483d998ecac64d0fef7c5aa4634414065f690 commit. -Any ``commit-ish`` reference from Git will work here (branch, tag, short/long hash etc.) +By default, the documentation preparation runs package verification to check if all packages are +importable, but you can add ``--skip-package-verification`` to skip it. + +.. code-block:: bash + + breeze release-management prepare-provider-documentation --skip-package-verification + +You can also add ``--answer yes`` to perform non-interactive build. + +.. 
image:: ./images/breeze/output_release-management_prepare-provider-documentation.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_prepare-provider-documentation.svg + :width: 100% + :alt: Breeze prepare-provider-documentation + +Preparing provider packages +........................... + +You can use Breeze to prepare provider packages. + +The packages are prepared in ``dist`` folder. Note, that this command cleans up the ``dist`` folder +before running, so you should run it before generating airflow package below as it will be removed. -If you ever need to get a list of the files that will be checked (for troubleshooting) use these commands: +The below example builds provider packages in the wheel format. .. code-block:: bash - breeze static-checks -t identity --verbose # currently staged files - breeze static-checks -t identity --verbose --from-ref $(git merge-base main HEAD) --to-ref HEAD # branch updates - -Those are all available flags of ``static-checks`` command: + breeze release-management prepare-provider-packages -.. image:: ./images/breeze/output-static-checks.svg - :width: 100% - :alt: Breeze static checks +If you run this command without packages, you will prepare all packages, you can however specify +providers that you would like to build. By default ``both`` types of packages are prepared ( +``wheel`` and ``sdist``, but you can change it providing optional --package-format flag. -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command but it is very similar to current ``breeze`` command): +.. code-block:: bash -.. raw:: html + breeze release-management prepare-provider-packages google amazon -
-    [embedded video: Airflow Breeze - Static checks]
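As a sketch of combining the optional ``--package-format`` flag with an explicit provider selection
(the provider name is just an example):

.. code-block:: bash

    breeze release-management prepare-provider-packages --package-format wheel google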
+You can see all providers available by running this command: -.. note:: +.. code-block:: bash - When you run static checks, some of the artifacts (mypy_cache) is stored in docker-compose volume - so that it can speed up static checks execution significantly. However, sometimes, the cache might - get broken, in which case you should run ``breeze stop`` to clean up the cache. + breeze release-management prepare-provider-packages --help +.. image:: ./images/breeze/output_release-management_prepare-provider-packages.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_prepare-provider-packages.svg + :width: 100% + :alt: Breeze prepare-provider-packages -Building the Documentation --------------------------- +Verifying provider packages +........................... -To build documentation in Breeze, use the ``build-docs`` command: +Breeze can also be used to verify if provider classes are importable and if they are following the +right naming conventions. This happens automatically on CI but you can also run it manually if you +just prepared provider packages and they are present in ``dist`` folder. .. code-block:: bash - breeze build-docs + breeze release-management verify-provider-packages -Results of the build can be found in the ``docs/_build`` folder. +You can also run the verification with an earlier airflow version to check for compatibility. -The documentation build consists of three steps: +.. code-block:: bash -* verifying consistency of indexes -* building documentation -* spell checking + breeze release-management verify-provider-packages --use-airflow-version 2.1.0 -You can choose only one stage of the two by providing ``--spellcheck-only`` or ``--docs-only`` after -extra ``--`` flag. +All the command parameters are here: -.. code-block:: bash +.. image:: ./images/breeze/output_release-management_verify-provider-packages.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_verify-provider-packages.svg + :width: 100% + :alt: Breeze verify-provider-packages - breeze build-docs --spellcheck-only -This process can take some time, so in order to make it shorter you can filter by package, using the flag -``--package-filter ``. The package name has to be one of the providers or ``apache-airflow``. For -instance, for using it with Amazon, the command would be: +Preparing airflow packages +.......................... -.. code-block:: bash +You can prepare airflow packages using Breeze: - breeze build-docs --package-filter apache-airflow-providers-amazon +.. code-block:: bash -Often errors during documentation generation come from the docstrings of auto-api generated classes. -During the docs building auto-api generated files are stored in the ``docs/_api`` folder. This helps you -easily identify the location the problems with documentation originated from. + breeze release-management prepare-airflow-package -Those are all available flags of ``build-docs`` command: +This prepares airflow .whl package in the dist folder. -.. image:: ./images/breeze/output-build-docs.svg - :width: 100% - :alt: Breeze build documentation +Again, you can specify optional ``--package-format`` flag to build selected formats of airflow packages, +default is to build ``both`` type of packages ``sdist`` and ``wheel``. -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command but it is very similar to current ``breeze`` command): +.. 
code-block:: bash -.. raw:: html + breeze release-management prepare-airflow-package --package-format=wheel -
-    [embedded video: Airflow Breeze - Build docs]
+.. image:: ./images/breeze/output_release-management_prepare-airflow-package.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_prepare-airflow-package.svg + :width: 100% + :alt: Breeze release-management prepare-airflow-package Generating constraints ----------------------- +...................... Whenever setup.py gets modified, the CI main job will re-generate constraint files. Those constraint files are stored in separated orphan branches: ``constraints-main``, ``constraints-2-0``. @@ -1133,7 +1623,7 @@ Those are constraint files as described in detail in the ``_ contributing documentation. -You can use ``breeze generate-constraints`` command to manually generate constraints for +You can use ``breeze release-management generate-constraints`` command to manually generate constraints for all or selected python version and single constraint mode like this: .. warning:: @@ -1144,7 +1634,7 @@ all or selected python version and single constraint mode like this: .. code-block:: bash - breeze generate-constraints --airflow-constraints-mode constraints + breeze release-management generate-constraints --airflow-constraints-mode constraints Constraints are generated separately for each python version and there are separate constraints modes: @@ -1163,7 +1653,8 @@ Constraints are generated separately for each python version and there are separ Those are all available flags of ``generate-constraints`` command: -.. image:: ./images/breeze/output-generate-constraints.svg +.. image:: ./images/breeze/output_release-management_generate-constraints.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_generate-constraints.svg :width: 100% :alt: Breeze generate-constraints @@ -1176,157 +1667,140 @@ This bumps the constraint files to latest versions and stores hash of setup.py. and setup.py hash files are stored in the ``files`` folder and while generating the constraints diff of changes vs the previous constraint files is printed. +Releasing Production images +........................... -Using local virtualenv environment in Your Host IDE ---------------------------------------------------- - -You can set up your host IDE (for example, IntelliJ's PyCharm/Idea) to work with Breeze -and benefit from all the features provided by your IDE, such as local and remote debugging, -language auto-completion, documentation support, etc. +The **Production image** can be released by release managers who have permissions to push the image. This +happens only when there is an RC candidate or final version of Airflow released. -To use your host IDE with Breeze: +You release "regular" and "slim" images as separate steps. -1. Create a local virtual environment: +Releasing "regular" images: - You can use any of the following wrappers to create and manage your virtual environments: - `pyenv `_, `pyenv-virtualenv `_, - or `virtualenvwrapper `_. +.. code-block:: bash -2. Use the right command to activate the virtualenv (``workon`` if you use virtualenvwrapper or - ``pyenv activate`` if you use pyenv. + breeze release-management release-prod-images --airflow-version 2.4.0 -3. Initialize the created local virtualenv: +Or "slim" images: .. code-block:: bash - ./scripts/tools/initialize_virtualenv.py - -.. warning:: - Make sure that you use the right Python version in this command - matching the Python version you have - in your local virtualenv. If you don't, you will get strange conflicts. 
+ breeze release-management release-prod-images --airflow-version 2.4.0 --slim-images -4. Select the virtualenv you created as the project's default virtualenv in your IDE. +By default when you are releasing the "final" image, we also tag image with "latest" tags but this +step can be skipped if you pass the ``--skip-latest`` flag. -Note that you can also use the local virtualenv for Airflow development without Breeze. -This is a lightweight solution that has its own limitations. +These are all of the available flags for the ``release-prod-images`` command: -More details on using the local virtualenv are available in the `LOCAL_VIRTUALENV.rst `_. +.. image:: ./images/breeze/output_release-management_release-prod-images.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_release-prod-images.svg + :width: 100% + :alt: Breeze release management release prod images -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -but it is not available in the ``breeze`` command): -.. raw:: html +Details of Breeze usage +======================= -
-    [embedded video: Airflow Breeze - Initialize virtualenv]
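For instance, when publishing a release candidate you would typically combine the flags above to keep
the ``latest`` tags untouched (the version shown is illustrative):

.. code-block:: bash

    breeze release-management release-prod-images --airflow-version 2.4.0rc1 --skip-latest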
+Database volumes in Breeze +-------------------------- -Running docker-compose tests ----------------------------- +Breeze keeps data for all it's integration in named docker volumes. Each backend and integration +keeps data in their own volume. Those volumes are persisted until ``breeze stop`` command. +You can also preserve the volumes by adding flag ``--preserve-volumes`` when you run the command. +Then, next time when you start Breeze, it will have the data pre-populated. -You can use Breeze to run docker-compose tests. Those tests are run using Production image -and they are running test with the Quick-start docker compose we have. +Those are all available flags of ``stop`` command: -.. image:: ./images/breeze/output-docker-compose-tests.svg +.. image:: ./images/breeze/output-stop.svg + :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output-stop.svg :width: 100% - :alt: Breeze generate-constraints - - -Running Kubernetes tests ------------------------- + :alt: Breeze stop -Breeze helps with running Kubernetes tests in the same environment/way as CI tests are run. -Breeze helps to setup KinD cluster for testing, setting up virtualenv and downloads the right tools -automatically to run the tests. -This is described in detail in `Testing Kubernetes `_. +Additional tools +---------------- -Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy`` -command and it is not yet available in the current ``breeze`` command): +To shrink the Docker image, not all tools are pre-installed in the Docker image. But we have made sure that there +is an easy process to install additional tools. -.. raw:: html +Additional tools are installed in ``/files/bin``. This path is added to ``$PATH``, so your shell will +automatically autocomplete files that are in that directory. You can also keep the binaries for your tools +in this directory if you need to. -
-  [embedded video: Airflow Breeze - Kubernetes tests]
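+For example, a minimal sketch of dropping your own tool into that folder (the tool name and the
+download URL below are purely illustrative):
+
+.. code-block:: bash
+
+    # Download a statically-linked binary into the PATH-visible, persisted folder
+    # (illustrative URL - replace it with the real download location of your tool)
+    curl -fsSL "https://example.com/downloads/mytool-linux-amd64" -o /files/bin/mytool
+    chmod +x /files/bin/mytool
+
+    # The binary survives container restarts because /files is mounted from your sources
+    mytool --version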
+**Installation scripts**

-Stopping the interactive environment
------------------------------------
+For development convenience, we have also provided installation scripts for commonly used tools. They are
+installed to ``/files/opt/``, so they are preserved after restarting the Breeze environment. Each script
+is also available in ``$PATH``, so just type ``install_`` to get a list of tools.

-After starting up, the environment runs in the background and takes precious memory.
-You can always stop it via:
+Currently available scripts:

-.. code-block:: bash
+* ``install_aws.sh`` - installs `the AWS CLI `__ including
+* ``install_az.sh`` - installs `the Azure CLI `__ including
+* ``install_gcloud.sh`` - installs `the Google Cloud SDK `__ including
+  ``gcloud``, ``gsutil``.
+* ``install_imgcat.sh`` - installs `imgcat - Inline Images Protocol `__
+  for iTerm2 (Mac OS only)
+* ``install_java.sh`` - installs `the OpenJDK 8u41 `__
+* ``install_kubectl.sh`` - installs `the Kubernetes command-line tool, kubectl `__
+* ``install_snowsql.sh`` - installs `SnowSQL `__
+* ``install_terraform.sh`` - installs `Terraform `__

-   breeze stop

+Launching Breeze integrations
+-----------------------------

-Those are all available flags of ``stop`` command:
+When Breeze starts, it can start additional integrations. Those are additional docker containers
+that are started in the same docker-compose command. They are required by some of the tests
+as described in ``_.

-.. image:: ./images/breeze/output-stop.svg
-   :width: 100%
-   :alt: Breeze stop
+By default, Breeze starts only the airflow container, without any integration enabled. If you selected
+the ``postgres`` or ``mysql`` backend, the container for the selected backend is also started (but only the one
+that is selected). You can start the additional integrations by passing the ``--integration`` flag
+with the appropriate integration name when starting Breeze. You can specify several ``--integration`` flags
+to start more than one integration at a time.
+Finally, you can specify ``--integration all`` to start all integrations.

-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
+Once an integration is started, it will continue to run until the environment is stopped with
+the ``breeze stop`` command or restarted via the ``breeze restart`` command.

-.. raw:: html
+Note that running integrations uses significant resources - CPU and memory.
-  [embedded video: Airflow Breeze - Stop environment]
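+For example, a sketch of starting Breeze with the Postgres backend and two integrations (the
+integration names below are only examples - ``breeze --help`` shows the full list supported by your version):
+
+.. code-block:: bash
+
+    # Start the airflow container plus the mongo and kerberos integration containers
+    breeze --backend postgres --integration mongo --integration kerberos
+
+    # Stop the environment (and the integration containers) when you are done
+    breeze stop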
-Resource check -============== +Using local virtualenv environment in Your Host IDE +--------------------------------------------------- -Breeze requires certain resources to be available - disk, memory, CPU. When you enter Breeze's shell, -the resources are checked and information if there is enough resources is displayed. However you can -manually run resource check any time by ``breeze resource-check`` command. +You can set up your host IDE (for example, IntelliJ's PyCharm/Idea) to work with Breeze +and benefit from all the features provided by your IDE, such as local and remote debugging, +language auto-completion, documentation support, etc. -Those are all available flags of ``resource-check`` command: +To use your host IDE with Breeze: -.. image:: ./images/breeze/output-resource-check.svg - :width: 100% - :alt: Breeze resource-check +1. Create a local virtual environment: + You can use any of the following wrappers to create and manage your virtual environments: + `pyenv `_, `pyenv-virtualenv `_, + or `virtualenvwrapper `_. -Freeing the space -================= +2. Use the right command to activate the virtualenv (``workon`` if you use virtualenvwrapper or + ``pyenv activate`` if you use pyenv. -When our CI runs a job, it needs all memory and disk it can have. We have a Breeze command that frees -the memory and disk space used. You can also use it clear space locally but it performs a few operations -that might be a bit invasive - such are removing swap file and complete pruning of docker disk space used. +3. Initialize the created local virtualenv: -Those are all available flags of ``free-space`` command: +.. code-block:: bash -.. image:: ./images/breeze/output-free-space.svg - :width: 100% - :alt: Breeze free-space + ./scripts/tools/initialize_virtualenv.py +.. warning:: + Make sure that you use the right Python version in this command - matching the Python version you have + in your local virtualenv. If you don't, you will get strange conflicts. -Tracking backtracking issues for CI builds -========================================== +4. Select the virtualenv you created as the project's default virtualenv in your IDE. -When our CI runs a job, we automatically upgrade our dependencies in the ``main`` build. However, this might -lead to conflicts and ``pip`` backtracking for a long time (possibly forever) for dependency resolution. -Unfortunately those issues are difficult to diagnose so we had to invent our own tool to help us with -diagnosing them. This tool is ``find-newer-dependencies`` and it works in the way that it helps to guess -which new dependency might have caused the backtracking. The whole process is described in -`tracking backtracking issues `_. +Note that you can also use the local virtualenv for Airflow development without Breeze. +This is a lightweight solution that has its own limitations. -Those are all available flags of ``find-newer-dependencies`` command: +More details on using the local virtualenv are available in the `LOCAL_VIRTUALENV.rst `_. -.. image:: ./images/breeze/output-find-newer-dependencies.svg - :width: 100% - :alt: Breeze find-newer-dependencies Internal details of Breeze ========================== @@ -1341,10 +1815,12 @@ When you are in the CI container, the following directories are used: /opt/airflow - Contains sources of Airflow mounted from the host (AIRFLOW_SOURCES). 
/root/airflow - Contains all the "dynamic" Airflow files (AIRFLOW_HOME), such as: airflow.db - sqlite database in case sqlite is used; - dags - folder with non-test dags (test dags are in /opt/airflow/tests/dags); logs - logs from Airflow executions; unittest.cfg - unit test configuration generated when entering the environment; webserver_config.py - webserver configuration generated when running Airflow in the container. + /files - files mounted from "files" folder in your sources. You can edit them in the host as well + dags - this is the folder where Airflow DAGs are read from + airflow-breeze-config - this is where you can keep your own customization configuration of breeze Note that when running in your local environment, the ``/root/airflow/logs`` folder is actually mounted from your ``logs`` directory in the Airflow sources, so all logs created in the container are automatically @@ -1358,43 +1834,17 @@ When you are in the production container, the following directories are used: /opt/airflow - Contains sources of Airflow mounted from the host (AIRFLOW_SOURCES). /root/airflow - Contains all the "dynamic" Airflow files (AIRFLOW_HOME), such as: airflow.db - sqlite database in case sqlite is used; - dags - folder with non-test dags (test dags are in /opt/airflow/tests/dags); logs - logs from Airflow executions; unittest.cfg - unit test configuration generated when entering the environment; webserver_config.py - webserver configuration generated when running Airflow in the container. + /files - files mounted from "files" folder in your sources. You can edit them in the host as well + dags - this is the folder where Airflow DAGs are read from Note that when running in your local environment, the ``/root/airflow/logs`` folder is actually mounted from your ``logs`` directory in the Airflow sources, so all logs created in the container are automatically visible in the host as well. Every time you enter the container, the ``logs`` directory is cleaned so that logs do not accumulate. -Running Arbitrary commands in the Breeze environment ----------------------------------------------------- - -To run other commands/executables inside the Breeze Docker-based environment, use the -``breeze shell`` command. - -.. code-block:: bash - - breeze shell "ls -la" - -Those are all available flags of ``shell`` command: - -.. image:: ./images/breeze/output-shell.svg - :width: 100% - :alt: Breeze shell - -Running "Docker Compose" commands ---------------------------------- - -To run Docker Compose commands (such as ``help``, ``pull``, etc), use the -``docker-compose`` command. To add extra arguments, specify them -after ``--`` as extra arguments. - -.. code-block:: bash - - ./breeze-legacy docker-compose pull -- --ignore-pull-failures - Setting default answers for user interaction -------------------------------------------- @@ -1410,25 +1860,6 @@ For automation scripts, you can export the ``ANSWER`` variable (and set it to export ANSWER="yes" -Fixing File/Directory Ownership -------------------------------- - -On Linux, there is a problem with propagating ownership of created files (a known Docker problem). The -files and directories created in the container are not owned by the host user (but by the root user in our -case). This may prevent you from switching branches, for example, if files owned by the root user are -created within your sources. In case you are on a Linux host and have some files in your sources created -by the root user, you can fix the ownership of those files by running : - -.. 
code-block:: - - breeze fix-ownership - -Those are all available flags of ``fix-ownership`` command: - -.. image:: ./images/breeze/output-fix-ownership.svg - :width: 100% - :alt: Breeze fix-ownership - Mounting Local Sources to Breeze -------------------------------- @@ -1448,6 +1879,11 @@ By default ``/files/dags`` folder is mounted from your local `` the directory used by airflow scheduler and webserver to scan dags for. You can use it to test your dags from local sources in Airflow. If you wish to add local DAGs that can be run by Breeze. +The ``/files/airflow-breeze-config`` folder contains configuration files that might be used to +customize your breeze instance. Those files will be kept across checking out a code from different +branches and stopping/starting breeze so you can keep your configuration there and use it continuously while +you switch to different source code versions. + Port Forwarding --------------- @@ -1508,28 +1944,18 @@ If you set these variables, next time when you enter the environment the new por Managing Dependencies --------------------- -If you need to change apt dependencies in the ``Dockerfile.ci``, add Python packages in ``setup.py`` or -add JavaScript dependencies in ``package.json``, you can either add dependencies temporarily for a single -Breeze session or permanently in ``setup.py``, ``Dockerfile.ci``, or ``package.json`` files. - -Installing Dependencies for a Single Breeze Session -................................................... - -You can install dependencies inside the container using ``sudo apt install``, ``pip install`` or -``yarn install`` (in ``airflow/www`` folder) respectively. This is useful if you want to test something -quickly while you are in the container. However, these changes are not retained: they disappear once you -exit the container (except for the node.js dependencies if your sources are mounted to the container). -Therefore, if you want to retain a new dependency, follow the second option described below. +If you need to change apt dependencies in the ``Dockerfile.ci``, add Python packages in ``setup.py`` +for airflow and in provider.yaml for packages. If you add any "node" dependencies in ``airflow/www`` +, you need to compile them in the host with ``breeze compile-www-assets`` command. Adding Dependencies Permanently ............................... -You can add dependencies to the ``Dockerfile.ci``, ``setup.py`` or ``package.json`` and rebuild the image. -This should happen automatically if you modify any of these files. +You can add dependencies to the ``Dockerfile.ci``, ``setup.py``. After you exit the container and re-run ``breeze``, Breeze detects changes in dependencies, asks you to confirm rebuilding the image and proceeds with rebuilding if you confirm (or skip it if you do not confirm). After rebuilding is done, Breeze drops you to shell. You may also use the -``build-image`` command to only rebuild CI image and not to go into shell. +``build`` command to only rebuild CI image and not to go into shell. Incremental apt Dependencies in the Dockerfile.ci during development .................................................................... diff --git a/CI.rst b/CI.rst index f24639271e977..5c937e579431f 100644 --- a/CI.rst +++ b/CI.rst @@ -21,7 +21,7 @@ CI Environment ============== Continuous Integration is important component of making Apache Airflow robust and stable. We are running -a lot of tests for every pull request, for main and v2-*-test branches and regularly as CRON jobs. 
+a lot of tests for every pull request, for main and v2-*-test branches and regularly as scheduled jobs. Our execution environment for CI is `GitHub Actions `_. GitHub Actions (GA) are very well integrated with GitHub code and Workflow and it has evolved fast in 2019/202 to become @@ -33,7 +33,7 @@ environments we use. Most of our CI jobs are written as bash scripts which are e the CI jobs. And we have a number of variables determine build behaviour. You can also take a look at the `CI Sequence Diagrams `_ for more graphical overview -of how Airlfow's CI works. +of how Airflow CI works. GitHub Actions runs ------------------- @@ -84,288 +84,88 @@ We use `GitHub Container Registry `_ but in essence it is a script that allows -you to re-create CI environment in your local development instance and interact with it. In its basic -form, when you do development you can run all the same tests that will be run in CI - but locally, -before you submit them as PR. Another use case where Breeze is useful is when tests fail on CI. You can -take the full ``COMMIT_SHA`` of the failed build pass it as ``--github-image-id`` parameter of Breeze and it will -download the very same version of image that was used in CI and run it locally. This way, you can very -easily reproduce any failed test that happens in CI - even if you do not check out the sources -connected with the run. -You can read more about it in `BREEZE.rst `_ and `TESTING.rst `_ +Naming conventions for stored images +==================================== -Difference between local runs and GitHub Action workflows ---------------------------------------------------------- +The images produced during the ``Build Images`` workflow of CI jobs are stored in the +`GitHub Container Registry `_ -Depending whether the scripts are run locally (most often via `Breeze `_) or whether they -are run in ``Build Images`` or ``Tests`` workflows they can take different values. +The images are stored with both "latest" tag (for last main push image that passes all the tests as well +with the COMMIT_SHA id for images that were used in particular build. -You can use those variables when you try to reproduce the build locally. +The image names follow the patterns (except the Python image, all the images are stored in +https://ghcr.io/ in ``apache`` organization. -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Variable | Local | Build Images | Tests | Comment | -| | development | CI workflow | Workflow | | -+=========================================+=============+==============+============+=================================================+ -| Basic variables | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``PYTHON_MAJOR_MINOR_VERSION`` | | | | Major/Minor version of Python used. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``DB_RESET`` | false | true | true | Determines whether database should be reset | -| | | | | at the container entry. By default locally | -| | | | | the database is not reset, which allows to | -| | | | | keep the database content between runs in | -| | | | | case of Postgres or MySQL. However, | -| | | | | it requires to perform manual init/reset | -| | | | | if you stop the environment. 
| -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Mount variables | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``MOUNT_SELECTED_LOCAL_SOURCES`` | true | false | false | Determines whether local sources are | -| | | | | mounted to inside the container. Useful for | -| | | | | local development, as changes you make | -| | | | | locally can be immediately tested in | -| | | | | the container. We mount only selected, | -| | | | | important folders. We do not mount the whole | -| | | | | project folder in order to avoid accidental | -| | | | | use of artifacts (such as ``egg-info`` | -| | | | | directories) generated locally on the | -| | | | | host during development. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``MOUNT_ALL_LOCAL_SOURCES`` | false | false | false | Determines whether all local sources are | -| | | | | mounted to inside the container. Useful for | -| | | | | local development when you need to access .git | -| | | | | folders and other folders excluded when | -| | | | | ``MOUNT_SELECTED_LOCAL_SOURCES`` is true. | -| | | | | You might need to manually delete egg-info | -| | | | | folder when you enter breeze and the folder was | -| | | | | generated using different Python versions. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Force variables | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``FORCE_BUILD_IMAGES`` | false | false | false | Forces building images. This is generally not | -| | | | | very useful in CI as in CI environment image | -| | | | | is built or pulled only once, so there is no | -| | | | | need to set the variable to true. For local | -| | | | | builds it forces rebuild, regardless if it | -| | | | | is determined to be needed. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``ANSWER`` | | yes | yes | This variable determines if answer to questions | -| | | | | during the build process should be | -| | | | | automatically given. For local development, | -| | | | | the user is occasionally asked to provide | -| | | | | answers to questions such as - whether | -| | | | | the image should be rebuilt. By default | -| | | | | the user has to answer but in the CI | -| | | | | environment, we force "yes" answer. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``SKIP_CHECK_REMOTE_IMAGE`` | false | true | true | Determines whether we check if remote image | -| | | | | is "fresher" than the current image. | -| | | | | When doing local breeze runs we try to | -| | | | | determine if it will be faster to rebuild | -| | | | | the image or whether the image should be | -| | | | | pulled first from the cache because it has | -| | | | | been rebuilt. This is slightly experimental | -| | | | | feature and will be improved in the future | -| | | | | as the current mechanism does not always | -| | | | | work properly. 
| -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Host variables | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``HOST_USER_ID`` | | | | User id of the host user. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``HOST_GROUP_ID`` | | | | Group id of the host user. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``HOST_OS`` | | Linux | Linux | OS of the Host (Darwin/Linux). | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Git variables | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``COMMIT_SHA`` | | GITHUB_SHA | GITHUB_SHA | SHA of the commit of the build is run | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Initialization | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``SKIP_ENVIRONMENT_INITIALIZATION`` | false\* | false\* | false\* | Skip initialization of test environment | -| | | | | | -| | | | | \* set to true in pre-commits | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``SKIP_SSH_SETUP`` | false\* | false\* | false\* | Skip setting up SSH server for tests. | -| | | | | | -| | | | | \* set to true in GitHub CodeSpaces | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Verbosity variables | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``PRINT_INFO_FROM_SCRIPTS`` | true\* | true\* | true\* | Allows to print output to terminal from running | -| | | | | scripts. It prints some extra outputs if true | -| | | | | including what the commands do, results of some | -| | | | | operations, summary of variable values, exit | -| | | | | status from the scripts, outputs of failing | -| | | | | commands. If verbose is on it also prints the | -| | | | | commands executed by docker, kind, helm, | -| | | | | kubectl. Disabled in pre-commit checks. | -| | | | | | -| | | | | \* set to false in pre-commits | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``VERBOSE`` | false | true | true | Determines whether docker, helm, kind, | -| | | | | kubectl commands should be printed before | -| | | | | execution. This is useful to determine | -| | | | | what exact commands were executed for | -| | | | | debugging purpose as well as allows | -| | | | | to replicate those commands easily by | -| | | | | copy&pasting them from the output. | -| | | | | requires ``PRINT_INFO_FROM_SCRIPTS`` set to | -| | | | | true. 
| -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``VERBOSE_COMMANDS`` | false | false | false | Determines whether every command | -| | | | | executed in bash should also be printed | -| | | | | before execution. This is a low-level | -| | | | | debugging feature of bash (set -x) and | -| | | | | it should only be used if you are lost | -| | | | | at where the script failed. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| Image build variables | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``UPGRADE_TO_NEWER_DEPENDENCIES`` | false | false | false\* | Determines whether the build should | -| | | | | attempt to upgrade Python base image and all | -| | | | | PIP dependencies to latest ones matching | -| | | | | ``setup.py`` limits. This tries to replicate | -| | | | | the situation of "fresh" user who just installs | -| | | | | airflow and uses latest version of matching | -| | | | | dependencies. By default we are using a | -| | | | | tested set of dependency constraints | -| | | | | stored in separated "orphan" branches | -| | | | | of the airflow repository | -| | | | | ("constraints-main, "constraints-2-0") | -| | | | | but when this flag is set to anything but false | -| | | | | (for example random value), they are not used | -| | | | | used and "eager" upgrade strategy is used | -| | | | | when installing dependencies. We set it | -| | | | | to true in case of direct pushes (merges) | -| | | | | to main and scheduled builds so that | -| | | | | the constraints are tested. In those builds, | -| | | | | in case we determine that the tests pass | -| | | | | we automatically push latest set of | -| | | | | "tested" constraints to the repository. | -| | | | | | -| | | | | Setting the value to random value is best way | -| | | | | to assure that constraints are upgraded even if | -| | | | | there is no change to setup.py | -| | | | | | -| | | | | This way our constraints are automatically | -| | | | | tested and updated whenever new versions | -| | | | | of libraries are released. | -| | | | | | -| | | | | \* true in case of direct pushes and | -| | | | | scheduled builds | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``CHECK_IMAGE_FOR_REBUILD`` | true | true | true\* | Determines whether attempt should be | -| | | | | made to rebuild the CI image with latest | -| | | | | sources. It is true by default for | -| | | | | local builds, however it is set to | -| | | | | true in case we know that the image | -| | | | | we pulled or built already contains | -| | | | | the right sources. In such case we | -| | | | | should set it to false, especially | -| | | | | in case our local sources are not the | -| | | | | ones we intend to use (for example | -| | | | | when ``--github-image-id`` is used | -| | | | | in Breeze. | -| | | | | | -| | | | | In CI jobs it is set to true | -| | | | | in case of the ``Build Images`` | -| | | | | workflow or when | -| | | | | waiting for images is disabled | -| | | | | in the "Tests" workflow. | -| | | | | | -| | | | | \* if waiting for images the variable is set | -| | | | | to false automatically. 
| -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ -| ``SKIP_BUILDING_PROD_IMAGE`` | false | false | false\* | Determines whether we should skip building | -| | | | | the PROD image with latest sources. | -| | | | | It is set to false, but in deploy app for | -| | | | | kubernetes step it is set to "true", because at | -| | | | | this stage we know we have good image build or | -| | | | | pulled. | -| | | | | | -| | | | | \* set to true in "Deploy App to Kubernetes" | -| | | | | to false automatically. | -+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +The packages are available under (CONTAINER_NAME is url-encoded name of the image). Note that "/" are +supported now in the ``ghcr.io`` as apart of the image name within ``apache`` organization, but they +have to be percent-encoded when you access them via UI (/ = %2F) -Running CI Jobs locally -======================= - -The scripts and configuration files for CI jobs are all in ``scripts/ci`` - so that in the -``pull_request_target`` target workflow, we can copy those scripts from the ``main`` branch and use them -regardless of the changes done in the PR. This way we are kept safe from PRs injecting code into the builds. - -* ``build_airflow`` - builds airflow packages -* ``constraints`` - scripts to build and publish latest set of valid constraints -* ``docs`` - scripts to build documentation -* ``images`` - scripts to build and push CI and PROD images -* ``kubernetes`` - scripts to setup kubernetes cluster, deploy airflow and run kubernetes tests with it -* ``openapi`` - scripts to run openapi generation -* ``pre_commit`` - scripts to run pre-commit checks -* ``provider_packages`` - scripts to build and test provider packages -* ``static_checks`` - scripts to run static checks manually -* ``testing`` - scripts that run unit and integration tests -* ``tools`` - scripts that can be used for various clean-up and preparation tasks - -Common libraries of functions for all the scripts can be found in ``libraries`` folder. The ``dockerfiles``, -``mysql.d``, ``openldap``, ``spectral_rules`` folders contains DockerFiles and configuration of integrations -needed to run tests. - -For detailed use of those scripts you can refer to ``.github/workflows/`` - those scripts are used -by the CI workflows of ours. There are some variables that you can set to change the behaviour of the -scripts. - -The default values are "sane" you can change them to interact with your own repositories or registries. -Note that you need to set "CI" variable to true in order to get the same results as in CI. - -+------------------------------+----------------------+-----------------------------------------------------+ -| Variable | Default | Comment | -+==============================+======================+=====================================================+ -| CI | ``false`` | If set to "true", we simulate behaviour of | -| | | all scripts as if they are in CI environment | -+------------------------------+----------------------+-----------------------------------------------------+ -| CI_TARGET_REPO | ``apache/airflow`` | Target repository for the CI job. Used to | -| | | compare incoming changes from PR with the target. 
| -+------------------------------+----------------------+-----------------------------------------------------+ -| CI_TARGET_BRANCH | ``main`` | Target branch where the PR should land. Used to | -| | | compare incoming changes from PR with the target. | -+------------------------------+----------------------+-----------------------------------------------------+ -| CI_BUILD_ID | ``0`` | Unique id of the build that is kept across re runs | -| | | (for GitHub actions it is ``GITHUB_RUN_ID``) | -+------------------------------+----------------------+-----------------------------------------------------+ -| CI_JOB_ID | ``0`` | Unique id of the job - used to produce unique | -| | | artifact names. | -+------------------------------+----------------------+-----------------------------------------------------+ -| CI_EVENT_TYPE | ``pull_request`` | Type of the event. It can be one of | -| | | [``pull_request``, ``pull_request_target``, | -| | | ``schedule``, ``push``] | -+------------------------------+----------------------+-----------------------------------------------------+ -| CI_REF | ``refs/head/main`` | Branch in the source repository that is used to | -| | | make the pull request. | -+------------------------------+----------------------+-----------------------------------------------------+ +``https://github.com/apache/airflow/pkgs/container/`` ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Image | Name:tag (both cases latest version and per-build) | Description | ++==============+==========================================================+==========================================================+ +| Python image | python:-slim-bullseye | Base Python image used by both production and CI image. | +| (DockerHub) | | Python maintainer release new versions of those image | +| | | with security fixes every few weeks in DockerHub. | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Airflow | airflow//python:-slim-bullseye | Version of python base image used in Airflow Builds | +| python base | | We keep the "latest" version only to mark last "good" | +| image | | python base that went through testing and was pushed. | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| PROD Build | airflow//prod-build/python:latest | Production Build image - this is the "build" stage of | +| image | | production image. It contains build-essentials and all | +| | | necessary apt packages to build/install PIP packages. | +| | | We keep the "latest" version only to speed up builds. | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Manifest | airflow//ci-manifest/python:latest | CI manifest image - this is the image used to optimize | +| CI image | | pulls and builds for Breeze development environment | +| | | They store hash indicating whether the image will be | +| | | faster to build or pull. | +| | | We keep the "latest" version only to help breeze to | +| | | check if new image should be pulled. | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| CI image | airflow//ci/python:latest | CI image - this is the image used for most of the tests. 
| +| | or | Contains all provider dependencies and tools useful | +| | airflow//ci/python: | For testing. This image is used in Breeze. | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| | | faster to build or pull. | +| PROD image | airflow//prod/python:latest | Production image. This is the actual production image | +| | or | optimized for size. | +| | airflow//prod/python: | It contains only compiled libraries and minimal set of | +| | | dependencies to run Airflow. | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ + +* might be either "main" or "v2-*-test" +* - Python version (Major + Minor).Should be one of ["3.7", "3.8", "3.9"]. +* - full-length SHA of commit either from the tip of the branch (for pushes/schedule) or + commit from the tip of the branch used for the PR. GitHub Registry Variables ========================= -Our CI uses GitHub Registry to pull and push images to/from by default. You can use your own repo by changing -``GITHUB_REPOSITORY`` and providing your own GitHub Username and Token. +Our CI uses GitHub Registry to pull and push images to/from by default. Those variables are set automatically +by GitHub Actions when you run Airflow workflows in your fork, so they should automatically use your +own repository as GitHub Registry to build and keep the images as build image cache. + +The variables are automatically set in GitHub actions +--------------------------------+---------------------------+----------------------------------------------+ | Variable | Default | Comment | +================================+===========================+==============================================+ | GITHUB_REPOSITORY | ``apache/airflow`` | Prefix of the image. It indicates which. | -| | | registry from GitHub to use | +| | | registry from GitHub to use for image cache | +| | | and to determine the name of the image. | ++--------------------------------+---------------------------+----------------------------------------------+ +| CONSTRAINTS_GITHUB_REPOSITORY | ``apache/airflow`` | Repository where constraints are stored | +--------------------------------+---------------------------+----------------------------------------------+ | GITHUB_USERNAME | | Username to use to login to GitHub | | | | | @@ -373,37 +173,37 @@ Our CI uses GitHub Registry to pull and push images to/from by default. You can | GITHUB_TOKEN | | Token to use to login to GitHub. | | | | Only used when pushing images on CI. | +--------------------------------+---------------------------+----------------------------------------------+ -| GITHUB_REGISTRY_PULL_IMAGE_TAG | ``latest`` | Pull this image tag. This is "latest" by | -| | | default, can also be full-length commit SHA. | -+--------------------------------+---------------------------+----------------------------------------------+ -| GITHUB_REGISTRY_PUSH_IMAGE_TAG | ``latest`` | Push this image tag. This is "latest" by | -| | | default, can also be full-length commit SHA. | -+--------------------------------+---------------------------+----------------------------------------------+ + +The Variables beginning with ``GITHUB_`` cannot be overridden in GitHub Actions by the workflow. +Those variables are set by GitHub Actions automatically and they are immutable. 
Therefore if +you want to override them in your own CI workflow and use ``breeze``, you need to pass the +values by corresponding ``breeze`` flags ``--github-repository``, ``--github-username``, +``--github-token`` rather than by setting them as environment variables in your workflow. +Unless you want to keep your own copy of constraints in orphaned ``constraints-*`` +branches, the ``CONSTRAINTS_GITHUB_REPOSITORY`` should remain ``apache/airflow``, regardless in which +repository the CI job is run. + +One of the variables you might want to override in your own GitHub Actions workflow when using ``breeze`` is +``--github-repository`` - you might want to force it to ``apache/airflow``, because then the cache from +``apache/airflow`` repository will be used and your builds will be much faster. + +Example command to build your CI image efficiently in your own CI workflow: + +.. code-block:: bash + + # GITHUB_REPOSITORY is set automatically in Github Actions so we need to override it with flag + # + breeze ci-image build --github-repository apache/airflow --python 3.10 + docker tag ghcr.io/apache/airflow/main/ci/python3.10 your-image-name:tag + Authentication in GitHub Registry ================================= We are using GitHub Container Registry as cache for our images. Authentication uses GITHUB_TOKEN mechanism. Authentication is needed for pushing the images (WRITE) only in "push", "pull_request_target" workflows. +When you are running the CI jobs in GitHub Actions, GITHUB_TOKEN is set automatically by the actions. -CI Architecture -=============== - -The following components are part of the CI infrastructure - -* **Apache Airflow Code Repository** - our code repository at https://github.com/apache/airflow -* **Apache Airflow Forks** - forks of the Apache Airflow Code Repository from which contributors make - Pull Requests -* **GitHub Actions** - (GA) UI + execution engine for our jobs -* **GA CRON trigger** - GitHub Actions CRON triggering our jobs -* **GA Workers** - virtual machines running our jobs at GitHub Actions (max 20 in parallel) -* **GitHub Image Registry** - image registry used as build cache for CI jobs. - It is at https://ghcr.io/apache/airflow -* **DockerHub Image Registry** - image registry used to pull base Python images and (manually) publish - the released Production Airflow images. It is at https://dockerhub.com/apache/airflow -* **Official Images** (future) - these are official images that are prominently visible in DockerHub. - We aim our images to become official images so that you will be able to pull them - with ``docker pull apache-airflow`` CI run types ============ @@ -411,6 +211,12 @@ CI run types The following CI Job run types are currently run for Apache Airflow (run by ci.yaml workflow) and each of the run types has different purpose and context. +Besides the regular "PR" runs we also have "Canary" runs that are able to detect most of the +problems that might impact regular PRs early, without necessarily failing all PRs when those +problems happen. This allows to provide much more stable environment for contributors, who +contribute their PR, while giving a chance to maintainers to react early on problems that +need reaction, when the "canary" builds fail. + Pull request run ---------------- @@ -426,33 +232,37 @@ CI, Production Images as well as base Python images that are also cached in the Also for those builds we only execute Python tests if important files changed (so for example if it is "no-code" change, no tests will be executed. 
-The workflow involved in Pull Requests review and approval is a bit more complex than simple workflows -in most of other projects because we've implemented some optimizations related to efficient use -of queue slots we share with other Apache Software Foundation projects. More details about it -can be found in `PULL_REQUEST_WORKFLOW.rst `_. +Regular PR builds run in a "stable" environment: +* fixed set of constraints (constraints that passed the tests) - except the PRs that change dependencies +* limited matrix and set of tests (determined by selective checks based on what changed in the PR) +* no ARM image builds are build in the regular PRs +* lower probability of flaky tests for non-committer PRs (public runners and less parallelism) -Direct Push/Merge Run ---------------------- +Canary run +---------- -Those runs are results of direct pushes done by the committers or as result of merge of a Pull Request +Those runs are results of direct pushes done by the committers - basically merging of a Pull Request by the committers. Those runs execute in the context of the Apache Airflow Code Repository and have also write permission for GitHub resources (container registry, code repository). + The main purpose for the run is to check if the code after merge still holds all the assertions - like -whether it still builds, all tests are green. +whether it still builds, all tests are green. This is a "Canary" build that helps us to detect early +problems with dependencies, image building, full matrix of tests in case they passed through selective checks. This is needed because some of the conflicting changes from multiple PRs might cause build and test failures after merge even if they do not fail in isolation. Also those runs are already reviewed and confirmed by the committers so they can be used to do some housekeeping: -- pushing most recent image build in the PR to the GitHub Container Registry (for caching) + +- pushing most recent image build in the PR to the GitHub Container Registry (for caching) including recent + Dockerfile changes and setup.py/setup.cfg changes (Early Cache) +- test that image in ``breeze`` command builds quickly +- run full matrix of tests to detect any tests that will be mistakenly missed in ``selective checks`` - upgrading to latest constraints and pushing those constraints if all tests succeed - refresh latest Python base images in case new patch-level is released The housekeeping is important - Python base images are refreshed with varying frequency (once every few months usually but sometimes several times per week) with the latest security and bug fixes. -Those patch level images releases can occasionally break Airflow builds (specifically Docker image builds -based on those images) therefore in PRs we only use latest "good" Python image that we store in the -GitHub Container Registry and those push requests will refresh the latest images if they changed. Scheduled runs -------------- @@ -528,51 +338,50 @@ Tests Workflow This workflow is a regular workflow that performs all checks of Airflow code. 
-+---------------------------+----------------------------------------------+-------+-------+------+ -| Job | Description | PR | Push | CRON | -| | | | Merge | (1) | -+===========================+==============================================+=======+=======+======+ -| Build info | Prints detailed information about the build | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Test OpenAPI client gen | Tests if OpenAPIClient continues to generate | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| UI tests | React UI tests for new Airflow UI | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| WWW tests | React tests for current Airflow UI | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Test image building | Tests if PROD image build examples work | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| CI Images | Waits for and verify CI Images (3) | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| (Basic) Static checks | Performs static checks (full or basic) | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Build docs | Builds documentation | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Tests | Run all the Pytest tests for Python code | Yes(2)| Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Tests provider packages | Tests if provider packages work | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Upload coverage | Uploads test coverage from all the tests | - | Yes | - | -+---------------------------+----------------------------------------------+-------+-------+------+ -| PROD Images | Waits for and verify PROD Images (3) | Yes | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Tests Kubernetes | Run Kubernetes test | Yes(2)| Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Constraints | Upgrade constraints to latest ones (4) | - | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ -| Push images | Pushes latest images to GitHub Registry (4) | - | Yes | Yes | -+---------------------------+----------------------------------------------+-------+-------+------+ - - -Comments: - - (1) CRON jobs builds images from scratch - to test if everything works properly for clean builds - (2) The tests are run when the Trigger Tests job determine that important files change (this allows - for example "no-code" changes to build much faster) - (3) The jobs wait for CI images to be available. - (4) PROD and CI images are pushed as "latest" to GitHub Container registry and constraints are upgraded - only if all tests are successful. The images are rebuilt in this step using constraints pushed - in the previous step. 
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Job | Description | PR | Canary | Scheduled | ++=============================+==========================================================+=========+==========+===========+ +| Build info | Prints detailed information about the build | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Build CI/PROD images | Builds images in-workflow (not in the build images one) | - | Yes | Yes (1) | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Push early cache & images | Pushes early cache/images to GitHub Registry and test | - | Yes | - | +| | speed of building breeze images from scratch | | | | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Test OpenAPI client gen | Tests if OpenAPIClient continues to generate | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| UI tests | React UI tests for new Airflow UI | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Test image building | Tests if PROD image build examples work | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| CI Images | Waits for and verify CI Images (2) | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| (Basic) Static checks | Performs static checks (full or basic) | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Build docs | Builds documentation | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Tests | Run all the Pytest tests for Python code | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Tests provider packages | Tests if provider packages work | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Upload coverage | Uploads test coverage from all the tests | - | Yes | - | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| PROD Images | Waits for and verify PROD Images (2) | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Tests Kubernetes | Run Kubernetes test | Yes | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Constraints | Upgrade constraints to latest ones (3) | - | Yes | Yes | ++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ +| Push cache & images | Pushes cache/images to GitHub Registry (3) | - | Yes | Yes | 
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+ + +``(1)`` Scheduled jobs builds images from scratch - to test if everything works properly for clean builds + +``(2)`` The jobs wait for CI images to be available. + +``(3)`` PROD and CI cache & images are pushed as "latest" to GitHub Container registry and constraints are +upgraded only if all tests are successful. The images are rebuilt in this step using constraints pushed +in the previous step. CodeQL scan ----------- @@ -595,65 +404,60 @@ For more information, see: Website endpoint: http://apache-airflow-docs.s3-website.eu-central-1.amazonaws.com/ -Naming conventions for stored images -==================================== -The images produced during the ``Build Images`` workflow of CI jobs are stored in the -`GitHub Container Registry `_ +Debugging CI Jobs in Github Actions +=================================== -The images are stored with both "latest" tag (for last main push image that passes all the tests as well -with the COMMIT_SHA id for images that were used in particular build. +The CI jobs are notoriously difficult to test, because you can only really see results of it when you run them +in CI environment, and the environment in which they run depend on who runs them (they might be either run +in our Self-Hosted runners (with 64 GB RAM 8 CPUs) or in the GitHub Public runners (6 GB of RAM, 2 CPUs) and +the results will vastly differ depending on which environment is used. We are utilizing parallelism to make +use of all the available CPU/Memory but sometimes you need to enable debugging and force certain environments. +Additional difficulty is that ``Build Images`` workflow is ``pull-request-target`` type, which means that it +will always run using the ``main`` version - no matter what is in your Pull Request. -The image names follow the patterns (except the Python image, all the images are stored in -https://ghcr.io/ in ``apache`` organization. +There are several ways how you can debug the CI jobs when you are maintainer. -The packages are available under (CONTAINER_NAME is url-encoded name of the image). Note that "/" are -supported now in the ``ghcr.io`` as apart of the image name within ``apache`` organization, but they -have to be percent-encoded when you access them via UI (/ = %2F) +* When you want to tests the build with all combinations of all python, backends etc on regular PR, + add ``full tests needed`` label to the PR. +* When you want to test maintainer PR using public runners, add ``public runners`` label to the PR +* When you want to see resources used by the run, add ``debug ci resources`` label to the PR +* When you want to test changes to breeze that include changes to how images are build you should push + your PR to ``apache`` repository not to your fork. This will run the images as part of the ``CI`` workflow + rather than using ``Build images`` workflow and use the same breeze version for building image and testing +* When you want to test changes to ``build-images.yml`` workflow you should push your branch as ``main`` + branch in your local fork. 
This will run changed ``build-images.yml`` workflow as it will be in ``main`` + branch of your fork -``https://github.com/apache/airflow/pkgs/container/`` +Replicating the CI Jobs locally +=============================== -+--------------+----------------------------------------------------------+----------------------------------------------------------+ -| Image | Name:tag (both cases latest version and per-build) | Description | -+==============+==========================================================+==========================================================+ -| Python image | python:-slim-bullseye | Base Python image used by both production and CI image. | -| (DockerHub) | | Python maintainer release new versions of those image | -| | | with security fixes every few weeks in DockerHub. | -+--------------+----------------------------------------------------------+----------------------------------------------------------+ -| Airflow | airflow//python:-slim-bullseye | Version of python base image used in Airflow Builds | -| python base | | We keep the "latest" version only to mark last "good" | -| image | | python base that went through testing and was pushed. | -+--------------+----------------------------------------------------------+----------------------------------------------------------+ -| PROD Build | airflow//prod-build/python:latest | Production Build image - this is the "build" stage of | -| image | | production image. It contains build-essentials and all | -| | | necessary apt packages to build/install PIP packages. | -| | | We keep the "latest" version only to speed up builds. | -+--------------+----------------------------------------------------------+----------------------------------------------------------+ -| Manifest | airflow//ci-manifest/python:latest | CI manifest image - this is the image used to optimize | -| CI image | | pulls and builds for Breeze development environment | -| | | They store hash indicating whether the image will be | -| | | faster to build or pull. | -| | | We keep the "latest" version only to help breeze to | -| | | check if new image should be pulled. | -+--------------+----------------------------------------------------------+----------------------------------------------------------+ -| CI image | airflow//ci/python:latest | CI image - this is the image used for most of the tests. | -| | or | Contains all provider dependencies and tools useful | -| | airflow//ci/python: | For testing. This image is used in Breeze. | -+--------------+----------------------------------------------------------+----------------------------------------------------------+ -| | | faster to build or pull. | -| PROD image | airflow//prod/python:latest | Production image. This is the actual production image | -| | or | optimized for size. | -| | airflow//prod/python: | It contains only compiled libraries and minimal set of | -| | | dependencies to run Airflow. | -+--------------+----------------------------------------------------------+----------------------------------------------------------+ +The main goal of the CI philosophy we have that no matter how complex the test and integration +infrastructure, as a developer you should be able to reproduce and re-run any of the failed checks +locally. One part of it are pre-commit checks, that allow you to run the same static checks in CI +and locally, but another part is the CI environment which is replicated locally with Breeze. 
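+For example, a sketch of running the same static checks locally that CI runs (the hook id below is
+only an example - the authoritative list lives in ``.pre-commit-config.yaml``):
+
+.. code-block:: bash
+
+    # Install the git hooks once
+    pre-commit install
+
+    # Run all configured static checks against all files, the same way the CI static-checks job does
+    pre-commit run --all-files
+
+    # Or run a single named check on selected files only
+    pre-commit run black --files airflow/models/dag.py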
-* might be either "main" or "v2-*-test" -* - Python version (Major + Minor).Should be one of ["3.7", "3.8", "3.9"]. -* - full-length SHA of commit either from the tip of the branch (for pushes/schedule) or - commit from the tip of the branch used for the PR. +You can read more about Breeze in `BREEZE.rst `_ but in essence it is a script that allows +you to re-create the CI environment in your local development instance and interact with it. In its basic +form, when you do development you can run all the same tests that will be run in CI - but locally, +before you submit them as a PR. Another use case where Breeze is useful is when tests fail on CI. You can +take the full ``COMMIT_SHA`` of the failed build, pass it as the ``--image-tag`` parameter of Breeze, and it will +download the very same version of the image that was used in CI and run it locally. This way, you can very +easily reproduce any failed test that happens in CI - even if you do not check out the sources +connected with the run. + +All our CI jobs are executed via ``breeze`` commands. You can replicate exactly what our CI is doing +by running the sequence of corresponding ``breeze`` commands. Make sure however that you look at both: -Reproducing CI Runs locally -=========================== +* flags passed to ``breeze`` commands +* environment variables used when the ``breeze`` command is run - this is useful when we want + to set a common flag for all ``breeze`` commands in the same job or even the whole workflow. For + example the ``VERBOSE`` variable is set to ``true`` for all our workflows so that more detailed information + about internal commands executed in CI is printed. + +In the output of the CI jobs, you will find both - the flags passed and environment variables set. + +You can read more about it in `BREEZE.rst `_ and `TESTING.rst `_ Since we store images from every CI run, you should be able easily reproduce any of the CI tests problems locally. You can do it by pulling and using the right image and running it with the right docker command, @@ -668,12 +472,11 @@ For example knowing that the CI job was for commit ``cd27124534b46c9688a1d89e75f But you usually need to pass more variables and complex setup if you want to connect to a database or enable some integrations. Therefore it is easiest to use `Breeze `_ for that. For example if -you need to reproduce a MySQL environment with kerberos integration enabled for commit -cd27124534b46c9688a1d89e75fcd137ab5137e3, in python 3.8 environment you can run: +you need to reproduce a MySQL environment with Python 3.8 you can run: .. code-block:: bash - ./breeze-legacy --github-image-id cd27124534b46c9688a1d89e75fcd137ab5137e3 --python 3.8 + breeze --image-tag cd27124534b46c9688a1d89e75fcd137ab5137e3 --python 3.8 --backend mysql You will be dropped into a shell with the exact version that was used during the CI run and you will be able to run pytest tests manually, easily reproducing the environment that was used in CI. Note that in @@ -681,26 +484,117 @@ this case, you do not need to checkout the sources that were used for that run - the image - but remember that any changes you make in those sources are lost when you leave the image as the sources are not mapped from your host machine. +Depending on whether the scripts are run locally via `Breeze `_ or whether they +are run in the ``Build Images`` or ``Tests`` workflows, they can take different values.
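To make this concrete, a hypothetical local invocation that overrides a few of the variables described in the table below could look like this (a minimal sketch - which variables you actually need depends on the job you are replicating, and the commit SHA is only the example used above):

.. code-block:: bash

    # mimic the CI defaults locally (see the table below for what each variable does)
    export ANSWER="yes"       # automatically answer questions such as whether to rebuild the image, as CI does
    export DB_RESET="true"    # reset the test database at container entry, as CI does
    export VERBOSE="true"     # print more detailed information about the internal commands being executed

    # enter the very same image that the failed CI run used
    breeze --image-tag cd27124534b46c9688a1d89e75fcd137ab5137e3 --python 3.8 --backend mysql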
-Adding new Python versions to CI --------------------------------- +You can use those variables when you try to reproduce the build locally (alternatively you can pass +those via command line flags passed to ``breeze`` command. -In the ``main`` branch of development line we currently support Python 3.7, 3.8, 3.9. ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| Variable | Local | Build Images | CI | Comment | +| | development | workflow | Workflow | | ++=========================================+=============+==============+============+=================================================+ +| Basic variables | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``PYTHON_MAJOR_MINOR_VERSION`` | | | | Major/Minor version of Python used. | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``DB_RESET`` | false | true | true | Determines whether database should be reset | +| | | | | at the container entry. By default locally | +| | | | | the database is not reset, which allows to | +| | | | | keep the database content between runs in | +| | | | | case of Postgres or MySQL. However, | +| | | | | it requires to perform manual init/reset | +| | | | | if you stop the environment. | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| Forcing answer | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``ANSWER`` | | yes | yes | This variable determines if answer to questions | +| | | | | during the build process should be | +| | | | | automatically given. For local development, | +| | | | | the user is occasionally asked to provide | +| | | | | answers to questions such as - whether | +| | | | | the image should be rebuilt. By default | +| | | | | the user has to answer but in the CI | +| | | | | environment, we force "yes" answer. | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| Host variables | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``HOST_USER_ID`` | | | | User id of the host user. | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``HOST_GROUP_ID`` | | | | Group id of the host user. | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``HOST_OS`` | | linux | linux | OS of the Host (darwin/linux/windows). 
| ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| Git variables | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``COMMIT_SHA`` | | GITHUB_SHA | GITHUB_SHA | SHA of the commit of the build is run | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| In container environment initialization | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``SKIP_ENVIRONMENT_INITIALIZATION`` | false\* | false\* | false\* | Skip initialization of test environment | +| | | | | | +| | | | | \* set to true in pre-commits | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``SKIP_SSH_SETUP`` | false\* | false\* | false\* | Skip setting up SSH server for tests. | +| | | | | | +| | | | | \* set to true in GitHub CodeSpaces | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``VERBOSE_COMMANDS`` | false | false | false | Determines whether every command | +| | | | | executed in docker should also be printed | +| | | | | before execution. This is a low-level | +| | | | | debugging feature of bash (set -x) enabled in | +| | | | | entrypoint and it should only be used if you | +| | | | | need to debug the bash scripts in container. | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| Image build variables | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ +| ``UPGRADE_TO_NEWER_DEPENDENCIES`` | false | false | false\* | Determines whether the build should | +| | | | | attempt to upgrade Python base image and all | +| | | | | PIP dependencies to latest ones matching | +| | | | | ``setup.py`` limits. This tries to replicate | +| | | | | the situation of "fresh" user who just installs | +| | | | | airflow and uses latest version of matching | +| | | | | dependencies. By default we are using a | +| | | | | tested set of dependency constraints | +| | | | | stored in separated "orphan" branches | +| | | | | of the airflow repository | +| | | | | ("constraints-main, "constraints-2-0") | +| | | | | but when this flag is set to anything but false | +| | | | | (for example random value), they are not used | +| | | | | used and "eager" upgrade strategy is used | +| | | | | when installing dependencies. We set it | +| | | | | to true in case of direct pushes (merges) | +| | | | | to main and scheduled builds so that | +| | | | | the constraints are tested. In those builds, | +| | | | | in case we determine that the tests pass | +| | | | | we automatically push latest set of | +| | | | | "tested" constraints to the repository. | +| | | | | | +| | | | | Setting the value to random value is best way | +| | | | | to assure that constraints are upgraded even if | +| | | | | there is no change to setup.py | +| | | | | | +| | | | | This way our constraints are automatically | +| | | | | tested and updated whenever new versions | +| | | | | of libraries are released. 
| +| | | | | | +| | | | | \* true in case of direct pushes and | +| | | | | scheduled builds | ++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+ + +Adding new Python versions to CI +================================ In order to add a new version the following operations should be done (example uses Python 3.10) * copy the latest constraints in ``constraints-main`` branch from previous versions and name it using the new Python version (``constraints-3.10.txt``). Commit and push -* add the new Python version to `breeze-complete `_ and - `_initialization.sh `_ - tests will fail if they are not - in sync. - * build image locally for both prod and CI locally using Breeze: .. code-block:: bash - breeze build-image --python 3.10 + breeze ci-image build --python 3.10 * Find the 2 new images (prod, ci) created in `GitHub Container registry `_ diff --git a/CI_DIAGRAMS.md b/CI_DIAGRAMS.md index 3b228b13c71b7..e19228a02444d 100644 --- a/CI_DIAGRAMS.md +++ b/CI_DIAGRAMS.md @@ -38,7 +38,6 @@ sequenceDiagram Note over Tests: Skip Build
(Runs in 'Build Images')
CI Images Note over Tests: Skip Build
(Runs in 'Build Images')
PROD Images par - GitHub Registry ->> Build Images: Pull CI Images
[latest] Note over Build Images: Build CI Images
[COMMIT_SHA]
Use latest constraints
or upgrade if setup changed and Note over Tests: OpenAPI client gen @@ -121,7 +120,6 @@ sequenceDiagram deactivate Build Images Note over Tests: Build info
Decide on tests
Decide on Matrix (selective) par - GitHub Registry ->> Tests: Pull CI Images
[latest] Note over Tests: Build CI Images
[COMMIT_SHA]
Use latest constraints
or upgrade if setup changed and Note over Tests: OpenAPI client gen @@ -134,7 +132,6 @@ sequenceDiagram GitHub Registry ->> Tests: Pull CI Images
[COMMIT_SHA] Note over Tests: Verify CI Image
[COMMIT_SHA] par - GitHub Registry ->> Tests: Pull PROD Images
[latest] Note over Tests: Build PROD Images
[COMMIT_SHA] and opt @@ -181,18 +178,17 @@ sequenceDiagram deactivate Tests ``` -## Direct Push/Merge flow +## Merge "Canary" run ```mermaid sequenceDiagram - Note over Airflow Repo: pull request + Note over Airflow Repo: push/merge Note over Tests: push
[Write Token] activate Airflow Repo Airflow Repo -->> Tests: Trigger 'push' activate Tests Note over Tests: Build info
All tests
Full matrix par - GitHub Registry ->> Tests: Pull CI Images
[latest] Note over Tests: Build CI Images
[COMMIT_SHA]
Always upgrade deps and Note over Tests: OpenAPI client gen @@ -200,12 +196,15 @@ sequenceDiagram Note over Tests: Test UI and Note over Tests: Test examples
PROD image building + and + Note over Tests: Build CI Images
Use original constraints + Tests ->> GitHub Registry: Push CI Image Early cache + latest + Note over Tests: Test 'breeze' image build quickly end Tests ->> GitHub Registry: Push CI Images
[COMMIT_SHA] GitHub Registry ->> Tests: Pull CI Images
[COMMIT_SHA] Note over Tests: Verify CI Image
[COMMIT_SHA] par - GitHub Registry ->> Tests: Pull PROD Images
[latest] Note over Tests: Build PROD Images
[COMMIT_SHA] and opt @@ -249,11 +248,9 @@ sequenceDiagram Tests ->> Airflow Repo: Push constraints if changed end opt In merge run? - GitHub Registry ->> Tests: Pull CI Image
[latest] Note over Tests: Build CI Images
[latest]
Use latest constraints Tests ->> GitHub Registry: Push CI Image
[latest] - GitHub Registry ->> Tests: Pull PROD Image
[latest] - Note over Tests: Build PROD Images
[latest] + Note over Tests: Build PROD Images
[latest]
Use latest constraints Tests ->> GitHub Registry: Push PROD Image
[latest] end Tests -->> Airflow Repo: Status update @@ -261,7 +258,7 @@ sequenceDiagram deactivate Tests ``` -## Scheduled build flow +## Scheduled run ```mermaid sequenceDiagram @@ -280,6 +277,10 @@ sequenceDiagram Note over Tests: Test UI and Note over Tests: Test examples
PROD image building + and + Note over Tests: Build CI Images
Use original constraints + Tests ->> GitHub Registry: Push CI Image Early cache + latest + Note over Tests: Test 'breeze' image build quickly end Tests ->> GitHub Registry: Push CI Images
[COMMIT_SHA] GitHub Registry ->> Tests: Pull CI Images
[COMMIT_SHA] @@ -326,12 +327,10 @@ sequenceDiagram end Note over Tests: Generate constraints Tests ->> Airflow Repo: Push constraints if changed - GitHub Registry ->> Tests: Pull CI Image
[latest] Note over Tests: Build CI Images
[latest]
Use latest constraints - Tests ->> GitHub Registry: Push CI Image
[latest] - GitHub Registry ->> Tests: Pull PROD Image
[latest] - Note over Tests: Build PROD Images
[latest] - Tests ->> GitHub Registry: Push PROD Image
[latest] + Tests ->> GitHub Registry: Push CI Image cache + latest + Note over Tests: Build PROD Images
[latest]
Use latest constraints + Tests ->> GitHub Registry: Push PROD Image cache + latest Tests -->> Airflow Repo: Status update deactivate Airflow Repo deactivate Tests diff --git a/COMMITTERS.rst b/COMMITTERS.rst index 054988407cb60..3d05f2ed197e0 100644 --- a/COMMITTERS.rst +++ b/COMMITTERS.rst @@ -27,7 +27,7 @@ Before reading this document, you should be familiar with `Contributor's guide < Guidelines to become an Airflow Committer ------------------------------------------ -Committers are community members who have write access to the project’s +Committers are community members who have write access to the project's repositories, i.e., they can modify the code, documentation, and website by themselves and also accept other contributions. There is no strict protocol for becoming a committer. Candidates for new committers are typically people that are active contributors and community members. @@ -75,9 +75,10 @@ Code contribution Community contributions ^^^^^^^^^^^^^^^^^^^^^^^^ -1. Was instrumental in triaging issues -2. Improved documentation of Airflow in significant way -3. Lead change and improvements introduction in the “community” processes and tools +1. Actively participates in `triaging issues `_ showing their understanding + of various areas of Airflow and willingness to help other community members. +2. Improves documentation of Airflow in significant way +3. Leads/implements changes and improvements introduction in the "community" processes and tools 4. Actively spreads the word about Airflow, for example organising Airflow summit, workshops for community members, giving and recording talks, writing blogs 5. Reporting bugs with detailed reproduction steps @@ -186,3 +187,4 @@ To be able to merge PRs, committers have to integrate their GitHub ID with Apach 3. Merge your Apache and GitHub accounts using `GitBox (Apache Account Linking utility) `__. You should see 3 green checks in GitBox. 4. Wait at least 30 minutes for an email inviting you to Apache GitHub Organization and accept invitation. 5. After accepting the GitHub Invitation verify that you are a member of the `Airflow committers team on GitHub `__. +6. Ask in ``#internal-airflow-ci-cd`` channel to be `configured in self-hosted runners `_ by the CI maintainers diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 8bab15352577a..a5e00305c1297 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -134,7 +134,7 @@ and guidelines. Committers/Maintainers ---------------------- -Committers are community members that have write access to the project’s repositories, i.e., they can modify the code, +Committers are community members that have write access to the project's repositories, i.e., they can modify the code, documentation, and website by themselves and also accept other contributions. The official list of committers can be found `here `__. @@ -277,7 +277,7 @@ For effective collaboration, make sure to join the following Airflow groups: - Mailing lists: - - Developer’s mailing list ``_ + - Developer's mailing list ``_ (quite substantial traffic on this list) - All commits mailing list: ``_ @@ -341,13 +341,11 @@ Step 4: Prepare PR * `doc` * `misc` - To add a newsfragment, simply create an rst file named ``{pr_number}.{type}.rst`` (e.g. ``1234.bugfix.rst``) + To add a newsfragment, create an ``rst`` file named ``{pr_number}.{type}.rst`` (e.g. ``1234.bugfix.rst``) and place in either `newsfragments `__ for core newsfragments, or `chart/newsfragments `__ for helm chart newsfragments. 
- For significant newsfragments, similar to git commits, the first line is the summary and optionally a - body can be added with an empty line separating it. - For other newsfragment types, only use a single summary line. + In general newsfragments must be one line. For newsfragment type ``significant``, you may include summary and body separated by a blank line, similar to ``git`` commit messages. 2. Rebase your fork, squash commits, and resolve all conflicts. See `How to rebase PR <#how-to-rebase-pr>`_ if you need help with rebasing your change. Remember to rebase often if your PR takes a lot of time to @@ -361,33 +359,6 @@ Step 4: Prepare PR PR guidelines described in `pull request guidelines <#pull-request-guidelines>`_. Create Pull Request! Make yourself ready for the discussion! -5. Depending on "scope" of your changes, your Pull Request might go through one of few paths after approval. - We run some non-standard workflow with high degree of automation that allows us to optimize the usage - of queue slots in GitHub Actions. Our automated workflows determine the "scope" of changes in your PR - and send it through the right path: - - * In case of a "no-code" change, approval will generate a comment that the PR can be merged and no - tests are needed. This is usually when the change modifies some non-documentation related RST - files (such as this file). No python tests are run and no CI images are built for such PR. Usually - it can be approved and merged few minutes after it is submitted (unless there is a big queue of jobs). - - * In case of change involving python code changes or documentation changes, a subset of full test matrix - will be executed. This subset of tests perform relevant tests for single combination of python, backend - version and only builds one CI image and one PROD image. Here the scope of tests depends on the - scope of your changes: - - * when your change does not change "core" of Airflow (Providers, CLI, WWW, Helm Chart) you will get the - comment that PR is likely ok to be merged without running "full matrix" of tests. However decision - for that is left to committer who approves your change. The committer might set a "full tests needed" - label for your PR and ask you to rebase your request or re-run all jobs. PRs with "full tests needed" - run full matrix of tests. - - * when your change changes the "core" of Airflow you will get the comment that PR needs full tests and - the "full tests needed" label is set for your PR. Additional check is set that prevents from - accidental merging of the request until full matrix of tests succeeds for the PR. - - More details about the PR workflow be found in `PULL_REQUEST_WORKFLOW.rst `_. - Step 5: Pass PR Review ---------------------- @@ -574,13 +545,6 @@ All details about using and running Airflow Breeze can be found in The Airflow Breeze solution is intended to ease your local development as "*It's a Breeze to develop Airflow*". -.. note:: - - We are in a process of switching to the new Python-based Breeze from a legacy Bash - Breeze. Not all functionality has been ported yet and the old Breeze is still available - until then as ``./breeze-legacy`` script. The documentation mentions when the old ./breeze-legacy - should be still used. - Benefits: - Breeze is a complete environment that includes external components, such as @@ -648,16 +612,16 @@ This is the full list of those extras: .. 
START EXTRAS HERE airbyte, alibaba, all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.drill, apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot, -apache.spark, apache.sqoop, apache.webhdfs, arangodb, asana, async, atlas, aws, azure, cassandra, -celery, cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud, -deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid, -elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, -grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap, -leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql, -neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus, -postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, -sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, trino, vertica, -virtualenv, webhdfs, winrm, yandex, zendesk +apache.spark, apache.sqoop, apache.webhdfs, arangodb, asana, async, atlas, atlassian.jira, aws, +azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, common.sql, crypto, dask, databricks, +datadog, dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, +doc, doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, +github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, +jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, +microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, +papermill, password, pinot, plexus, postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, +samba, segment, sendgrid, sentry, sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, +tableau, tabular, telegram, trino, vertica, virtualenv, webhdfs, winrm, yandex, zendesk .. END EXTRAS HERE Provider packages @@ -666,7 +630,23 @@ Provider packages Airflow 2.0 is split into core and providers. They are delivered as separate packages: * ``apache-airflow`` - core of Apache Airflow -* ``apache-airflow-providers-*`` - More than 50 provider packages to communicate with external services +* ``apache-airflow-providers-*`` - More than 70 provider packages to communicate with external services + +The information/meta-data about the providers is kept in ``provider.yaml`` file in the right sub-directory +of ``airflow\providers``. This file contains: + +* package name (``apache-airflow-provider-*``) +* user-facing name of the provider package +* description of the package that is available in the documentation +* list of versions of package that have been released so far +* list of dependencies of the provider package +* list of additional-extras that the provider package provides (together with dependencies of those extras) +* list of integrations, operators, hooks, sensors, transfers provided by the provider (useful for documentation generation) +* list of connection types, extra-links, secret backends, auth backends, and logging handlers (useful to both + register them as they are needed by Airflow and to include them in documentation automatically). 
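If you want to look up this metadata for a particular provider, the files can simply be located in the source tree (a minimal sketch - it assumes you run it from the root of an Airflow source checkout):

.. code-block:: bash

    # list every provider metadata file shipped in the source tree
    find airflow/providers -name provider.yaml | sort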
+ +If you want to add dependencies to the provider, you should add them to the corresponding ``provider.yaml`` +and Airflow pre-commits and package generation commands will use them when preparing package information. In Airflow 1.10 all those providers were installed together within one single package and when you installed airflow locally, from sources, they were also installed. In Airflow 2.0, providers are separated out, @@ -685,7 +665,7 @@ in this airflow folder - the providers package is importable. Some of the packages have cross-dependencies with other providers packages. This typically happens for transfer operators where operators use hooks from the other providers in case they are transferring data between the providers. The list of dependencies is maintained (automatically with pre-commits) -in the ``airflow/providers/dependencies.json``. Pre-commits are also used to generate dependencies. +in the ``generated/provider_dependencies.json``. Pre-commits are also used to generate dependencies. The dependency list is automatically used during PyPI packages generation. Cross-dependencies between provider packages are converted into extras - if you need functionality from @@ -695,49 +675,8 @@ the other provider package you can install it adding [extra] after the transfer operators from Amazon ECS. If you add a new dependency between different providers packages, it will be detected automatically during -pre-commit phase and pre-commit will fail - and add entry in dependencies.json so that the package extra -dependencies are properly added when package is installed. - -You can regenerate the whole list of provider dependencies by running this command (you need to have -``pre-commits`` installed). - -.. code-block:: bash - - pre-commit run build-providers-dependencies - - -Here is the list of packages and their extras: - - - .. START PACKAGE DEPENDENCIES HERE - -========================== =========================== -Package Extras -========================== =========================== -airbyte http -amazon apache.hive,cncf.kubernetes,exasol,ftp,google,imap,mongo,salesforce,ssh -apache.beam google -apache.druid apache.hive -apache.hive amazon,microsoft.mssql,mysql,presto,samba,vertica -apache.livy http -dbt.cloud http -dingding http -discord http -google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,salesforce,sftp,ssh,trino -hashicorp google -microsoft.azure google,oracle,sftp -mysql amazon,presto,trino,vertica -postgres amazon -presto google -salesforce tableau -sftp ssh -slack http -snowflake slack -trino google -========================== =========================== - - .. END PACKAGE DEPENDENCIES HERE - +and pre-commit will generate new entry in ``generated/provider_dependencies.json`` so that +the package extra dependencies are properly handled when package is installed. Developing community managed provider packages ---------------------------------------------- @@ -1190,12 +1129,12 @@ itself comes bundled with jQuery and bootstrap. While they may be phased out over time, these packages are currently not managed with yarn. Make sure you are using recent versions of node and yarn. No problems have been -found with node\>=8.11.3 and yarn\>=1.19.1. - -Installing yarn and its packages --------------------------------- +found with node\>=8.11.3 and yarn\>=1.19.1. 
The pre-commit framework of ours install +node and yarn automatically when installed - if you use ``breeze`` you do not need to install +neither node nor yarn. -Make sure yarn is available in your environment. +Installing yarn and its packages manually +----------------------------------------- To install yarn on macOS: @@ -1219,27 +1158,6 @@ To install yarn on macOS: export PATH="$HOME/.yarn/bin:$PATH" 4. Install third-party libraries defined in ``package.json`` by running the - following commands within the ``airflow/www/`` directory: - - -.. code-block:: bash - - # from the root of the repository, move to where our JS package.json lives - cd airflow/www/ - # run yarn install to fetch all the dependencies - yarn install - - -These commands install the libraries in a new ``node_modules/`` folder within -``www/``. - -Should you add or upgrade a node package, run -``yarn add --dev `` for packages needed in development or -``yarn add `` for packages used by the code. -Then push the newly generated ``package.json`` and ``yarn.lock`` file so that we -could get a reproducible build. See the `Yarn docs -`_ for more details. - Generate Bundled Files with yarn -------------------------------- @@ -1253,7 +1171,7 @@ commands: yarn run prod # Starts a web server that manages and updates your assets as you modify them - # You'll need to run the webserver in debug mode too: `airflow webserver -d` + # You'll need to run the webserver in debug mode too: ``airflow webserver -d`` yarn run dev @@ -1285,10 +1203,10 @@ commands: React, JSX and Chakra ----------------------------- -In order to create a more modern UI, we have started to include [React](https://reactjs.org/) in the ``airflow/www/`` project. +In order to create a more modern UI, we have started to include `React `__ in the ``airflow/www/`` project. If you are unfamiliar with React then it is recommended to check out their documentation to understand components and jsx syntax. -We are using [Chakra UI](https://chakra-ui.com/) as a component and styling library. Notably, all styling is done in a theme file or +We are using `Chakra UI `__ as a component and styling library. Notably, all styling is done in a theme file or inline when defining a component. There are a few shorthand style props like ``px`` instead of ``padding-right, padding-left``. To make this work, all Chakra styling and css styling are completely separate. It is best to think of the React components as a separate app that lives inside of the main app. 
@@ -1526,14 +1444,14 @@ Here are a few rules that are important to keep in mind when you enter our commu * There is a #newbie-questions channel in slack as a safe place to ask questions * You can ask one of the committers to be a mentor for you, committers can guide within the community * You can apply to more structured `Apache Mentoring Programme `_ -* It’s your responsibility as an author to take your PR from start-to-end including leading communication +* It's your responsibility as an author to take your PR from start-to-end including leading communication in the PR -* It’s your responsibility as an author to ping committers to review your PR - be mildly annoying sometimes, - it’s OK to be slightly annoying with your change - it is also a sign for committers that you care +* It's your responsibility as an author to ping committers to review your PR - be mildly annoying sometimes, + it's OK to be slightly annoying with your change - it is also a sign for committers that you care * Be considerate to the high code quality/test coverage requirements for Apache Airflow * If in doubt - ask the community for their opinion or propose to vote at the devlist * Discussions should concern subject matters - judge or criticise the merit but never criticise people -* It’s OK to express your own emotions while communicating - it helps other people to understand you +* It's OK to express your own emotions while communicating - it helps other people to understand you * Be considerate for feelings of others. Tell about how you feel not what you think of others Commit Policy @@ -1549,6 +1467,6 @@ and slightly modified and consensus reached in October 2020: Resources & Links ================= -- `Airflow’s official documentation `__ +- `Airflow's official documentation `__ - `More resources and links to Airflow related content on the Wiki `__ diff --git a/CONTRIBUTORS_QUICK_START.rst b/CONTRIBUTORS_QUICK_START.rst index 10ff2cc9516d9..bffd4967c8902 100644 --- a/CONTRIBUTORS_QUICK_START.rst +++ b/CONTRIBUTORS_QUICK_START.rst @@ -50,11 +50,11 @@ Local machine development If you do not work with remote development environment, you need those prerequisites. -1. Docker Community Edition +1. Docker Community Edition (you can also use Colima, see instructions below) 2. Docker Compose 3. pyenv (you can also use pyenv-virtualenv or virtualenvwrapper) -The below setup describe Ubuntu installation. It might be slightly different on different machines. +The below setup describe `Ubuntu installation `_. It might be slightly different on different machines. Docker Community Edition ------------------------ @@ -66,25 +66,24 @@ Docker Community Edition $ sudo apt-get update $ sudo apt-get install \ - apt-transport-https \ ca-certificates \ curl \ - gnupg-agent \ - software-properties-common + gnupg \ + lsb-release - $ curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - + $ sudo mkdir -p /etc/apt/keyrings + $ curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg - $ sudo add-apt-repository \ - "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) \ - stable" + $ echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null -2. Install Docker +2. Install Docker Engine, containerd, and Docker Compose Plugin. .. 
code-block:: bash $ sudo apt-get update - $ sudo apt-get install docker-ce docker-ce-cli containerd.io + $ sudo apt-get install docker-ce docker-ce-cli containerd.io docker-compose-plugin 3. Creating group for docker and adding current user to it. @@ -101,6 +100,24 @@ Note : After adding user to docker group Logout and Login again for group member $ docker run hello-world +Colima +------ +If you use Colima as your container runtimes engine, please follow the next steps: + +1. `Install buildx manually `_ and follow it's instructions + +2. Link the Colima socket to the default socket path. Note that this may break other Docker servers. + +.. code-block:: bash + + $ sudo ln -sf $HOME/.colima/default/docker.sock /var/run/docker.sock + +3. Change docker context to use default: + +.. code-block:: bash + + $ docker context use default + Docker Compose -------------- @@ -218,13 +235,13 @@ Setting up Breeze .. code-block:: bash - $ breeze setup-autocomplete + $ breeze setup autocomplete 4. Initialize breeze environment with required python version and backend. This may take a while for first time. .. code-block:: bash - $ breeze --python 3.7 --backend mysql + $ breeze --python 3.8 --backend postgres .. note:: If you encounter an error like "docker.credentials.errors.InitializationError: @@ -272,7 +289,7 @@ Using Breeze ------------ 1. Starting breeze environment using ``breeze start-airflow`` starts Breeze environment with last configuration run( - In this case python and backend will be picked up from last execution ``breeze --python 3.8 --backend mysql``) + In this case python and backend will be picked up from last execution ``breeze --python 3.8 --backend postgres``) It also automatically starts webserver, backend and scheduler. It drops you in tmux with scheduler in bottom left and webserver in bottom right. Use ``[Ctrl + B] and Arrow keys`` to navigate. @@ -301,7 +318,7 @@ Using Breeze * 26379 -> forwarded to Redis broker -> redis:6379 Here are links to those services that you can use on host: - * ssh connection for remote debugging: ssh -p 12322 airflow@127.0.0.1 pw: airflow + * ssh connection for remote debugging: ssh -p 12322 airflow@127.0.0.1 (password: airflow) * Webserver: http://127.0.0.1:28080 * Flower: http://127.0.0.1:25555 * Postgres: jdbc:postgresql://127.0.0.1:25433/airflow?user=postgres&password=airflow @@ -324,7 +341,7 @@ Using Breeze .. code-block:: bash - $ breeze --python 3.8 --backend mysql + $ breeze --python 3.8 --backend postgres 2. Open tmux @@ -505,7 +522,7 @@ To avoid burden on CI infrastructure and to save time, Pre-commit hooks can be r .. code-block:: bash - $ pre-commit run --files airflow/decorators.py tests/utils/test_task_group.py + $ pre-commit run --files airflow/utils/decorators.py tests/utils/test_task_group.py @@ -515,7 +532,7 @@ To avoid burden on CI infrastructure and to save time, Pre-commit hooks can be r $ pre-commit run black --files airflow/decorators.py tests/utils/test_task_group.py black...............................................................Passed - $ pre-commit run flake8 --files airflow/decorators.py tests/utils/test_task_group.py + $ pre-commit run run-flake8 --files airflow/decorators.py tests/utils/test_task_group.py Run flake8..........................................................Passed @@ -612,8 +629,7 @@ All Tests are inside ./tests directory. .. 
code-block:: bash - $ breeze --backend mysql --mysql-version 5.7 --python 3.8 --db-reset --test-type All tests - + $ breeze --backend postgres --postgres-version 10 --python 3.8 --db-reset testing tests --test-type All - Running specific type of test @@ -623,7 +639,7 @@ All Tests are inside ./tests directory. .. code-block:: bash - $ breeze --backend mysql --mysql-version 5.7 --python 3.8 --db-reset --test-type Core + $ breeze --backend postgres --postgres-version 10 --python 3.8 --db-reset testing tests --test-type Core - Running Integration test for specific test type @@ -632,7 +648,7 @@ All Tests are inside ./tests directory. .. code-block:: bash - $ breeze --backend mysql --mysql-version 5.7 --python 3.8 --db-reset --test-type All --integration mongo + $ breeze --backend postgres --postgres-version 10 --python 3.8 --db-reset testing tests --test-type All --integration mongo - For more information on Testing visit : |TESTING.rst| diff --git a/Dockerfile b/Dockerfile index cab46244ae954..3da82f2d2b15b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,11 +44,11 @@ ARG AIRFLOW_UID="50000" ARG AIRFLOW_USER_HOME_DIR=/home/airflow # latest released version here -ARG AIRFLOW_VERSION="2.3.1" +ARG AIRFLOW_VERSION="2.4.3" ARG PYTHON_BASE_IMAGE="python:3.7-slim-bullseye" -ARG AIRFLOW_PIP_VERSION=22.1.2 +ARG AIRFLOW_PIP_VERSION=22.3.1 ARG AIRFLOW_IMAGE_REPOSITORY="https://github.com/apache/airflow" ARG AIRFLOW_IMAGE_README_URL="https://raw.githubusercontent.com/apache/airflow/main/docs/docker-stack/README.md" @@ -71,39 +71,106 @@ FROM scratch as scripts # make the PROD Dockerfile standalone ############################################################################################## -# The content below is automatically copied from scripts/docker/determine_debian_version_specific_variables.sh -COPY <<"EOF" /determine_debian_version_specific_variables.sh -function determine_debian_version_specific_variables() { - local color_red - color_red=$'\e[31m' - local color_reset - color_reset=$'\e[0m' - - local debian_version - debian_version=$(lsb_release -cs) - if [[ ${debian_version} == "buster" ]]; then - export DISTRO_LIBENCHANT="libenchant-dev" - export DISTRO_LIBGCC="libgcc-8-dev" - export DISTRO_SELINUX="python-selinux" - export DISTRO_LIBFFI="libffi6" - # Note missing man directories on debian-buster - # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=863199 - mkdir -pv /usr/share/man/man1 - mkdir -pv /usr/share/man/man7 - elif [[ ${debian_version} == "bullseye" ]]; then - export DISTRO_LIBENCHANT="libenchant-2-2" - export DISTRO_LIBGCC="libgcc-10-dev" - export DISTRO_SELINUX="python3-selinux" - export DISTRO_LIBFFI="libffi7" +# The content below is automatically copied from scripts/docker/install_os_dependencies.sh +COPY <<"EOF" /install_os_dependencies.sh +set -euo pipefail + +DOCKER_CLI_VERSION=20.10.9 + +if [[ "$#" != 1 ]]; then + echo "ERROR! There should be 'runtime' or 'dev' parameter passed as argument.". + exit 1 +fi + +if [[ "${1}" == "runtime" ]]; then + INSTALLATION_TYPE="RUNTIME" +elif [[ "${1}" == "dev" ]]; then + INSTALLATION_TYPE="dev" +else + echo "ERROR! Wrong argument. Passed ${1} and it should be one of 'runtime' or 'dev'.". 
+ exit 1 +fi + +function get_dev_apt_deps() { + if [[ "${DEV_APT_DEPS=}" == "" ]]; then + DEV_APT_DEPS="apt-transport-https apt-utils build-essential ca-certificates dirmngr \ +freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils libffi-dev \ +libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev libsasl2-modules \ +libssl-dev locales lsb-release openssh-client sasl2-bin \ +software-properties-common sqlite3 sudo unixodbc unixodbc-dev" + export DEV_APT_DEPS + fi +} + +function get_runtime_apt_deps() { + if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then + RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \ +curl dumb-init freetds-bin gosu krb5-user \ +ldap-utils libffi7 libldap-2.4-2 libsasl2-2 libsasl2-modules libssl1.1 locales \ +lsb-release netcat openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo unixodbc" + export RUNTIME_APT_DEPS + fi +} + +function install_docker_cli() { + local platform + if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then + platform="aarch64" else - echo - echo "${color_red}Unknown distro version ${debian_version}${color_reset}" - echo - exit 1 + platform="x86_64" fi + curl --silent \ + "https://download.docker.com/linux/static/stable/${platform}/docker-${DOCKER_CLI_VERSION}.tgz" \ + | tar -C /usr/bin --strip-components=1 -xvzf - docker/docker } -determine_debian_version_specific_variables +function install_debian_dev_dependencies() { + apt-get update + apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 + apt-get install -y --no-install-recommends curl gnupg2 lsb-release + # shellcheck disable=SC2086 + export ${ADDITIONAL_DEV_APT_ENV?} + if [[ ${DEV_APT_COMMAND} != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}" + fi + if [[ ${ADDITIONAL_DEV_APT_COMMAND} != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}" + fi + apt-get update + # shellcheck disable=SC2086 + apt-get install -y --no-install-recommends ${DEV_APT_DEPS} ${ADDITIONAL_DEV_APT_DEPS} +} + +function install_debian_runtime_dependencies() { + apt-get update + apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 + apt-get install -y --no-install-recommends curl gnupg2 lsb-release + # shellcheck disable=SC2086 + export ${ADDITIONAL_RUNTIME_APT_ENV?} + if [[ "${RUNTIME_APT_COMMAND}" != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}" + fi + if [[ "${ADDITIONAL_RUNTIME_APT_COMMAND}" != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}" + fi + apt-get update + # shellcheck disable=SC2086 + apt-get install -y --no-install-recommends ${RUNTIME_APT_DEPS} ${ADDITIONAL_RUNTIME_APT_DEPS} + apt-get autoremove -yqq --purge + apt-get clean + rm -rf /var/lib/apt/lists/* /var/log/* +} + +if [[ "${INSTALLATION_TYPE}" == "RUNTIME" ]]; then + get_runtime_apt_deps + install_debian_runtime_dependencies + install_docker_cli + +else + get_dev_apt_deps + install_debian_dev_dependencies + install_docker_cli +fi EOF # The content below is automatically copied from scripts/docker/install_mysql.sh @@ -178,8 +245,6 @@ set -euo pipefail COLOR_BLUE=$'\e[34m' readonly COLOR_BLUE -COLOR_YELLOW=$'\e[33m' -readonly COLOR_YELLOW COLOR_RESET=$'\e[0m' readonly COLOR_RESET @@ -197,19 +262,8 @@ function install_mssql_client() { local distro local version distro=$(lsb_release -is | tr '[:upper:]' '[:lower:]') - version_name=$(lsb_release -cs | tr '[:upper:]' 
'[:lower:]') version=$(lsb_release -rs) - local driver - if [[ ${version_name} == "buster" ]]; then - driver=msodbcsql17 - elif [[ ${version_name} == "bullseye" ]]; then - driver=msodbcsql18 - else - echo - echo "${COLOR_YELLOW}Only Buster or Bullseye are supported. Skipping MSSQL installation${COLOR_RESET}" - echo - return - fi + local driver=msodbcsql18 curl --silent https://packages.microsoft.com/keys/microsoft.asc | apt-key add - >/dev/null 2>&1 curl --silent "https://packages.microsoft.com/config/${distro}/${version}/prod.list" > \ /etc/apt/sources.list.d/mssql-release.list @@ -221,11 +275,6 @@ function install_mssql_client() { apt-get clean && rm -rf /var/lib/apt/lists/* } -if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then - # disable MSSQL for ARM64 - INSTALL_MSSQL_CLIENT="false" -fi - install_mssql_client "${@}" EOF @@ -276,20 +325,12 @@ COPY <<"EOF" /install_pip_version.sh : "${AIRFLOW_PIP_VERSION:?Should be set}" -function install_pip_version() { - echo - echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}" - echo - pip install --disable-pip-version-check --no-cache-dir --upgrade "pip==${AIRFLOW_PIP_VERSION}" && - mkdir -p ${HOME}/.local/bin -} - common::get_colors common::get_airflow_version_specification common::override_pip_version_if_needed common::show_pip_version_and_location -install_pip_version +common::install_pip_version EOF # The content below is automatically copied from scripts/docker/install_airflow_dependencies_from_branch_tip.sh @@ -317,10 +358,10 @@ function install_airflow_dependencies_from_branch_tip() { # are conflicts, this might fail, but it should be fixed in the following installation steps set -x pip install --root-user-action ignore \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ "https://github.com/${AIRFLOW_REPO}/archive/${AIRFLOW_BRANCH}.tar.gz#egg=apache-airflow[${AIRFLOW_EXTRAS}]" \ --constraint "${AIRFLOW_CONSTRAINTS_LOCATION}" || true - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version pip freeze | grep apache-airflow-providers | xargs pip uninstall --yes 2>/dev/null || true set +x echo @@ -367,7 +408,7 @@ function common::get_airflow_version_specification() { function common::override_pip_version_if_needed() { if [[ -n ${AIRFLOW_VERSION} ]]; then if [[ ${AIRFLOW_VERSION} =~ ^2\.0.* || ${AIRFLOW_VERSION} =~ ^1\.* ]]; then - export AIRFLOW_PIP_VERSION="22.1.2" + export AIRFLOW_PIP_VERSION="22.3.1" fi fi } @@ -395,101 +436,18 @@ function common::show_pip_version_and_location() { echo "pip on path: $(which pip)" echo "Using pip: $(pip --version)" } -EOF - -# The content below is automatically copied from scripts/docker/prepare_node_modules.sh -COPY <<"EOF" /prepare_node_modules.sh -set -euo pipefail - -COLOR_BLUE=$'\e[34m' -readonly COLOR_BLUE -COLOR_RESET=$'\e[0m' -readonly COLOR_RESET - -function prepare_node_modules() { - echo - echo "${COLOR_BLUE}Preparing node modules${COLOR_RESET}" - echo - local www_dir - if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then - # In case we are building from sources in production image, we should build the assets - www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www" - else - www_dir="$(python -m site --user-site)/airflow/www" - fi - pushd ${www_dir} || exit 1 - set +e - yarn install --frozen-lockfile --no-cache 2>/tmp/out-yarn-install.txt - local res=$? 
- if [[ ${res} != 0 ]]; then - >&2 echo - >&2 echo "Error when running yarn install:" - >&2 echo - >&2 cat /tmp/out-yarn-install.txt && rm -f /tmp/out-yarn-install.txt - exit 1 - fi - rm -f /tmp/out-yarn-install.txt - popd || exit 1 -} -prepare_node_modules -EOF - -# The content below is automatically copied from scripts/docker/compile_www_assets.sh -COPY <<"EOF" /compile_www_assets.sh -set -euo pipefail - -BUILD_TYPE=${BUILD_TYPE="prod"} -REMOVE_ARTIFACTS=${REMOVE_ARTIFACTS="true"} - -COLOR_BLUE=$'\e[34m' -readonly COLOR_BLUE -COLOR_RESET=$'\e[0m' -readonly COLOR_RESET - -function compile_www_assets() { +function common::install_pip_version() { echo - echo "${COLOR_BLUE}Compiling www assets: running yarn ${BUILD_TYPE}${COLOR_RESET}" + echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}" echo - local www_dir - if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then - # In case we are building from sources in production image, we should build the assets - www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www" + if [[ ${AIRFLOW_PIP_VERSION} =~ .*https.* ]]; then + pip install --disable-pip-version-check --no-cache-dir "pip @ ${AIRFLOW_PIP_VERSION}" else - www_dir="$(python -m site --user-site)/airflow/www" + pip install --disable-pip-version-check --no-cache-dir "pip==${AIRFLOW_PIP_VERSION}" fi - pushd ${www_dir} || exit 1 - set +e - yarn run "${BUILD_TYPE}" 2>/tmp/out-yarn-run.txt - res=$? - if [[ ${res} != 0 ]]; then - >&2 echo - >&2 echo "Error when running yarn run:" - >&2 echo - >&2 cat /tmp/out-yarn-run.txt && rm -rf /tmp/out-yarn-run.txt - exit 1 - fi - rm -f /tmp/out-yarn-run.txt - set -e - local md5sum_file - md5sum_file="static/dist/sum.md5" - readonly md5sum_file - find package.json yarn.lock static/css static/js -type f | sort | xargs md5sum > "${md5sum_file}" - if [[ ${REMOVE_ARTIFACTS} == "true" ]]; then - echo - echo "${COLOR_BLUE}Removing generated node modules${COLOR_RESET}" - echo - rm -rf "${www_dir}/node_modules" - rm -vf "${www_dir}"/{package.json,yarn.lock,.eslintignore,.eslintrc,.stylelintignore,.stylelintrc,compile_assets.sh,webpack.config.js} - else - echo - echo "${COLOR_BLUE}Leaving generated node modules${COLOR_RESET}" - echo - fi - popd || exit 1 + mkdir -p "${HOME}/.local/bin" } - -compile_www_assets EOF # The content below is automatically copied from scripts/docker/pip @@ -571,12 +529,12 @@ function install_airflow_and_providers_from_docker_context_files(){ # force reinstall all airflow + provider package local files with eager upgrade set -x pip install "${pip_flags[@]}" --root-user-action ignore --upgrade --upgrade-strategy eager \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ ${reinstalling_apache_airflow_package} ${reinstalling_apache_airflow_providers_packages} \ ${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} set +x - # make sure correct PIP version is left installed - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version pip check } @@ -591,10 +549,10 @@ function install_all_other_packages_from_docker_context_files() { grep -v apache_airflow | grep -v apache-airflow || true) if [[ -n "${reinstalling_other_packages}" ]]; then set -x - pip install --root-user-action ignore --force-reinstall --no-deps --no-index ${reinstalling_other_packages} - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null - set -x + pip install ${ADDITIONAL_PIP_INSTALL_FLAGS} \ + --root-user-action ignore --force-reinstall --no-deps 
--no-index ${reinstalling_other_packages} + common::install_pip_version + set +x fi } @@ -642,6 +600,7 @@ function install_airflow() { echo # eager upgrade pip install --root-user-action ignore --upgrade --upgrade-strategy eager \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \ ${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} if [[ -n "${AIRFLOW_INSTALL_EDITABLE_FLAG}" ]]; then @@ -650,12 +609,12 @@ function install_airflow() { set -x pip uninstall apache-airflow --yes pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" set +x fi - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" echo @@ -666,16 +625,16 @@ function install_airflow() { echo set -x pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \ --constraint "${AIRFLOW_CONSTRAINTS_LOCATION}" - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version # then upgrade if needed without using constraints to account for new limits in setup.py pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ ${AIRFLOW_INSTALL_EDITABLE_FLAG} \ "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version set +x echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" @@ -712,9 +671,9 @@ function install_additional_dependencies() { echo set -x pip install --root-user-action ignore --upgrade --upgrade-strategy eager \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ ${ADDITIONAL_PYTHON_DEPS} ${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version set +x echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" @@ -726,9 +685,9 @@ function install_additional_dependencies() { echo set -x pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ ${ADDITIONAL_PYTHON_DEPS} - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version set +x echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" @@ -1100,44 +1059,10 @@ ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \ DEBIAN_FRONTEND=noninteractive LANGUAGE=C.UTF-8 LANG=C.UTF-8 LC_ALL=C.UTF-8 \ LC_CTYPE=C.UTF-8 LC_MESSAGES=C.UTF-8 -ARG DEV_APT_DEPS="\ - apt-transport-https \ - apt-utils \ - build-essential \ - ca-certificates \ - dirmngr \ - freetds-bin \ - freetds-dev \ - gosu \ - krb5-user \ - ldap-utils \ - libffi-dev \ - libkrb5-dev \ - libldap2-dev \ - libsasl2-2 \ - libsasl2-dev \ - libsasl2-modules \ - libssl-dev \ - locales \ - lsb-release \ - nodejs \ - openssh-client \ - sasl2-bin \ - software-properties-common \ - sqlite3 \ - sudo \ - unixodbc \ - unixodbc-dev \ - yarn" - +ARG 
DEV_APT_DEPS="" ARG ADDITIONAL_DEV_APT_DEPS="" -ARG DEV_APT_COMMAND="\ - curl --silent --fail --location https://deb.nodesource.com/setup_14.x | \ - bash -o pipefail -o errexit -o nolog - \ - && curl --silent https://dl.yarnpkg.com/debian/pubkey.gpg | \ - apt-key add - >/dev/null 2>&1\ - && echo 'deb https://dl.yarnpkg.com/debian/ stable main' > /etc/apt/sources.list.d/yarn.list" -ARG ADDITIONAL_DEV_APT_COMMAND="echo" +ARG DEV_APT_COMMAND="" +ARG ADDITIONAL_DEV_APT_COMMAND="" ARG ADDITIONAL_DEV_APT_ENV="" ENV DEV_APT_DEPS=${DEV_APT_DEPS} \ @@ -1146,23 +1071,8 @@ ENV DEV_APT_DEPS=${DEV_APT_DEPS} \ ADDITIONAL_DEV_APT_COMMAND=${ADDITIONAL_DEV_APT_COMMAND} \ ADDITIONAL_DEV_APT_ENV=${ADDITIONAL_DEV_APT_ENV} -COPY --from=scripts determine_debian_version_specific_variables.sh /scripts/docker/ -# Install basic and additional apt dependencies -RUN apt-get update \ - && apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 \ - && apt-get install -y --no-install-recommends curl gnupg2 lsb-release \ - && export ${ADDITIONAL_DEV_APT_ENV?} \ - && source /scripts/docker/determine_debian_version_specific_variables.sh \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}" \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}" \ - && apt-get update \ - && apt-get install -y --no-install-recommends \ - ${DEV_APT_DEPS} \ - "${DISTRO_SELINUX}" \ - ${ADDITIONAL_DEV_APT_DEPS} \ - && apt-get autoremove -yqq --purge \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* +COPY --from=scripts install_os_dependencies.sh /scripts/docker/ +RUN bash /scripts/docker/install_os_dependencies.sh dev ARG INSTALL_MYSQL_CLIENT="true" ARG INSTALL_MSSQL_CLIENT="true" @@ -1196,18 +1106,10 @@ ARG INSTALL_PROVIDERS_FROM_SOURCES="false" # But it also can be `.` from local installation or GitHub URL pointing to specific branch or tag # Of Airflow. Note That for local source installation you need to have local sources of # Airflow checked out together with the Dockerfile and AIRFLOW_SOURCES_FROM and AIRFLOW_SOURCES_TO -# set to "." and "/opt/airflow" respectively. Similarly AIRFLOW_SOURCES_WWW_FROM/TO are set to right source -# and destination +# set to "." and "/opt/airflow" respectively. ARG AIRFLOW_INSTALLATION_METHOD="apache-airflow" # By default we do not upgrade to latest dependencies ARG UPGRADE_TO_NEWER_DEPENDENCIES="false" -# By default we install latest airflow from PyPI so we do not need to copy sources of Airflow -# www to compile the assets but in case of breeze/CI builds we use latest sources and we override those -# those SOURCES_FROM/TO with "airflow/www" and "/opt/airflow/airflow/www" respectively. -# This is to rebuild the assets only when any of the www sources change -ARG AIRFLOW_SOURCES_WWW_FROM="Dockerfile" -ARG AIRFLOW_SOURCES_WWW_TO="/Dockerfile" - # By default we install latest airflow from PyPI so we do not need to copy sources of Airflow # but in case of breeze/CI builds we use latest sources and we override those # those SOURCES_FROM/TO with "." 
and "/opt/airflow" respectively @@ -1251,6 +1153,9 @@ RUN if [[ -f /docker-context-files/pip.conf ]]; then \ cp /docker-context-files/.piprc "${AIRFLOW_USER_HOME_DIR}/.piprc"; \ fi +# Additional PIP flags passed to all pip install commands except reinstalling pip itself +ARG ADDITIONAL_PIP_INSTALL_FLAGS="" + ENV AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \ AIRFLOW_PRE_CACHED_PIP_PACKAGES=${AIRFLOW_PRE_CACHED_PIP_PACKAGES} \ INSTALL_PROVIDERS_FROM_SOURCES=${INSTALL_PROVIDERS_FROM_SOURCES} \ @@ -1270,6 +1175,7 @@ ENV AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \ PATH=${PATH}:${AIRFLOW_USER_HOME_DIR}/.local/bin \ AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \ PIP_PROGRESS_BAR=${PIP_PROGRESS_BAR} \ + ADDITIONAL_PIP_INSTALL_FLAGS=${ADDITIONAL_PIP_INSTALL_FLAGS} \ AIRFLOW_USER_HOME_DIR=${AIRFLOW_USER_HOME_DIR} \ AIRFLOW_HOME=${AIRFLOW_HOME} \ AIRFLOW_UID=${AIRFLOW_UID} \ @@ -1283,61 +1189,50 @@ ENV AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \ COPY --from=scripts common.sh install_pip_version.sh \ install_airflow_dependencies_from_branch_tip.sh /scripts/docker/ +# We can set this value to true in case we want to install .whl/.tar.gz packages placed in the +# docker-context-files folder. This can be done for both additional packages you want to install +# as well as Airflow and Provider packages (it will be automatically detected if airflow +# is installed from docker-context files rather than from PyPI) +ARG INSTALL_PACKAGES_FROM_CONTEXT="false" + # In case of Production build image segment we want to pre-install main version of airflow # dependencies from GitHub so that we do not have to always reinstall it from the scratch. # The Airflow (and providers in case INSTALL_PROVIDERS_FROM_SOURCES is "false") # are uninstalled, only dependencies remain # the cache is only used when "upgrade to newer dependencies" is not set to automatically -# account for removed dependencies (we do not install them in the first place) -# Upgrade to specific PIP version +# account for removed dependencies (we do not install them in the first place) and in case +# INSTALL_PACKAGES_FROM_CONTEXT is not set (because then caching it from main makes no sense). RUN bash /scripts/docker/install_pip_version.sh; \ if [[ ${AIRFLOW_PRE_CACHED_PIP_PACKAGES} == "true" && \ - ${UPGRADE_TO_NEWER_DEPENDENCIES} == "false" ]]; then \ + ${INSTALL_PACKAGES_FROM_CONTEXT} == "false" && \ + ${UPGRADE_TO_NEWER_DEPENDENCIES} == "false" ]]; then \ bash /scripts/docker/install_airflow_dependencies_from_branch_tip.sh; \ fi -COPY --from=scripts compile_www_assets.sh prepare_node_modules.sh /scripts/docker/ -COPY --chown=airflow:0 ${AIRFLOW_SOURCES_WWW_FROM} ${AIRFLOW_SOURCES_WWW_TO} - -# hadolint ignore=SC2086, SC2010 -RUN if [[ ${AIRFLOW_INSTALLATION_METHOD} == "." ]]; then \ - # only prepare node modules and compile assets if the prod image is build from sources - # otherwise they are already compiled-in. We should do it in one step with removing artifacts \ - # as we want to keep the final image small - bash /scripts/docker/prepare_node_modules.sh; \ - REMOVE_ARTIFACTS="true" BUILD_TYPE="prod" bash /scripts/docker/compile_www_assets.sh; \ - # Copy generated dist folder (otherwise it will be overridden by the COPY step below) - mv -f /opt/airflow/airflow/www/static/dist /tmp/dist; \ - fi; - COPY --chown=airflow:0 ${AIRFLOW_SOURCES_FROM} ${AIRFLOW_SOURCES_TO} -# Copy back the generated dist folder -RUN if [[ ${AIRFLOW_INSTALLATION_METHOD} == "." 
]]; then \ - mv -f /tmp/dist /opt/airflow/airflow/www/static/dist; \ - fi; - # Add extra python dependencies ARG ADDITIONAL_PYTHON_DEPS="" -# We can set this value to true in case we want to install .whl .tar.gz packages placed in the -# docker-context-files folder. This can be done for both - additional packages you want to install -# and for airflow as well (you have to set AIRFLOW_IS_IN_CONTEXT to true in this case) -ARG INSTALL_PACKAGES_FROM_CONTEXT="false" -# By default we install latest airflow from PyPI or sources. You can set this parameter to false -# if Airflow is in the .whl or .tar.gz packages placed in `docker-context-files` folder and you want -# to skip installing Airflow/Providers from PyPI or sources. -ARG AIRFLOW_IS_IN_CONTEXT="false" + # Those are additional constraints that are needed for some extras but we do not want to # Force them on the main Airflow package. # * dill<0.3.3 required by apache-beam -ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3" +# * pyarrow>=6.0.0 is because pip resolver decides for Python 3.10 to downgrade pyarrow to 5 even if it is OK +# for python 3.10 and other dependencies adding the limit helps resolver to make better decisions +# We need to limit the protobuf library to < 4.21.0 because not all google libraries we use +# are compatible with the new protobuf version. All the google python client libraries need +# to be upgraded to >=2.0.0 in order to able to lift that limitation +# https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates +# * authlib, gcloud_aio_auth, adal are needed to generate constraints for PyPI packages and can be removed after we release +# new google, azure providers +# !!! MAKE SURE YOU SYNCHRONIZE THE LIST BETWEEN: Dockerfile, Dockerfile.ci, find_newer_dependencies.py +ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3 pyarrow>=6.0.0 protobuf<4.21.0 authlib>=1.0.0 gcloud_aio_auth>=4.0.0 adal>=1.2.7" ENV ADDITIONAL_PYTHON_DEPS=${ADDITIONAL_PYTHON_DEPS} \ INSTALL_PACKAGES_FROM_CONTEXT=${INSTALL_PACKAGES_FROM_CONTEXT} \ - AIRFLOW_IS_IN_CONTEXT=${AIRFLOW_IS_IN_CONTEXT} \ EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS=${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} -WORKDIR /opt/airflow +WORKDIR ${AIRFLOW_HOME} COPY --from=scripts install_from_docker_context_files.sh install_airflow.sh \ install_additional_dependencies.sh /scripts/docker/ @@ -1345,7 +1240,8 @@ COPY --from=scripts install_from_docker_context_files.sh install_airflow.sh \ # hadolint ignore=SC2086, SC2010 RUN if [[ ${INSTALL_PACKAGES_FROM_CONTEXT} == "true" ]]; then \ bash /scripts/docker/install_from_docker_context_files.sh; \ - elif [[ ${AIRFLOW_IS_IN_CONTEXT} == "false" ]]; then \ + fi; \ + if ! 
airflow version 2>/dev/null >/dev/null; then \ bash /scripts/docker/install_airflow.sh; \ fi; \ if [[ -n "${ADDITIONAL_PYTHON_DEPS}" ]]; then \ @@ -1388,32 +1284,10 @@ ARG AIRFLOW_PIP_VERSION ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \ # Make sure noninteractive debian install is used and language variables set DEBIAN_FRONTEND=noninteractive LANGUAGE=C.UTF-8 LANG=C.UTF-8 LC_ALL=C.UTF-8 \ - LC_CTYPE=C.UTF-8 LC_MESSAGES=C.UTF-8 \ + LC_CTYPE=C.UTF-8 LC_MESSAGES=C.UTF-8 LD_LIBRARY_PATH=/usr/local/lib \ AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} -ARG RUNTIME_APT_DEPS="\ - apt-transport-https \ - apt-utils \ - ca-certificates \ - curl \ - dumb-init \ - freetds-bin \ - gosu \ - krb5-user \ - ldap-utils \ - libldap-2.4-2 \ - libsasl2-2 \ - libsasl2-modules \ - libssl1.1 \ - locales \ - lsb-release \ - netcat \ - openssh-client \ - rsync \ - sasl2-bin \ - sqlite3 \ - sudo \ - unixodbc" +ARG RUNTIME_APT_DEPS="" ARG ADDITIONAL_RUNTIME_APT_DEPS="" ARG RUNTIME_APT_COMMAND="echo" ARG ADDITIONAL_RUNTIME_APT_COMMAND="" @@ -1432,25 +1306,8 @@ ENV RUNTIME_APT_DEPS=${RUNTIME_APT_DEPS} \ GUNICORN_CMD_ARGS="--worker-tmp-dir /dev/shm" \ AIRFLOW_INSTALLATION_METHOD=${AIRFLOW_INSTALLATION_METHOD} -COPY --from=scripts determine_debian_version_specific_variables.sh /scripts/docker/ - -# Install basic and additional apt dependencies -RUN apt-get update \ - && apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 \ - && apt-get install -y --no-install-recommends curl gnupg2 lsb-release \ - && export ${ADDITIONAL_RUNTIME_APT_ENV?} \ - && source /scripts/docker/determine_debian_version_specific_variables.sh \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}" \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}" \ - && apt-get update \ - && apt-get install -y --no-install-recommends \ - ${RUNTIME_APT_DEPS} \ - "${DISTRO_LIBFFI}" \ - ${ADDITIONAL_RUNTIME_APT_DEPS} \ - && apt-get autoremove -yqq --purge \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* \ - && rm -rf /var/log/* +COPY --from=scripts install_os_dependencies.sh /scripts/docker/ +RUN bash /scripts/docker/install_os_dependencies.sh runtime # Having the variable in final image allows to disable providers manager warnings when # production image is prepared from sources rather than from package @@ -1472,7 +1329,7 @@ COPY --from=scripts install_mysql.sh install_mssql.sh install_postgres.sh /scrip # We run scripts with bash here to make sure we can execute the scripts. Changing to +x might have an # unexpected result - the cache for Dockerfiles might get invalidated in case the host system # had different umask set and group x bit was not set. In Azure the bit might be not set at all. 
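# A minimal sketch of customizing this final stage at build time now that RUNTIME_APT_DEPS
# defaults to "" and the default package list lives in install_os_dependencies.sh. The package
# names and image tag below are illustrative assumptions only, not values used by this Dockerfile:
#
#   docker build . \
#     --build-arg ADDITIONAL_RUNTIME_APT_DEPS="vim procps" \
#     --build-arg ADDITIONAL_PIP_INSTALL_FLAGS="--no-cache-dir" \
#     --tag my-airflow-prod:0.0.1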
-# That also protects against AUFS Docker backen dproblem where changing the executable bit required sync +# That also protects against AUFS Docker backend problem where changing the executable bit required sync RUN bash /scripts/docker/install_mysql.sh prod \ && bash /scripts/docker/install_mssql.sh \ && bash /scripts/docker/install_postgres.sh prod \ diff --git a/Dockerfile.ci b/Dockerfile.ci index c089544f244b3..cc28c355d916d 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -31,39 +31,106 @@ FROM ${PYTHON_BASE_IMAGE} as scripts # make the PROD Dockerfile standalone ############################################################################################## -# The content below is automatically copied from scripts/docker/determine_debian_version_specific_variables.sh -COPY <<"EOF" /determine_debian_version_specific_variables.sh -function determine_debian_version_specific_variables() { - local color_red - color_red=$'\e[31m' - local color_reset - color_reset=$'\e[0m' - - local debian_version - debian_version=$(lsb_release -cs) - if [[ ${debian_version} == "buster" ]]; then - export DISTRO_LIBENCHANT="libenchant-dev" - export DISTRO_LIBGCC="libgcc-8-dev" - export DISTRO_SELINUX="python-selinux" - export DISTRO_LIBFFI="libffi6" - # Note missing man directories on debian-buster - # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=863199 - mkdir -pv /usr/share/man/man1 - mkdir -pv /usr/share/man/man7 - elif [[ ${debian_version} == "bullseye" ]]; then - export DISTRO_LIBENCHANT="libenchant-2-2" - export DISTRO_LIBGCC="libgcc-10-dev" - export DISTRO_SELINUX="python3-selinux" - export DISTRO_LIBFFI="libffi7" +# The content below is automatically copied from scripts/docker/install_os_dependencies.sh +COPY <<"EOF" /install_os_dependencies.sh +set -euo pipefail + +DOCKER_CLI_VERSION=20.10.9 + +if [[ "$#" != 1 ]]; then + echo "ERROR! There should be 'runtime' or 'dev' parameter passed as argument.". + exit 1 +fi + +if [[ "${1}" == "runtime" ]]; then + INSTALLATION_TYPE="RUNTIME" +elif [[ "${1}" == "dev" ]]; then + INSTALLATION_TYPE="dev" +else + echo "ERROR! Wrong argument. Passed ${1} and it should be one of 'runtime' or 'dev'.". 
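# For reference, the Dockerfiles in this change call this script as
# "bash /scripts/docker/install_os_dependencies.sh dev" in the build/CI stages and
# "bash /scripts/docker/install_os_dependencies.sh runtime" in the final prod stage;
# any other argument falls through to this error branch.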
+ exit 1 +fi + +function get_dev_apt_deps() { + if [[ "${DEV_APT_DEPS=}" == "" ]]; then + DEV_APT_DEPS="apt-transport-https apt-utils build-essential ca-certificates dirmngr \ +freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils libffi-dev \ +libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev libsasl2-modules \ +libssl-dev locales lsb-release openssh-client sasl2-bin \ +software-properties-common sqlite3 sudo unixodbc unixodbc-dev" + export DEV_APT_DEPS + fi +} + +function get_runtime_apt_deps() { + if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then + RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \ +curl dumb-init freetds-bin gosu krb5-user \ +ldap-utils libffi7 libldap-2.4-2 libsasl2-2 libsasl2-modules libssl1.1 locales \ +lsb-release netcat openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo unixodbc" + export RUNTIME_APT_DEPS + fi +} + +function install_docker_cli() { + local platform + if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then + platform="aarch64" else - echo - echo "${color_red}Unknown distro version ${debian_version}${color_reset}" - echo - exit 1 + platform="x86_64" fi + curl --silent \ + "https://download.docker.com/linux/static/stable/${platform}/docker-${DOCKER_CLI_VERSION}.tgz" \ + | tar -C /usr/bin --strip-components=1 -xvzf - docker/docker } -determine_debian_version_specific_variables +function install_debian_dev_dependencies() { + apt-get update + apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 + apt-get install -y --no-install-recommends curl gnupg2 lsb-release + # shellcheck disable=SC2086 + export ${ADDITIONAL_DEV_APT_ENV?} + if [[ ${DEV_APT_COMMAND} != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}" + fi + if [[ ${ADDITIONAL_DEV_APT_COMMAND} != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}" + fi + apt-get update + # shellcheck disable=SC2086 + apt-get install -y --no-install-recommends ${DEV_APT_DEPS} ${ADDITIONAL_DEV_APT_DEPS} +} + +function install_debian_runtime_dependencies() { + apt-get update + apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 + apt-get install -y --no-install-recommends curl gnupg2 lsb-release + # shellcheck disable=SC2086 + export ${ADDITIONAL_RUNTIME_APT_ENV?} + if [[ "${RUNTIME_APT_COMMAND}" != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}" + fi + if [[ "${ADDITIONAL_RUNTIME_APT_COMMAND}" != "" ]]; then + bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}" + fi + apt-get update + # shellcheck disable=SC2086 + apt-get install -y --no-install-recommends ${RUNTIME_APT_DEPS} ${ADDITIONAL_RUNTIME_APT_DEPS} + apt-get autoremove -yqq --purge + apt-get clean + rm -rf /var/lib/apt/lists/* /var/log/* +} + +if [[ "${INSTALLATION_TYPE}" == "RUNTIME" ]]; then + get_runtime_apt_deps + install_debian_runtime_dependencies + install_docker_cli + +else + get_dev_apt_deps + install_debian_dev_dependencies + install_docker_cli +fi EOF # The content below is automatically copied from scripts/docker/install_mysql.sh @@ -138,8 +205,6 @@ set -euo pipefail COLOR_BLUE=$'\e[34m' readonly COLOR_BLUE -COLOR_YELLOW=$'\e[33m' -readonly COLOR_YELLOW COLOR_RESET=$'\e[0m' readonly COLOR_RESET @@ -157,19 +222,8 @@ function install_mssql_client() { local distro local version distro=$(lsb_release -is | tr '[:upper:]' '[:lower:]') - version_name=$(lsb_release -cs | tr '[:upper:]' 
'[:lower:]') version=$(lsb_release -rs) - local driver - if [[ ${version_name} == "buster" ]]; then - driver=msodbcsql17 - elif [[ ${version_name} == "bullseye" ]]; then - driver=msodbcsql18 - else - echo - echo "${COLOR_YELLOW}Only Buster or Bullseye are supported. Skipping MSSQL installation${COLOR_RESET}" - echo - return - fi + local driver=msodbcsql18 curl --silent https://packages.microsoft.com/keys/microsoft.asc | apt-key add - >/dev/null 2>&1 curl --silent "https://packages.microsoft.com/config/${distro}/${version}/prod.list" > \ /etc/apt/sources.list.d/mssql-release.list @@ -181,11 +235,6 @@ function install_mssql_client() { apt-get clean && rm -rf /var/lib/apt/lists/* } -if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then - # disable MSSQL for ARM64 - INSTALL_MSSQL_CLIENT="false" -fi - install_mssql_client "${@}" EOF @@ -236,20 +285,12 @@ COPY <<"EOF" /install_pip_version.sh : "${AIRFLOW_PIP_VERSION:?Should be set}" -function install_pip_version() { - echo - echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}" - echo - pip install --disable-pip-version-check --no-cache-dir --upgrade "pip==${AIRFLOW_PIP_VERSION}" && - mkdir -p ${HOME}/.local/bin -} - common::get_colors common::get_airflow_version_specification common::override_pip_version_if_needed common::show_pip_version_and_location -install_pip_version +common::install_pip_version EOF # The content below is automatically copied from scripts/docker/install_airflow_dependencies_from_branch_tip.sh @@ -277,10 +318,10 @@ function install_airflow_dependencies_from_branch_tip() { # are conflicts, this might fail, but it should be fixed in the following installation steps set -x pip install --root-user-action ignore \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ "https://github.com/${AIRFLOW_REPO}/archive/${AIRFLOW_BRANCH}.tar.gz#egg=apache-airflow[${AIRFLOW_EXTRAS}]" \ --constraint "${AIRFLOW_CONSTRAINTS_LOCATION}" || true - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version pip freeze | grep apache-airflow-providers | xargs pip uninstall --yes 2>/dev/null || true set +x echo @@ -327,7 +368,7 @@ function common::get_airflow_version_specification() { function common::override_pip_version_if_needed() { if [[ -n ${AIRFLOW_VERSION} ]]; then if [[ ${AIRFLOW_VERSION} =~ ^2\.0.* || ${AIRFLOW_VERSION} =~ ^1\.* ]]; then - export AIRFLOW_PIP_VERSION="22.1.2" + export AIRFLOW_PIP_VERSION="22.3.1" fi fi } @@ -355,6 +396,18 @@ function common::show_pip_version_and_location() { echo "pip on path: $(which pip)" echo "Using pip: $(pip --version)" } + +function common::install_pip_version() { + echo + echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}" + echo + if [[ ${AIRFLOW_PIP_VERSION} =~ .*https.* ]]; then + pip install --disable-pip-version-check --no-cache-dir "pip @ ${AIRFLOW_PIP_VERSION}" + else + pip install --disable-pip-version-check --no-cache-dir "pip==${AIRFLOW_PIP_VERSION}" + fi + mkdir -p "${HOME}/.local/bin" +} EOF # The content below is automatically copied from scripts/docker/install_pipx_tools.sh @@ -384,101 +437,6 @@ common::get_colors install_pipx_tools EOF -# The content below is automatically copied from scripts/docker/prepare_node_modules.sh -COPY <<"EOF" /prepare_node_modules.sh -set -euo pipefail - -COLOR_BLUE=$'\e[34m' -readonly COLOR_BLUE -COLOR_RESET=$'\e[0m' -readonly COLOR_RESET - -function prepare_node_modules() { - echo - echo "${COLOR_BLUE}Preparing 
node modules${COLOR_RESET}" - echo - local www_dir - if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then - # In case we are building from sources in production image, we should build the assets - www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www" - else - www_dir="$(python -m site --user-site)/airflow/www" - fi - pushd ${www_dir} || exit 1 - set +e - yarn install --frozen-lockfile --no-cache 2>/tmp/out-yarn-install.txt - local res=$? - if [[ ${res} != 0 ]]; then - >&2 echo - >&2 echo "Error when running yarn install:" - >&2 echo - >&2 cat /tmp/out-yarn-install.txt && rm -f /tmp/out-yarn-install.txt - exit 1 - fi - rm -f /tmp/out-yarn-install.txt - popd || exit 1 -} - -prepare_node_modules -EOF - -# The content below is automatically copied from scripts/docker/compile_www_assets.sh -COPY <<"EOF" /compile_www_assets.sh -set -euo pipefail - -BUILD_TYPE=${BUILD_TYPE="prod"} -REMOVE_ARTIFACTS=${REMOVE_ARTIFACTS="true"} - -COLOR_BLUE=$'\e[34m' -readonly COLOR_BLUE -COLOR_RESET=$'\e[0m' -readonly COLOR_RESET - -function compile_www_assets() { - echo - echo "${COLOR_BLUE}Compiling www assets: running yarn ${BUILD_TYPE}${COLOR_RESET}" - echo - local www_dir - if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then - # In case we are building from sources in production image, we should build the assets - www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www" - else - www_dir="$(python -m site --user-site)/airflow/www" - fi - pushd ${www_dir} || exit 1 - set +e - yarn run "${BUILD_TYPE}" 2>/tmp/out-yarn-run.txt - res=$? - if [[ ${res} != 0 ]]; then - >&2 echo - >&2 echo "Error when running yarn run:" - >&2 echo - >&2 cat /tmp/out-yarn-run.txt && rm -rf /tmp/out-yarn-run.txt - exit 1 - fi - rm -f /tmp/out-yarn-run.txt - set -e - local md5sum_file - md5sum_file="static/dist/sum.md5" - readonly md5sum_file - find package.json yarn.lock static/css static/js -type f | sort | xargs md5sum > "${md5sum_file}" - if [[ ${REMOVE_ARTIFACTS} == "true" ]]; then - echo - echo "${COLOR_BLUE}Removing generated node modules${COLOR_RESET}" - echo - rm -rf "${www_dir}/node_modules" - rm -vf "${www_dir}"/{package.json,yarn.lock,.eslintignore,.eslintrc,.stylelintignore,.stylelintrc,compile_assets.sh,webpack.config.js} - else - echo - echo "${COLOR_BLUE}Leaving generated node modules${COLOR_RESET}" - echo - fi - popd || exit 1 -} - -compile_www_assets -EOF - # The content below is automatically copied from scripts/docker/install_airflow.sh COPY <<"EOF" /install_airflow.sh @@ -511,6 +469,7 @@ function install_airflow() { echo # eager upgrade pip install --root-user-action ignore --upgrade --upgrade-strategy eager \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \ ${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} if [[ -n "${AIRFLOW_INSTALL_EDITABLE_FLAG}" ]]; then @@ -519,12 +478,12 @@ function install_airflow() { set -x pip uninstall apache-airflow --yes pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" set +x fi - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" echo @@ -535,16 +494,16 @@ function install_airflow() { echo set -x pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ 
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \ --constraint "${AIRFLOW_CONSTRAINTS_LOCATION}" - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version # then upgrade if needed without using constraints to account for new limits in setup.py pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ ${AIRFLOW_INSTALL_EDITABLE_FLAG} \ "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version set +x echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" @@ -581,9 +540,9 @@ function install_additional_dependencies() { echo set -x pip install --root-user-action ignore --upgrade --upgrade-strategy eager \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ ${ADDITIONAL_PYTHON_DEPS} ${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version set +x echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" @@ -595,9 +554,9 @@ function install_additional_dependencies() { echo set -x pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ ${ADDITIONAL_PYTHON_DEPS} - # make sure correct PIP version is used - pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null + common::install_pip_version set +x echo echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}" @@ -622,7 +581,7 @@ if [[ ${VERBOSE_COMMANDS:="false"} == "true" ]]; then set -x fi -. /opt/airflow/scripts/in_container/_in_container_script_init.sh +. "${AIRFLOW_SOURCES:-/opt/airflow}"/scripts/in_container/_in_container_script_init.sh LD_PRELOAD="/usr/lib/$(uname -m)-linux-gnu/libstdc++.so.6" export LD_PRELOAD @@ -637,6 +596,35 @@ export AIRFLOW_HOME=${AIRFLOW_HOME:=${HOME}} : "${AIRFLOW_SOURCES:?"ERROR: AIRFLOW_SOURCES not set !!!!"}" +function wait_for_asset_compilation() { + if [[ -f "${AIRFLOW_SOURCES}/.build/www/.asset_compile.lock" ]]; then + echo + echo "${COLOR_YELLOW}Waiting for asset compilation to complete in the background.${COLOR_RESET}" + echo + local counter=0 + while [[ -f "${AIRFLOW_SOURCES}/.build/www/.asset_compile.lock" ]]; do + echo "${COLOR_BLUE}Still waiting .....${COLOR_RESET}" + sleep 1 + ((counter=counter+1)) + if [[ ${counter} == "30" ]]; then + echo + echo "${COLOR_YELLOW}The asset compilation is taking too long.${COLOR_YELLOW}" + echo """ +If it does not complete soon, you might want to stop it and remove file lock: + * press Ctrl-C + * run 'rm ${AIRFLOW_SOURCES}/.build/www/.asset_compile.lock' +""" + fi + if [[ ${counter} == "60" ]]; then + echo + echo "${COLOR_RED}The asset compilation is taking too long. 
Exiting.${COLOR_RED}" + echo + exit 1 + fi + done + fi +} + if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then @@ -657,17 +645,13 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then RUN_TESTS=${RUN_TESTS:="false"} CI=${CI:="false"} USE_AIRFLOW_VERSION="${USE_AIRFLOW_VERSION:=""}" + TEST_TIMEOUT=${TEST_TIMEOUT:="60"} if [[ ${USE_AIRFLOW_VERSION} == "" ]]; then export PYTHONPATH=${AIRFLOW_SOURCES} echo echo "${COLOR_BLUE}Using airflow version from current sources${COLOR_RESET}" echo - if [[ -d "${AIRFLOW_SOURCES}/airflow/www/" ]]; then - pushd "${AIRFLOW_SOURCES}/airflow/www/" >/dev/null - ./ask_for_recompile_assets_if_needed.sh - popd >/dev/null - fi # Cleanup the logs, tmp when entering the environment sudo rm -rf "${AIRFLOW_SOURCES}"/logs/* sudo rm -rf "${AIRFLOW_SOURCES}"/tmp/* @@ -686,9 +670,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo "${COLOR_BLUE}Uninstalling airflow and providers" echo uninstall_airflow_and_providers - echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then echo @@ -696,9 +686,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none" + else + echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi uninstall_providers else echo @@ -706,9 +702,19 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo uninstall_airflow_and_providers echo - echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" - echo - install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none" + else + echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints 
reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}" + echo + install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}" + fi + if [[ "${USE_AIRFLOW_VERSION}" =~ ^2\.2\..*|^2\.1\..*|^2\.0\..* && "${AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=}" != "" ]]; then + # make sure old variable is used for older airflow versions + export AIRFLOW__CORE__SQL_ALCHEMY_CONN="${AIRFLOW__DATABASE__SQL_ALCHEMY_CONN}" + fi fi if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then echo @@ -768,11 +774,12 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then echo exit ${ENVIRONMENT_EXIT_CODE} fi - # Create symbolic link to fix possible issues with kubectl config cmd-path mkdir -p /usr/lib/google-cloud-sdk/bin touch /usr/lib/google-cloud-sdk/bin/gcloud ln -s -f /usr/bin/gcloud /usr/lib/google-cloud-sdk/bin/gcloud + in_container_fix_ownership + if [[ ${SKIP_SSH_SETUP="false"} == "false" ]]; then # Set up ssh keys echo 'yes' | ssh-keygen -t rsa -C your_email@youremail.com -m PEM -P '' -f ~/.ssh/id_rsa \ @@ -803,8 +810,9 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then cd "${AIRFLOW_SOURCES}" if [[ ${START_AIRFLOW:="false"} == "true" || ${START_AIRFLOW} == "True" ]]; then - export AIRFLOW__CORE__LOAD_DEFAULT_CONNECTIONS=${LOAD_DEFAULT_CONNECTIONS} + export AIRFLOW__DATABASE__LOAD_DEFAULT_CONNECTIONS=${LOAD_DEFAULT_CONNECTIONS} export AIRFLOW__CORE__LOAD_EXAMPLES=${LOAD_EXAMPLES} + wait_for_asset_compilation # shellcheck source=scripts/in_container/bin/run_tmux exec run_tmux fi @@ -816,7 +824,8 @@ if [[ "${RUN_TESTS}" != "true" ]]; then fi set -u -export RESULT_LOG_FILE="/files/test_result-${TEST_TYPE}-${BACKEND}.xml" +export RESULT_LOG_FILE="/files/test_result-${TEST_TYPE/\[*\]/}-${BACKEND}.xml" +export WARNINGS_FILE="/files/warnings-${TEST_TYPE/\[*\]/}-${BACKEND}.txt" EXTRA_PYTEST_ARGS=( "--verbosity=0" @@ -828,9 +837,11 @@ EXTRA_PYTEST_ARGS=( # timeouts in seconds for individual tests "--timeouts-order" "moi" - "--setup-timeout=60" - "--execution-timeout=60" - "--teardown-timeout=60" + "--setup-timeout=${TEST_TIMEOUT}" + "--execution-timeout=${TEST_TIMEOUT}" + "--teardown-timeout=${TEST_TIMEOUT}" + "--output=${WARNINGS_FILE}" + "--disable-warnings" # Only display summary for non-expected case # f - failed # E - error @@ -844,9 +855,11 @@ EXTRA_PYTEST_ARGS=( ) if [[ "${TEST_TYPE}" == "Helm" ]]; then + _cpus="$(grep -c 'cpu[0-9]' /proc/stat)" + echo "Running tests with ${_cpus} CPUs in parallel" # Enable parallelism EXTRA_PYTEST_ARGS+=( - "-n" "auto" + "-n" "${_cpus}" ) else EXTRA_PYTEST_ARGS+=( @@ -858,7 +871,7 @@ if [[ ${ENABLE_TEST_COVERAGE:="false"} == "true" ]]; then EXTRA_PYTEST_ARGS+=( "--cov=airflow/" "--cov-config=.coveragerc" - "--cov-report=xml:/files/coverage-${TEST_TYPE}-${BACKEND}.xml" + "--cov-report=xml:/files/coverage-${TEST_TYPE/\[*\]/}-${BACKEND}.xml" ) fi @@ -900,11 +913,13 @@ else ) WWW_TESTS=("tests/www") HELM_CHART_TESTS=("tests/charts") + INTEGRATION_TESTS=("tests/integration") ALL_TESTS=("tests") ALL_PRESELECTED_TESTS=( "${CLI_TESTS[@]}" "${API_TESTS[@]}" "${HELM_CHART_TESTS[@]}" + "${INTEGRATION_TESTS[@]}" "${PROVIDERS_TESTS[@]}" "${CORE_TESTS[@]}" "${ALWAYS_TESTS[@]}" @@ -925,33 +940,38 @@ else SELECTED_TESTS=("${WWW_TESTS[@]}") elif [[ ${TEST_TYPE:=""} == "Helm" ]]; then SELECTED_TESTS=("${HELM_CHART_TESTS[@]}") + elif [[ ${TEST_TYPE:=""} == "Integration" ]]; then + SELECTED_TESTS=("${INTEGRATION_TESTS[@]}") elif [[ ${TEST_TYPE:=""} == "Other" ]]; then find_all_other_tests SELECTED_TESTS=("${ALL_OTHER_TESTS[@]}") elif [[ 
${TEST_TYPE:=""} == "All" || ${TEST_TYPE} == "Quarantined" || \ ${TEST_TYPE} == "Always" || \ ${TEST_TYPE} == "Postgres" || ${TEST_TYPE} == "MySQL" || \ - ${TEST_TYPE} == "Long" || \ - ${TEST_TYPE} == "Integration" ]]; then + ${TEST_TYPE} == "Long" ]]; then SELECTED_TESTS=("${ALL_TESTS[@]}") + elif [[ ${TEST_TYPE} =~ Providers\[(.*)\] ]]; then + SELECTED_TESTS=() + for provider in ${BASH_REMATCH[1]//,/ } + do + providers_dir="tests/providers/${provider//./\/}" + if [[ -d ${providers_dir} ]]; then + SELECTED_TESTS+=("${providers_dir}") + else + echo "${COLOR_YELLOW}Skip ${providers_dir} as the directory does not exist.${COLOR_RESET}" + fi + done else echo echo "${COLOR_RED}ERROR: Wrong test type ${TEST_TYPE} ${COLOR_RESET}" echo exit 1 fi - fi readonly SELECTED_TESTS CLI_TESTS API_TESTS PROVIDERS_TESTS CORE_TESTS WWW_TESTS \ ALL_TESTS ALL_PRESELECTED_TESTS -if [[ -n ${LIST_OF_INTEGRATION_TESTS_TO_RUN=} ]]; then - # Integration tests - for INT in ${LIST_OF_INTEGRATION_TESTS_TO_RUN} - do - EXTRA_PYTEST_ARGS+=("--integration" "${INT}") - done -elif [[ ${TEST_TYPE:=""} == "Long" ]]; then +if [[ ${TEST_TYPE:=""} == "Long" ]]; then EXTRA_PYTEST_ARGS+=( "-m" "long_running" "--include-long-running" @@ -998,16 +1018,6 @@ COPY <<"EOF" /entrypoint_exec.sh exec /bin/bash "${@}" EOF -############################################################################################## -# This is the www image where we keep all inlined files needed to build ui -# It is copied separately to volume to speed up building and avoid cache miss on changed -# file permissions. -# We use PYTHON_BASE_IMAGE to make sure that the scripts are different for different platforms. -############################################################################################## -FROM ${PYTHON_BASE_IMAGE} as www -COPY airflow/www/package.json airflow/www/yarn.lock airflow/www/webpack.config.js / -COPY airflow/www/static/ /static - FROM ${PYTHON_BASE_IMAGE} as main # Nolog bash flag is currently ignored - but you can replace it with other flags (for example @@ -1018,7 +1028,7 @@ ARG PYTHON_BASE_IMAGE ARG AIRFLOW_IMAGE_REPOSITORY="https://github.com/apache/airflow" # By increasing this number we can do force build of all dependencies -ARG DEPENDENCIES_EPOCH_NUMBER="6" +ARG DEPENDENCIES_EPOCH_NUMBER="7" # Make sure noninteractive debian install is used and language variables set ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \ @@ -1031,68 +1041,35 @@ ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \ RUN echo "Base image version: ${PYTHON_BASE_IMAGE}" -ARG ADDITIONAL_DEV_APT_DEPS="" -ARG DEV_APT_COMMAND="\ - curl --silent --fail --location https://deb.nodesource.com/setup_14.x | bash - \ - && curl --silent --fail https://dl.yarnpkg.com/debian/pubkey.gpg | apt-key add - >/dev/null 2>&1 \ - && echo 'deb https://dl.yarnpkg.com/debian/ stable main' > /etc/apt/sources.list.d/yarn.list" +ARG ADDITIONAL_DEV_APT_DEPS="git graphviz gosu libpq-dev netcat rsync" +ARG DEV_APT_COMMAND="" ARG ADDITIONAL_DEV_APT_COMMAND="" ARG ADDITIONAL_DEV_ENV_VARS="" +ARG ADDITIONAL_DEV_APT_DEPS="bash-completion dumb-init git graphviz gosu krb5-user \ +less libenchant-2-2 libgcc-10-dev libpq-dev net-tools netcat \ +openssh-server postgresql-client software-properties-common rsync tmux unzip vim xxd" + +ARG ADDITIONAL_DEV_APT_ENV="" ENV DEV_APT_COMMAND=${DEV_APT_COMMAND} \ ADDITIONAL_DEV_APT_DEPS=${ADDITIONAL_DEV_APT_DEPS} \ ADDITIONAL_DEV_APT_COMMAND=${ADDITIONAL_DEV_APT_COMMAND} -COPY --from=scripts determine_debian_version_specific_variables.sh /scripts/docker/ - -# 
Install basic and additional apt dependencies -RUN apt-get update \ - && apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 \ - && apt-get install -y --no-install-recommends curl gnupg2 lsb-release \ - && mkdir -pv /usr/share/man/man1 \ - && mkdir -pv /usr/share/man/man7 \ - && export ${ADDITIONAL_DEV_ENV_VARS?} \ - && source /scripts/docker/determine_debian_version_specific_variables.sh \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}" \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}" \ - && apt-get update \ - && apt-get install -y --no-install-recommends \ - apt-utils \ - build-essential \ - dirmngr \ - dumb-init \ - freetds-bin \ - freetds-dev \ - git \ - graphviz \ - gosu \ - libffi-dev \ - libldap2-dev \ - libkrb5-dev \ - libpq-dev \ - libsasl2-2 \ - libsasl2-dev \ - libsasl2-modules \ - libssl-dev \ - "${DISTRO_LIBENCHANT}" \ - locales \ - netcat \ - nodejs \ - rsync \ - sasl2-bin \ - sudo \ - unixodbc \ - unixodbc-dev \ - yarn \ - ${ADDITIONAL_DEV_APT_DEPS} \ - && apt-get autoremove -yqq --purge \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* +COPY --from=scripts install_os_dependencies.sh /scripts/docker/ +RUN bash /scripts/docker/install_os_dependencies.sh dev # Only copy mysql/mssql installation scripts for now - so that changing the other # scripts which are needed much later will not invalidate the docker layer here. COPY --from=scripts install_mysql.sh install_mssql.sh install_postgres.sh /scripts/docker/ +ARG HOME=/root +ARG AIRFLOW_HOME=/root/airflow +ARG AIRFLOW_SOURCES=/opt/airflow + +ENV HOME=${HOME} \ + AIRFLOW_HOME=${AIRFLOW_HOME} \ + AIRFLOW_SOURCES=${AIRFLOW_SOURCES} + # We run scripts with bash here to make sure we can execute the scripts. Changing to +x might have an # unexpected result - the cache for Dockerfiles might get invalidated in case the host system # had different umask set and group x bit was not set. In Azure the bit might be not set at all. 
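# A minimal sketch of building this CI image directly now that OS packages come from
# install_os_dependencies.sh; it mirrors the docker command documented in IMAGES.rst in this
# change, and the extra apt/python packages are illustrative, not required by the Dockerfile:
DOCKER_BUILDKIT=1 docker build . -f Dockerfile.ci \
    --pull \
    --build-arg PYTHON_BASE_IMAGE="python:3.7-slim-bullseye" \
    --build-arg ADDITIONAL_PYTHON_DEPS="pandas" \
    --build-arg ADDITIONAL_DEV_APT_DEPS="gcc g++" \
    --tag my-image:0.0.1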
@@ -1107,28 +1084,8 @@ RUN bash /scripts/docker/install_mysql.sh prod \ && echo "airflow ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers.d/airflow \ && chmod 0440 /etc/sudoers.d/airflow -ARG RUNTIME_APT_DEPS="\ - apt-transport-https \ - bash-completion \ - ca-certificates \ - software-properties-common \ - krb5-user \ - krb5-user \ - ldap-utils \ - less \ - lsb-release \ - net-tools \ - openssh-client \ - openssh-server \ - postgresql-client \ - sqlite3 \ - tmux \ - unzip \ - vim \ - xxd" - # Install Helm -ARG HELM_VERSION="v3.6.3" +ARG HELM_VERSION="v3.9.4" RUN SYSTEM=$(uname -s | tr '[:upper:]' '[:lower:]') \ && PLATFORM=$([ "$(uname -m)" = "aarch64" ] && echo "arm64" || echo "amd64" ) \ @@ -1136,42 +1093,6 @@ RUN SYSTEM=$(uname -s | tr '[:upper:]' '[:lower:]') \ && curl --silent --location "${HELM_URL}" | tar -xz -O "${SYSTEM}-${PLATFORM}/helm" > /usr/local/bin/helm \ && chmod +x /usr/local/bin/helm -ARG ADDITIONAL_RUNTIME_APT_DEPS="" -ARG RUNTIME_APT_COMMAND="" -ARG ADDITIONAL_RUNTIME_APT_COMMAND="" -ARG ADDITIONAL_DEV_APT_ENV="" -ARG ADDITIONAL_RUNTIME_APT_ENV="" - -ARG DOCKER_CLI_VERSION=19.03.9 -ARG HOME=/root -ARG AIRFLOW_HOME=/root/airflow -ARG AIRFLOW_SOURCES=/opt/airflow - -ENV RUNTIME_APT_DEP=${RUNTIME_APT_DEPS} \ - ADDITIONAL_RUNTIME_APT_DEPS=${ADDITIONAL_RUNTIME_APT_DEPS} \ - RUNTIME_APT_COMMAND=${RUNTIME_APT_COMMAND} \ - ADDITIONAL_RUNTIME_APT_COMMAND=${ADDITIONAL_RUNTIME_APT_COMMAND}\ - DOCKER_CLI_VERSION=${DOCKER_CLI_VERSION} \ - HOME=${HOME} \ - AIRFLOW_HOME=${AIRFLOW_HOME} \ - AIRFLOW_SOURCES=${AIRFLOW_SOURCES} - -RUN export ${ADDITIONAL_DEV_APT_ENV?} \ - && export ${ADDITIONAL_RUNTIME_APT_ENV?} \ - && source /scripts/docker/determine_debian_version_specific_variables.sh \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}" \ - && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}" \ - && apt-get update \ - && apt-get install --no-install-recommends -y \ - "${DISTRO_LIBGCC}" \ - ${RUNTIME_APT_DEPS} \ - ${ADDITIONAL_RUNTIME_APT_DEPS} \ - && apt-get autoremove -yqq --purge \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* \ - && curl --silent "https://download.docker.com/linux/static/stable/x86_64/docker-${DOCKER_CLI_VERSION}.tgz" \ - | tar -C /usr/bin --strip-components=1 -xvzf - docker/docker - WORKDIR ${AIRFLOW_SOURCES} RUN mkdir -pv ${AIRFLOW_HOME} && \ @@ -1195,7 +1116,7 @@ ARG AIRFLOW_CI_BUILD_EPOCH="3" ARG AIRFLOW_PRE_CACHED_PIP_PACKAGES="true" # By default in the image, we are installing all providers when installing from sources ARG INSTALL_PROVIDERS_FROM_SOURCES="true" -ARG AIRFLOW_PIP_VERSION=22.1.2 +ARG AIRFLOW_PIP_VERSION=22.3.1 # Setup PIP # By default PIP install run without cache to make image smaller ARG PIP_NO_CACHE_DIR="true" @@ -1208,7 +1129,10 @@ ARG CASS_DRIVER_NO_CYTHON="1" # Build cassandra driver on multiple CPUs ARG CASS_DRIVER_BUILD_CONCURRENCY="8" -ARG AIRFLOW_VERSION="2.3.0.dev" +ARG AIRFLOW_VERSION="2.5.0.dev0" + +# Additional PIP flags passed to all pip install commands except reinstalling pip itself +ARG ADDITIONAL_PIP_INSTALL_FLAGS="" ENV AIRFLOW_REPO=${AIRFLOW_REPO}\ AIRFLOW_BRANCH=${AIRFLOW_BRANCH} \ @@ -1237,6 +1161,7 @@ ENV AIRFLOW_REPO=${AIRFLOW_REPO}\ AIRFLOW_VERSION_SPECIFICATION="" \ PIP_NO_CACHE_DIR=${PIP_NO_CACHE_DIR} \ PIP_PROGRESS_BAR=${PIP_PROGRESS_BAR} \ + ADDITIONAL_PIP_INSTALL_FLAGS=${ADDITIONAL_PIP_INSTALL_FLAGS} \ CASS_DRIVER_BUILD_CONCURRENCY=${CASS_DRIVER_BUILD_CONCURRENCY} \ CASS_DRIVER_NO_CYTHON=${CASS_DRIVER_NO_CYTHON} @@ -1245,7 +1170,16 @@ RUN echo 
"Airflow version: ${AIRFLOW_VERSION}" # Those are additional constraints that are needed for some extras but we do not want to # force them on the main Airflow package. Those limitations are: # * dill<0.3.3 required by apache-beam -ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3" +# * pyarrow>=6.0.0 is because pip resolver decides for Python 3.10 to downgrade pyarrow to 5 even if it is OK +# for python 3.10 and other dependencies adding the limit helps resolver to make better decisions +# We need to limit the protobuf library to < 4.21.0 because not all google libraries we use +# are compatible with the new protobuf version. All the google python client libraries need +# to be upgraded to >= 2.0.0 in order to able to lift that limitation +# https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates +# * authlib, gcloud_aio_auth, adal are needed to generate constraints for PyPI packages and can be removed after we release +# new google, azure providers +# !!! MAKE SURE YOU SYNCHRONIZE THE LIST BETWEEN: Dockerfile, Dockerfile.ci, find_newer_dependencies.py +ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3 pyarrow>=6.0.0 protobuf<4.21.0 authlib>=1.0.0 gcloud_aio_auth>=4.0.0 adal>=1.2.7" ARG UPGRADE_TO_NEWER_DEPENDENCIES="false" ENV EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS=${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} \ UPGRADE_TO_NEWER_DEPENDENCIES=${UPGRADE_TO_NEWER_DEPENDENCIES} @@ -1279,32 +1213,13 @@ COPY --from=scripts install_pipx_tools.sh /scripts/docker/ # dependencies installed in Airflow RUN bash /scripts/docker/install_pipx_tools.sh -# Copy package.json and yarn.lock to install node modules -# this way even if other static check files change, node modules will not need to be installed -# we want to keep node_modules so we can do this step separately from compiling assets -COPY --from=www package.json yarn.lock ${AIRFLOW_SOURCES}/airflow/www/ -COPY --from=scripts prepare_node_modules.sh /scripts/docker/ - -# Package JS/css for production -RUN bash /scripts/docker/prepare_node_modules.sh - -# Copy all the needed www/ for assets compilation. Done as two separate COPY -# commands so as otherwise it copies the _contents_ of static/ in to www/ -COPY --from=www webpack.config.js ${AIRFLOW_SOURCES}/airflow/www/ -COPY --from=www static ${AIRFLOW_SOURCES}/airflow/www/static/ -COPY --from=scripts compile_www_assets.sh /scripts/docker/ - -# Build artifacts without removing temporary artifacts (we will need them for incremental changes) -# in build mode -RUN REMOVE_ARTIFACTS="false" BUILD_TYPE="build" bash /scripts/docker/compile_www_assets.sh - # Airflow sources change frequently but dependency configuration won't change that often # We copy setup.py and other files needed to perform setup of dependencies # So in case setup.py changes we can install latest dependencies required. COPY setup.py ${AIRFLOW_SOURCES}/setup.py COPY setup.cfg ${AIRFLOW_SOURCES}/setup.cfg - COPY airflow/__init__.py ${AIRFLOW_SOURCES}/airflow/ +COPY generated/provider_dependencies.json ${AIRFLOW_SOURCES}/generated/ COPY --from=scripts install_airflow.sh /scripts/docker/ diff --git a/IMAGES.rst b/IMAGES.rst index 58ef0cca54852..f6b78aee0f310 100644 --- a/IMAGES.rst +++ b/IMAGES.rst @@ -85,13 +85,13 @@ You can build the CI image using current sources this command: .. code-block:: bash - breeze build-image + breeze ci-image build You can build the PROD image using current sources with this command: .. 
code-block:: bash - breeze build-prod-image + breeze prod-image build By adding ``--python `` parameter you can build the image version for the chosen Python version. @@ -104,13 +104,13 @@ For example if you want to build Python 3.7 version of production image with .. code-block:: bash - breeze build-prod-image --python 3.7 --extras "all" + breeze prod-image build --python 3.7 --extras "all" If you just want to add new extras you can add them like that: .. code-block:: bash - breeze build-prod-image --python 3.7 --additional-extras "all" + breeze prod-image build --python 3.7 --additional-extras "all" The command that builds the CI image is optimized to minimize the time needed to rebuild the image when the source code of Airflow evolves. This means that if you already have the image locally downloaded and @@ -128,7 +128,7 @@ parameter to Breeze: .. code-block:: bash - breeze build-prod-image --python 3.7 --additional-extras=trino --install-airflow-version=2.0.0 + breeze prod-image build --python 3.7 --additional-extras=trino --install-airflow-version=2.0.0 This will build the image using command similar to: @@ -165,8 +165,7 @@ You can also skip installing airflow and install it from locally provided files .. code-block:: bash - breeze build-prod-image --python 3.7 --additional-extras=trino \ - --airflow-is-in-context-pypi --install-packages-from-context + breeze prod-image build --python 3.7 --additional-extras=trino --install-packages-from-context In this case you airflow and all packages (.whl files) should be placed in ``docker-context-files`` folder. @@ -193,21 +192,21 @@ or ``disabled`` flags when you run Breeze commands. For example: .. code-block:: bash - breeze build-image --python 3.7 --docker-cache local + breeze ci-image build --python 3.7 --docker-cache local Will build the CI image using local build cache (note that it will take quite a long time the first time you run it). .. code-block:: bash - breeze build-prod-image --python 3.7 --docker-cache registry + breeze prod-image build --python 3.7 --docker-cache registry Will build the production image with cache used from registry. .. code-block:: bash - breeze build-prod-image --python 3.7 --docker-cache disabled + breeze prod-image build --python 3.7 --docker-cache disabled Will build the production image from the scratch. @@ -281,7 +280,7 @@ to refresh them. Every developer can also pull and run images being result of a specific CI run in GitHub Actions. This is a powerful tool that allows to reproduce CI failures locally, enter the images and fix them much -faster. It is enough to pass ``--github-image-id`` and the registry and Breeze will download and execute +faster. It is enough to pass ``--image-tag`` and the registry and Breeze will download and execute commands using the same image that was used during the CI tests. For example this command will run the same Python 3.8 image as was used in build identified with @@ -289,8 +288,7 @@ For example this command will run the same Python 3.8 image as was used in build .. code-block:: bash - ./breeze-legacy --github-image-id 9a621eaa394c0a0a336f8e1b31b35eff4e4ee86e \ - --python 3.8 --integration rabbitmq + breeze --image-tag 9a621eaa394c0a0a336f8e1b31b35eff4e4ee86e --python 3.8 --integration rabbitmq You can see more details and examples in `Breeze `_ @@ -318,10 +316,9 @@ you have ``buildx`` plugin installed. DOCKER_BUILDKIT=1 docker build . 
-f Dockerfile.ci \ --pull \ --build-arg PYTHON_BASE_IMAGE="python:3.7-slim-bullseye" \ - --build-arg ADDITIONAL_AIRFLOW_EXTRAS="jdbc" - --build-arg ADDITIONAL_PYTHON_DEPS="pandas" - --build-arg ADDITIONAL_DEV_APT_DEPS="gcc g++" - --build-arg ADDITIONAL_RUNTIME_APT_DEPS="default-jre-headless" + --build-arg ADDITIONAL_AIRFLOW_EXTRAS="jdbc" \ + --build-arg ADDITIONAL_PYTHON_DEPS="pandas" \ + --build-arg ADDITIONAL_DEV_APT_DEPS="gcc g++" \ --tag my-image:0.0.1 @@ -329,8 +326,8 @@ the same image can be built using ``breeze`` (it supports auto-completion of the .. code-block:: bash - breeze build-prod-image --python 3.7 --additional-extras=jdbc --additional-python-deps="pandas" \ - --additional-dev-apt-deps="gcc g++" --additional-runtime-apt-deps="default-jre-headless" + breeze ci-image build --python 3.7 --additional-extras=jdbc --additional-python-deps="pandas" \ + --additional-dev-apt-deps="gcc g++" You can customize more aspects of the image - such as additional commands executed before apt dependencies are installed, or adding extra sources to install your dependencies from. You can see all the arguments @@ -355,10 +352,7 @@ based on example in `this comment `_ where they describe +the issues they think are Airflow issues and should be solved. There are two kinds of issues: + +* Bugs - when the user thinks the reported issue is a bug in Airflow +* Features - when there are small features that the user would like to see in Airflow + +We have `templates `_ for both types +of issues defined in Airflow. + +However, important part of our issue reporting process are +`GitHub Discussions `_ . Issues should represent +clear, small feature requests or reproducible bugs which can/should be either implemented or fixed. +Users are encouraged to open discussions rather than issues if there are no clear, reproducible +steps, or when they have troubleshooting problems, and one of the important points of issue triaging is +to determine if the issue reported should be rather a discussion. Converting an issue to a discussion +while explaining the user why is an important part of triaging process. + +Responding to issues/discussions (relatively) quickly +''''''''''''''''''''''''''''''''''''''''''''''''''''' + +It is vital to provide rather quick feedback to issues and discussions opened by our users, so that they +feel listened to rather than ignored. Even if the response is "we are not going to work on it because ...", +or "converting this issue to discussion because ..." or "closing because it is a duplicate of #xxx", it is +far more welcoming than leaving issues and discussions unanswered. Sometimes issues and discussions are +answered by other users (and this is cool) but if an issue/discussion is not responded to for a few days or +weeks, this gives an impression that the user was ignored and that the Airflow project is unwelcoming. + +We strive to provide relatively quick responses to all such issues and discussions. Users should exercise +patience while waiting for those (knowing that people might be busy, on vacations etc.) however they should +not wait weeks until someone looks at their issues. + + +Issue Triage Team +'''''''''''''''''' + +While many of the issues can be responded to by other users and committers, the committer team is not +big enough to handle all such requests and sometimes they are busy with implementing important +Therefore, some people who are regularly contributing and helping other users and shown their deep interest +in the project can be invited to join the triage team. 
+Members of the triage team are added to `the .asf.yaml <.asf.yaml>`_ file in the ``collaborators`` section. + +Committers can invite people to become members of the triage team if they see that the users are already +helping and responding to issues and when they see the users are involved regularly. But you can also ask +to become a member of the team (on devlist) if you can show that you have done that and when you want to have +more ways to help others. + +The triage team members do not have committer privileges but they can +assign, edit, and close issues and pull requests without having capabilities to merge the code. They can +also convert issues into discussions and back. The expectation for the issue triage team is that they +spend a bit of their time on those efforts. Triaging means not only assigning the labels but often responding +to the issues and answering user concerns or, if additional input is needed, tagging the committers or other community members who might be able to help provide more complete answers. + +Being an active and helpful member of the "Issue Triage Team" is actually one of the paths towards +becoming a committer. By actively helping the users, triaging the issues, responding to them and +involving others (when needed) shows that you are not only willing to help our users and the community, +but are also ready to learn about parts of the projects you are not actively contributing to - all of these +are super valuable components of being eligible to `become a committer `_. + +If you are a member of the triage team and not able to make any commitment, it's best to ask to have yourself +removed from the triage team. + +BTW. Every committer is pretty much automatically part of the "Issue Triage Team" - so if you are a committer, +feel free to follow the process for every issue you stumble upon. + +Actions that can be taken by the issue triager +'''''''''''''''''''''''''''''''''''''''''''''' + +There are several actions an issue triager might take: + +* Closing an issue with the "invalid" label explaining why it is closed in case the issue is invalid. This + should be accompanied by information that we can always re-open an issue if our understanding was wrong + or if the user provides more information. + +* Converting an issue to a discussion, if it is not very likely it is an Airflow issue or when it is not + reproducible, or when it is a bigger feature proposal requiring discussion or when it's really users + troubleshooting or when the issue description is not at all clear. This also involves inviting the user + to a discussion if more information might change it. + +* Assigning the issue to a milestone, if the issue seems important enough that it should likely be looked + at before the next release but there is not enough information or doubts on why and what can be fixed. + Usually we assign to the next bugfix release - then, no matter what, the issue will be looked at + by the release manager and it might trigger additional actions during the release preparation. + This is usually followed by one of the actions below. + +* Fixing the issue in a PR if you see it is easy to fix. This is a great way also to learn and + contribute to parts that you usually are not contributing to, and sometimes it is surprisingly easy. + +* Assigning the "good first issue" label if an issue is clear but not important to be fixed immediately. This + often leads to contributors picking up the issues when they are interested. 
This can be followed by assigning + the user who comments "I want to work on this" in the issue (which is most welcome). + +* Asking the user for additional information if it is needed to perform further investigations. This should + be accompanied by assigning ``pending response`` label so that we can clearly see the issues that need + extra information. + +* Calling other people who might be knowledgeable in the area by @-mentioning them in a comment. + +* Assigning other labels to the issue as described below. + Labels '''''' -Since Apache Airflow uses GitHub Issues as the issue tracking system, the -use of labels is extensive. Though issue labels tend to change over time -based on components within the project, the majority of the ones listed -below should stand the test of time. +Since Apache Airflow uses "GitHub Issues" and "Github Discussions" as the +issue tracking systems, the use of labels is extensive. Though issue +labels tend to change over time based on components within the project, +the majority of the ones listed below should stand the test of time. The intention with the use of labels with the Apache Airflow project is that they should ideally be non-temporal in nature and primarily used @@ -44,16 +140,16 @@ to indicate the following elements: **Kind** -The “kind” labels indicate “what kind of issue it is”. The most -commonly used “kind” labels are: bug, feature, documentation, or task. +The "kind" labels indicate "what kind of issue it is". The most +commonly used "kind" labels are: bug, feature, documentation, or task. Therefore, when reporting an issue, the label of ``kind:bug`` is to indicate a problem with the functionality, whereas the label of ``kind:feature`` is a desire to extend the functionality. There has been discussion within the project about whether to separate -the desire for “new features” from “enhancements to existing features”, -but in practice most “feature requests” are actually enhancement requests, +the desire for "new features" from "enhancements to existing features", +but in practice most "feature requests" are actually enhancement requests, so we decided to combine them both into ``kind:feature``. The ``kind:task`` is used to categorize issues which are @@ -67,7 +163,7 @@ made to the documentation within the project. **Area** -The “area” set of labels should indicate the component of the code +The "area" set of labels should indicate the component of the code referenced by the issue. At a high level, the biggest areas of the project are: Airflow Core and Airflow Providers, which are referenced by ``area:core`` and ``area:providers``. This is especially important since these are now @@ -75,15 +171,18 @@ being released and versioned independently. There are more detailed areas of the Core Airflow project such as Scheduler, Webserver, API, UI, Logging, and Kubernetes, which are all conceptually under the -“Airflow Core” area of the project. +"Airflow Core" area of the project. Similarly within Airflow Providers, the larger providers such as Apache, AWS, Azure, and Google who have many hooks and operators within them, have labels directly -associated with them such as ``provider/Apache``, ``provider/AWS``, -``provider/Azure``, and ``provider/Google``. +associated with them such as ``provider:Apache``, ``provider:AWS``, +``provider:Azure``, and ``provider:Google``. + These make it easier for developers working on a single provider to track issues for that provider. 
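For example, a contributor who cares about a single provider might filter on these labels with the
GitHub CLI (the ``gh`` invocation below is just one possible way to use the labels, not part of the
triage process itself):

.. code-block:: bash

    # list open issues labelled for the Google provider
    gh issue list --repo apache/airflow --label "provider:Google" --state open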
+Some provider labels may couple several providers, for example: ``provider:Protocols`` + Most issues need a combination of "kind" and "area" labels to be actionable. For example: @@ -91,7 +190,6 @@ For example: * Bug report on the User Interface would have ``kind:bug`` and ``area:UI`` * Documentation request on the Kubernetes Executor, would have ``kind:documentation`` and ``area:kubernetes`` - Response to issues '''''''''''''''''' @@ -116,7 +214,7 @@ Therefore, the priority labels used are: It's important to use priority labels effectively so we can triage incoming issues appropriately and make sure that when we release a new version of Airflow, -we can ship a release confident that there are no “production blocker” issues in it. +we can ship a release confident that there are no "production blocker" issues in it. This applies to both Core Airflow as well as the Airflow Providers. With the separation of the Providers release from Core Airflow, a ``priority:critical`` bug in a single @@ -151,11 +249,7 @@ to be able to reproduce the issue. Typically, this may require a response to the issue creator asking for more information, with the issue then being tagged with the label ``pending-response``. Also, during this stage, additional labels may be added to the issue to help -classification and triage, such as ``reported_version`` and ``area``. - -Occasionally an issue may require a larger discussion among the Airflow PMC or -the developer mailing list. This issue may then be tagged with the -``needs:discussion`` label. +classification and triage, such as ``affected_version`` and ``area``. Some issues may need a detailed review by one of the core committers of the project and this could be tagged with a ``needs:triage`` label. @@ -164,7 +258,7 @@ and this could be tagged with a ``needs:triage`` label. **Good First Issue** Issues which are relatively straightforward to solve will be tagged with -the ``GoodFirstIssue`` label. +the ``good first issue`` label. The intention here is to galvanize contributions from new and inexperienced contributors who are looking to contribute to the project. This has been successful @@ -175,13 +269,13 @@ Ideally, these issues only require one or two files to be changed. The intention here is that incremental changes to existing files are a lot easier for a new contributor as compared to adding something completely new. -Another possibility here is to add “how to fix” in the comments of such issues, so +Another possibility here is to add "how to fix" in the comments of such issues, so that new contributors have a running start when they pick up these issues. **Timeliness** -For the sake of quick responses, the general “soft" rule within the Airflow project +For the sake of quick responses, the general "soft" rule within the Airflow project is that if there is no assignee, anyone can take an issue to solve. However, this depends on timely resolution of the issue by the assignee. The @@ -196,13 +290,20 @@ issue creator. After the pending-response label has been assigned, if there is no further information for a period of 1 month, the issue will be automatically closed. - **Invalidity** At times issues are marked as invalid and later closed because of one of the following situations: * The issue is a duplicate of an already reported issue. In such cases, the latter issue is marked as ``duplicate``. -* Despite attempts to reproduce the issue to resolve it, the issue cannot be reproduced by the Airflow team based on the given information.
In such cases, the issue is marked as ``Can’t Reproduce``. +* Despite attempts to reproduce the issue to resolve it, the issue cannot be reproduced by the Airflow team based on the given information. In such cases, the issue is marked as ``Can't Reproduce``. * In some cases, the original creator realizes that the issue was incorrectly reported and then marks it as ``invalid``. Also, a committer could mark it as ``invalid`` if the issue being reported is for an unsupported operation or environment. * In some cases, the issue may be legitimate, but may not be addressed in the short to medium term based on current project priorities or because this will be irrelevant because of an upcoming change. The committer could mark this as ``wontfix`` to set expectations that it won't be directly addressed in the near term. + +**GitHub Discussions** + +Issues should represent clear feature requests which can/should be implemented. If the idea is vague or can be solved with easier steps, +we normally convert such issues to discussions in the Ideas category. +Issues that seem more like support requests are also converted to discussions in the Q&A category. +We use judgment about which issues to convert to discussions; it's best to always clarify with a comment why the issue is being converted. +Note that we can always convert discussions back to issues. diff --git a/LOCAL_VIRTUALENV.rst b/LOCAL_VIRTUALENV.rst index 1cbf5b6b12e81..3208634a4ee62 100644 --- a/LOCAL_VIRTUALENV.rst +++ b/LOCAL_VIRTUALENV.rst @@ -205,21 +205,15 @@ Activate your virtualenv, e.g. by using ``workon``, and once you are in it, run: .. code-block:: bash - ./breeze-legacy initialize-local-virtualenv + ./scripts/tools/initialize_virtualenv.py -By default Breeze installs the ``devel`` extra only. You can optionally control which extras are installed by exporting ``VIRTUALENV_EXTRAS`` before calling Breeze: +By default Breeze installs the ``devel`` extra only. You can optionally control which extras are +installed by adding extra dependencies as a parameter. .. code-block:: bash - export VIRTUALENV_EXTRAS="devel,google,postgres" - ./breeze-legacy initialize-local-virtualenv + ./scripts/tools/initialize_virtualenv.py devel,google,postgres -5. (optionally) run yarn build if you plan to run the webserver - -.. code-block:: bash - - cd airflow/www - yarn build Developing Providers -------------------- diff --git a/MANIFEST.in b/MANIFEST.in index 8f4b22b7ca01c..bfcaf6057c8a1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -38,3 +38,4 @@ include airflow/customized_form_field_behaviours.schema.json include airflow/serialization/schema.json include airflow/utils/python_virtualenv_script.jinja2 include airflow/utils/context.pyi +include generated diff --git a/NOTICE b/NOTICE index 4c7b795d88ce6..84c77cd4fc12c 100644 --- a/NOTICE +++ b/NOTICE @@ -20,3 +20,10 @@ This product contains a modified portion of 'Flask App Builder' developed by Dan (https://github.com/dpgaspar/Flask-AppBuilder). * Copyright 2013, Daniel Vaz Gaspar + +Chakra UI: +----- +This product contains a modified portion of 'Chakra UI' developed by Segun Adebayo. +(https://github.com/chakra-ui/chakra-ui). + +* Copyright 2019, Segun Adebayo diff --git a/PULL_REQUEST_WORKFLOW.rst b/PULL_REQUEST_WORKFLOW.rst deleted file mode 100644 index d7ca2f9b93eaa..0000000000000 --- a/PULL_REQUEST_WORKFLOW.rst +++ /dev/null @@ -1,158 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements.
See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -.. contents:: :local: - -Why non-standard pull request workflow? ---------------------------------------- - -This document describes the Pull Request Workflow we've implemented in Airflow. The workflow is slightly -more complex than regular workflow you might encounter in most of the projects because after experiencing -some huge delays in processing queues in October 2020 with GitHub Actions, we've decided to optimize the -workflow to minimize the use of GitHub Actions build time by utilising selective approach on which tests -and checks in the CI system are run depending on analysis of which files changed in the incoming PR and -allowing the Committers to control the scope of the tests during the approval/review process. - -Just to give a bit of context, we started off with the approach that we always run all tests for all the -incoming PRs, however due to our matrix of tests growing, this approach did not scale with the increasing -number of PRs and when we had to compete with other Apache Software Foundation projects for the 180 -slots that are available for the whole organization. More Apache Software Foundation projects started -to use GitHub Actions and we've started to experience long queues when our jobs waited for free slots. - -We approached the problem by: - -1) Improving mechanism of cancelling duplicate workflow runs more efficiently in case of queue conditions - (duplicate workflow runs are generated when someone pushes a fixup quickly - leading to running both - out-dated and current run to completion, taking precious slots. This has been implemented by improving - `cancel-workflow-run `_ action we are using. In version - 4.1 it got a new feature of cancelling all duplicates even if there is a long queue of builds. - -2) Heavily decreasing strain on the GitHub Actions jobs by introducing selective checks - mechanism - to control which parts of the tests are run during the tests. This is implemented by the - ``scripts/ci/selective_ci_checks.sh`` script in our repository. This script analyses which part of the - code has changed and based on that it sets the right outputs that control which tests are executed in - the ``Tests`` workflow, and whether we need to build CI images necessary to run those steps. This allowed to - heavily decrease the strain especially for the Pull Requests that were not touching code (in which case - the builds can complete in < 2 minutes) but also by limiting the number of tests executed in PRs that do - not touch the "core" of Airflow, or only touching some - standalone - parts of Airflow such as - "Providers", "WWW" or "CLI". 
This solution is not yet perfect as there are likely some edge cases but - it is easy to maintain and we have an escape-hatch - all the tests are always executed in main pushes, - so contributors can easily spot if there is a "missed" case and fix it - both by fixing the problem and - adding those exceptions to the code. More about it can be found in `Selective checks `_ - -3) Even more optimisation came from limiting the scope of tests to only "default" matrix parameters. So far - in Airflow we always run all tests for all matrix combinations. The primary matrix components are: - - * Python versions (currently 3.7, 3.8, 3.9, 3.10) - * Backend types (currently MySQL/Postgres) - * Backed version (currently MySQL 5.7, MySQL 8, Postgres 13 - - We've decided that instead of running all the combinations of parameters for all matrix component we will - only run default values (Python 3.7, Mysql 5.7, Postgres 13) for all PRs which are not approved yet by - the committers. This has a nice effect, that full set of tests (though with limited combinations of - the matrix) are still run in the CI for every Pull Request that needs tests at all - allowing the - contributors to make sure that their PR is "good enough" to be reviewed. - - Even after approval, the automated workflows we've implemented, check if the PR seems to need - "full test matrix" and provide helpful information to both contributors and committers in the form of - explanatory comments and labels set automatically showing the status of the PR. Committers have still - control whether they want to merge such requests automatically or ask for rebase or re-run the tests - and run "full tests" by applying the "full tests needed" label and re-running such request. - The "full tests needed" label is also applied automatically after approval when the change touches - the "core" of Airflow - also a separate check is added to the PR so that the "merge" button status - will indicate to the committer that full tests are still needed. The committer might still decide, - whether to merge such PR without the "full matrix". The "escape hatch" we have - i.e. running the full - matrix of tests in the "merge push" will enable committers to catch and fix such problems quickly. - More about it can be found in `Approval workflow and Matrix tests <#approval-workflow-and-matrix-tests>`_ - chapter. - -4) We've also applied (and received) funds to run self-hosted runners. They are used for ``main`` runs - and whenever the PRs are done by one of the maintainers. Maintainers can force using Public GitHub runners - by applying "use public runners" label to the PR before submitting it. - - -Approval Workflow and Matrix tests ----------------------------------- - -As explained above the approval and matrix tests workflow works according to the algorithm below: - -1) In case of "no-code" changes - so changes that do not change any of the code or environment of - the application, no test are run (this is done via selective checks). Also no CI/PROD images are - build saving extra minutes. Such build takes less than 2 minutes currently and only few jobs are run - which is a very small fraction of the "full build" time. - -2) When new PR is created, only a "default set" of matrix test are running. Only default - values for each of the parameters are used effectively limiting it to running matrix builds for only - one python version and one version of each of the backends. In this case only one CI and one PROD - image is built, saving precious job slots. 
This build takes around 50% less time than the "full matrix" - build. - -3) When such PR gets approved, the system further analyses the files changed in this PR and further - decision is made that should be communicated to both Committer and Reviewer. - -3a) In case of "no-code" builds, a message is communicated that the PR is ready to be merged and - no tests are needed. - -.. image:: images/pr/pr-no-tests-needed-comment.png - :align: center - :alt: No tests needed for "no-code" builds - -3b) In case of "non-core" builds a message is communicated that such PR is likely OK to be merged as is with - limited set of tests, but that the committer might decide to re-run the PR after applying - "full tests needed" label, which will trigger full matrix build for tests for this PR. The committer - might make further decision on what to do with this PR. - -.. image:: images/pr/pr-likely-ok-to-merge.png - :align: center - :alt: Likely ok to merge the PR with only small set of tests - -3c) In case of "core" builds (i. e. when the PR touches some "core" part of Airflow) a message is - communicated that this PR needs "full test matrix", the "full tests needed" label is applied - automatically and either the contributor might rebase the request to trigger full test build or the - committer might re-run the build manually to trigger such full test rebuild. Also a check "in-progress" - is added, so that the committer realises that the PR is not yet "green to merge". Pull requests with - "full tests needed" label always trigger the full matrix build when rebased or re-run so if the - PR gets rebased, it will continue triggering full matrix build. - -.. image:: images/pr/pr-full-tests-needed.png - :align: center - :alt: Full tests are needed for the PR - -4) If this or another committer "request changes" in a previously approved PR with "full tests needed" - label, the bot automatically removes the label, moving it back to "run only default set of parameters" - mode. For PRs touching core of airflow once the PR gets approved back, the label will be restored. - If it was manually set by the committer, it has to be restored manually. - -.. note:: Note that setting the labels and adding comments might be delayed, due to limitation of GitHub Actions, - in case of queues, processing of Pull Request reviews might take some time, so it is advised not to merge - PR immediately after approval. Luckily, the comments describing the status of the PR trigger notifications - for the PRs and they provide good "notification" for the committer to act on a PR that was recently - approved. - -The PR approval workflow is possible thanks to two custom GitHub Actions we've developed: - -* `Get workflow origin `_ -* `Label when approved `_ - - -Next steps ----------- - -We are planning to also propose the approach to other projects from Apache Software Foundation to -make it a common approach, so that our effort is not limited only to one project. 
- -Discussion about it in `this discussion `_ diff --git a/README.md b/README.md index 3acb883dc4ba8..3b5e704e5f086 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Twitter Follow](https://img.shields.io/twitter/follow/ApacheAirflow.svg?style=social&label=Follow)](https://twitter.com/ApacheAirflow) [![Slack Status](https://img.shields.io/badge/slack-join_chat-white.svg?logo=slack&style=social)](https://s.apache.org/airflow-slack) +[![Contributors](https://img.shields.io/github/contributors/apache/airflow)](https://github.com/apache/airflow/graphs/contributors) [Apache Airflow](https://airflow.apache.org/docs/apache-airflow/stable/) (or simply Airflow) is a platform to programmatically author, schedule, and monitor workflows. @@ -55,7 +56,7 @@ Use Airflow to author workflows as directed acyclic graphs (DAGs) of tasks. The - [Support for Python and Kubernetes versions](#support-for-python-and-kubernetes-versions) - [Base OS support for reference Airflow images](#base-os-support-for-reference-airflow-images) - [Approach to dependencies of Airflow](#approach-to-dependencies-of-airflow) -- [Support for providers](#support-for-providers) +- [Release process for Providers](#release-process-for-providers) - [Contributing](#contributing) - [Who uses Apache Airflow?](#who-uses-apache-airflow) - [Who Maintains Apache Airflow?](#who-maintains-apache-airflow) @@ -70,7 +71,7 @@ Use Airflow to author workflows as directed acyclic graphs (DAGs) of tasks. The Airflow works best with workflows that are mostly static and slowly changing. When the DAG structure is similar from one run to the next, it clarifies the unit of work and continuity. Other similar projects include [Luigi](https://github.com/spotify/luigi), [Oozie](https://oozie.apache.org/) and [Azkaban](https://azkaban.github.io/). -Airflow is commonly used to process data, but has the opinion that tasks should ideally be idempotent (i.e., results of the task will be the same, and will not create duplicated data in a destination system), and should not pass large quantities of data from one task to the next (though tasks can pass metadata using Airflow's [Xcom feature](https://airflow.apache.org/docs/apache-airflow/stable/concepts.html#xcoms)). For high-volume, data-intensive tasks, a best practice is to delegate to external services specializing in that type of work. +Airflow is commonly used to process data, but has the opinion that tasks should ideally be idempotent (i.e., results of the task will be the same, and will not create duplicated data in a destination system), and should not pass large quantities of data from one task to the next (though tasks can pass metadata using Airflow's [XCom feature](https://airflow.apache.org/docs/apache-airflow/stable/concepts/xcoms.html)). For high-volume, data-intensive tasks, a best practice is to delegate to external services specializing in that type of work. Airflow is not a streaming solution, but it is often used to process real-time data, pulling data off streams in batches. 
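As a small, illustrative sketch of the metadata-passing pattern described above (the DAG and task names are made up for the example, and the heavy processing is assumed to live in an external system), a TaskFlow DAG can hand a file path from one task to the next via XCom:

```python
from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG(dag_id="xcom_metadata_example", start_date=datetime(2022, 1, 1), schedule=None):

    @task
    def extract() -> str:
        # The return value is pushed to XCom automatically - keep it small (metadata, not data).
        return "s3://my-bucket/raw/2022-01-01.csv"

    @task
    def process(path: str) -> None:
        # The path arrives via XCom; the actual number crunching is delegated elsewhere.
        print(f"Submitting processing job for {path}")

    process(extract())
```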
@@ -85,12 +86,12 @@ Airflow is not a streaming solution, but it is often used to process real-time d Apache Airflow is tested with: -| | Main version (dev) | Stable version (2.3.1) | +| | Main version (dev) | Stable version (2.5.1) | |---------------------|------------------------------|------------------------------| | Python | 3.7, 3.8, 3.9, 3.10 | 3.7, 3.8, 3.9, 3.10 | | Platform | AMD64/ARM64(\*) | AMD64/ARM64(\*) | -| Kubernetes | 1.20, 1.21, 1.22, 1.23, 1.24 | 1.20, 1.21, 1.22, 1.23, 1.24 | -| PostgreSQL | 10, 11, 12, 13, 14 | 10, 11, 12, 13, 14 | +| Kubernetes | 1.21, 1.22, 1.23, 1.24, 1.25 | 1.21, 1.22, 1.23, 1.24, 1.25 | +| PostgreSQL | 11, 12, 13, 14, 15 | 11, 12, 13, 14, 15 | | MySQL | 5.7, 8 | 5.7, 8 | | SQLite | 3.15.0+ | 3.15.0+ | | MSSQL | 2017(\*), 2019 (\*) | 2017(\*), 2019 (\*) | @@ -104,9 +105,6 @@ MariaDB is not tested/recommended. **Note**: SQLite is used in Airflow tests. Do not use it in production. We recommend using the latest stable version of SQLite for local development. -**Note**: Support for Python v3.10 will be available from Airflow 2.3.0. The `main` (development) branch -already supports Python 3.10. - **Note**: Airflow currently can be run on POSIX-compliant Operating Systems. For development it is regularly tested on fairly modern Linux Distros and recent versions of MacOS. On Windows you can run it via WSL2 (Windows Subsystem for Linux 2) or via Linux Containers. @@ -120,13 +118,13 @@ is used in the [Community managed DockerHub image](https://hub.docker.com/p/apac Visit the official Airflow website documentation (latest **stable** release) for help with [installing Airflow](https://airflow.apache.org/docs/apache-airflow/stable/installation.html), -[getting started](https://airflow.apache.org/docs/apache-airflow/stable/start/index.html), or walking +[getting started](https://airflow.apache.org/docs/apache-airflow/stable/start.html), or walking through a more complete [tutorial](https://airflow.apache.org/docs/apache-airflow/stable/tutorial.html). > Note: If you're looking for documentation for the main branch (latest development branch): you can find it on [s.apache.org/airflow-docs](https://s.apache.org/airflow-docs/). For more information on Airflow Improvement Proposals (AIPs), visit -the [Airflow Wiki](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvements+Proposals). +the [Airflow Wiki](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals). Documentation for dependent projects like provider packages, Docker image, Helm Chart, you'll find it in [the documentation index](https://airflow.apache.org/docs/). @@ -160,15 +158,15 @@ them to the appropriate format and workflow that your tool requires. ```bash -pip install 'apache-airflow==2.3.1' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.3.1/constraints-3.7.txt" +pip install 'apache-airflow==2.5.1' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.5.1/constraints-3.7.txt" ``` 2. 
Installing with extras (i.e., postgres, google) ```bash -pip install 'apache-airflow[postgres,google]==2.3.1' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.3.1/constraints-3.7.txt" +pip install 'apache-airflow[postgres,google]==2.5.1' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.5.1/constraints-3.7.txt" ``` For information on installing provider packages, check @@ -273,7 +271,7 @@ Apache Airflow version life cycle: | Version | Current Patch/Minor | State | First Release | Limited Support | EOL/Terminated | |-----------|-----------------------|-----------|-----------------|-------------------|------------------| -| 2 | 2.3.1 | Supported | Dec 17, 2020 | TBD | TBD | +| 2 | 2.5.1 | Supported | Dec 17, 2020 | TBD | TBD | | 1.10 | 1.10.15 | EOL | Aug 27, 2018 | Dec 17, 2020 | June 17, 2021 | | 1.9 | 1.9.0 | EOL | Jan 03, 2018 | Aug 27, 2018 | Aug 27, 2018 | | 1.8 | 1.8.2 | EOL | Mar 19, 2017 | Jan 03, 2018 | Jan 03, 2018 | @@ -293,8 +291,8 @@ They are based on the official release schedule of Python and Kubernetes, nicely [Python Developer's Guide](https://devguide.python.org/#status-of-python-branches) and [Kubernetes version skew policy](https://kubernetes.io/docs/setup/release/version-skew-policy/). -1. We drop support for Python and Kubernetes versions when they reach EOL. Except for kubernetes, a - version stay supported by Airflow if two major cloud provider still provide support for it. We drop +1. We drop support for Python and Kubernetes versions when they reach EOL. Except for Kubernetes, a + version stays supported by Airflow if two major cloud providers still provide support for it. We drop support for those EOL versions in main right after EOL date, and it is effectively removed when we release the first new MINOR (Or MAJOR if there is no new MINOR version) of Airflow. For example, for Python 3.7 it means that we will drop support in main right after 27.06.2023, and the first MAJOR or MINOR version of @@ -303,7 +301,7 @@ They are based on the official release schedule of Python and Kubernetes, nicely 2. The "oldest" supported version of Python/Kubernetes is the default one until we decide to switch to later version. "Default" is only meaningful in terms of "smoke tests" in CI PRs, which are run using this default version and the default reference image available. Currently `apache/airflow:latest` - and `apache/airflow:2.3.1` images are Python 3.7 images. This means that default reference image will + and `apache/airflow:2.5.1` images are Python 3.7 images. This means that default reference image will become the default at the time when we start preparing for dropping 3.7 support which is few months before the end of life for Python 3.7. @@ -321,7 +319,7 @@ we publish an Apache Airflow release. Those images contain: Airflow released (so there could be different versions for 2.3 and 2.2 line for example) * Libraries required to connect to suppoerted Databases (again the set of databases supported depends on the MINOR version of Airflow. -* Predefined set of popular providers (for details see the [Dockerfile](Dockerfile)). +* Predefined set of popular providers (for details see the [Dockerfile](https://raw.githubusercontent.com/apache/airflow/main/Dockerfile)). 
* Possibility of building your own, custom image where the user can choose their own set of providers and libraries (see [Building the image](https://airflow.apache.org/docs/docker-stack/build.html)) * In the future Airflow might also support a "slim" version without providers nor database clients installed @@ -330,14 +328,17 @@ The version of the base OS image is the stable version of Debian. Airflow suppor stable versions - as soon as all Airflow dependencies support building, and we set up the CI pipeline for building and testing the OS version. Approximately 6 months before the end-of-life of a previous stable version of the OS, Airflow switches the images released to use the latest supported version of the OS. -For example since Debian Buster end-of-life is August 2022, Airflow switches the images in `main` branch -to use Debian Bullseye in February/March 2022. The version will be used in the next MINOR release after -the switch happens. In case of the Bullseye switch - 2.3.0 version will use Bullseye. The images released -in the previous MINOR version continue to use the version that all other releases for the MINOR version -used. +For example since ``Debian Buster`` end-of-life was August 2022, Airflow switched the images in `main` branch +to use ``Debian Bullseye`` in February/March 2022. The version was used in the next MINOR release after +the switch happened. In case of the Bullseye switch - 2.3.0 version used ``Debian Bullseye``. +The images released in the previous MINOR version continue to use the version that all other releases +for the MINOR version used. + +Support for ``Debian Buster`` image was dropped in August 2022 completely and everyone is expected to +stop building their images using ``Debian Buster``. Users will continue to be able to build their images using stable Debian releases until the end of life and -building and verifying of the images happens in our CI but no unit tests are executed using this image in +building and verifying of the images happens in our CI but no unit tests were executed using this image in the `main` branch. ## Approach to dependencies of Airflow @@ -391,22 +392,81 @@ The important dependencies are: ### Approach for dependencies in Airflow Providers and extras -Those `extras` and `providers` dependencies are maintained in `setup.py`. +Those `extras` and `providers` dependencies are maintained in `provider.yaml` of each provider. By default, we should not upper-bound dependencies for providers, however each provider's maintainer might decide to add additional limits (and justify them with comment) -## Support for providers +## Release process for Providers -Providers released by the community have limitation of a minimum supported version of Airflow. The minimum -version of Airflow is the `MINOR` version (2.1, 2.2 etc.) indicating that the providers might use features -that appeared in this release. The default support timespan for the minimum version of Airflow -(there could be justified exceptions) is that we increase the minimum Airflow version, when 12 months passed -since the first release for the MINOR version of Airflow. +Providers released by the community (with roughly monthly cadence) have +limitation of a minimum supported version of Airflow. The minimum version of +Airflow is the `MINOR` version (2.2, 2.3 etc.) indicating that the providers +might use features that appeared in this release. 
The default support timespan +for the minimum version of Airflow (there could be justified exceptions) is +that we increase the minimum Airflow version when 12 months have passed since the +first release for the MINOR version of Airflow. For example, this means that by default we upgrade the minimum version of Airflow supported by providers -to 2.2.0 in the first Provider's release after 21st of May 2022 (21st of May 2021 is the date when the -first `PATCHLEVEL` of 2.1 (2.1.0) has been released. +to 2.4.0 in the first Provider's release after 30th of April 2023. The 30th of April 2022 is the date when the +first `PATCHLEVEL` of 2.3 (2.3.0) was released. + +When we increase the minimum Airflow version, this is not a reason to bump the `MAJOR` version of the providers +(unless there are other breaking changes in the provider). The reason for that is that people who use +an older version of Airflow will not be able to use that provider (so it is not a breaking change for them) +and for people who are using a supported version of Airflow this is not a breaking change on its own - they +will be able to use the new version without breaking their workflows. When we upgraded the min-version to +2.2+, our approach was different, but as of the 2.3+ upgrade (November 2022) we only bump the `MINOR` version of the +provider when we increase the minimum Airflow version. + +Providers are often connected with some stakeholders that are vitally interested in maintaining backwards +compatibility in their integrations (for example cloud providers, or specific service providers). But, +we are also bound by the [Apache Software Foundation release policy](https://www.apache.org/legal/release-policy.html) +which describes who releases, and how to release the ASF software. The providers' governance model is something we call +"mixed governance" - where we follow the release policies, while the burden of maintaining and testing +the cherry-picked versions is on those who commit to perform the cherry-picks and make PRs to older +branches. + +The "mixed governance" (optional, per-provider) means that: + +* The Airflow Community and release manager decide when to release those providers. + This is fully managed by the community and the usual release-management process following the + [Apache Software Foundation release policy](https://www.apache.org/legal/release-policy.html) +* The contributors (who might or might not be direct stakeholders in the provider) will carry the burden + of cherry-picking and testing the older versions of providers. +* There is no "selection" and acceptance process to determine which version of the provider is released. + It is determined by the actions of contributors raising the PR with cherry-picked changes and it follows + the usual PR review process where a maintainer approves (or not) and merges (or not) such a PR. Simply + speaking - the completed action of cherry-picking and testing the older version of the provider makes + it eligible to be released. Unless there is someone who volunteers and performs the cherry-picking and + testing, the provider is not released. +* Branches to raise PRs against are created when a contributor commits to perform the cherry-picking + (for example, as a comment in the PR to cherry-pick) + +Usually, community effort is focused on the most recent version of each provider.
The community approach is +that we should rather aggressively remove deprecations in "major" versions of the providers - whenever +there is an opportunity to increase major version of a provider, we attempt to remove all deprecations. +However, sometimes there is a contributor (who might or might not represent stakeholder), +willing to make their effort on cherry-picking and testing the non-breaking changes to a selected, +previous major branch of the provider. This results in releasing at most two versions of a +provider at a time: + +* potentially breaking "latest" major version +* selected past major version with non-breaking changes applied by the contributor + +Cherry-picking such changes follows the same process for releasing Airflow +patch-level releases for a previous minor Airflow version. Usually such cherry-picking is done when +there is an important bugfix and the latest version contains breaking changes that are not +coupled with the bugfix. Releasing them together in the latest version of the provider effectively couples +them, and therefore they're released separately. The cherry-picked changes have to be merged by the committer following the usual rules of the +community. + +There is no obligation to cherry-pick and release older versions of the providers. +The community continues to release such older versions of the providers for as long as there is an effort +of the contributors to perform the cherry-picks and carry-on testing of the older provider version. + +The availability of stakeholder that can manage "service-oriented" maintenance and agrees to such a +responsibility, will also drive our willingness to accept future, new providers to become community managed. ## Contributing diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index 87319b55aaebb..e879df02c328a 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -21,6 +21,1146 @@ .. towncrier release notes start +Airflow 2.5.1 (2023-01-16) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +Trigger gevent ``monkeypatching`` via environment variable (#28283) +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + + If you are using gevent for your webserver deployment and used local settings to ``monkeypatch`` gevent, + you might want to replace local settings patching with an ``_AIRFLOW_PATCH_GEVENT`` environment variable + set to 1 in your webserver. This ensures gevent patching is done as early as possible. (#8212) + +Bug Fixes +^^^^^^^^^ +- Fix masking of non-sensitive environment variables (#28802) +- Remove swagger-ui extra from connexion and install ``swagger-ui-dist`` via npm package (#28788) +- Fix ``UIAlert`` should_show when ``AUTH_ROLE_PUBLIC`` set (#28781) +- Only patch single label when adopting pod (#28776) +- Update CSRF token to expire with session (#28730) +- Fix "airflow tasks render" cli command for mapped task instances (#28698) +- Allow XComArgs for ``external_task_ids`` of ExternalTaskSensor (#28692) +- Row-lock TIs to be removed during mapped task expansion (#28689) +- Handle ConnectionReset exception in Executor cleanup (#28685) +- Fix description of output redirection for access_log for gunicorn (#28672) +- Add back join to zombie query that was dropped in #28198 (#28544) +- Fix calendar view for CronTriggerTimeTable dags (#28411) +- After running the DAG the employees table is empty. 
(#28353) +- Fix ``DetachedInstanceError`` when finding zombies in Dag Parsing process (#28198) +- Nest header blocks in ``divs`` to fix ``dagid`` copy nit on dag.html (#28643) +- Fix UI caret direction (#28624) +- Guard not-yet-expanded ti in trigger rule dep (#28592) +- Move TI ``setNote`` endpoints under TaskInstance in OpenAPI (#28566) +- Consider previous run in ``CronTriggerTimetable`` (#28532) +- Ensure correct log dir in file task handler (#28477) +- Fix bad pods pickled in executor_config (#28454) +- Add ``ensure_ascii=False`` in trigger dag run API (#28451) +- Add setters to MappedOperator on_*_callbacks (#28313) +- Fix ``ti._try_number`` for deferred and up_for_reschedule tasks (#26993) +- separate ``callModal`` from dag.js (#28410) +- A manual run can't look like a scheduled one (#28397) +- Dont show task/run durations when there is no start_date (#28395) +- Maintain manual scroll position in task logs (#28386) +- Correctly select a mapped task's "previous" task (#28379) +- Trigger gevent ``monkeypatching`` via environment variable (#28283) +- Fix db clean warnings (#28243) +- Make arguments 'offset' and 'length' not required (#28234) +- Make live logs reading work for "other" k8s executors (#28213) +- Add custom pickling hooks to ``LazyXComAccess`` (#28191) +- fix next run datasets error (#28165) +- Ensure that warnings from ``@dag`` decorator are reported in dag file (#28153) +- Do not warn when airflow dags tests command is used (#28138) +- Ensure the ``dagbag_size`` metric decreases when files are deleted (#28135) +- Improve run/task grid view actions (#28130) +- Make BaseJob.most_recent_job favor "running" jobs (#28119) +- Don't emit FutureWarning when code not calling old key (#28109) +- Add ``airflow.api.auth.backend.session`` to backend sessions in compose (#28094) +- Resolve false warning about calling conf.get on moved item (#28075) +- Return list of tasks that will be changed (#28066) +- Handle bad zip files nicely when parsing DAGs. 
(#28011) +- Prevent double loading of providers from local paths (#27988) +- Fix deadlock when chaining multiple empty mapped tasks (#27964) +- fix: current_state method on TaskInstance doesn't filter by map_index (#27898) +- Don't log CLI actions if db not initialized (#27851) +- Make sure we can get out of a faulty scheduler state (#27834) +- dagrun, ``next_dagruns_to_examine``, add MySQL index hint (#27821) +- Handle DAG disappearing mid-flight when dag verification happens (#27720) +- fix: continue checking sla (#26968) +- Allow generation of connection URI to work when no conn type (#26765) + +Misc/Internal +^^^^^^^^^^^^^ +- Add automated version replacement in example dag indexes (#28090) +- Cleanup and do housekeeping with plugin examples (#28537) +- Limit ``SQLAlchemy`` to below ``2.0`` (#28725) +- Bump ``json5`` from ``1.0.1`` to ``1.0.2`` in ``/airflow/www`` (#28715) +- Fix some docs on using sensors with taskflow (#28708) +- Change Architecture and OperatingSystem classes into ``Enums`` (#28627) +- Add doc-strings and small improvement to email util (#28634) +- Fix ``Connection.get_extra`` type (#28594) +- navbar, cap dropdown size, and add scroll bar (#28561) +- Emit warnings for ``conf.get*`` from the right source location (#28543) +- Move MyPY plugins of ours to dev folder (#28498) +- Add retry to ``purge_inactive_dag_warnings`` (#28481) +- Re-enable Plyvel on ARM as it now builds cleanly (#28443) +- Add SIGUSR2 handler for LocalTaskJob and workers to aid debugging (#28309) +- Convert ``test_task_command`` to Pytest and ``unquarantine`` tests in it (#28247) +- Make invalid characters exception more readable (#28181) +- Bump decode-uri-component from ``0.2.0`` to ``0.2.2`` in ``/airflow/www`` (#28080) +- Use asserts instead of exceptions for executor not started (#28019) +- Simplify dataset ``subgraph`` logic (#27987) +- Order TIs by ``map_index`` (#27904) +- Additional info about Segmentation Fault in ``LocalTaskJob`` (#27381) + +Doc Only Changes +^^^^^^^^^^^^^^^^ +- Mention mapped operator in cluster policy doc (#28885) +- Slightly improve description of Dynamic DAG generation preamble (#28650) +- Restructure Docs (#27235) +- Update scheduler docs about low priority tasks (#28831) +- Clarify that versioned constraints are fixed at release time (#28762) +- Clarify about docker compose (#28729) +- Adding an example dag for dynamic task mapping (#28325) +- Use docker compose v2 command (#28605) +- Add AIRFLOW_PROJ_DIR to docker-compose example (#28517) +- Remove outdated Optional Provider Feature outdated documentation (#28506) +- Add documentation for [core] mp_start_method config (#27993) +- Documentation for the LocalTaskJob return code counter (#27972) +- Note which versions of Python are supported (#27798) + + +Airflow 2.5.0 (2022-12-02) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +``airflow dags test`` no longer performs a backfill job (#26400) +"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + + In order to make ``airflow dags test`` more useful as a testing and debugging tool, we no + longer run a backfill job and instead run a "local task runner". Users can still backfill + their DAGs using the ``airflow dags backfill`` command. 
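As an illustration, a DAG can now be exercised quickly in-process; the sketch below uses a placeholder DAG and assumes the ``dag.test()`` helper that backs the reworked ``airflow dags test`` command (#26400):

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="my_dag", start_date=datetime(2022, 1, 1), schedule=None) as dag:
        EmptyOperator(task_id="noop")

    if __name__ == "__main__":
        # Runs the DAG's tasks in-process with local logging - no backfill job is created.
        dag.test()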
+ +Airflow config section ``kubernetes`` renamed to ``kubernetes_executor`` (#26873) +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + KubernetesPodOperator no longer considers any core kubernetes config params, so this section now only applies to kubernetes executor. Renaming it reduces potential for confusion. + +``AirflowException`` is now thrown as soon as any dependent tasks of ExternalTaskSensor fails (#27190) +"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +``ExternalTaskSensor`` no longer hangs indefinitely when ``failed_states`` is set, an ``execute_date_fn`` is used, and some but not all of the dependent tasks fail. + Instead, an ``AirflowException`` is thrown as soon as any of the dependent tasks fail. + Any code handling this failure in addition to timeouts should move to caching the ``AirflowException`` ``BaseClass`` and not only the ``AirflowSensorTimeout`` subclass. + +The Airflow config option ``scheduler.deactivate_stale_dags_interval`` has been renamed to ``scheduler.parsing_cleanup_interval`` (#27828). +"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + The old option will continue to work but will issue deprecation warnings, and will be removed entirely in Airflow 3. + +New Features +^^^^^^^^^^^^ +- ``TaskRunner``: notify of component start and finish (#27855) +- Add DagRun state change to the Listener plugin system(#27113) +- Metric for raw task return codes (#27155) +- Add logic for XComArg to pull specific map indexes (#27771) +- Clear TaskGroup (#26658, #28003) +- Add critical section query duration metric (#27700) +- Add: #23880 :: Audit log for ``AirflowModelViews(Variables/Connection)`` (#24079, #27994, #27923) +- Add postgres 15 support (#27444) +- Expand tasks in mapped group at run time (#27491) +- reset commits, clean submodules (#27560) +- scheduler_job, add metric for scheduler loop timer (#27605) +- Allow datasets to be used in taskflow (#27540) +- Add expanded_ti_count to ti context (#27680) +- Add user comment to task instance and dag run (#26457, #27849, #27867) +- Enable copying DagRun JSON to clipboard (#27639) +- Implement extra controls for SLAs (#27557) +- add dag parsed time in DAG view (#27573) +- Add max_wait for exponential_backoff in BaseSensor (#27597) +- Expand tasks in mapped group at parse time (#27158) +- Add disable retry flag on backfill (#23829) +- Adding sensor decorator (#22562) +- Api endpoint update ti (#26165) +- Filtering datasets by recent update events (#26942) +- Support ``Is /not`` Null filter for value is None on ``webui`` (#26584) +- Add search to datasets list (#26893) +- Split out and handle 'params' in mapped operator (#26100) +- Add authoring API for TaskGroup mapping (#26844) +- Add ``one_done`` trigger rule (#26146) +- Create a more efficient airflow dag test command that also has better local logging (#26400) +- Support add/remove permissions to roles commands (#26338) +- Auto tail file logs in Web UI (#26169) +- Add triggerer info to task instance in API (#26249) +- Flag to deserialize value on custom XCom backend (#26343) + +Improvements +^^^^^^^^^^^^ +- Allow depth-first execution (#27827) +- UI: Update offset height if data changes (#27865) +- Improve TriggerRuleDep typing and readability (#27810) +- Make views requiring session, keyword only args (#27790) +- Optimize ``TI.xcom_pull()`` with explicit task_ids and map_indexes (#27699) 
+- Allow hyphens in pod id used by k8s executor (#27737) +- optimise task instances filtering (#27102) +- Use context managers to simplify log serve management (#27756) +- Fix formatting leftovers (#27750) +- Improve task deadlock messaging (#27734) +- Improve "sensor timeout" messaging (#27733) +- Replace urlparse with ``urlsplit`` (#27389) +- Align TaskGroup semantics to AbstractOperator (#27723) +- Add new files to parsing queue on every loop of dag processing (#27060) +- Make Kubernetes Executor & Scheduler resilient to error during PMH execution (#27611) +- Separate dataset deps into individual graphs (#27356) +- Use log.exception where more economical than log.error (#27517) +- Move validation ``branch_task_ids`` into ``SkipMixin`` (#27434) +- Coerce LazyXComAccess to list when pushed to XCom (#27251) +- Update cluster-policies.rst docs (#27362) +- Add warning if connection type already registered within the provider (#27520) +- Activate debug logging in commands with --verbose option (#27447) +- Add classic examples for Python Operators (#27403) +- change ``.first()`` to ``.scalar()`` (#27323) +- Improve reset_dag_run description (#26755) +- Add examples and ``howtos`` about sensors (#27333) +- Make grid view widths adjustable (#27273) +- Sorting plugins custom menu links by category before name (#27152) +- Simplify DagRun.verify_integrity (#26894) +- Add mapped task group info to serialization (#27027) +- Correct the JSON style used for Run config in Grid View (#27119) +- No ``extra__conn_type__`` prefix required for UI behaviors (#26995) +- Improve dataset update blurb (#26878) +- Rename kubernetes config section to kubernetes_executor (#26873) +- decode params for dataset searches (#26941) +- Get rid of the DAGRun details page & rely completely on Grid (#26837) +- Fix scheduler ``crashloopbackoff`` when using ``hostname_callable`` (#24999) +- Reduce log verbosity in KubernetesExecutor. (#26582) +- Don't iterate tis list twice for no reason (#26740) +- Clearer code for PodGenerator.deserialize_model_file (#26641) +- Don't import kubernetes unless you have a V1Pod (#26496) +- Add updated_at column to DagRun and Ti tables (#26252) +- Move the deserialization of custom XCom Backend to 2.4.0 (#26392) +- Avoid calculating all elements when one item is needed (#26377) +- Add ``__future__``.annotations automatically by isort (#26383) +- Handle list when serializing expand_kwargs (#26369) +- Apply PEP-563 (Postponed Evaluation of Annotations) to core airflow (#26290) +- Add more weekday operator and sensor examples #26071 (#26098) +- Align TaskGroup semantics to AbstractOperator (#27723) + +Bug Fixes +^^^^^^^^^ +- Gracefully handle whole config sections being renamed (#28008) +- Add allow list for imports during deserialization (#27887) +- Soft delete datasets that are no longer referenced in DAG schedules or task outlets (#27828) +- Redirect to home view when there are no valid tags in the URL (#25715) +- Refresh next run datasets info in dags view (#27839) +- Make MappedTaskGroup depend on its expand inputs (#27876) +- Make DagRun state updates for paused DAGs faster (#27725) +- Don't explicitly set include_examples to False on task run command (#27813) +- Fix menu border color (#27789) +- Fix backfill queued task getting reset to scheduled state. 
(#23720) +- Fix clearing child dag mapped tasks from parent dag (#27501) +- Handle json encoding of ``V1Pod`` in task callback (#27609) +- Fix ExternalTaskSensor can't check zipped dag (#27056) +- Avoid re-fetching DAG run in TriggerDagRunOperator (#27635) +- Continue on exception when retrieving metadata (#27665) +- External task sensor fail fix (#27190) +- Add the default None when pop actions (#27537) +- Display parameter values from serialized dag in trigger dag view. (#27482, #27944) +- Move TriggerDagRun conf check to execute (#27035) +- Resolve trigger assignment race condition (#27072) +- Update google_analytics.html (#27226) +- Fix some bug in web ui dags list page (auto-refresh & jump search null state) (#27141) +- Fixed broken URL for docker-compose.yaml (#26721) +- Fix xcom arg.py .zip bug (#26636) +- Fix 404 ``taskInstance`` errors and split into two tables (#26575) +- Fix browser warning of improper thread usage (#26551) +- template rendering issue fix (#26390) +- Clear ``autoregistered`` DAGs if there are any import errors (#26398) +- Fix ``from airflow import version`` lazy import (#26239) +- allow scroll in triggered dag runs modal (#27965) + +Misc/Internal +^^^^^^^^^^^^^ +- Remove ``is_mapped`` attribute (#27881) +- Simplify FAB table resetting (#27869) +- Fix old-style typing in Base Sensor (#27871) +- Switch (back) to late imports (#27730) +- Completed D400 for multiple folders (#27748) +- simplify notes accordion test (#27757) +- completed D400 for ``airflow/callbacks/* airflow/cli/*`` (#27721) +- Completed D400 for ``airflow/api_connexion/* directory`` (#27718) +- Completed D400 for ``airflow/listener/* directory`` (#27731) +- Completed D400 for ``airflow/lineage/* directory`` (#27732) +- Update API & Python Client versions (#27642) +- Completed D400 & D401 for ``airflow/api/*`` directory (#27716) +- Completed D400 for multiple folders (#27722) +- Bump ``minimatch`` from ``3.0.4 to 3.0.8`` in ``/airflow/www`` (#27688) +- Bump loader-utils from ``1.4.1 to 1.4.2 ``in ``/airflow/www`` (#27697) +- Disable nested task mapping for now (#27681) +- bump alembic minimum version (#27629) +- remove unused code.html (#27585) +- Enable python string normalization everywhere (#27588) +- Upgrade dependencies in order to avoid backtracking (#27531) +- Strengthen a bit and clarify importance of triaging issues (#27262) +- Deduplicate type hints (#27508) +- Add stub 'yield' to ``BaseTrigger.run`` (#27416) +- Remove upper-bound limit to dask (#27415) +- Limit Dask to under ``2022.10.1`` (#27383) +- Update old style typing (#26872) +- Enable string normalization for docs (#27269) +- Slightly faster up/downgrade tests (#26939) +- Deprecate use of core get_kube_client in PodManager (#26848) +- Add ``memray`` files to ``gitignore / dockerignore`` (#27001) +- Bump sphinx and ``sphinx-autoapi`` (#26743) +- Simplify ``RTIF.delete_old_records()`` (#26667) +- migrate last react files to typescript (#26112) +- Work around ``pyupgrade`` edge cases (#26384) + +Doc only changes +^^^^^^^^^^^^^^^^ +- Document dag_file_processor_timeouts metric as deprecated (#27067) +- Drop support for PostgreSQL 10 (#27594) +- Update index.rst (#27529) +- Add note about pushing the lazy XCom proxy to XCom (#27250) +- Fix BaseOperator link (#27441) +- [docs] best-practices add use variable with template example. 
(#27316) +- docs for custom view using plugin (#27244) +- Update graph view and grid view on overview page (#26909) +- Documentation fixes (#26819) +- make consistency on markup title string level (#26696) +- Add documentation to dag test function (#26713) +- Fix broken URL for ``docker-compose.yaml`` (#26726) +- Add a note against use of top level code in timetable (#26649) +- Fix example_datasets dag names (#26495) +- Update docs: zip-like effect is now possible in task mapping (#26435) +- changing to task decorator in docs from classic operator use (#25711) + +Airflow 2.4.3 (2022-11-14) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +Make ``RotatingFilehandler`` used in ``DagProcessor`` non-caching (#27223) +"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + + In case you want to decrease cache memory when ``CONFIG_PROCESSOR_MANAGER_LOGGER=True``, and you have your local settings created before, + you can update ``processor_manager_handler`` to use ``airflow.utils.log.non_caching_file_handler.NonCachingRotatingFileHandler`` handler instead of ``logging.RotatingFileHandler``. (#27065) + +Bug Fixes +^^^^^^^^^ +- Fix double logging with some task logging handler (#27591) +- Replace FAB url filtering function with Airflow's (#27576) +- Fix mini scheduler expansion of mapped task (#27506) +- ``SLAMiss`` is nullable and not always given back when pulling task instances (#27423) +- Fix behavior of ``_`` when searching for DAGs (#27448) +- Fix getting the ``dag/task`` ids from BaseExecutor (#27550) +- Fix SQLAlchemy primary key black-out error on DDRQ (#27538) +- Fix IntegrityError during webserver startup (#27297) +- Add case insensitive constraint to username (#27266) +- Fix python external template keys (#27256) +- Reduce extraneous task log requests (#27233) +- Make ``RotatingFilehandler`` used in ``DagProcessor`` non-caching (#27223) +- Listener: Set task on SQLAlchemy TaskInstance object (#27167) +- Fix dags list page auto-refresh & jump search null state (#27141) +- Set ``executor.job_id`` to ``BackfillJob.id`` for backfills (#27020) + +Misc/Internal +^^^^^^^^^^^^^ +- Bump loader-utils from ``1.4.0`` to ``1.4.1`` in ``/airflow/www`` (#27552) +- Reduce log level for k8s ``TCP_KEEPALIVE`` etc warnings (#26981) + +Doc only changes +^^^^^^^^^^^^^^^^ +- Use correct executable in docker compose docs (#27529) +- Fix wording in DAG Runs description (#27470) +- Document that ``KubernetesExecutor`` overwrites container args (#27450) +- Fix ``BaseOperator`` links (#27441) +- Correct timer units to seconds from milliseconds. (#27360) +- Add missed import in the Trigger Rules example (#27309) +- Update SLA wording to reflect it is relative to ``Dag Run`` start. (#27111) +- Add ``kerberos`` environment variables to the docs (#27028) + +Airflow 2.4.2 (2022-10-23) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +Default for ``[webserver] expose_stacktrace`` changed to ``False`` (#27059) +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + + The default for ``[webserver] expose_stacktrace`` has been set to ``False``, instead of ``True``. This means administrators must opt-in to expose tracebacks to end users. 
+ +Bug Fixes +^^^^^^^^^ +- Make tracebacks opt-in (#27059) +- Add missing AUTOINC/SERIAL for FAB tables (#26885) +- Add separate error handler for 405(Method not allowed) errors (#26880) +- Don't re-patch pods that are already controlled by current worker (#26778) +- Handle mapped tasks in task duration chart (#26722) +- Fix task duration cumulative chart (#26717) +- Avoid 500 on dag redirect (#27064) +- Filter dataset dependency data on webserver (#27046) +- Remove double collection of dags in ``airflow dags reserialize`` (#27030) +- Fix auto refresh for graph view (#26926) +- Don't overwrite connection extra with invalid json (#27142) +- Fix next run dataset modal links (#26897) +- Change dag audit log sort by date from asc to desc (#26895) +- Bump min version of jinja2 (#26866) +- Add missing colors to ``state_color_mapping`` jinja global (#26822) +- Fix running debuggers inside ``airflow tasks test`` (#26806) +- Fix warning when using xcomarg dependencies (#26801) +- demote Removed state in priority for displaying task summaries (#26789) +- Ensure the log messages from operators during parsing go somewhere (#26779) +- Add restarting state to TaskState Enum in REST API (#26776) +- Allow retrieving error message from data.detail (#26762) +- Simplify origin string cleaning (#27143) +- Remove DAG parsing from StandardTaskRunner (#26750) +- Fix non-hidden cumulative chart on duration view (#26716) +- Remove TaskFail duplicates check (#26714) +- Fix airflow tasks run --local when dags_folder differs from that of processor (#26509) +- Fix yarn warning from d3-color (#27139) +- Fix version for a couple configurations (#26491) +- Revert "No grid auto-refresh for backfill dag runs (#25042)" (#26463) +- Retry on Airflow Schedule DAG Run DB Deadlock (#26347) + +Misc/Internal +^^^^^^^^^^^^^ +- Clean-ups around task-mapping code (#26879) +- Move user-facing string to template (#26815) +- add icon legend to datasets graph (#26781) +- Bump ``sphinx`` and ``sphinx-autoapi`` (#26743) +- Simplify ``RTIF.delete_old_records()`` (#26667) +- Bump FAB to ``4.1.4`` (#26393) + +Doc only changes +^^^^^^^^^^^^^^^^ +- Fixed triple quotes in task group example (#26829) +- Documentation fixes (#26819) +- make consistency on markup title string level (#26696) +- Add a note against use of top level code in timetable (#26649) +- Fix broken URL for ``docker-compose.yaml`` (#26726) + + +Airflow 2.4.1 (2022-09-30) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +No significant changes. 
+ +Bug Fixes +^^^^^^^^^ + +- When rendering template, unmap task in context (#26702) +- Fix scroll overflow for ConfirmDialog (#26681) +- Resolve deprecation warning re ``Table.exists()`` (#26616) +- Fix XComArg zip bug (#26636) +- Use COALESCE when ordering runs to handle NULL (#26626) +- Check user is active (#26635) +- No missing user warning for public admin (#26611) +- Allow MapXComArg to resolve after serialization (#26591) +- Resolve warning about DISTINCT ON query on dags view (#26608) +- Log warning when secret backend kwargs is invalid (#26580) +- Fix grid view log try numbers (#26556) +- Template rendering issue in passing ``templates_dict`` to task decorator (#26390) +- Fix Deferrable stuck as ``scheduled`` during backfill (#26205) +- Suppress SQLALCHEMY_TRACK_MODIFICATIONS warning in db init (#26617) +- Correctly set ``json_provider_class`` on Flask app so it uses our encoder (#26554) +- Fix WSGI root app (#26549) +- Fix deadlock when mapped task with removed upstream is rerun (#26518) +- ExecutorConfigType should be ``cacheable`` (#26498) +- Fix proper joining of the path for logs retrieved from celery workers (#26493) +- DAG Deps extends ``base_template`` (#26439) +- Don't update backfill run from the scheduler (#26342) + +Doc only changes +^^^^^^^^^^^^^^^^ + +- Clarify owner links document (#26515) +- Fix invalid RST in dataset concepts doc (#26434) +- Document the ``non-sensitive-only`` option for ``expose_config`` (#26507) +- Fix ``example_datasets`` dag names (#26495) +- Zip-like effect is now possible in task mapping (#26435) +- Use task decorator in docs instead of classic operators (#25711) + +Airflow 2.4.0 (2022-09-19) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +Data-aware Scheduling and ``Dataset`` concept added to Airflow +"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +New to this release of Airflow is the concept of Datasets to Airflow, and with it a new way of scheduling dags: +data-aware scheduling. + +This allows DAG runs to be automatically created as a result of a task "producing" a dataset. In some ways +this can be thought of as the inverse of ``TriggerDagRunOperator``, where instead of the producing DAG +controlling which DAGs get created, the consuming DAGs can "listen" for changes. + +A dataset is identified by a URI: + +.. code-block:: python + + from airflow import Dataset + + # The URI doesn't have to be absolute + dataset = Dataset(uri="my-dataset") + # Or you can use a scheme to show where it lives. + dataset2 = Dataset(uri="s3://bucket/prefix") + +To create a DAG that runs whenever a Dataset is updated use the new ``schedule`` parameter (see below) and +pass a list of 1 or more Datasets: + +.. code-block:: python + + with DAG(dag_id='dataset-consumer', schedule=[dataset]): + ... + +And to mark a task as producing a dataset pass the dataset(s) to the ``outlets`` attribute: + +.. code-block:: python + + @task(outlets=[dataset]) + def my_task(): + ... + + + # Or for classic operators + BashOperator(task_id="update-ds", bash_command=..., outlets=[dataset]) + +If you have the producer and consumer in different files you do not need to use the same Dataset object, two +``Dataset()``\s created with the same URI are equal. + +Datasets represent the abstract concept of a dataset, and (for now) do not have any direct read or write +capability - in this release we are adding the foundational feature that we will build upon. + +For more info on Datasets please see :doc:`/authoring-and-scheduling/datasets`. 
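+
+For illustration, a minimal end-to-end sketch combining the snippets above into a producer and a consumer DAG (the dag ids, task ids, ``bash_command`` and the ``s3://bucket/prefix`` URI are only placeholders, not required names):
+
+.. code-block:: python
+
+    import pendulum
+
+    from airflow import DAG, Dataset
+    from airflow.decorators import task
+    from airflow.operators.bash import BashOperator
+
+    example_dataset = Dataset("s3://bucket/prefix")
+
+    with DAG(dag_id="producer", start_date=pendulum.datetime(2022, 1, 1), schedule="@daily"):
+        # A successful run of this task records an update of the dataset.
+        BashOperator(task_id="update-ds", bash_command="echo produce", outlets=[example_dataset])
+
+    with DAG(dag_id="dataset-consumer", start_date=pendulum.datetime(2022, 1, 1), schedule=[example_dataset]):
+        # This DAG is scheduled to run whenever the dataset above is updated.
+        @task
+        def process():
+            ...
+
+        process()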
+ +Expanded dynamic task mapping support +""""""""""""""""""""""""""""""""""""" + +Dynamic task mapping now includes support for ``expand_kwargs``, ``zip`` and ``map``. + +For more info on dynamic task mapping please see :doc:`/authoring-and-scheduling/dynamic-task-mapping`. + +DAGS used in a context manager no longer need to be assigned to a module variable (#23592) +"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +Previously you had to assign a DAG to a module-level variable in order for Airflow to pick it up. For example this + + +.. code-block:: python + + with DAG(dag_id="example") as dag: + ... + + + @dag + def dag_maker(): + ... + + + dag2 = dag_maker() + + +can become + +.. code-block:: python + + with DAG(dag_id="example"): + ... + + + @dag + def dag_maker(): + ... + + + dag_maker() + +If you want to disable the behaviour for any reason then set ``auto_register=False`` on the dag: + +.. code-block:: python + + # This dag will not be picked up by Airflow as it's not assigned to a variable + with DAG(dag_id="example", auto_register=False): + ... + +Deprecation of ``schedule_interval`` and ``timetable`` arguments (#25410) +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +We added new DAG argument ``schedule`` that can accept a cron expression, timedelta object, *timetable* object, or list of dataset objects. Arguments ``schedule_interval`` and ``timetable`` are deprecated. + +If you previously used the ``@daily`` cron preset, your DAG may have looked like this: + +.. code-block:: python + + with DAG( + dag_id="my_example", + start_date=datetime(2021, 1, 1), + schedule_interval="@daily", + ): + ... + +Going forward, you should use the ``schedule`` argument instead: + +.. code-block:: python + + with DAG( + dag_id="my_example", + start_date=datetime(2021, 1, 1), + schedule="@daily", + ): + ... + +The same is true if you used a custom timetable. Previously you would have used the ``timetable`` argument: + +.. code-block:: python + + with DAG( + dag_id="my_example", + start_date=datetime(2021, 1, 1), + timetable=EventsTimetable(event_dates=[pendulum.datetime(2022, 4, 5)]), + ): + ... + +Now you should use the ``schedule`` argument: + +.. code-block:: python + + with DAG( + dag_id="my_example", + start_date=datetime(2021, 1, 1), + schedule=EventsTimetable(event_dates=[pendulum.datetime(2022, 4, 5)]), + ): + ... + +Removal of experimental Smart Sensors (#25507) +"""""""""""""""""""""""""""""""""""""""""""""" + +Smart Sensors were added in 2.0 and deprecated in favor of Deferrable operators in 2.2, and have now been removed. + +``airflow.contrib`` packages and deprecated modules are dynamically generated (#26153, #26179, #26167) +"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +The ``airflow.contrib`` packages and deprecated modules from Airflow 1.10 in ``airflow.hooks``, ``airflow.operators``, ``airflow.sensors`` packages are now dynamically generated modules and while users can continue using the deprecated contrib classes, they are no longer visible for static code check tools and will be reported as missing. It is recommended for the users to move to the non-deprecated classes. + +``DBApiHook`` and ``SQLSensor`` have moved (#24836) +""""""""""""""""""""""""""""""""""""""""""""""""""" + +``DBApiHook`` and ``SQLSensor`` have been moved to the ``apache-airflow-providers-common-sql`` provider. 
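+
+For users importing these classes directly, a sketch of the import change (the module paths below are assumed from the ``common-sql`` provider layout; the old ``airflow.hooks.dbapi`` and ``airflow.sensors.sql`` locations should keep working through deprecation shims):
+
+.. code-block:: python
+
+    # Deprecated core locations
+    # from airflow.hooks.dbapi import DbApiHook
+    # from airflow.sensors.sql import SqlSensor
+
+    # New locations in the apache-airflow-providers-common-sql package
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.providers.common.sql.sensors.sql import SqlSensor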
+ +DAG runs sorting logic changed in grid view (#25090) +"""""""""""""""""""""""""""""""""""""""""""""""""""" + +The ordering of DAG runs in the grid view has been changed to be more "natural". +The new logic generally orders by data interval, but a custom ordering can be +applied by setting the DAG to use a custom timetable. + + +New Features +^^^^^^^^^^^^ +- Add Data-aware Scheduling (`AIP-48 `_) +- Add ``@task.short_circuit`` TaskFlow decorator (#25752) +- Make ``execution_date_or_run_id`` optional in ``tasks test`` command (#26114) +- Automatically register DAGs that are used in a context manager (#23592, #26398) +- Add option of sending DAG parser logs to stdout. (#25754) +- Support multiple ``DagProcessors`` parsing files from different locations. (#25935) +- Implement ``ExternalPythonOperator`` (#25780) +- Make execution_date optional for command ``dags test`` (#26111) +- Implement ``expand_kwargs()`` against a literal list (#25925) +- Add trigger rule tooltip (#26043) +- Add conf parameter to CLI for airflow dags test (#25900) +- Include scheduled slots in pools view (#26006) +- Add ``output`` property to ``MappedOperator`` (#25604) +- Add roles delete command to cli (#25854) +- Add Airflow specific warning classes (#25799) +- Add support for ``TaskGroup`` in ``ExternalTaskSensor`` (#24902) +- Add ``@task.kubernetes`` taskflow decorator (#25663) +- Add a way to import Airflow without side-effects (#25832) +- Let timetables control generated run_ids. (#25795) +- Allow per-timetable ordering override in grid view (#25633) +- Grid logs for mapped instances (#25610, #25621, #25611) +- Consolidate to one ``schedule`` param (#25410) +- DAG regex flag in backfill command (#23870) +- Adding support for owner links in the Dags view UI (#25280) +- Ability to clear a specific DAG Run's task instances via REST API (#23516) +- Possibility to document DAG with a separate markdown file (#25509) +- Add parsing context to DAG Parsing (#25161) +- Implement ``CronTriggerTimetable`` (#23662) +- Add option to mask sensitive data in UI configuration page (#25346) +- Create new databases from the ORM (#24156) +- Implement ``XComArg.zip(*xcom_args)`` (#25176) +- Introduce ``sla_miss`` metric (#23402) +- Implement ``map()`` semantic (#25085) +- Add override method to TaskGroupDecorator (#25160) +- Implement ``expand_kwargs()`` (#24989) +- Add parameter to turn off SQL query logging (#24570) +- Add ``DagWarning`` model, and a check for missing pools (#23317) +- Add Task Logs to Grid details panel (#24249) +- Added small health check server and endpoint in scheduler(#23905) +- Add built-in External Link for ``ExternalTaskMarker`` operator (#23964) +- Add default task retry delay config (#23861) +- Add clear DagRun endpoint. 
(#23451) +- Add support for timezone as string in cron interval timetable (#23279) +- Add auto-refresh to dags home page (#22900, #24770) + +Improvements +^^^^^^^^^^^^ + +- Add more weekday operator and sensor examples #26071 (#26098) +- Add subdir parameter to dags reserialize command (#26170) +- Update zombie message to be more descriptive (#26141) +- Only send an ``SlaCallbackRequest`` if the DAG is scheduled (#26089) +- Promote ``Operator.output`` more (#25617) +- Upgrade API files to typescript (#25098) +- Less ``hacky`` double-rendering prevention in mapped task (#25924) +- Improve Audit log (#25856) +- Remove mapped operator validation code (#25870) +- More ``DAG(schedule=...)`` improvements (#25648) +- Reduce ``operator_name`` dupe in serialized JSON (#25819) +- Make grid view group/mapped summary UI more consistent (#25723) +- Remove useless statement in ``task_group_to_grid`` (#25654) +- Add optional data interval to ``CronTriggerTimetable`` (#25503) +- Remove unused code in ``/grid`` endpoint (#25481) +- Add and document description fields (#25370) +- Improve Airflow logging for operator Jinja template processing (#25452) +- Update core example DAGs to use ``@task.branch`` decorator (#25242) +- Update DAG ``audit_log`` route (#25415) +- Change stdout and stderr access mode to append in commands (#25253) +- Remove ``getTasks`` from Grid view (#25359) +- Improve taskflow type hints with ParamSpec (#25173) +- Use tables in grid details panes (#25258) +- Explicitly list ``@dag`` arguments (#25044) +- More typing in ``SchedulerJob`` and ``TaskInstance`` (#24912) +- Patch ``getfqdn`` with more resilient version (#24981) +- Replace all ``NBSP`` characters by ``whitespaces`` (#24797) +- Re-serialize all DAGs on ``airflow db upgrade`` (#24518) +- Rework contract of try_adopt_task_instances method (#23188) +- Make ``expand()`` error vague so it's not misleading (#24018) +- Add enum validation for ``[webserver]analytics_tool`` (#24032) +- Add ``dttm`` searchable field in audit log (#23794) +- Allow more parameters to be piped through via ``execute_in_subprocess`` (#23286) +- Use ``func.count`` to count rows (#23657) +- Remove stale serialized dags (#22917) +- AIP45 Remove dag parsing in airflow run local (#21877) +- Add support for queued state in DagRun update endpoint. 
(#23481) +- Add fields to dagrun endpoint (#23440) +- Use ``sql_alchemy_conn`` for celery result backend when ``result_backend`` is not set (#24496) + +Bug Fixes +^^^^^^^^^ + +- Have consistent types between the ORM and the migration files (#24044, #25869) +- Disallow any dag tags longer than 100 char (#25196) +- Add the dag_id to ``AirflowDagCycleException`` message (#26204) +- Properly build URL to retrieve logs independently from system (#26337) +- For worker log servers only bind to IPV6 when dual stack is available (#26222) +- Fix ``TaskInstance.task`` not defined before ``handle_failure`` (#26040) +- Undo secrets backend config caching (#26223) +- Fix faulty executor config serialization logic (#26191) +- Show ``DAGs`` and ``Datasets`` menu links based on role permission (#26183) +- Allow setting ``TaskGroup`` tooltip via function docstring (#26028) +- Fix RecursionError on graph view of a DAG with many tasks (#26175) +- Fix backfill occasional deadlocking (#26161) +- Fix ``DagRun.start_date`` not set during backfill with ``--reset-dagruns`` True (#26135) +- Use label instead of id for dynamic task labels in graph (#26108) +- Don't fail DagRun when leaf ``mapped_task`` is SKIPPED (#25995) +- Add group prefix to decorated mapped task (#26081) +- Fix UI flash when triggering with dup logical date (#26094) +- Fix Make items nullable for ``TaskInstance`` related endpoints to avoid API errors (#26076) +- Fix ``BranchDateTimeOperator`` to be ``timezone-awreness-insensitive`` (#25944) +- Fix legacy timetable schedule interval params (#25999) +- Fix response schema for ``list-mapped-task-instance`` (#25965) +- Properly check the existence of missing mapped TIs (#25788) +- Fix broken auto-refresh on grid view (#25950) +- Use per-timetable ordering in grid UI (#25880) +- Rewrite recursion when parsing DAG into iteration (#25898) +- Find cross-group tasks in ``iter_mapped_dependants`` (#25793) +- Fail task if mapping upstream fails (#25757) +- Support ``/`` in variable get endpoint (#25774) +- Use cfg default_wrap value for grid logs (#25731) +- Add origin request args when triggering a run (#25729) +- Operator name separate from class (#22834) +- Fix incorrect data interval alignment due to assumption on input time alignment (#22658) +- Return None if an ``XComArg`` fails to resolve (#25661) +- Correct ``json`` arg help in ``airflow variables set`` command (#25726) +- Added MySQL index hint to use ``ti_state`` on ``find_zombies`` query (#25725) +- Only excluded actually expanded fields from render (#25599) +- Grid, fix toast for ``axios`` errors (#25703) +- Fix UI redirect (#26409) +- Require dag_id arg for dags list-runs (#26357) +- Check for queued states for dags auto-refresh (#25695) +- Fix upgrade code for the ``dag_owner_attributes`` table (#25579) +- Add map index to task logs api (#25568) +- Ensure that zombie tasks for dags with errors get cleaned up (#25550) +- Make extra link work in UI (#25500) +- Sync up plugin API schema and definition (#25524) +- First/last names can be empty (#25476) +- Refactor DAG pages to be consistent (#25402) +- Check ``expand_kwargs()`` input type before unmapping (#25355) +- Filter XCOM by key when calculating map lengths (#24530) +- Fix ``ExternalTaskSensor`` not working with dynamic task (#25215) +- Added exception catching to send default email if template file raises any exception (#24943) +- Bring ``MappedOperator`` members in sync with ``BaseOperator`` (#24034) + + +Misc/Internal +^^^^^^^^^^^^^ + +- Add automatically generated ``ERD`` schema for 
the ``MetaData`` DB (#26217) +- Mark serialization functions as internal (#26193) +- Remove remaining deprecated classes and replace them with ``PEP562`` (#26167) +- Move ``dag_edges`` and ``task_group_to_dict`` to corresponding util modules (#26212) +- Lazily import many modules to improve import speed (#24486, #26239) +- FIX Incorrect typing information (#26077) +- Add missing contrib classes to deprecated dictionaries (#26179) +- Re-configure/connect the ``ORM`` after forking to run a DAG processor (#26216) +- Remove cattrs from lineage processing. (#26134) +- Removed deprecated contrib files and replace them with ``PEP-562`` getattr (#26153) +- Make ``BaseSerialization.serialize`` "public" to other classes. (#26142) +- Change the template to use human readable task_instance description (#25960) +- Bump ``moment-timezone`` from ``0.5.34`` to ``0.5.35`` in ``/airflow/www`` (#26080) +- Fix Flask deprecation warning (#25753) +- Add ``CamelCase`` to generated operations types (#25887) +- Fix migration issues and tighten the CI upgrade/downgrade test (#25869) +- Fix type annotations in ``SkipMixin`` (#25864) +- Workaround setuptools editable packages path issue (#25848) +- Bump ``undici`` from ``5.8.0 to 5.9.1`` in /airflow/www (#25801) +- Add custom_operator_name attr to ``_BranchPythonDecoratedOperator`` (#25783) +- Clarify ``filename_template`` deprecation message (#25749) +- Use ``ParamSpec`` to replace ``...`` in Callable (#25658) +- Remove deprecated modules (#25543) +- Documentation on task mapping additions (#24489) +- Remove Smart Sensors (#25507) +- Fix ``elasticsearch`` test config to avoid warning on deprecated template (#25520) +- Bump ``terser`` from ``4.8.0 to 4.8.1`` in /airflow/ui (#25178) +- Generate ``typescript`` types from rest ``API`` docs (#25123) +- Upgrade utils files to ``typescript`` (#25089) +- Upgrade remaining context file to ``typescript``. 
(#25096) +- Migrate files to ``ts`` (#25267) +- Upgrade grid Table component to ``ts.`` (#25074) +- Skip mapping against mapped ``ti`` if it returns None (#25047) +- Refactor ``js`` file structure (#25003) +- Move mapped kwargs introspection to separate type (#24971) +- Only assert stuff for mypy when type checking (#24937) +- Bump ``moment`` from ``2.29.3 to 2.29.4`` in ``/airflow/www`` (#24885) +- Remove "bad characters" from our codebase (#24841) +- Remove ``xcom_push`` flag from ``BashOperator`` (#24824) +- Move Flask hook registration to end of file (#24776) +- Upgrade more javascript files to ``typescript`` (#24715) +- Clean up task decorator type hints and docstrings (#24667) +- Preserve original order of providers' connection extra fields in UI (#24425) +- Rename ``charts.css`` to ``chart.css`` (#24531) +- Rename ``grid.css`` to ``chart.css`` (#24529) +- Misc: create new process group by ``set_new_process_group`` utility (#24371) +- Airflow UI fix Prototype Pollution (#24201) +- Bump ``moto`` version (#24222) +- Remove unused ``[github_enterprise]`` from ref docs (#24033) +- Clean up ``f-strings`` in logging calls (#23597) +- Add limit for ``JPype1`` (#23847) +- Simply json responses (#25518) +- Add min attrs version (#26408) + +Doc only changes +^^^^^^^^^^^^^^^^ +- Add url prefix setting for ``Celery`` Flower (#25986) +- Updating deprecated configuration in examples (#26037) +- Fix wrong link for taskflow tutorial (#26007) +- Reorganize tutorials into a section (#25890) +- Fix concept doc for dynamic task map (#26002) +- Update code examples from "classic" operators to taskflow (#25845, #25657) +- Add instructions on manually fixing ``MySQL`` Charset problems (#25938) +- Prefer the local Quick Start in docs (#25888) +- Fix broken link to ``Trigger Rules`` (#25840) +- Improve docker documentation (#25735) +- Correctly link to Dag parsing context in docs (#25722) +- Add note on ``task_instance_mutation_hook`` usage (#25607) +- Note that TaskFlow API automatically passes data between tasks (#25577) +- Update DAG run to clarify when a DAG actually runs (#25290) +- Update tutorial docs to include a definition of operators (#25012) +- Rewrite the Airflow documentation home page (#24795) +- Fix ``task-generated mapping`` example (#23424) +- Add note on subtle logical date change in ``2.2.0`` (#24413) +- Add missing import in best-practices code example (#25391) + + + +Airflow 2.3.4 (2022-08-23) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +Added new config ``[logging]log_formatter_class`` to fix timezone display for logs on UI (#24811) +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +If you are using a custom Formatter subclass in your ``[logging]logging_config_class``, please inherit from ``airflow.utils.log.timezone_aware.TimezoneAware`` instead of ``logging.Formatter``. +For example, in your ``custom_config.py``: + +.. code-block:: python + + from airflow.utils.log.timezone_aware import TimezoneAware + + # before + class YourCustomFormatter(logging.Formatter): + ... + + + # after + class YourCustomFormatter(TimezoneAware): + ... + + + AIRFLOW_FORMATTER = LOGGING_CONFIG["formatters"]["airflow"] + AIRFLOW_FORMATTER["class"] = "somewhere.your.custom_config.YourCustomFormatter" + # or use TimezoneAware class directly. If you don't have custom Formatter. 
+ AIRFLOW_FORMATTER["class"] = "airflow.utils.log.timezone_aware.TimezoneAware" + +Bug Fixes +^^^^^^^^^ + +- Disable ``attrs`` state management on ``MappedOperator`` (#24772) +- Serialize ``pod_override`` to JSON before pickling ``executor_config`` (#24356) +- Fix ``pid`` check (#24636) +- Rotate session id during login (#25771) +- Fix mapped sensor with reschedule mode (#25594) +- Cache the custom secrets backend so the same instance gets re-used (#25556) +- Add right padding (#25554) +- Fix reducing mapped length of a mapped task at runtime after a clear (#25531) +- Fix ``airflow db reset`` when dangling tables exist (#25441) +- Change ``disable_verify_ssl`` behaviour (#25023) +- Set default task group in dag.add_task method (#25000) +- Removed interfering force of index. (#25404) +- Remove useless logging line (#25347) +- Adding mysql index hint to use index on ``task_instance.state`` in critical section query (#25673) +- Configurable umask to all daemonized processes. (#25664) +- Fix the errors raised when None is passed to template filters (#25593) +- Allow wildcarded CORS origins (#25553) +- Fix "This Session's transaction has been rolled back" (#25532) +- Fix Serialization error in ``TaskCallbackRequest`` (#25471) +- fix - resolve bash by absolute path (#25331) +- Add ``__repr__`` to ParamsDict class (#25305) +- Only load distribution of a name once (#25296) +- convert ``TimeSensorAsync`` ``target_time`` to utc on call time (#25221) +- call ``updateNodeLabels`` after ``expandGroup`` (#25217) +- Stop SLA callbacks gazumping other callbacks and DOS'ing the ``DagProcessorManager`` queue (#25147) +- Fix ``invalidateQueries`` call (#25097) +- ``airflow/www/package.json``: Add name, version fields. (#25065) +- No grid auto-refresh for backfill dag runs (#25042) +- Fix tag link on dag detail page (#24918) +- Fix zombie task handling with multiple schedulers (#24906) +- Bind log server on worker to ``IPv6`` address (#24755) (#24846) +- Add ``%z`` for ``%(asctime)s`` to fix timezone for logs on UI (#24811) +- ``TriggerDagRunOperator.operator_extra_links`` is attr (#24676) +- Send DAG timeout callbacks to processor outside of ``prohibit_commit`` (#24366) +- Don't rely on current ORM structure for db clean command (#23574) +- Clear next method when clearing TIs (#23929) +- Two typing fixes (#25690) + +Doc only changes +^^^^^^^^^^^^^^^^ + +- Update set-up-database.rst (#24983) +- Fix syntax in mysql setup documentation (#24893 (#24939) +- Note how DAG policy works with default_args (#24804) +- Update PythonVirtualenvOperator Howto (#24782) +- Doc: Add hyperlinks to Github PRs for Release Notes (#24532) + +Misc/Internal +^^^^^^^^^^^^^ + +- Remove depreciation warning when use default remote tasks logging handlers (#25764) +- clearer method name in scheduler_job.py (#23702) +- Bump cattrs version (#25689) +- Include missing mention of ``external_executor_id`` in ``sql_engine_collation_for_ids`` docs (#25197) +- Refactor ``DR.task_instance_scheduling_decisions`` (#24774) +- Sort operator extra links (#24992) +- Extends ``resolve_xcom_backend`` function level documentation (#24965) +- Upgrade FAB to 4.1.3 (#24884) +- Limit Flask to <2.3 in the wake of 2.2 breaking our tests (#25511) +- Limit astroid version to < 2.12 (#24982) +- Move javascript compilation to host (#25169) +- Bump typing-extensions and mypy for ParamSpec (#25088) + + +Airflow 2.3.3 (2022-07-09) +-------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +We've upgraded Flask App Builder to a major version 4.* (#24399) 
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +Flask App Builder is one of the important components of Airflow Webserver, as +it uses a lot of dependencies that are essential to run the webserver and integrate it +in enterprise environments - especially authentication. + +The FAB 4.* upgrades a number of dependencies to major releases, which upgrades them to versions +that have a number of security issues fixed. A lot of tests were performed to bring the dependencies +in a backwards-compatible way, however the dependencies themselves implement breaking changes in their +internals so it might be that some of those changes might impact the users in case they are using the +libraries for their own purposes. + +One important change that you likely will need to apply to Oauth configuration is to add +``server_metadata_url`` or ``jwks_uri`` and you can read about it more +in `this issue `_. + +Here is the list of breaking changes in dependencies that comes together with FAB 4: + + * ``Flask`` from 1.X to 2.X `breaking changes `__ + + * ``flask-jwt-extended`` 3.X to 4.X `breaking changes: `__ + + * ``Jinja2`` 2.X to 3.X `breaking changes: `__ + + * ``Werkzeug`` 1.X to 2.X `breaking changes `__ + + * ``pyJWT`` 1.X to 2.X `breaking changes: `__ + + * ``Click`` 7.X to 8.X `breaking changes: `__ + + * ``itsdangerous`` 1.X to 2.X `breaking changes `__ + +Bug Fixes +^^^^^^^^^ + +- Fix exception in mini task scheduler (#24865) +- Fix cycle bug with attaching label to task group (#24847) +- Fix timestamp defaults for ``sensorinstance`` (#24638) +- Move fallible ``ti.task.dag`` assignment back inside ``try/except`` block (#24533) (#24592) +- Add missing types to ``FSHook`` (#24470) +- Mask secrets in ``stdout`` for ``airflow tasks test`` (#24362) +- ``DebugExecutor`` use ``ti.run()`` instead of ``ti._run_raw_task`` (#24357) +- Fix bugs in ``URI`` constructor for ``MySQL`` connection (#24320) +- Missing ``scheduleinterval`` nullable true added in ``openapi`` (#24253) +- Unify ``return_code`` interface for task runner (#24093) +- Handle occasional deadlocks in trigger with retries (#24071) +- Remove special serde logic for mapped ``op_kwargs`` (#23860) +- ``ExternalTaskSensor`` respects ``soft_fail`` if the external task enters a ``failed_state`` (#23647) +- Fix ``StatD`` timing metric units (#21106) +- Add ``cache_ok`` flag to sqlalchemy TypeDecorators. 
(#24499) +- Allow for ``LOGGING_LEVEL=DEBUG`` (#23360) +- Fix grid date ticks (#24738) +- Debounce status highlighting in Grid view (#24710) +- Fix Grid vertical scrolling (#24684) +- don't try to render child rows for closed groups (#24637) +- Do not calculate grid root instances (#24528) +- Maintain grid view selection on filtering upstream (#23779) +- Speed up ``grid_data`` endpoint by 10x (#24284) +- Apply per-run log templates to log handlers (#24153) +- Don't crash scheduler if exec config has old k8s objects (#24117) +- ``TI.log_url`` fix for ``map_index`` (#24335) +- Fix migration ``0080_2_0_2`` - Replace null values before setting column not null (#24585) +- Patch ``sql_alchemy_conn`` if old Postgres schemes used (#24569) +- Seed ``log_template`` table (#24511) +- Fix deprecated ``log_id_template`` value (#24506) +- Fix toast messages (#24505) +- Add indexes for CASCADE deletes for ``task_instance`` (#24488) +- Return empty dict if Pod JSON encoding fails (#24478) +- Improve grid rendering performance with a custom tooltip (#24417, #24449) +- Check for ``run_id`` for grid group summaries (#24327) +- Optimize calendar view for cron scheduled DAGs (#24262) +- Use ``get_hostname`` instead of ``socket.getfqdn`` (#24260) +- Check that edge nodes actually exist (#24166) +- Fix ``useTasks`` crash on error (#24152) +- Do not fail re-queued TIs (#23846) +- Reduce grid view API calls (#24083) +- Rename Permissions to Permission Pairs. (#24065) +- Replace ``use_task_execution_date`` with ``use_task_logical_date`` (#23983) +- Grid fix details button truncated and small UI tweaks (#23934) +- Add TaskInstance State ``REMOVED`` to finished states and success states (#23797) +- Fix mapped task immutability after clear (#23667) +- Fix permission issue for dag that has dot in name (#23510) +- Fix closing connection ``dbapi.get_pandas_df`` (#23452) +- Check bag DAG ``schedule_interval`` match timetable (#23113) +- Parse error for task added to multiple groups (#23071) +- Fix flaky order of returned dag runs (#24405) +- Migrate ``jsx`` files that affect run/task selection to ``tsx`` (#24509) +- Fix links to sources for examples (#24386) +- Set proper ``Content-Type`` and ``chartset`` on ``grid_data`` endpoint (#24375) + +Doc only changes +^^^^^^^^^^^^^^^^ + +- Update templates doc to mention ``extras`` and format Airflow ``Vars`` / ``Conns`` (#24735) +- Document built in Timetables (#23099) +- Alphabetizes two tables (#23923) +- Clarify that users should not use Maria DB (#24556) +- Add imports to deferring code samples (#24544) +- Add note about image regeneration in June 2022 (#24524) +- Small cleanup of ``get_current_context()`` chapter (#24482) +- Fix default 2.2.5 ``log_id_template`` (#24455) +- Update description of installing providers separately from core (#24454) +- Mention context variables and logging (#24304) + +Misc/Internal +^^^^^^^^^^^^^ + +- Remove internet explorer support (#24495) +- Removing magic status code numbers from ``api_connexion`` (#24050) +- Upgrade FAB to ``4.1.2`` (#24619) +- Switch Markdown engine to ``markdown-it-py`` (#19702) +- Update ``rich`` to latest version across the board. (#24186) +- Get rid of ``TimedJSONWebSignatureSerializer`` (#24519) +- Update flask-appbuilder ``authlib``/ ``oauth`` dependency (#24516) +- Upgrade to ``webpack`` 5 (#24485) +- Add ``typescript`` (#24337) +- The JWT claims in the request to retrieve logs have been standardized: we use ``nbf`` and ``aud`` claims for + maturity and audience of the requests. 
Also "filename" payload field is used to keep log name. (#24519) +- Address all ``yarn`` test warnings (#24722) +- Upgrade to react 18 and chakra 2 (#24430) +- Refactor ``DagRun.verify_integrity`` (#24114) +- Upgrade FAB to ``4.1.1`` (#24399) +- We now need at least ``Flask-WTF 0.15`` (#24621) + + +Airflow 2.3.2 (2022-06-04) +-------------------------- + +No significant changes + +Bug Fixes +^^^^^^^^^ + +- Run the ``check_migration`` loop at least once +- Fix grid view for mapped tasks (#24059) +- Icons in grid view for different DAG run types (#23970) +- Faster grid view (#23951) +- Disallow calling expand with no arguments (#23463) +- Add missing ``is_mapped`` field to Task response. (#23319) +- DagFileProcessorManager: Start a new process group only if current process not a session leader (#23872) +- Mask sensitive values for not-yet-running TIs (#23807) +- Add cascade to ``dag_tag`` to ``dag`` foreign key (#23444) +- Use ``--subdir`` argument value for standalone dag processor. (#23864) +- Highlight task states by hovering on legend row (#23678) +- Fix and speed up grid view (#23947) +- Prevent UI from crashing if grid task instances are null (#23939) +- Remove redundant register exit signals in ``dag-processor`` command (#23886) +- Add ``__wrapped__`` property to ``_TaskDecorator`` (#23830) +- Fix UnboundLocalError when ``sql`` is empty list in DbApiHook (#23816) +- Enable clicking on DAG owner in autocomplete dropdown (#23804) +- Simplify flash message for ``_airflow_moved`` tables (#23635) +- Exclude missing tasks from the gantt view (#23627) + +Doc only changes +^^^^^^^^^^^^^^^^ + +- Add column names for DB Migration Reference (#23853) + +Misc/Internal +^^^^^^^^^^^^^ + +- Remove pinning for xmltodict (#23992) + + Airflow 2.3.1 (2022-05-25) -------------------------- @@ -118,7 +1258,7 @@ Continuing the effort to bind TaskInstance to a DagRun, XCom entries are now als Task log templates are now read from the metadata database instead of ``airflow.cfg`` (#20165) """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" -Previously, a task’s log is dynamically rendered from the ``[core] log_filename_template`` and ``[elasticsearch] log_id_template`` config values at runtime. This resulted in unfortunate characteristics, e.g. it is impractical to modify the config value after an Airflow instance is running for a while, since all existing task logs have be saved under the previous format and cannot be found with the new config value. +Previously, a task's log is dynamically rendered from the ``[core] log_filename_template`` and ``[elasticsearch] log_id_template`` config values at runtime. This resulted in unfortunate characteristics, e.g. it is impractical to modify the config value after an Airflow instance is running for a while, since all existing task logs have be saved under the previous format and cannot be found with the new config value. A new ``log_template`` table is introduced to solve this problem. This table is synchronized with the aforementioned config values every time Airflow starts, and a new field ``log_template_id`` is added to every DAG run to point to the format used by tasks (``NULL`` indicates the first ever entry for compatibility). @@ -135,9 +1275,9 @@ No change in behavior is expected. 
This was necessary in order to take advantag XCom now defined by ``run_id`` instead of ``execution_date`` (#20975) """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" -As a continuation to the TaskInstance-DagRun relation change started in Airflow 2.2, the ``execution_date`` columns on XCom has been removed from the database, and replaced by an `association proxy `_ field at the ORM level. If you access Airflow’s metadata database directly, you should rewrite the implementation to use the ``run_id`` column instead. +As a continuation to the TaskInstance-DagRun relation change started in Airflow 2.2, the ``execution_date`` columns on XCom has been removed from the database, and replaced by an `association proxy `_ field at the ORM level. If you access Airflow's metadata database directly, you should rewrite the implementation to use the ``run_id`` column instead. -Note that Airflow’s metadatabase definition on both the database and ORM levels are considered implementation detail without strict backward compatibility guarantees. +Note that Airflow's metadatabase definition on both the database and ORM levels are considered implementation detail without strict backward compatibility guarantees. Non-JSON-serializable params deprecated (#21135). """"""""""""""""""""""""""""""""""""""""""""""""" @@ -178,14 +1318,14 @@ Details in the `SQLAlchemy Changelog `_ fields at the ORM level. If you access Airflow’s metadatabase directly, you should rewrite the implementation to use the ``run_id`` columns instead. +As a part of the TaskInstance-DagRun relation change, the ``execution_date`` columns on TaskInstance and TaskReschedule have been removed from the database, and replaced by `association proxy `_ fields at the ORM level. If you access Airflow's metadatabase directly, you should rewrite the implementation to use the ``run_id`` columns instead. -Note that Airflow’s metadatabase definition on both the database and ORM levels are considered implementation detail without strict backward compatibility guarantees. +Note that Airflow's metadatabase definition on both the database and ORM levels are considered implementation detail without strict backward compatibility guarantees. DaskExecutor - Dask Worker Resources and queues """"""""""""""""""""""""""""""""""""""""""""""" If dask workers are not started with complementary resources to match the specified queues, it will now result in an ``AirflowException``\ , whereas before it would have just ignored the ``queue`` argument. +Logical date of a DAG run triggered from the web UI now have its sub-second component set to zero +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +Due to a change in how the logical date (``execution_date``) is generated for a manual DAG run, a manual DAG run's logical date may not match its time-of-trigger, but have its sub-second part zero-ed out. For example, a DAG run triggered on ``2021-10-11T12:34:56.78901`` would have its logical date set to ``2021-10-11T12:34:56.00000``. + +This may affect some logic that expects on this quirk to detect whether a run is triggered manually or not. Note that ``dag_run.run_type`` is a more authoritative value for this purpose. Also, if you need this distinction between automated and manually-triggered run for "next execution date" calculation, please also consider using the new data interval variables instead, which provide a more consistent behavior between the two run types. 
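+
+For illustration, a minimal sketch of checking ``dag_run.run_type`` instead of relying on the sub-second quirk (the task name and printed messages are made up for this example):
+
+.. code-block:: python
+
+    from airflow.decorators import task
+    from airflow.utils.types import DagRunType
+
+
+    @task
+    def report_trigger_kind(dag_run=None):
+        # ``dag_run`` is injected from the task context; ``run_type`` is the more
+        # authoritative way to tell manual runs from scheduled or backfill runs.
+        if dag_run.run_type == DagRunType.MANUAL:
+            print("manually triggered run")
+        else:
+            print("scheduled or backfill run")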
+ New Features ^^^^^^^^^^^^ @@ -6828,9 +7975,7 @@ New signature: .. code-block:: python - def wait_for_transfer_job( - self, job, expected_statuses=(GcpTransferOperationStatus.SUCCESS,) - ): + def wait_for_transfer_job(self, job, expected_statuses=(GcpTransferOperationStatus.SUCCESS,)): ... The behavior of ``wait_for_transfer_job`` has changed: @@ -7798,7 +8943,6 @@ There are five roles created for Airflow by default: Admin, User, Op, Viewer, an Breaking changes ~~~~~~~~~~~~~~~~ - * AWS Batch Operator renamed property queue to job_queue to prevent conflict with the internal queue from CeleryExecutor - AIRFLOW-2542 * Users created and stored in the old users table will not be migrated automatically. FAB's built-in authentication support must be reconfigured. * Airflow dag home page is now ``/home`` (instead of ``/admin``\ ). @@ -8711,14 +9855,14 @@ A logger is the entry point into the logging system. Each logger is a named buck Each message that is written to the logger is a Log Record. Each log record contains a log level indicating the severity of that specific message. A log record can also contain useful metadata that describes the event that is being logged. This can include details such as a stack trace or an error code. -When a message is given to the logger, the log level of the message is compared to the log level of the logger. If the log level of the message meets or exceeds the log level of the logger itself, the message will undergo further processing. If it doesn’t, the message will be ignored. +When a message is given to the logger, the log level of the message is compared to the log level of the logger. If the log level of the message meets or exceeds the log level of the logger itself, the message will undergo further processing. If it doesn't, the message will be ignored. Once a logger has determined that a message needs to be processed, it is passed to a Handler. This configuration is now more flexible and can be easily be maintained in a single file. Changes in Airflow Logging ~~~~~~~~~~~~~~~~~~~~~~~~~~ -Airflow's logging mechanism has been refactored to use Python’s built-in ``logging`` module to perform logging of the application. By extending classes with the existing ``LoggingMixin``\ , all the logging will go through a central logger. Also the ``BaseHook`` and ``BaseOperator`` already extend this class, so it is easily available to do logging. +Airflow's logging mechanism has been refactored to use Python's built-in ``logging`` module to perform logging of the application. By extending classes with the existing ``LoggingMixin``\ , all the logging will go through a central logger. Also the ``BaseHook`` and ``BaseOperator`` already extend this class, so it is easily available to do logging. The main benefit is easier configuration of the logging by setting a single centralized python file. Disclaimer; there is still some inline configuration, but this will be removed eventually. The new logging class is defined by setting the dotted classpath in your ``~/airflow/airflow.cfg`` file: diff --git a/SELECTIVE_CHECKS.md b/SELECTIVE_CHECKS.md deleted file mode 100644 index 3a92d9c817987..0000000000000 --- a/SELECTIVE_CHECKS.md +++ /dev/null @@ -1,144 +0,0 @@ - - -# Selective CI Checks - -In order to optimise our CI jobs, we've implemented optimisations to only run selected checks for some -kind of changes. 
The logic implemented reflects the internal architecture of Airflow 2.0 packages -and it helps to keep down both the usage of jobs in GitHub Actions as well as CI feedback time to -contributors in case of simpler changes. - -We have the following test types (separated by packages in which they are): - -* Always - those are tests that should be always executed (always folder) -* Core - for the core Airflow functionality (core folder) -* API - Tests for the Airflow API (api and api_connexion folders) -* CLI - Tests for the Airflow CLI (cli folder) -* WWW - Tests for the Airflow webserver (www folder) -* Providers - Tests for all Providers of Airflow (providers folder) -* Other - all other tests (all other folders that are not part of any of the above) - -We also have several special kinds of tests that are not separated by packages but they are marked with -pytest markers. They can be found in any of those packages and they can be selected by the appropriate -pytest custom command line options. See `TESTING.rst `_ for details but those are: - -* Integration - tests that require external integration images running in docker-compose -* Quarantined - tests that are flaky and need to be fixed -* Postgres - tests that require Postgres database. They are only run when backend is Postgres -* MySQL - tests that require MySQL database. They are only run when backend is MySQL - -Even if the types are separated, In case they share the same backend version/python version, they are -run sequentially in the same job, on the same CI machine. Each of them in a separate `docker run` command -and with additional docker cleaning between the steps to not fall into the trap of exceeding resource -usage in one big test run, but also not to increase the number of jobs per each Pull Request. - -The logic implemented for the changes works as follows: - -1) In case of direct push (so when PR gets merged) or scheduled run, we always run all tests and checks. - This is in order to make sure that the merge did not miss anything important. The remainder of the logic - is executed only in case of Pull Requests. We do not add providers tests in case DEFAULT_BRANCH is - different than main, because providers are only important in main branch and PRs to main branch. - -2) We retrieve which files have changed in the incoming Merge Commit (github.sha is a merge commit - automatically prepared by GitHub in case of Pull Request, so we can retrieve the list of changed - files from that commit directly). - -3) If any of the important, environment files changed (Dockerfile, ci scripts, setup.py, GitHub workflow - files), then we again run all tests and checks. Those are cases where the logic of the checks changed - or the environment for the checks changed so we want to make sure to check everything. We do not add - providers tests in case DEFAULT_BRANCH is different than main, because providers are only - important in main branch and PRs to main branch. - -4) If any of py files changed: we need to have CI image and run full static checks so we enable image building - -5) If any of docs changed: we need to have CI image so we enable image building - -6) If any of chart files changed, we need to run helm tests so we enable helm unit tests - -7) If any of API files changed, we need to run API tests so we enable them - -8) If any of the relevant source files that trigger the tests have changed at all. Those are airflow - sources, chart, tests and kubernetes_tests. 
If any of those files changed, we enable tests and we - enable image building, because the CI images are needed to run tests. - -9) Then we determine which types of the tests should be run. We count all the changed files in the - relevant airflow sources (airflow, chart, tests, kubernetes_tests) first and then we count how many - files changed in different packages: - - * in any case tests in `Always` folder are run. Those are special tests that should be run any time - modifications to any Python code occurs. Example test of this type is verifying proper structure of - the project including proper naming of all files. - * if any of the Airflow API files changed we enable `API` test type - * if any of the Airflow CLI files changed we enable `CLI` test type and Kubernetes tests (the - K8S tests depend on CLI changes as helm chart uses CLI to run Airflow). - * if this is a main branch and if any of the Provider files changed we enable `Providers` test type - * if any of the WWW files changed we enable `WWW` test type - * if any of the Kubernetes files changed we enable `Kubernetes` test type - * Then we subtract count of all the `specific` above per-type changed files from the count of - all changed files. In case there are any files changed, then we assume that some unknown files - changed (likely from the core of airflow) and in this case we enable all test types above and the - Core test types - simply because we do not want to risk to miss anything. - * In all cases where tests are enabled we also add Integration and - depending on - the backend used = Postgres or MySQL types of tests. - -10) Quarantined tests are always run when tests are run - we need to run them often to observe how - often they fail so that we can decide to move them out of quarantine. Details about the - Quarantined tests are described in `TESTING.rst `_ - -11) There is a special case of static checks. In case the above logic determines that the CI image - needs to be built, we run long and more comprehensive version of static checks - including - Mypy, Flake8. And those tests are run on all files, no matter how many files changed. - In case the image is not built, we run only simpler set of changes - the longer static checks - that require CI image are skipped, and we only run the tests on the files that changed in the incoming - commit - unlike flake8/mypy, those static checks are per-file based and they should not miss any - important change. - -Similarly to selective tests we also run selective security scans. In Pull requests, -the Python scan will only run when there is a python code change and JavaScript scan will only run if -there is a JavaScript or `yarn.lock` file change. For main builds, all scans are always executed. - -The selective check algorithm is shown here: - - -````mermaid -flowchart TD -A(PR arrives)-->B[Selective Check] -B-->C{Direct push merge?} -C-->|Yes| N[Enable images] -N-->D(Run Full Test
+Quarantined
Run full static checks) -C-->|No| E[Retrieve changed files] -E-->F{Environment files changed?} -F-->|Yes| N -F-->|No| G{Docs changed} -G-->|Yes| O[Enable images building] -O-->I{Chart files changed?} -G-->|No| I -I-->|Yes| P[Enable helm tests] -P-->J{API files changed} -I-->|No| J -J-->|Yes| Q[Enable API tests] -Q-->H{Sources changed?} -J-->|No| H -H-->|Yes| R[Enable Pytests] -R-->K[Determine test type] -K-->S{Core files changed} -S-->|Yes| N -S-->|No| M(Run selected test+
Integration, Quarantined
Full static checks) -H-->|No| L[Skip running test
Run subset of static checks] -``` diff --git a/STATIC_CODE_CHECKS.rst b/STATIC_CODE_CHECKS.rst index dccade808e425..b27d170d1c01e 100644 --- a/STATIC_CODE_CHECKS.rst +++ b/STATIC_CODE_CHECKS.rst @@ -59,7 +59,7 @@ After installation, pre-commit hooks are run automatically when you commit the c only run on the files that you change during your commit, so they are usually pretty fast and do not slow down your iteration speed on your changes. There are also ways to disable the ``pre-commits`` temporarily when you commit your code with ``--no-verify`` switch or skip certain checks that you find -to much disturbing your local workflow. See `Available pre-commit checks<#available-pre-commit-checks>`_ +to much disturbing your local workflow. See `Available pre-commit checks <#available-pre-commit-checks>`_ and `Using pre-commit <#using-pre-commit>`_ .. note:: Additional prerequisites might be needed @@ -76,7 +76,7 @@ The current list of prerequisites is limited to ``xmllint``: Some pre-commit hooks also require the Docker Engine to be configured as the static checks are executed in the Docker environment (See table in the -Available pre-commit checks<#available-pre-commit-checks>`_ . You should build the images +`Available pre-commit checks <#available-pre-commit-checks>`_ . You should build the images locally before installing pre-commit checks as described in `BREEZE.rst `__. Sometimes your image is outdated and needs to be rebuilt because some dependencies have been changed. @@ -129,195 +129,210 @@ require Breeze Docker image to be build locally. .. BEGIN AUTO-GENERATED STATIC CHECK LIST -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| ID | Description | Image | -+========================================================+==================================================================+=========+ -| black | Run Black (the uncompromising Python code formatter) | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| blacken-docs | Run black on python code blocks in documentation files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-airflow-2-1-compatibility | Check that providers are 2.1 compatible. 
| | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-airflow-config-yaml-consistent | Checks for consistency between config.yml and default_config.cfg | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-airflow-providers-have-extras | Checks providers available when declared by extras in setup.py | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-apache-license-rat | Check if licenses are OK for Apache | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-base-operator-partial-arguments | Check BaseOperator and partial() arguments | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-base-operator-usage | * Check BaseOperator[Link] core imports | | -| | * Check BaseOperator[Link] other imports | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-boring-cyborg-configuration | Checks for Boring Cyborg configuration consistency | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-breeze-top-dependencies-limited | Breeze should have small number of top-level dependencies | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-builtin-literals | Require literal syntax when initializing Python builtin types | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-changelog-has-no-duplicates | Check changelogs for duplicate entries | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-docstring-param-types | Check that docstrings do not specify param types | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-executables-have-shebangs | Check that executables have shebang | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-extra-packages-references | Checks setup extra packages | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-extras-order | Check order of extras in Dockerfile | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-for-inclusive-language | Check for language that we do not accept as community | | 
-+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-hooks-apply | Check if all hooks apply to the repository | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-incorrect-use-of-LoggingMixin | Make sure LoggingMixin is not used alone | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-integrations-are-consistent | Check if integration list is consistent in various places | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-merge-conflict | Check that merge conflicts are not being committed | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-newsfragments-are-valid | Check newsfragments are valid | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-no-providers-in-core-examples | No providers imports in core example DAGs | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-no-relative-imports | No relative imports | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-persist-credentials-disabled-in-github-workflows | Check that workflow files have persist-credentials disabled | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-pre-commit-information-consistent | Update information re pre-commit hooks and verify ids and names | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-provide-create-sessions-imports | Check provide_session and create_session imports | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-provider-yaml-valid | Validate providers.yaml files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-providers-init-file-missing | Provider init file is missing | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-providers-subpackages-init-file-exist | Provider subpackage init files are there | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-pydevd-left-in-code | Check for pydevd debug statements accidentally left | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-revision-heads-map | Check that the REVISION_HEADS_MAP is up-to-date | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-safe-filter-usage-in-html | Don't use safe in templates | | 
-+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-setup-order | Check order of dependencies in setup.cfg and setup.py | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-start-date-not-used-in-defaults | 'start_date' not to be defined in default_args in example_dags | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-system-tests-present | Check if system tests have required segments of code | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| check-xml | Check XML files with xmllint | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| codespell | Run codespell to check for common misspellings in files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| debug-statements | Detect accidentally committed debug statements | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| detect-private-key | Detect if private key is added to the repository | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| doctoc | Add TOC for md and rst files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| end-of-file-fixer | Make sure that there is an empty line at the end | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| fix-encoding-pragma | Remove encoding header from python files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| flynt | Run flynt string format converter for Python | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| forbid-tabs | Fail if tabs are used in the project | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| identity | Print input to the static check hooks for troubleshooting | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| insert-license | * Add license for all SQL files | | -| | * Add license for all rst files | | -| | * Add license for all CSS/JS/PUML/TS/TSX files | | -| | * Add license for all JINJA template files | | -| | * Add license for all shell files | | -| | * Add license for all Python files | | -| | * Add license for all XML files | | -| | * Add license for all YAML files | | -| | * Add license for all md files | | -| | * Add license for all other files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| isort | Run isort to sort imports in Python files | | 
-+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-chart-schema | Lint chart/values.schema.json file | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-css | stylelint | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-dockerfile | Lint dockerfile | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-helm-chart | Lint Helm Chart | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-javascript | * ESLint against airflow/ui | * | -| | * ESLint against current UI JavaScript files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-json-schema | * Lint JSON Schema files with JSON Schema | | -| | * Lint NodePort Service with JSON Schema | | -| | * Lint Docker compose files with JSON Schema | | -| | * Lint chart/values.schema.json file with JSON Schema | | -| | * Lint chart/values.yaml file with JSON Schema | | -| | * Lint airflow/config_templates/config.yml file with JSON Schema | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-markdown | Run markdownlint | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| lint-openapi | * Lint OpenAPI using spectral | | -| | * Lint OpenAPI using openapi-spec-validator | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| mixed-line-ending | Detect if mixed line ending is used (\r vs. 
\r\n) | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| pretty-format-json | Format json files | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| pydocstyle | Run pydocstyle | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| python-no-log-warn | Check if there are no deprecate log warn | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| pyupgrade | Upgrade Python code automatically | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| rst-backticks | Check if RST files use double backticks for code | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| run-flake8 | Run flake8 | * | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| run-mypy | * Run mypy for dev | * | -| | * Run mypy for core | | -| | * Run mypy for providers | | -| | * Run mypy for /docs/ folder | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| run-shellcheck | Check Shell scripts syntax correctness | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| static-check-autoflake | Remove all unused code | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| trailing-whitespace | Remove trailing whitespace at end of line | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-breeze-file | Update output of breeze commands in BREEZE.rst | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-breeze-readme-config-hash | Update Breeze README.md with config files hash | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-extras | Update extras in documentation | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-in-the-wild-to-be-sorted | Sort INTHEWILD.md alphabetically | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-inlined-dockerfile-scripts | Inline Dockerfile and Dockerfile.ci scripts | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-local-yml-file | Update mounts in the local yml file | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-migration-references | Update migration ref doc | * | 
-+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-providers-dependencies | Update cross-dependencies for providers packages | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-setup-cfg-file | Update setup.cfg file with all licenses | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-spelling-wordlist-to-be-sorted | Sort alphabetically and uniquify spelling_wordlist.txt | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-supported-versions | Updates supported versions in documentation | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-vendored-in-k8s-json-schema | Vendor k8s definitions into values.schema.json | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| update-version | Update version to the latest version in the documentation | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| yamllint | Check YAML files with yamllint | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ -| yesqa | Remove unnecessary noqa statements | | -+--------------------------------------------------------+------------------------------------------------------------------+---------+ ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| ID | Description | Image | ++===========================================================+==================================================================+=========+ +| black | Run black (python formatter) | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| blacken-docs | Run black on python code blocks in documentation files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-airflow-config-yaml-consistent | Checks for consistency between config.yml and default_config.cfg | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-airflow-provider-compatibility | Check compatibility of Providers with Airflow | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-apache-license-rat | Check if licenses are OK for Apache | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-base-operator-partial-arguments | Check BaseOperator and partial() arguments | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-base-operator-usage | * Check BaseOperator[Link] core imports | | 
+| | * Check BaseOperator[Link] other imports | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-boring-cyborg-configuration | Checks for Boring Cyborg configuration consistency | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-breeze-top-dependencies-limited | Breeze should have small number of top-level dependencies | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-builtin-literals | Require literal syntax when initializing Python builtin types | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-changelog-has-no-duplicates | Check changelogs for duplicate entries | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-core-deprecation-classes | Verify using of dedicated Airflow deprecation classes in core | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-decorated-operator-implements-custom-name | Check @task decorator implements custom_operator_name | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-docstring-param-types | Check that docstrings do not specify param types | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-example-dags-urls | Check that example dags url include provider versions | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-executables-have-shebangs | Check that executables have shebang | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-extra-packages-references | Checks setup extra packages | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-extras-order | Check order of extras in Dockerfile | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-for-inclusive-language | Check for language that we do not accept as community | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-hooks-apply | Check if all hooks apply to the repository | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-incorrect-use-of-LoggingMixin | Make sure LoggingMixin is not used alone | | 
++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-init-decorator-arguments | Check model __init__ and decorator arguments are in sync | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-lazy-logging | Check that all logging methods are lazy | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-links-to-example-dags-do-not-use-hardcoded-versions | Check that example dags do not use hard-coded version numbers | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-merge-conflict | Check that merge conflicts are not being committed | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-newsfragments-are-valid | Check newsfragments are valid | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-no-providers-in-core-examples | No providers imports in core example DAGs | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-no-relative-imports | No relative imports | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-persist-credentials-disabled-in-github-workflows | Check that workflow files have persist-credentials disabled | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-pre-commit-information-consistent | Update information re pre-commit hooks and verify ids and names | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-provide-create-sessions-imports | Check provide_session and create_session imports | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-provider-yaml-valid | Validate provider.yaml files | * | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-providers-init-file-missing | Provider init file is missing | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-providers-subpackages-init-file-exist | Provider subpackage init files are there | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-pydevd-left-in-code | Check for pydevd debug statements accidentally left | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-revision-heads-map | Check that the REVISION_HEADS_MAP is up-to-date | | 
++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-safe-filter-usage-in-html | Don't use safe in templates | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-setup-order | Check order of dependencies in setup.cfg and setup.py | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-start-date-not-used-in-defaults | 'start_date' not to be defined in default_args in example_dags | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-system-tests-present | Check if system tests have required segments of code | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-system-tests-tocs | Check that system tests is properly added | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| check-xml | Check XML files with xmllint | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| codespell | Run codespell to check for common misspellings in files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| compile-www-assets | Compile www assets | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| compile-www-assets-dev | Compile www assets in dev mode | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| create-missing-init-py-files-tests | Create missing init.py files in tests | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| debug-statements | Detect accidentally committed debug statements | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| detect-private-key | Detect if private key is added to the repository | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| doctoc | Add TOC for md and rst files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| end-of-file-fixer | Make sure that there is an empty line at the end | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| fix-encoding-pragma | Remove encoding header from python files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| flynt | Run flynt string format converter for Python | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| identity | Print input to the 
static check hooks for troubleshooting | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| insert-license | * Add license for all SQL files | | +| | * Add license for all rst files | | +| | * Add license for all CSS/JS/PUML/TS/TSX files | | +| | * Add license for all JINJA template files | | +| | * Add license for all shell files | | +| | * Add license for all Python files | | +| | * Add license for all XML files | | +| | * Add license for all YAML files | | +| | * Add license for all md files | | +| | * Add license for all other files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| isort | Run isort to sort imports in Python files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| lint-chart-schema | Lint chart/values.schema.json file | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| lint-css | stylelint | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| lint-dockerfile | Lint dockerfile | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| lint-helm-chart | Lint Helm Chart | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| lint-json-schema | * Lint JSON Schema files with JSON Schema | | +| | * Lint NodePort Service with JSON Schema | | +| | * Lint Docker compose files with JSON Schema | | +| | * Lint chart/values.schema.json file with JSON Schema | | +| | * Lint chart/values.yaml file with JSON Schema | | +| | * Lint airflow/config_templates/config.yml file with JSON Schema | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| lint-markdown | Run markdownlint | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| lint-openapi | * Lint OpenAPI using spectral | | +| | * Lint OpenAPI using openapi-spec-validator | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| mixed-line-ending | Detect if mixed line ending is used (\r vs. 
\r\n) | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| pretty-format-json | Format json files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| pydocstyle | Run pydocstyle | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| python-no-log-warn | Check if there are no deprecate log warn | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| pyupgrade | Upgrade Python code automatically | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| replace-bad-characters | Replace bad characters | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| rst-backticks | Check if RST files use double backticks for code | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| run-flake8 | Run flake8 | * | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| run-mypy | * Run mypy for dev | * | +| | * Run mypy for core | | +| | * Run mypy for providers | | +| | * Run mypy for /docs/ folder | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| run-shellcheck | Check Shell scripts syntax correctness | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| static-check-autoflake | Remove all unused code | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| trailing-whitespace | Remove trailing whitespace at end of line | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| ts-compile-and-lint-javascript | TS types generation and ESLint against current UI files | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-breeze-cmd-output | Update output of breeze commands in BREEZE.rst | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-breeze-readme-config-hash | Update Breeze README.md with config files hash | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-er-diagram | Update ER diagram | * | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-extras | Update extras in documentation | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-in-the-wild-to-be-sorted | Sort INTHEWILD.md alphabetically | | 
++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-inlined-dockerfile-scripts | Inline Dockerfile and Dockerfile.ci scripts | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-local-yml-file | Update mounts in the local yml file | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-migration-references | Update migration ref doc | * | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-providers-dependencies | Update cross-dependencies for providers packages | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-spelling-wordlist-to-be-sorted | Sort alphabetically and uniquify spelling_wordlist.txt | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-supported-versions | Updates supported versions in documentation | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-vendored-in-k8s-json-schema | Vendor k8s definitions into values.schema.json | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| update-version | Update version to the latest version in the documentation | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| yamllint | Check YAML files with yamllint | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ +| yesqa | Remove unnecessary noqa statements | | ++-----------------------------------------------------------+------------------------------------------------------------------+---------+ .. END AUTO-GENERATED STATIC CHECK LIST @@ -387,7 +402,7 @@ The static code checks can be launched using the Breeze environment. You run the static code checks via ``breeze static-check`` or commands. You can see the list of available static checks either via ``--help`` flag or by using the autocomplete -option. Note that the ``all`` static check runs all configured static checks. +option. Run the ``mypy`` check for the currently staged changes: @@ -417,19 +432,19 @@ Run all checks for the currently staged files: .. code-block:: bash - breeze static-checks --type all + breeze static-checks Run all checks for all files: .. code-block:: bash - breeze static-checks --type all --all-files + breeze static-checks --all-files Run all checks for last commit : .. 
code-block:: bash

-    breeze static-checks --type all --last-commit
+    breeze static-checks --last-commit

Debugging pre-commit check scripts requiring image
--------------------------------------------------
diff --git a/TESTING.rst b/TESTING.rst
index 12983726e1ebb..7cbb9d49c9057 100644
--- a/TESTING.rst
+++ b/TESTING.rst
@@ -51,10 +51,48 @@ Follow the guidelines when writing unit tests:
* For new tests, use standard "asserts" of Python and ``pytest`` decorators/context managers for testing rather than ``unittest`` ones. See `pytest docs `_ for details.
* Use a parameterized framework for tests that have variations in parameters.
+* Use ``pytest.warns`` to capture warnings rather than the ``recwarn`` fixture. We are aiming for zero warnings in our
+  tests, so we run pytest with ``--disable-warnings``; instead, we use the ``pytest-capture-warnings`` plugin, which
+  overrides the ``recwarn`` fixture behaviour.

**NOTE:** We plan to convert all unit tests to standard "asserts" semi-automatically, but this will be done later in Airflow 2.0 development phase. That will include setUp/tearDown/context managers and decorators.

+Airflow test types
+------------------
+
+Airflow tests in the CI environment are split into several test types:
+
+* Always - those are tests that should be always executed (always folder)
+* Core - for the core Airflow functionality (core folder)
+* API - Tests for the Airflow API (api and api_connexion folders)
+* CLI - Tests for the Airflow CLI (cli folder)
+* WWW - Tests for the Airflow webserver (www folder)
+* Providers - Tests for all Providers of Airflow (providers folder)
+* Other - all other tests (all other folders that are not part of any of the above)
+
+This is done for two reasons:
+
+1. in order to selectively run only a subset of the test types for some PRs
+2. in order to allow parallel execution of the tests on Self-Hosted runners
+
+For case 2, we can utilise memory and CPUs available on both CI and local development machines to run
+tests in parallel. This way we can decrease the time of running all tests in self-hosted runners from
+60 minutes to ~15 minutes.
+
+.. note::
+
+  We need to split tests manually into separate suites rather than utilise
+  ``pytest-xdist`` or ``pytest-parallel`` which could be a simpler and much more "native" parallelization
+  mechanism. Unfortunately, we cannot utilise those tools because our tests are not truly ``unit`` tests that
+  can run in parallel. A lot of our tests rely on shared databases - and they update/reset/cleanup the
+  databases while they are executing. They are also exercising features of the Database such as locking which
+  further increases cross-dependency between tests. Until we make all our tests truly unit tests (not
+  touching the database) or until we isolate all such tests to a separate test type, we cannot really rely on
+  frameworks that run tests in parallel. In our solution each of the test types is run in parallel with its
+  own database (!) so when we have 8 test types running in parallel, there are in fact 8 databases running
+  behind the scenes to support them and each of the test types executes its own tests sequentially.
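+
+For example (a minimal illustration - the flags used here are described in more detail later in this document),
+you can select a single test type when running tests via Breeze:
+
+.. code-block:: bash
+
+    breeze testing tests --test-type Core --db-reset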
+
Running Unit Tests from PyCharm IDE
-----------------------------------
@@ -108,8 +146,9 @@ To run unit tests from the Visual Studio Code:
   :align: center
   :alt: Running tests

-Running Unit Tests
--------------------------------
+Running Unit Tests in local virtualenv
+--------------------------------------
+
To run unit, integration, and system tests from the Breeze and your virtualenv, you can use the `pytest `_ framework.
@@ -158,8 +197,8 @@ for debugging purposes, enter:

    pytest --log-cli-level=DEBUG tests/core/test_core.py::TestCore

-Running Tests for a Specified Target Using Breeze from the Host
----------------------------------------------------------------
+Running Tests using Breeze from the Host
+----------------------------------------

If you wish to only run tests and not to drop into the shell, apply the ``tests`` command.
You can add extra targets and pytest flags after the ``--`` command. Note that
@@ -168,20 +207,37 @@ to breeze.

.. code-block:: bash

-    breeze tests tests/providers/http/hooks/test_http.py tests/core/test_core.py --db-reset --log-cli-level=DEBUG
+    breeze testing tests tests/providers/http/hooks/test_http.py tests/core/test_core.py --db-reset --log-cli-level=DEBUG

You can run the whole test suite without adding the test target:

.. code-block:: bash

-    breeze tests --db-reset
+    breeze testing tests --db-reset

You can also specify individual tests or a group of tests:

.. code-block:: bash

-    breeze tests --db-reset tests/core/test_core.py::TestCore
+    breeze testing tests --db-reset tests/core/test_core.py::TestCore
+
+You can also limit the tests to execute to a specific group of tests:
+
+.. code-block:: bash
+
+    breeze testing tests --test-type Core
+
+In case of Providers tests, you can run tests for all providers:
+
+.. code-block:: bash
+
+    breeze testing tests --test-type Providers
+
+You can also limit the set of providers you would like to run tests of:
+
+.. code-block:: bash
+
+    breeze testing tests --test-type "Providers[airbyte,http]"

Running Tests of a specified type from the Host
-----------------------------------------------
@@ -201,15 +257,15 @@ kinds of test types:

  .. code-block:: bash

-     ./breeze-legacy --test-type Core --db-reset tests
+     breeze testing tests --test-type Core --db-reset tests

  Runs all provider tests:

  .. code-block:: bash

-     ./breeze-legacy --test-type Providers --db-reset tests
+     breeze testing tests --test-type Providers --db-reset tests

-* Special kinds of tests - Integration, Quarantined, Postgres, MySQL, which are marked with pytest
+* Special kinds of tests - Quarantined, Postgres, MySQL, which are marked with pytest
  marks and for those you need to select the type using test-type switch. If you want to run such tests using breeze,
  you need to pass appropriate ``--test-type`` otherwise the test will be skipped.
  Similarly to the per-directory tests if you do not specify the test or tests to run,
@@ -219,74 +275,168 @@ kinds of test types:

  .. code-block:: bash

-     ./breeze-legacy --test-type Quarantined tests tests/cli/commands/test_task_command.py --db-reset
+     breeze testing tests --test-type Quarantined tests tests/cli/commands/test_task_command.py --db-reset

  Run all Quarantined tests:

  .. code-block:: bash

-     ./breeze-legacy --test-type Quarantined tests --db-reset
+     breeze testing tests --test-type Quarantined tests --db-reset

-Helm Unit Tests
-===============
-
-On the Airflow Project, we have decided to stick with pythonic testing for our Helm chart. This makes our chart
-easier to test, easier to modify, and able to run with the same testing infrastructure. To add Helm unit tests
-add them in ``tests/charts``.
+Running full Airflow unit test suite in parallel
+------------------------------------------------
+
+If you run ``breeze testing tests --run-in-parallel``, tests run in parallel
+on your development machine - maxing out the number of parallel runs at the number of cores you
+have available in your Docker engine.
+
+In case you do not have enough memory available to your Docker (8 GB), the ``Integration``, ``Provider``
+and ``Core`` test types are executed sequentially, with the docker setup cleaned up in-between.
+
+This allows for massive speedup in full test execution. On an 8 CPU machine with 16 cores and 64 GB memory
+and fast SSD disk, the whole suite of tests completes in about 5 minutes (!). Same suite of tests takes
+more than 30 minutes on the same machine when tests are run sequentially.
+
+.. note::
+
+  On MacOS you might have fewer CPUs and less memory available to run the tests than you have in the host,
+  simply because your Docker engine runs in a Linux Virtual Machine under-the-hood. If you want to make
+  use of the parallelism and memory usage for the CI tests you might want to increase the resources available
+  to your docker engine. See the `Resources `_ chapter
+  in the ``Docker for Mac`` documentation on how to do it.
+
+You can also limit the parallelism by specifying the maximum number of parallel jobs via
+MAX_PARALLEL_TEST_JOBS variable. If you set it to "1", all the test types will be run sequentially.
+
+.. code-block:: bash
+
+    MAX_PARALLEL_TEST_JOBS="1" ./scripts/ci/testing/ci_run_airflow_testing.sh
+
+.. note::
+
+  In case you would like to cleanup after execution of such tests you might have to cleanup
+  some of the docker containers running in case you use ctrl-c to stop execution. You can easily do it by
+  running this command (it will kill all docker containers running so do not use it if you want to keep some
+  docker containers running):
+
+  .. code-block:: bash
+
+      docker kill $(docker ps -q)
+
+Running Backend-Specific Tests
+------------------------------
+
+Tests that are using a specific backend are marked with a custom pytest marker ``pytest.mark.backend``.
+The marker has a single parameter - the name of a backend. It corresponds to the ``--backend`` switch of
+the Breeze environment (one of ``mysql``, ``sqlite``, or ``postgres``). Backend-specific tests only run when
+the Breeze environment is running with the right backend. If you specify more than one backend
+in the marker, the test runs for all specified backends.
+
+Example of the ``postgres`` only test:

.. code-block:: python

-    class TestBaseChartTest(unittest.TestCase):
+    @pytest.mark.backend("postgres")
+    def test_copy_expert(self):
        ...

-To render the chart create a YAML string with the nested dictionary of options you wish to test. You can then
-use our ``render_chart`` function to render the object of interest into a testable Python dictionary. Once the chart
-has been rendered, you can use the ``render_k8s_object`` function to create a k8s model object. It simultaneously
-ensures that the object created properly conforms to the expected resource spec and allows you to use object values
-instead of nested dictionaries.
-Example test here:
+Example of the ``postgres,mysql`` test (they are skipped with the ``sqlite`` backend):

.. code-block:: python

-    from tests.charts.helm_template_generator import render_chart, render_k8s_object
+    @pytest.mark.backend("postgres", "mysql")
+    def test_celery_executor(self):
+        ...

-    git_sync_basic = """
-    dags:
-      gitSync:
-        enabled: true
-    """
+You can use the custom ``--backend`` switch in pytest to only run tests specific for that backend.
+Here is an example of running only postgres-specific backend tests:
+
+.. code-block:: bash

-    class TestGitSyncScheduler(unittest.TestCase):
-        def test_basic(self):
-            helm_settings = yaml.safe_load(git_sync_basic)
-            res = render_chart(
-                "GIT-SYNC",
-                helm_settings,
-                show_only=["templates/scheduler/scheduler-deployment.yaml"],
-            )
-            dep: k8s.V1Deployment = render_k8s_object(res[0], k8s.V1Deployment)
-            assert "dags" == dep.spec.template.spec.volumes[1].name
+    pytest --backend postgres
+
+Running Long-running tests
+--------------------------
+
+Some of the tests run for a long time. Such tests are marked with the ``@pytest.mark.long_running`` annotation.
+Those tests are skipped by default. You can enable them with the ``--include-long-running`` flag. You
+can also decide to run only those tests by using the ``-m long-running`` flag.

-To run tests using breeze run the following command
+Running Quarantined tests
+-------------------------
+
+Some of our tests are quarantined. This means that these tests will be run in isolation and that they will be
+re-run several times. Also when quarantined tests fail, the whole test suite will not fail. The quarantined
+tests are usually flaky tests that need some attention and fix.
+
+Those tests are marked with the ``@pytest.mark.quarantined`` annotation.
+Those tests are skipped by default. You can enable them with the ``--include-quarantined`` flag. You
+can also decide to run only those tests by using the ``-m quarantined`` flag.
+
+Running Tests with provider packages
+------------------------------------
+
+Airflow 2.0 introduced the concept of splitting the monolithic Airflow package into separate
+providers packages. The main "apache-airflow" package contains the bare Airflow implementation,
+and additionally we have 70+ providers that we can install to get integrations with
+external services. Those providers live in the same monorepo as Airflow, but we build separate
+packages for them and the main "apache-airflow" package does not contain the providers.
+
+Most of the development in Breeze happens by iterating on sources and when you run
+your tests during development, you usually do not want to build packages and install them separately.
+Therefore, by default, when you enter Breeze, airflow and all providers are available directly from
+sources rather than installed from packages. This allows you, for example, to test the "provider discovery"
+mechanism that reads provider information from the package meta-data.
+
+When Airflow is run from sources, the metadata is read from provider.yaml
+files, but when Airflow is installed from packages, it is read via the package entrypoint
+``apache_airflow_provider``.
+
+By default, all packages are prepared in wheel format. To install Airflow from packages you
+need to run the following steps:
+
+1. Prepare provider packages

.. code-block:: bash

-    ./breeze-legacy --test-type Helm tests
+    breeze release-management prepare-provider-packages [PACKAGE ...]
+
+If you run this command without packages, you will prepare all packages. However, you can specify the
+providers that you would like to build if you just want to build a few provider packages.
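+For example, to prepare just two provider packages (a sketch - assuming ``airbyte`` and ``http`` are the
+provider ids you want to build):
+
+.. code-block:: bash
+
+    breeze release-management prepare-provider-packages airbyte http
+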
+The packages are prepared in ``dist`` folder. Note that this command cleans up the ``dist`` folder +before running, so you should run it before generating ``apache-airflow`` package. + +2. Prepare airflow packages + +.. code-block:: bash + + breeze release-management prepare-airflow-package + +This prepares airflow .whl package in the dist folder. + +3. Enter breeze installing both airflow and providers from the dist packages + +.. code-block:: bash + + breeze --use-airflow-version wheel --use-packages-from-dist --skip-mounting-local-sources + Airflow Integration Tests ========================= Some of the tests in Airflow are integration tests. These tests require ``airflow`` Docker -image and extra images with integrations (such as ``redis``, ``mongodb``, etc.). - +image and extra images with integrations (such as ``celery``, ``mongodb``, etc.). +The integration tests are all stored in the ``tests/integration`` folder. Enabling Integrations --------------------- Airflow integration tests cannot be run in the local virtualenv. They can only run in the Breeze -environment with enabled integrations and in the CI. See `<.github/workflows/ci.yml>`_ for details about Airflow CI. +environment with enabled integrations and in the CI. See `CI `_ for details about Airflow CI. When you are in the Breeze environment, by default, all integrations are disabled. This enables only true unit tests to be executed in Breeze. You can enable the integration by passing the ``--integration `` @@ -295,8 +445,7 @@ or using the ``--integration all`` switch that enables all integrations. NOTE: Every integration requires a separate container with the corresponding integration image. These containers take precious resources on your PC, mainly the memory. The started integrations are not stopped -until you stop the Breeze environment with the ``stop`` command and restart it -via ``restart`` command. +until you stop the Breeze environment with the ``stop`` command and started with the ``start`` command. The following integrations are available: @@ -312,13 +461,9 @@ The following integrations are available: - Integration that provides Kerberos authentication * - mongo - Integration required for MongoDB hooks - * - openldap - - Integration required for OpenLDAP hooks * - pinot - Integration required for Apache Pinot hooks - * - rabbitmq - - Integration required for Celery executor tests - * - redis + * - celery - Integration required for Celery executor tests * - trino - Integration required for Trino hooks @@ -341,10 +486,6 @@ To start all integrations, enter: breeze --integration all -In the CI environment, integrations can be enabled by specifying the ``ENABLED_INTEGRATIONS`` variable -storing a space-separated list of integrations to start. Thanks to that, we can run integration and -integration-less tests separately in different jobs, which is desired from the memory usage point of view. - Note that Kerberos is a special kind of integration. Some tests run differently when Kerberos integration is enabled (they retrieve and use a Kerberos authentication token) and differently when the Kerberos integration is disabled (they neither retrieve nor use the token). Therefore, one of the test jobs @@ -356,11 +497,11 @@ Running Integration Tests All tests using an integration are marked with a custom pytest marker ``pytest.mark.integration``. The marker has a single parameter - the name of integration. -Example of the ``redis`` integration test: +Example of the ``celery`` integration test: .. 
code-block:: python - @pytest.mark.integration("redis") + @pytest.mark.integration("celery") def test_real_ping(self): hook = RedisHook(redis_conn_id="redis_default") redis = hook.get_conn() @@ -384,267 +525,202 @@ To run only ``mongo`` integration tests: .. code-block:: bash - pytest --integration mongo + pytest --integration mongo tests/integration -To run integration tests for ``mongo`` and ``rabbitmq``: +To run integration tests for ``mongo`` and ``celery``: .. code-block:: bash - pytest --integration mongo --integration rabbitmq - -Note that collecting all tests takes some time. So, if you know where your tests are located, you can -speed up the test collection significantly by providing the folder where the tests are located. - -Here is an example of the collection limited to the ``providers/apache`` directory: - -.. code-block:: bash - - pytest --integration cassandra tests/providers/apache/ - -Running Backend-Specific Tests ------------------------------- - -Tests that are using a specific backend are marked with a custom pytest marker ``pytest.mark.backend``. -The marker has a single parameter - the name of a backend. It corresponds to the ``--backend`` switch of -the Breeze environment (one of ``mysql``, ``sqlite``, or ``postgres``). Backend-specific tests only run when -the Breeze environment is running with the right backend. If you specify more than one backend -in the marker, the test runs for all specified backends. - -Example of the ``postgres`` only test: - -.. code-block:: python - - @pytest.mark.backend("postgres") - def test_copy_expert(self): - ... - + pytest --integration mongo --integration celery tests/integration -Example of the ``postgres,mysql`` test (they are skipped with the ``sqlite`` backend): - -.. code-block:: python - - @pytest.mark.backend("postgres", "mysql") - def test_celery_executor(self): - ... - -You can use the custom ``--backend`` switch in pytest to only run tests specific for that backend. -Here is an example of running only postgres-specific backend tests: +Here is an example of the collection limited to the ``providers/apache`` sub-directory: .. code-block:: bash - pytest --backend postgres - -Running Long-running tests --------------------------- - -Some of the tests rung for a long time. Such tests are marked with ``@pytest.mark.long_running`` annotation. -Those tests are skipped by default. You can enable them with ``--include-long-running`` flag. You -can also decide to only run tests with ``-m long-running`` flags to run only those tests. + pytest --integration cassandra tests/integrations/providers/apache -Quarantined tests ------------------ +Running Integration Tests from the Host +--------------------------------------- -Some of our tests are quarantined. This means that this test will be run in isolation and that it will be -re-run several times. Also when quarantined tests fail, the whole test suite will not fail. The quarantined -tests are usually flaky tests that need some attention and fix. +You can also run integration tests using Breeze from the host. -Those tests are marked with ``@pytest.mark.quarantined`` annotation. -Those tests are skipped by default. You can enable them with ``--include-quarantined`` flag. You -can also decide to only run tests with ``-m quarantined`` flag to run only those tests. +Runs all integration tests: + .. 
code-block:: bash -Airflow test types -================== + breeze testing integration-tests --db-reset --integration all -Airflow tests in the CI environment are split into several test types: +Runs all mongo DB tests: -* Always - those are tests that should be always executed (always folder) -* Core - for the core Airflow functionality (core folder) -* API - Tests for the Airflow API (api and api_connexion folders) -* CLI - Tests for the Airflow CLI (cli folder) -* WWW - Tests for the Airflow webserver (www folder) -* Providers - Tests for all Providers of Airflow (providers folder) -* Other - all other tests (all other folders that are not part of any of the above) + .. code-block:: bash -This is done for three reasons: + breeze testing integration-tests --db-reset --integration mongo -1. in order to selectively run only subset of the test types for some PRs -2. in order to allow parallel execution of the tests on Self-Hosted runners +Helm Unit Tests +=============== -For case 1. see `Pull Request Workflow `_ for details. +On the Airflow Project, we have decided to stick with pythonic testing for our Helm chart. This makes our chart +easier to test, easier to modify, and able to run with the same testing infrastructure. To add Helm unit tests +add them in ``tests/charts``. -For case 2. We can utilise memory and CPUs available on both CI and local development machines to run -test in parallel. This way we can decrease the time of running all tests in self-hosted runners from -60 minutes to ~15 minutes. +.. code-block:: python -.. note:: + class TestBaseChartTest: + ... - We need to split tests manually into separate suites rather than utilise - ``pytest-xdist`` or ``pytest-parallel`` which could be a simpler and much more "native" parallelization - mechanism. Unfortunately, we cannot utilise those tools because our tests are not truly ``unit`` tests that - can run in parallel. A lot of our tests rely on shared databases - and they update/reset/cleanup the - databases while they are executing. They are also exercising features of the Database such as locking which - further increases cross-dependency between tests. Until we make all our tests truly unit tests (and not - touching the database or until we isolate all such tests to a separate test type, we cannot really rely on - frameworks that run tests in parallel. In our solution each of the test types is run in parallel with its - own database (!) so when we have 8 test types running in parallel, there are in fact 8 databases run - behind the scenes to support them and each of the test types executes its own tests sequentially. +To render the chart create a YAML string with the nested dictionary of options you wish to test. You can then +use our ``render_chart`` function to render the object of interest into a testable Python dictionary. Once the chart +has been rendered, you can use the ``render_k8s_object`` function to create a k8s model object. It simultaneously +ensures that the object created properly conforms to the expected resource spec and allows you to use object values +instead of nested dictionaries. +Example test here: -Running full Airflow test suite in parallel -=========================================== +.. code-block:: python -If you run ``./scripts/ci/testing/ci_run_airflow_testing.sh`` tests run in parallel -on your development machine - maxing out the number of parallel runs at the number of cores you -have available in your Docker engine. 
+ from tests.charts.helm_template_generator import render_chart, render_k8s_object -In case you do not have enough memory available to your Docker (~32 GB), the ``Integration`` test type -is always run sequentially - after all tests are completed (docker cleanup is performed in-between). + git_sync_basic = """ + dags: + gitSync: + enabled: true + """ -This allows for massive speedup in full test execution. On 8 CPU machine with 16 cores and 64 GB memory -and fast SSD disk, the whole suite of tests completes in about 5 minutes (!). Same suite of tests takes -more than 30 minutes on the same machine when tests are run sequentially. -.. note:: + class TestGitSyncScheduler: + def test_basic(self): + helm_settings = yaml.safe_load(git_sync_basic) + res = render_chart( + "GIT-SYNC", + helm_settings, + show_only=["templates/scheduler/scheduler-deployment.yaml"], + ) + dep: k8s.V1Deployment = render_k8s_object(res[0], k8s.V1Deployment) + assert "dags" == dep.spec.template.spec.volumes[1].name - On MacOS you might have less CPUs and less memory available to run the tests than you have in the host, - simply because your Docker engine runs in a Linux Virtual Machine under-the-hood. If you want to make - use of the parallelism and memory usage for the CI tests you might want to increase the resources available - to your docker engine. See the `Resources `_ chapter - in the ``Docker for Mac`` documentation on how to do it. -You can also limit the parallelism by specifying the maximum number of parallel jobs via -MAX_PARALLEL_TEST_JOBS variable. If you set it to "1", all the test types will be run sequentially. +To execute all Helm tests using breeze command and utilize parallel pytest tests, you can run the +following command (but it takes quite a long time even in a multi-processor machine). .. code-block:: bash - MAX_PARALLEL_TEST_JOBS="1" ./scripts/ci/testing/ci_run_airflow_testing.sh - -.. note:: - - In case you would like to cleanup after execution of such tests you might have to cleanup - some of the docker containers running in case you use ctrl-c to stop execution. You can easily do it by - running this command (it will kill all docker containers running so do not use it if you want to keep some - docker containers running): - - .. code-block:: bash - - docker kill $(docker ps -q) - + breeze testing helm-tests -Running Tests with provider packages -==================================== - -Airflow 2.0 introduced the concept of splitting the monolithic Airflow package into separate -providers packages. The main "apache-airflow" package contains the bare Airflow implementation, -and additionally we have 70+ providers that we can install additionally to get integrations with -external services. Those providers live in the same monorepo as Airflow, but we build separate -packages for them and the main "apache-airflow" package does not contain the providers. +You can also run Helm tests individually via the usual ``breeze`` command. Just enter breeze and run the +tests with pytest as you would do with regular unit tests (you can add ``-n auto`` command to run Helm +tests in parallel - unlike most of the regular unit tests of ours that require a database, the Helm tests are +perfectly safe to be run in parallel (and if you have multiple processors, you can gain significant +speedups when using parallel runs): -Most of the development in Breeze happens by iterating on sources and when you run -your tests during development, you usually do not want to build packages and install them separately. 
-Therefore by default, when you enter Breeze airflow and all providers are available directly from
-sources rather than installed from packages. This is for example to test the "provider discovery"
-mechanism available that reads provider information from the package meta-data.
-
-When Airflow is run from sources, the metadata is read from provider.yaml
-files, but when Airflow is installed from packages, it is read via the package entrypoint
-``apache_airflow_provider``.
+.. code-block:: bash

-By default, all packages are prepared in wheel format. To install Airflow from packages you
-need to run the following steps:
+    breeze

-1. Prepare provider packages
+This enters the breeze container.

.. code-block:: bash

-    breeze prepare-provider-packages [PACKAGE ...]
-
-If you run this command without packages, you will prepare all packages. However, You can specify
-providers that you would like to build if you just want to build few provider packages.
-The packages are prepared in ``dist`` folder. Note that this command cleans up the ``dist`` folder
-before running, so you should run it before generating ``apache-airflow`` package.
+    pytest tests/charts -n auto

-2. Prepare airflow packages
+This runs all chart tests using all processors you have available.

.. code-block:: bash

-    breeze prepare-airflow-package
-
-This prepares airflow .whl package in the dist folder.
-
-3. Enter breeze installing both airflow and providers from the packages
+    pytest tests/charts/test_airflow_common.py -n auto

-This installs airflow and enters
+This will run all tests from the ``test_airflow_common.py`` file using all processors you have available.

.. code-block:: bash

-    ./breeze-legacy --use-airflow-version wheel --use-packages-from-dist --skip-mounting-local-sources
+    pytest tests/charts/test_airflow_common.py

+This will run all tests from the ``test_airflow_common.py`` file sequentially.

-Running Tests with Kubernetes
-=============================
+Kubernetes tests
+================

Airflow has tests that are run against real Kubernetes cluster. We are using
`Kind `_ to create and run the cluster. We integrated the tools to start/stop/
deploy and run the cluster tests in our repository and into Breeze development environment.

-Configuration for the cluster is kept in ``./build/.kube/config`` file in your Airflow source repository, and
-our scripts set the ``KUBECONFIG`` variable to it. If you want to interact with the Kind cluster created
-you can do it from outside of the scripts by exporting this variable and point it to this file.
+KinD has a really nice ``kind`` tool that you can use to interact with the cluster. Run ``kind --help`` to
+learn more.

-Starting Kubernetes Cluster
----------------------------
+K8S test environment
+------------------------

-For your testing, you manage Kind cluster with ``kind-cluster`` breeze command:
+Before running ``breeze k8s`` cluster commands you need to set up the environment. This is done
+by the ``breeze k8s setup-env`` command. Breeze in this command makes sure to download the tools that
+are needed to run k8s tests: Helm, Kind, Kubectl in the right versions, and sets up a
+Python virtualenv that is needed to run the tests. All those tools and the env are set up in
+the ``.build/.k8s-env`` folder. You can activate this environment yourself as usual by sourcing the
+``bin/activate`` script, but since we are supporting multiple clusters in the same installation
+it is best if you use ``breeze k8s shell`` with the right parameters specifying which cluster
+to use.

-..
code-block:: bash +Multiple cluster support +------------------------ - ./breeze-legacy kind-cluster [ start | stop | recreate | status | deploy | test | shell | k9s ] +The main feature of ``breeze k8s`` command is that it allows you to manage multiple KinD clusters - one +per each combination of Python and Kubernetes version. This is used during CI where we can run same +tests against those different clusters - even in parallel. -The command allows you to start/stop/recreate/status Kind Kubernetes cluster, deploy Airflow via Helm -chart as well as interact with the cluster (via test and shell commands). +The cluster name follows the pattern ``airflow-python-X.Y-vA.B.C`` where X.Y is a major/minor Python version +and A.B.C is Kubernetes version. Example cluster name: ``airflow-python-3.7-v1.24.0`` -Setting up the Kind Kubernetes cluster takes some time, so once you started it, the cluster continues running -until it is stopped with the ``kind-cluster stop`` command or until ``kind-cluster recreate`` -command is used (it will stop and recreate the cluster image). +Most of the commands can be executed in parallel for multiple images/clusters by adding ``--run-in-parallel`` +to create clusters or deploy airflow. Similarly checking for status, dumping logs and deleting clusters +can be run with ``--all`` flag and they will be executed sequentially for all locally created clusters. -The cluster name follows the pattern ``airflow-python-X.Y-vA.B.C`` where X.Y is a Python version -and A.B.C is a Kubernetes version. This way you can have multiple clusters set up and running at the same -time for different Python versions and different Kubernetes versions. +Per-cluster configuration files +------------------------------- +Once you start the cluster, the configuration for it is stored in a dynamically created folder - separate +folder for each python/kubernetes_version combination. The folder is ``./build/.k8s-clusters/`` -Deploying Airflow to Kubernetes Cluster ---------------------------------------- +There are two files there: -Deploying Airflow to the Kubernetes cluster created is also done via ``kind-cluster deploy`` breeze command: +* kubectl config file stored in .kubeconfig file - our scripts set the ``KUBECONFIG`` variable to it +* KinD cluster configuration in .kindconfig.yml file - our scripts set the ``KINDCONFIG`` variable to it -.. code-block:: bash +The ``KUBECONFIG`` file is automatically used when you enter any of the ``breeze k8s`` commands that use +``kubectl`` or when you run ``kubectl`` in the k8s shell. The ``KINDCONFIG`` file is used when cluster is +started but You and the ``k8s`` command can inspect it to know for example what port is forwarded to the +webserver running in the cluster. + +The files are deleted by ``breeze k8s delete-cluster`` command. - ./breeze-legacy kind-cluster deploy +Managing Kubernetes Cluster +--------------------------- -The deploy command performs those steps: +For your testing, you manage Kind cluster with ``k8s`` breeze command group. Those commands allow to +created: -1. It rebuilds the latest ``apache/airflow:main-pythonX.Y`` production images using the - latest sources using local caching. It also adds example DAGs to the image, so that they do not - have to be mounted inside. -2. Loads the image to the Kind Cluster using the ``kind load`` command. -3. Starts airflow in the cluster using the official helm chart (in ``airflow`` namespace) -4. Forwards Local 8080 port to the webserver running in the cluster -5. 
Applies the volumes.yaml to get the volumes deployed to ``default`` namespace - this is where
-   KubernetesExecutor starts its pods.
+.. image:: ./images/breeze/output_k8s.svg
+  :width: 100%
+  :alt: Breeze k8s

-You can also specify a different executor by providing the ``--executor`` optional argument:
+The command group allows you to set up the environment, start/stop/recreate/check the status of the Kind
+Kubernetes cluster, and configure the cluster (via the ``create-cluster`` and ``configure-cluster`` commands).
+Those commands can be run with the ``--run-in-parallel`` flag for all/selected clusters, and they can be
+executed in parallel.

-.. code-block:: bash
+In order to deploy Airflow, the PROD image of Airflow needs to be extended, and example DAGs and POD
+template files should be added to the image. This is done via ``build-k8s-image`` and ``upload-k8s-image``.
+This can also be done for all/selected images/clusters in parallel via the ``--run-in-parallel`` flag.

-    ./breeze-legacy kind-cluster deploy --executor CeleryExecutor
+Then Airflow (by using the Helm Chart) can be deployed to the cluster via the ``deploy-airflow`` command.
+This can also be done for all/selected images/clusters in parallel via the ``--run-in-parallel`` flag. You can
+pass extra options when deploying airflow to configure your deployment.

-Note that when you specify the ``--executor`` option, it becomes the default. Therefore, every other operations
-on ``./breeze-legacy kind-cluster`` will default to using this executor. To change that, use the ``--executor`` option on the
-subsequent commands too.
+You can check the status, dump logs and finally delete the cluster via the ``status``, ``logs`` and
+``delete-cluster`` commands. This can also be done for all created clusters in parallel via the ``--all`` flag.
+
+You can interact with the cluster (via the ``shell`` and ``k9s`` commands).
+
+You can run a set of k8s tests via the ``tests`` command. You can also run tests in parallel on all/selected
+clusters with the ``--run-in-parallel`` flag.

Running tests with Kubernetes Cluster
@@ -653,29 +729,20 @@ Running tests with Kubernetes Cluster

You can either run all tests or you can select which tests to run. You can also enter interactive virtualenv
to run the tests manually one by one.

-Running Kubernetes tests via shell:
-
-.. code-block:: bash
-
-    export EXECUTOR="KubernetesExecutor" ## can be also CeleryExecutor or CeleryKubernetesExecutor
-
-    ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh - runs all kubernetes tests
-    ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh TEST [TEST ...] - runs selected kubernetes tests (from kubernetes_tests folder)
-
Running Kubernetes tests via breeze:

.. code-block:: bash

-    ./breeze-legacy kind-cluster test
-    ./breeze-legacy kind-cluster test -- TEST TEST [TEST ...]
+    breeze k8s tests
+    breeze k8s tests TEST TEST [TEST ...]

Optionally add ``--executor``:

.. code-block:: bash

-    ./breeze-legacy kind-cluster test --executor CeleryExecutor
-    ./breeze-legacy kind-cluster test -- TEST TEST [TEST ...] --executor CeleryExecutor
+    breeze k8s tests --executor CeleryExecutor
+    breeze k8s tests --executor CeleryExecutor TEST TEST [TEST ...]

Entering shell with Kubernetes Cluster
--------------------------------------
@@ -683,31 +750,18 @@ Entering shell with Kubernetes Cluster

This shell is prepared to run Kubernetes tests interactively. It has ``kubectl`` and ``kind`` cli tools
available in the path, it has also activated virtualenv environment that allows you to run tests via pytest.

-The binaries are available in ./.build/kubernetes-bin/``KUBERNETES_VERSION`` path.
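For example, a quick interactive session in that shell could look like this (a sketch only: the ``airflow``
namespace is the one used by the Helm deployment described later, and the test selection below uses plain
pytest options):

.. code-block:: bash

    # kubectl and pytest come from .build/.k8s-env/bin, activated by "breeze k8s shell"
    kubectl get pods -n airflow                          # inspect the deployed Airflow components
    pytest kubernetes_tests -k "kubernetes_executor" -s  # run a subset of the Kubernetes tests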
-The virtualenv is available in ./.build/.kubernetes_venv/``KIND_CLUSTER_NAME``_host_python_``HOST_PYTHON_VERSION`` - -Where ``KIND_CLUSTER_NAME`` is the name of the cluster and ``HOST_PYTHON_VERSION`` is the version of python -in the host. - -You can enter the shell via those scripts - -.. code-block:: bash - - export EXECUTOR="KubernetesExecutor" ## can be also CeleryExecutor or CeleryKubernetesExecutor - - ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh [-i|--interactive] - Activates virtual environment ready to run tests and drops you in - ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh [--help] - Prints this help message - +The virtualenv is available in ./.build/.k8s-env/ +The binaries are available in ``.build/.k8s-env/bin`` path. .. code-block:: bash - ./breeze-legacy kind-cluster shell + breeze k8s shell Optionally add ``--executor``: .. code-block:: bash - ./breeze-legacy kind-cluster shell --executor CeleryExecutor + breeze k8s shell --executor CeleryExecutor K9s CLI - debug Kubernetes in style! @@ -734,7 +788,7 @@ You can enter the k9s tool via breeze (after you deployed Airflow): .. code-block:: bash - ./breeze-legacy kind-cluster k9s + breeze k8s k9s You can exit k9s by pressing Ctrl-C. @@ -743,73 +797,316 @@ Typical testing pattern for Kubernetes tests The typical session for tests with Kubernetes looks like follows: -1. Start the Kind cluster: + +1. Prepare the environment: .. code-block:: bash - ./breeze-legacy kind-cluster start + breeze k8s setup-env - Starts Kind Kubernetes cluster +The first time you run it, it should result in creating the virtualenv and installing good versions +of kind, kubectl and helm. All of them are installed in ``./build/.k8s-env`` (binaries available in ``bin`` +sub-folder of it). - Use CI image. +.. code-block:: text + + Initializing K8S virtualenv in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env + Reinstalling PIP version in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env + Installing necessary packages in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env + The ``kind`` tool is not downloaded yet. Downloading 0.14.0 version. + Downloading from: https://github.com/kubernetes-sigs/kind/releases/download/v0.14.0/kind-darwin-arm64 + The ``kubectl`` tool is not downloaded yet. Downloading 1.24.3 version. + Downloading from: https://storage.googleapis.com/kubernetes-release/release/v1.24.3/bin/darwin/arm64/kubectl + The ``helm`` tool is not downloaded yet. Downloading 3.9.2 version. + Downloading from: https://get.helm.sh/helm-v3.9.2-darwin-arm64.tar.gz + Extracting the darwin-arm64/helm to /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin + Moving the helm to /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin/helm - Branch name: main - Docker image: apache/airflow:main-python3.7-ci - Airflow source version: 2.0.0.dev0 - Python version: 3.7 - Backend: postgres 10 +This prepares the virtual environment for tests and downloads the right versions of the tools +to ``./build/.k8s-env`` - No kind clusters found. +2. Create the KinD cluster: + +.. code-block:: bash - Creating cluster + breeze k8s create-cluster + +Should result in KinD creating the K8S cluster. + +.. code-block:: text - Creating cluster "airflow-python-3.7-v1.17.0" ... - ✓ Ensuring node image (kindest/node:v1.17.0) 🖼 + Config created in /Users/jarek/IdeaProjects/airflow/.build/.k8s-clusters/airflow-python-3.7-v1.24.2/.kindconfig.yaml: + + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. 
See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, + # software distributed under the License is distributed on an + # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + # KIND, either express or implied. See the License for the + # specific language governing permissions and limitations + # under the License. + --- + kind: Cluster + apiVersion: kind.x-k8s.io/v1alpha4 + networking: + ipFamily: ipv4 + apiServerAddress: "127.0.0.1" + apiServerPort: 48366 + nodes: + - role: control-plane + - role: worker + extraPortMappings: + - containerPort: 30007 + hostPort: 18150 + listenAddress: "127.0.0.1" + protocol: TCP + + + + Creating cluster "airflow-python-3.7-v1.24.2" ... + ✓ Ensuring node image (kindest/node:v1.24.2) 🖼 ✓ Preparing nodes 📦 📦 ✓ Writing configuration 📜 ✓ Starting control-plane 🕹️ ✓ Installing CNI 🔌 - Could not read storage manifest, falling back on old k8s.io/host-path default ... ✓ Installing StorageClass 💾 ✓ Joining worker nodes 🚜 - Set kubectl context to "kind-airflow-python-3.7-v1.17.0" + Set kubectl context to "kind-airflow-python-3.7-v1.24.2" You can now use your cluster with: - kubectl cluster-info --context kind-airflow-python-3.7-v1.17.0 + kubectl cluster-info --context kind-airflow-python-3.7-v1.24.2 - Have a question, bug, or feature request? Let us know! https://kind.sigs.k8s.io/#community 🙂 + Not sure what to do next? 😅 Check out https://kind.sigs.k8s.io/docs/user/quick-start/ - Created cluster airflow-python-3.7-v1.17.0 + KinD Cluster API server URL: http://localhost:48366 + Connecting to localhost:18150. Num try: 1 + Error when connecting to localhost:18150 : ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')) + Airflow webserver is not available at port 18150. Run `breeze k8s deploy-airflow --python 3.7 --kubernetes-version v1.24.2` to (re)deploy airflow -2. Check the status of the cluster + KinD cluster airflow-python-3.7-v1.24.2 created! + + NEXT STEP: You might now configure your cluster by: + + breeze k8s configure-cluster + +3. Configure cluster for Airflow - this will recreate namespace and upload test resources for Airflow. .. code-block:: bash - ./breeze-legacy kind-cluster status + breeze k8s configure-cluster + +.. code-block:: text - Checks status of Kind Kubernetes cluster + Configuring airflow-python-3.7-v1.24.2 to be ready for Airflow deployment + Deleting K8S namespaces for kind-airflow-python-3.7-v1.24.2 + Error from server (NotFound): namespaces "airflow" not found + Error from server (NotFound): namespaces "test-namespace" not found + Creating namespaces + namespace/airflow created + namespace/test-namespace created + Created K8S namespaces for cluster kind-airflow-python-3.7-v1.24.2 - Use CI image. 
+ Deploying test resources for cluster kind-airflow-python-3.7-v1.24.2 + persistentvolume/test-volume created + persistentvolumeclaim/test-volume created + service/airflow-webserver-node-port created + Deployed test resources for cluster kind-airflow-python-3.7-v1.24.2 - Branch name: main - Docker image: apache/airflow:main-python3.7-ci - Airflow source version: 2.0.0.dev0 - Python version: 3.7 - Backend: postgres 10 + NEXT STEP: You might now build your k8s image by: - airflow-python-3.7-v1.17.0-control-plane - airflow-python-3.7-v1.17.0-worker + breeze k8s build-k8s-image -3. Deploy Airflow to the cluster +4. Check the status of the cluster .. code-block:: bash - ./breeze-legacy kind-cluster deploy + breeze k8s status + +Should show the status of current KinD cluster. + +.. code-block:: text + + ======================================================================================================================== + Cluster: airflow-python-3.7-v1.24.2 + + * KUBECONFIG=/Users/jarek/IdeaProjects/airflow/.build/.k8s-clusters/airflow-python-3.7-v1.24.2/.kubeconfig + * KINDCONFIG=/Users/jarek/IdeaProjects/airflow/.build/.k8s-clusters/airflow-python-3.7-v1.24.2/.kindconfig.yaml + + Cluster info: airflow-python-3.7-v1.24.2 + + Kubernetes control plane is running at https://127.0.0.1:48366 + CoreDNS is running at https://127.0.0.1:48366/api/v1/namespaces/kube-system/services/kube-dns:dns/proxy + + To further debug and diagnose cluster problems, use 'kubectl cluster-info dump'. + + Storage class for airflow-python-3.7-v1.24.2 + + NAME PROVISIONER RECLAIMPOLICY VOLUMEBINDINGMODE ALLOWVOLUMEEXPANSION AGE + standard (default) rancher.io/local-path Delete WaitForFirstConsumer false 83s + + Running pods for airflow-python-3.7-v1.24.2 + + NAME READY STATUS RESTARTS AGE + coredns-6d4b75cb6d-rwp9d 1/1 Running 0 71s + coredns-6d4b75cb6d-vqnrc 1/1 Running 0 71s + etcd-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s + kindnet-ckc8l 1/1 Running 0 69s + kindnet-qqt8k 1/1 Running 0 71s + kube-apiserver-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s + kube-controller-manager-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s + kube-proxy-6g7hn 1/1 Running 0 69s + kube-proxy-dwfvp 1/1 Running 0 71s + kube-scheduler-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s + + KinD Cluster API server URL: http://localhost:48366 + Connecting to localhost:18150. Num try: 1 + Error when connecting to localhost:18150 : ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')) + + Airflow webserver is not available at port 18150. Run `breeze k8s deploy-airflow --python 3.7 --kubernetes-version v1.24.2` to (re)deploy airflow + + + Cluster healthy: airflow-python-3.7-v1.24.2 + +5. Build the image base on PROD Airflow image. You need to build the PROD image first (the command will + guide you if you did not - either by running the build separately or passing ``--rebuild-base-image`` flag + +.. code-block:: bash + + breeze k8s build-k8s-image + +.. 
code-block:: text + + Building the K8S image for Python 3.7 using airflow base image: ghcr.io/apache/airflow/main/prod/python3.7:latest + + [+] Building 0.1s (8/8) FINISHED + => [internal] load build definition from Dockerfile 0.0s + => => transferring dockerfile: 301B 0.0s + => [internal] load .dockerignore 0.0s + => => transferring context: 35B 0.0s + => [internal] load metadata for ghcr.io/apache/airflow/main/prod/python3.7:latest 0.0s + => [1/3] FROM ghcr.io/apache/airflow/main/prod/python3.7:latest 0.0s + => [internal] load build context 0.0s + => => transferring context: 3.00kB 0.0s + => CACHED [2/3] COPY airflow/example_dags/ /opt/airflow/dags/ 0.0s + => CACHED [3/3] COPY airflow/kubernetes_executor_templates/ /opt/airflow/pod_templates/ 0.0s + => exporting to image 0.0s + => => exporting layers 0.0s + => => writing image sha256:c0bdd363c549c3b0731b8e8ce34153d081f239ee2b582355b7b3ffd5394c40bb 0.0s + => => naming to ghcr.io/apache/airflow/main/prod/python3.7-kubernetes:latest -4. Run Kubernetes tests + NEXT STEP: You might now upload your k8s image by: + + breeze k8s upload-k8s-image + + +5. Upload the image to KinD cluster - this uploads your image to make it available for the KinD cluster. + +.. code-block:: bash + + breeze k8s upload-k8s-image + +.. code-block:: text + + K8S Virtualenv is initialized in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env + Good version of kind installed: 0.14.0 in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin + Good version of kubectl installed: 1.25.0 in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin + Good version of helm installed: 3.9.2 in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin + Stable repo is already added + Uploading Airflow image ghcr.io/apache/airflow/main/prod/python3.7-kubernetes to cluster airflow-python-3.7-v1.24.2 + Image: "ghcr.io/apache/airflow/main/prod/python3.7-kubernetes" with ID "sha256:fb6195f7c2c2ad97788a563a3fe9420bf3576c85575378d642cd7985aff97412" not yet present on node "airflow-python-3.7-v1.24.2-worker", loading... + Image: "ghcr.io/apache/airflow/main/prod/python3.7-kubernetes" with ID "sha256:fb6195f7c2c2ad97788a563a3fe9420bf3576c85575378d642cd7985aff97412" not yet present on node "airflow-python-3.7-v1.24.2-control-plane", loading... + + NEXT STEP: You might now deploy airflow by: + + breeze k8s deploy-airflow + + +7. Deploy Airflow to the cluster - this will use Airflow Helm Chart to deploy Airflow to the cluster. + +.. code-block:: bash + + breeze k8s deploy-airflow + +.. code-block:: text + + Deploying Airflow for cluster airflow-python-3.7-v1.24.2 + Deploying kind-airflow-python-3.7-v1.24.2 with airflow Helm Chart. + Copied chart sources to /private/var/folders/v3/gvj4_mw152q556w2rrh7m46w0000gn/T/chart_edu__kir/chart + Deploying Airflow from /private/var/folders/v3/gvj4_mw152q556w2rrh7m46w0000gn/T/chart_edu__kir/chart + NAME: airflow + LAST DEPLOYED: Tue Aug 30 22:57:54 2022 + NAMESPACE: airflow + STATUS: deployed + REVISION: 1 + TEST SUITE: None + NOTES: + Thank you for installing Apache Airflow 2.3.4! + + Your release is named airflow. 
+ You can now access your dashboard(s) by executing the following command(s) and visiting the corresponding port at localhost in your browser: + + Airflow Webserver: kubectl port-forward svc/airflow-webserver 8080:8080 --namespace airflow + Default Webserver (Airflow UI) Login credentials: + username: admin + password: admin + Default Postgres connection credentials: + username: postgres + password: postgres + port: 5432 + + You can get Fernet Key value by running the following: + + echo Fernet Key: $(kubectl get secret --namespace airflow airflow-fernet-key -o jsonpath="{.data.fernet-key}" | base64 --decode) + + WARNING: + Kubernetes workers task logs may not persist unless you configure log persistence or remote logging! + Logging options can be found at: https://airflow.apache.org/docs/helm-chart/stable/manage-logs.html + (This warning can be ignored if logging is configured with environment variables or secrets backend) + + ########################################################### + # WARNING: You should set a static webserver secret key # + ########################################################### + + You are using a dynamically generated webserver secret key, which can lead to + unnecessary restarts of your Airflow components. + + Information on how to set a static webserver secret key can be found here: + https://airflow.apache.org/docs/helm-chart/stable/production-guide.html#webserver-secret-key + Deployed kind-airflow-python-3.7-v1.24.2 with airflow Helm Chart. + + Airflow for Python 3.7 and K8S version v1.24.2 has been successfully deployed. + + The KinD cluster name: airflow-python-3.7-v1.24.2 + The kubectl cluster name: kind-airflow-python-3.7-v1.24.2. + + + KinD Cluster API server URL: http://localhost:48366 + Connecting to localhost:18150. Num try: 1 + Established connection to webserver at http://localhost:18150/health and it is healthy. + Airflow Web server URL: http://localhost:18150 (admin/admin) + + NEXT STEP: You might now run tests or interact with airflow via shell (kubectl, pytest etc.) or k9s commands: + + + breeze k8s tests + + breeze k8s shell + + breeze k8s k9s + + +8. Run Kubernetes tests Note that the tests are executed in production container not in the CI container. There is no need for the tests to run inside the Airflow CI container image as they only @@ -817,69 +1114,77 @@ communicate with the Kubernetes-run Airflow deployed via the production image. Those Kubernetes tests require virtualenv to be created locally with airflow installed. The virtualenv required will be created automatically when the scripts are run. -4a) You can run all the tests +8a) You can run all the tests .. code-block:: bash - ./breeze-legacy kind-cluster test + breeze k8s tests +.. code-block:: text -4b) You can enter an interactive shell to run tests one-by-one + Running tests with kind-airflow-python-3.7-v1.24.2 cluster. 
+ Command to run: pytest kubernetes_tests + ========================================================================================= test session starts ========================================================================================== + platform darwin -- Python 3.9.9, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin/python + cachedir: .pytest_cache + rootdir: /Users/jarek/IdeaProjects/airflow, configfile: pytest.ini + plugins: anyio-3.6.1, instafail-0.4.2, xdist-2.5.0, forked-1.4.0, timeouts-1.2.1, cov-3.0.0 + setup timeout: 0.0s, execution timeout: 0.0s, teardown timeout: 0.0s + collected 55 items -This prepares and enters the virtualenv in ``.build/.kubernetes_venv_`` folder: + kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag PASSED [ 1%] + kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag_with_scheduler_failure PASSED [ 3%] + kubernetes_tests/test_kubernetes_pod_operator.py::TestKubernetesPodOperatorSystem::test_already_checked_on_failure PASSED [ 5%] + kubernetes_tests/test_kubernetes_pod_operator.py::TestKubernetesPodOperatorSystem::test_already_checked_on_success ... -.. code-block:: bash - - ./breeze-legacy kind-cluster shell - -Once you enter the environment, you receive this information: +8b) You can enter an interactive shell to run tests one-by-one +This enters the virtualenv in ``.build/.k8s-env`` folder: .. code-block:: bash - Activating the virtual environment for kubernetes testing + breeze k8s shell - You can run kubernetes testing via 'pytest kubernetes_tests/....' - You can add -s to see the output of your tests on screen +Once you enter the environment, you receive this information: - The webserver is available at http://localhost:8080/ +.. code-block:: text - User/password: admin/admin + Entering interactive k8s shell. - You are entering the virtualenv now. Type exit to exit back to the original shell + (kind-airflow-python-3.7-v1.24.2:KubernetesExecutor)> In a separate terminal you can open the k9s CLI: .. code-block:: bash - ./breeze-legacy kind-cluster k9s + breeze k8s k9s Use it to observe what's going on in your cluster. -6. Debugging in IntelliJ/PyCharm +9. Debugging in IntelliJ/PyCharm It is very easy to running/debug Kubernetes tests with IntelliJ/PyCharm. Unlike the regular tests they are in ``kubernetes_tests`` folder and if you followed the previous steps and entered the shell using -``./breeze-legacy kind-cluster shell`` command, you can setup your IDE very easy to run (and debug) your +``breeze k8s shell`` command, you can setup your IDE very easy to run (and debug) your tests using the standard IntelliJ Run/Debug feature. You just need a few steps: -a) Add the virtualenv as interpreter for the project: +9a) Add the virtualenv as interpreter for the project: .. image:: images/testing/kubernetes-virtualenv.png :align: center :alt: Kubernetes testing virtualenv The virtualenv is created in your "Airflow" source directory in the -``.build/.kubernetes_venv_`` folder and you -have to find ``python`` binary and choose it when selecting interpreter. +``.build/.k8s-env`` folder and you have to find ``python`` binary and choose +it when selecting interpreter. -b) Choose pytest as test runner: +9b) Choose pytest as test runner: .. 
image:: images/testing/pytest-runner.png
    :align: center
    :alt: Pytest runner

-c) Run/Debug tests using standard "Run/Debug" feature of IntelliJ
+9c) Run/Debug tests using standard "Run/Debug" feature of IntelliJ

.. image:: images/testing/run-test.png
    :align: center
@@ -889,7 +1194,7 @@ c) Run/Debug tests using standard "Run/Debug" feature of IntelliJ

NOTE! The first time you run it, it will likely fail with
``kubernetes.config.config_exception.ConfigException``:
``Invalid kube-config file. Expected key current-context in kube-config``. You need to add KUBECONFIG
-environment variable copying it from the result of "./breeze-legacy kind-cluster test":
+environment variable copying it from the result of "breeze k8s tests":

.. code-block:: bash

@@ -897,7 +1202,6 @@ environment variable copying it from the result of "./breeze-legacy kind-cluster
    /home/jarek/code/airflow/.build/.kube/config

-
.. image:: images/testing/kubeconfig-env.png
    :align: center
    :alt: Run/Debug tests

@@ -914,21 +1218,28 @@ print output generated test logs and print statements to the terminal immediatel

    pytest kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag_with_scheduler_failure -s

-
You can modify the tests or KubernetesPodOperator and re-run them without re-deploying
Airflow to KinD cluster.

+10. Dumping logs

-Sometimes there are side effects from running tests. You can run ``redeploy_airflow.sh`` without
-recreating the whole cluster. This will delete the whole namespace, including the database data
-and start a new Airflow deployment in the cluster.
+Sometimes you want to see the logs of the cluster. This can be done with ``breeze k8s logs``.

.. code-block:: bash

-    ./scripts/ci/redeploy_airflow.sh
+    breeze k8s logs
+
+11. Redeploying Airflow
+
+Sometimes there are side effects from running tests. You can run ``breeze k8s deploy-airflow --upgrade``
+without recreating the whole cluster.
+
+.. code-block:: bash

-If needed you can also delete the cluster manually:
+    breeze k8s deploy-airflow --upgrade

+If needed you can also delete the cluster manually (within the virtualenv activated by
+``breeze k8s shell``):

.. code-block:: bash

@@ -941,20 +1252,30 @@ Kind has also useful commands to inspect your running cluster:

    kind --help

-
-However, when you change Kubernetes executor implementation, you need to redeploy
-Airflow to the cluster.
+12. Stop KinD cluster when you are done

.. code-block:: bash

-    ./breeze-legacy kind-cluster deploy
+    breeze k8s delete-cluster
+
+.. code-block:: text

+    Deleting KinD cluster airflow-python-3.7-v1.24.2!
+    Deleting cluster "airflow-python-3.7-v1.24.2" ...
+    KinD cluster airflow-python-3.7-v1.24.2 deleted!

-7. Stop KinD cluster when you are done
+
+Running complete k8s tests
+--------------------------
+
+You can also run complete k8s tests with:

.. code-block:: bash

-    ./breeze-legacy kind-cluster stop
+    breeze k8s run-complete-tests
+
+This will create the cluster, build images, deploy airflow, run tests and finally delete the clusters, as a single
+command. This is the way it is run in our CI; you can also run such complete tests in parallel.

Airflow System Tests
@@ -1074,8 +1395,7 @@ A simple example of a system test is available in:

``tests/providers/google/cloud/operators/test_compute_system.py``.

-It runs two DAGs defined in ``airflow.providers.google.cloud.example_dags.example_compute.py`` and
-``airflow.providers.google.cloud.example_dags.example_compute_igm.py``.
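Such a system test is launched with plain pytest (a sketch only: it assumes the ``--system`` switch provided
by Airflow's test harness and valid Google Cloud credentials/variables for the compute example):

.. code-block:: bash

    # hypothetical invocation of the example system test referenced above
    pytest --system google tests/providers/google/cloud/operators/test_compute_system.py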
+It runs two DAGs defined in ``airflow.providers.google.cloud.example_dags.example_compute.py``. Preparing provider packages for System Tests for Airflow 1.10.* series ---------------------------------------------------------------------- @@ -1086,7 +1406,7 @@ example, the below command will build google, postgres and mysql wheel packages: .. code-block:: bash - breeze prepare-provider-packages google postgres mysql + breeze release-management prepare-provider-packages google postgres mysql Those packages will be prepared in ./dist folder. This folder is mapped to /dist folder when you enter Breeze, so it is easy to automate installing those packages for testing. diff --git a/airflow/__init__.py b/airflow/__init__.py index cbbb03dd1be29..19624252a3875 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -15,8 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - """ Authentication is implemented using flask_login and different environments can implement their own login mechanisms by providing an `airflow_login` module @@ -25,28 +23,38 @@ isort:skip_file """ - +from __future__ import annotations # flake8: noqa: F401 +import os import sys -from typing import Callable, Optional +from typing import Callable -from airflow import settings -from airflow import version +if os.environ.get("_AIRFLOW_PATCH_GEVENT"): + # If you are using gevents and start airflow webserver, you might want to run gevent monkeypatching + # as one of the first thing when Airflow is started. This allows gevent to patch networking and other + # system libraries to make them gevent-compatible before anything else patches them (for example boto) + from gevent.monkey import patch_all -__version__ = version.version + patch_all() -__all__ = ['__version__', 'login', 'DAG', 'PY36', 'PY37', 'PY38', 'PY39', 'PY310', 'XComArg'] +from airflow import settings + +__all__ = ["__version__", "login", "DAG", "PY36", "PY37", "PY38", "PY39", "PY310", "XComArg"] # Make `airflow` an namespace package, supporting installing # airflow.providers.* in different locations (i.e. one in site, and one in user # lib.) __path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore -settings.initialize() -login: Optional[Callable] = None +# Perform side-effects unless someone has explicitly opted out before import +# WARNING: DO NOT USE THIS UNLESS YOU REALLY KNOW WHAT YOU'RE DOING. 
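+# For example, a process that imports Airflow purely as a library and does not want the
+# settings side effects below can set the environment variable _AIRFLOW__AS_LIBRARY to any
+# non-empty value before the first "import airflow"; settings.initialize() is then skipped.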
+if not os.environ.get("_AIRFLOW__AS_LIBRARY", None): + settings.initialize() + +login: Callable | None = None PY36 = sys.version_info >= (3, 6) PY37 = sys.version_info >= (3, 7) @@ -54,28 +62,31 @@ PY39 = sys.version_info >= (3, 9) PY310 = sys.version_info >= (3, 10) -# Things to lazy import in form 'name': 'path.to.module' -__lazy_imports = { - 'DAG': 'airflow.models.dag', - 'XComArg': 'airflow.models.xcom_arg', - 'AirflowException': 'airflow.exceptions', +# Things to lazy import in form {local_name: ('target_module', 'target_name')} +__lazy_imports: dict[str, tuple[str, str]] = { + "DAG": (".models.dag", "DAG"), + "Dataset": (".datasets", "Dataset"), + "XComArg": (".models.xcom_arg", "XComArg"), + "AirflowException": (".exceptions", "AirflowException"), + "version": (".version", ""), + "__version__": (".version", "version"), } -def __getattr__(name): +def __getattr__(name: str): # PEP-562: Lazy loaded attributes on python modules - path = __lazy_imports.get(name) - if not path: + module_path, attr_name = __lazy_imports.get(name, ("", "")) + if not module_path: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - import operator + import importlib - # Strip off the "airflow." prefix because of how `__import__` works (it always returns the top level - # module) - without_prefix = path.split('.', 1)[-1] + mod = importlib.import_module(module_path, __name__) + if attr_name: + val = getattr(mod, attr_name) + else: + val = mod - getter = operator.attrgetter(f'{without_prefix}.{name}') - val = getter(__import__(path)) # Store for next time globals()[name] = val return val @@ -99,7 +110,7 @@ def __getattr__(name): # into knowing the types of these symbols, and what # they contain. STATICA_HACK = True -globals()['kcah_acitats'[::-1].upper()] = False +globals()["kcah_acitats"[::-1].upper()] = False if STATICA_HACK: # pragma: no cover from airflow.models.dag import DAG from airflow.models.xcom_arg import XComArg diff --git a/airflow/__main__.py b/airflow/__main__.py index 334126b2d930b..6114534e1c887 100644 --- a/airflow/__main__.py +++ b/airflow/__main__.py @@ -17,8 +17,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Main executable module""" +from __future__ import annotations + import os import argcomplete @@ -29,14 +30,14 @@ def main(): """Main executable function""" - if conf.get("core", "security") == 'kerberos': - os.environ['KRB5CCNAME'] = conf.get('kerberos', 'ccache') - os.environ['KRB5_KTNAME'] = conf.get('kerberos', 'keytab') + if conf.get("core", "security") == "kerberos": + os.environ["KRB5CCNAME"] = conf.get("kerberos", "ccache") + os.environ["KRB5_KTNAME"] = conf.get("kerberos", "keytab") parser = cli_parser.get_parser() argcomplete.autocomplete(parser) args = parser.parse_args() args.func(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/airflow/api/__init__.py b/airflow/api/__init__.py index da07429869877..656009b0dd69d 100644 --- a/airflow/api/__init__.py +++ b/airflow/api/__init__.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
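As a side note to the ``load_auth()`` changes just below: the backend that gets loaded comes from the
``[api] auth_backends`` option (falling back to ``airflow.api.auth.backend.default``). A minimal sketch,
using the standard configuration-via-environment convention:

.. code-block:: bash

    # illustrative only - select the basic-auth backend edited later in this change
    export AIRFLOW__API__AUTH_BACKENDS="airflow.api.auth.backend.basic_auth"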
-"""Authentication backend""" +"""Authentication backend.""" +from __future__ import annotations + import logging from importlib import import_module @@ -26,8 +28,8 @@ def load_auth(): - """Loads authentication backends""" - auth_backends = 'airflow.api.auth.backend.default' + """Load authentication backends.""" + auth_backends = "airflow.api.auth.backend.default" try: auth_backends = conf.get("api", "auth_backends") except AirflowConfigException: diff --git a/airflow/api/auth/backend/basic_auth.py b/airflow/api/auth/backend/basic_auth.py index 397a722a98cf2..3f802fde636a5 100644 --- a/airflow/api/auth/backend/basic_auth.py +++ b/airflow/api/auth/backend/basic_auth.py @@ -14,33 +14,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Basic authentication backend""" +"""Basic authentication backend.""" +from __future__ import annotations + from functools import wraps -from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast +from typing import Any, Callable, TypeVar, cast -from flask import Response, current_app, request +from flask import Response, request from flask_appbuilder.const import AUTH_LDAP from flask_login import login_user +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import User -CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None +CLIENT_AUTH: tuple[str, str] | Any | None = None def init_app(_): - """Initializes authentication backend""" + """Initialize authentication backend.""" T = TypeVar("T", bound=Callable) -def auth_current_user() -> Optional[User]: - """Authenticate and set current user if Authorization header exists""" +def auth_current_user() -> User | None: + """Authenticate and set current user if Authorization header exists.""" auth = request.authorization if auth is None or not auth.username or not auth.password: return None - ab_security_manager = current_app.appbuilder.sm + ab_security_manager = get_airflow_app().appbuilder.sm user = None if ab_security_manager.auth_type == AUTH_LDAP: user = ab_security_manager.auth_user_ldap(auth.username, auth.password) @@ -52,7 +55,7 @@ def auth_current_user() -> Optional[User]: def requires_authentication(function: T): - """Decorator for functions that require authentication""" + """Decorate functions that require authentication.""" @wraps(function) def decorated(*args, **kwargs): diff --git a/airflow/api/auth/backend/default.py b/airflow/api/auth/backend/default.py index 6b0a1a6c67907..b1d8f9bcf2674 100644 --- a/airflow/api/auth/backend/default.py +++ b/airflow/api/auth/backend/default.py @@ -15,22 +15,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Default authentication backend - everything is allowed""" +"""Default authentication backend - everything is allowed.""" +from __future__ import annotations + from functools import wraps -from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast +from typing import Any, Callable, TypeVar, cast -CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None +CLIENT_AUTH: tuple[str, str] | Any | None = None def init_app(_): - """Initializes authentication backend""" + """Initialize authentication backend.""" T = TypeVar("T", bound=Callable) def requires_authentication(function: T): - """Decorator for functions that require authentication""" + """Decorate functions that require authentication.""" @wraps(function) def decorated(*args, **kwargs): diff --git a/airflow/api/auth/backend/deny_all.py b/airflow/api/auth/backend/deny_all.py index 614e263684ad8..29b23f8d73b5a 100644 --- a/airflow/api/auth/backend/deny_all.py +++ b/airflow/api/auth/backend/deny_all.py @@ -15,24 +15,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Authentication backend that denies all requests""" +"""Authentication backend that denies all requests.""" +from __future__ import annotations + from functools import wraps -from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast +from typing import Any, Callable, TypeVar, cast from flask import Response -CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None +CLIENT_AUTH: tuple[str, str] | Any | None = None def init_app(_): - """Initializes authentication""" + """Initialize authentication.""" T = TypeVar("T", bound=Callable) def requires_authentication(function: T): - """Decorator for functions that require authentication""" + """Decorate functions that require authentication.""" @wraps(function) def decorated(*args, **kwargs): diff --git a/airflow/api/auth/backend/kerberos_auth.py b/airflow/api/auth/backend/kerberos_auth.py index fb76e8a1aa0fa..abb951be968ed 100644 --- a/airflow/api/auth/backend/kerberos_auth.py +++ b/airflow/api/auth/backend/kerberos_auth.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations # # Copyright (c) 2013, Michael Komitee @@ -43,23 +44,23 @@ import logging import os from functools import wraps -from socket import getfqdn -from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast +from typing import Any, Callable, TypeVar, cast import kerberos from flask import Response, _request_ctx_stack as stack, g, make_response, request # type: ignore from requests_kerberos import HTTPKerberosAuth from airflow.configuration import conf +from airflow.utils.net import getfqdn log = logging.getLogger(__name__) -CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = HTTPKerberosAuth(service='airflow') +CLIENT_AUTH: tuple[str, str] | Any | None = HTTPKerberosAuth(service="airflow") class KerberosService: - """Class to keep information about the Kerberos Service initialized""" + """Class to keep information about the Kerberos Service initialized.""" def __init__(self): self.service_name = None @@ -70,18 +71,18 @@ def __init__(self): def init_app(app): - """Initializes application with kerberos""" - hostname = app.config.get('SERVER_NAME') + """Initialize application with kerberos.""" + hostname = app.config.get("SERVER_NAME") if not hostname: hostname = getfqdn() log.info("Kerberos: hostname %s", hostname) - service = 'airflow' + service = "airflow" _KERBEROS_SERVICE.service_name = f"{service}@{hostname}" - if 'KRB5_KTNAME' not in os.environ: - os.environ['KRB5_KTNAME'] = conf.get('kerberos', 'keytab') + if "KRB5_KTNAME" not in os.environ: + os.environ["KRB5_KTNAME"] = conf.get("kerberos", "keytab") try: log.info("Kerberos init: %s %s", service, hostname) @@ -94,7 +95,7 @@ def init_app(app): def _unauthorized(): """ - Indicate that authorization is required + Indicate that authorization is required. :return: """ return Response("Unauthorized", 401, {"WWW-Authenticate": "Negotiate"}) @@ -130,21 +131,21 @@ def _gssapi_authenticate(token): def requires_authentication(function: T): - """Decorator for functions that require authentication with Kerberos""" + """Decorate functions that require authentication with Kerberos.""" @wraps(function) def decorated(*args, **kwargs): header = request.headers.get("Authorization") if header: ctx = stack.top - token = ''.join(header.split()[1:]) + token = "".join(header.split()[1:]) return_code = _gssapi_authenticate(token) if return_code == kerberos.AUTH_GSS_COMPLETE: g.user = ctx.kerberos_user response = function(*args, **kwargs) response = make_response(response) if ctx.kerberos_token is not None: - response.headers['WWW-Authenticate'] = ' '.join(['negotiate', ctx.kerberos_token]) + response.headers["WWW-Authenticate"] = " ".join(["negotiate", ctx.kerberos_token]) return response if return_code != kerberos.AUTH_GSS_CONTINUE: diff --git a/airflow/api/auth/backend/session.py b/airflow/api/auth/backend/session.py index 3f345c8b38240..8cc75ccad024e 100644 --- a/airflow/api/auth/backend/session.py +++ b/airflow/api/auth/backend/session.py @@ -14,24 +14,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Session authentication backend""" +"""Session authentication backend.""" +from __future__ import annotations + from functools import wraps -from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast +from typing import Any, Callable, TypeVar, cast from flask import Response, g -CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None +CLIENT_AUTH: tuple[str, str] | Any | None = None def init_app(_): - """Initializes authentication backend""" + """Initialize authentication backend.""" T = TypeVar("T", bound=Callable) def requires_authentication(function: T): - """Decorator for functions that require authentication""" + """Decorate functions that require authentication.""" @wraps(function) def decorated(*args, **kwargs): diff --git a/airflow/api/client/__init__.py b/airflow/api/client/__init__.py index 49224b5336a32..35608032abcb6 100644 --- a/airflow/api/client/__init__.py +++ b/airflow/api/client/__init__.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""API Client that allows interacting with Airflow API""" +"""API Client that allows interacting with Airflow API.""" +from __future__ import annotations + from importlib import import_module -from typing import Any from airflow import api from airflow.api.client.api_client import Client @@ -25,17 +26,17 @@ def get_current_api_client() -> Client: - """Return current API Client based on current Airflow configuration""" - api_module = import_module(conf.get_mandatory_value('cli', 'api_client')) # type: Any + """Return current API Client based on current Airflow configuration.""" + api_module = import_module(conf.get_mandatory_value("cli", "api_client")) auth_backends = api.load_auth() session = None for backend in auth_backends: - session_factory = getattr(backend, 'create_client_session', None) + session_factory = getattr(backend, "create_client_session", None) if session_factory: session = session_factory() api_client = api_module.Client( - api_base_url=conf.get('cli', 'endpoint_url'), - auth=getattr(backend, 'CLIENT_AUTH', None), + api_base_url=conf.get("cli", "endpoint_url"), + auth=getattr(backend, "CLIENT_AUTH", None), session=session, ) return api_client diff --git a/airflow/api/client/api_client.py b/airflow/api/client/api_client.py index c116d3b75ebb7..1334771b8de65 100644 --- a/airflow/api/client/api_client.py +++ b/airflow/api/client/api_client.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """Client for all the API clients.""" +from __future__ import annotations + import httpx @@ -75,7 +77,7 @@ def delete_pool(self, name): def get_lineage(self, dag_id: str, execution_date: str): """ - Return the lineage information for the dag on this execution date + Return the lineage information for the dag on this execution date. :param dag_id: :param execution_date: :return: diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py index d87a3aeef65b3..c6995ceb2a5cf 100644 --- a/airflow/api/client/json_client.py +++ b/airflow/api/client/json_client.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""JSON API Client""" +"""JSON API Client.""" +from __future__ import annotations from urllib.parse import urljoin @@ -25,12 +26,12 @@ class Client(api_client.Client): """Json API client implementation.""" - def _request(self, url, method='GET', json=None): + def _request(self, url, method="GET", json=None): params = { - 'url': url, + "url": url, } if json is not None: - params['json'] = json + params["json"] = json resp = getattr(self._session, method.lower())(**params) if resp.is_error: # It is justified here because there might be many resp types. @@ -38,64 +39,64 @@ def _request(self, url, method='GET', json=None): data = resp.json() except Exception: data = {} - raise OSError(data.get('error', 'Server error')) + raise OSError(data.get("error", "Server error")) return resp.json() def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): - endpoint = f'/api/experimental/dags/{dag_id}/dag_runs' + endpoint = f"/api/experimental/dags/{dag_id}/dag_runs" url = urljoin(self._api_base_url, endpoint) data = self._request( url, - method='POST', + method="POST", json={ "run_id": run_id, "conf": conf, "execution_date": execution_date, }, ) - return data['message'] + return data["message"] def delete_dag(self, dag_id): - endpoint = f'/api/experimental/dags/{dag_id}/delete_dag' + endpoint = f"/api/experimental/dags/{dag_id}/delete_dag" url = urljoin(self._api_base_url, endpoint) - data = self._request(url, method='DELETE') - return data['message'] + data = self._request(url, method="DELETE") + return data["message"] def get_pool(self, name): - endpoint = f'/api/experimental/pools/{name}' + endpoint = f"/api/experimental/pools/{name}" url = urljoin(self._api_base_url, endpoint) pool = self._request(url) - return pool['pool'], pool['slots'], pool['description'] + return pool["pool"], pool["slots"], pool["description"] def get_pools(self): - endpoint = '/api/experimental/pools' + endpoint = "/api/experimental/pools" url = urljoin(self._api_base_url, endpoint) pools = self._request(url) - return [(p['pool'], p['slots'], p['description']) for p in pools] + return [(p["pool"], p["slots"], p["description"]) for p in pools] def create_pool(self, name, slots, description): - endpoint = '/api/experimental/pools' + endpoint = "/api/experimental/pools" url = urljoin(self._api_base_url, endpoint) pool = self._request( url, - method='POST', + method="POST", json={ - 'name': name, - 'slots': slots, - 'description': description, + "name": name, + "slots": slots, + "description": description, }, ) - return pool['pool'], pool['slots'], pool['description'] + return pool["pool"], pool["slots"], pool["description"] def delete_pool(self, name): - endpoint = f'/api/experimental/pools/{name}' + endpoint = f"/api/experimental/pools/{name}" url = urljoin(self._api_base_url, endpoint) - pool = self._request(url, method='DELETE') - return pool['pool'], pool['slots'], pool['description'] + pool = self._request(url, method="DELETE") + return pool["pool"], pool["slots"], pool["description"] def get_lineage(self, dag_id: str, execution_date: str): endpoint = f"/api/experimental/lineage/{dag_id}/{execution_date}" url = urljoin(self._api_base_url, endpoint) - data = self._request(url, method='GET') - return data['message'] + data = self._request(url, method="GET") + return data["message"] diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py index c0050672a8e47..3b6e61241eaa1 100644 --- a/airflow/api/client/local_client.py +++ b/airflow/api/client/local_client.py @@ -15,7 +15,8 @@ # 
KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Local client API""" +"""Local client API.""" +from __future__ import annotations from airflow.api.client import api_client from airflow.api.common import delete_dag, trigger_dag diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py index 5e0afa81cb5c9..39f1461fccc9e 100644 --- a/airflow/api/common/delete_dag.py +++ b/airflow/api/common/delete_dag.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """Delete DAGs APIs.""" +from __future__ import annotations + import logging from sqlalchemy import and_, or_ @@ -34,6 +36,8 @@ @provide_session def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int: """ + Delete a DAG by a dag_id. + :param dag_id: the dag_id of the DAG to delete :param keep_records_in_log: whether keep records of the given dag_id in the Log table in the backend database (for reasons like auditing). @@ -72,12 +76,12 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i for model in get_sqla_model_classes(): if hasattr(model, "dag_id"): - if keep_records_in_log and model.__name__ == 'Log': + if keep_records_in_log and model.__name__ == "Log": continue count += ( session.query(model) .filter(model.dag_id.in_(dags_to_delete)) - .delete(synchronize_session='fetch') + .delete(synchronize_session="fetch") ) if dag.is_subdag: parent_dag_id, task_id = dag_id.rsplit(".", 1) @@ -89,7 +93,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i # Delete entries in Import Errors table for a deleted DAG # This handles the case when the dag_id is changed in the file session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete( - synchronize_session='fetch' + synchronize_session="fetch" ) return count diff --git a/airflow/api/common/experimental/__init__.py b/airflow/api/common/experimental/__init__.py index b161e04346358..35a4da3f1735a 100644 --- a/airflow/api/common/experimental/__init__.py +++ b/airflow/api/common/experimental/__init__.py @@ -16,15 +16,16 @@ # specific language governing permissions and limitations # under the License. 
"""Experimental APIs.""" +from __future__ import annotations + from datetime import datetime -from typing import Optional from airflow.exceptions import DagNotFound, DagRunNotFound, TaskNotFound from airflow.models import DagBag, DagModel, DagRun -def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel: - """Checks that DAG exists and in case it is specified that Task exist""" +def check_and_get_dag(dag_id: str, task_id: str | None = None) -> DagModel: + """Check DAG existence and in case it is specified that Task exists.""" dag_model = DagModel.get_current(dag_id) if dag_model is None: raise DagNotFound(f"Dag id {dag_id} not found in DagModel") @@ -35,15 +36,15 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel: error_message = f"Dag id {dag_id} not found" raise DagNotFound(error_message) if task_id and not dag.has_task(task_id): - error_message = f'Task {task_id} not found in dag {dag_id}' + error_message = f"Task {task_id} not found in dag {dag_id}" raise TaskNotFound(error_message) return dag def check_and_get_dagrun(dag: DagModel, execution_date: datetime) -> DagRun: - """Get DagRun object and check that it exists""" + """Get DagRun object and check that it exists.""" dagrun = dag.get_dagrun(execution_date=execution_date) if not dagrun: - error_message = f'Dag Run for date {execution_date} not found in dag {dag.dag_id}' + error_message = f"Dag Run for date {execution_date} not found in dag {dag.dag_id}" raise DagRunNotFound(error_message) return dagrun diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py index 36bf7dd8c46a7..821b80aa9af57 100644 --- a/airflow/api/common/experimental/delete_dag.py +++ b/airflow/api/common/experimental/delete_dag.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import warnings from airflow.api.common.delete_dag import * # noqa diff --git a/airflow/api/common/experimental/get_code.py b/airflow/api/common/experimental/get_code.py index d4232b1d0903b..9e2e8c08a4b56 100644 --- a/airflow/api/common/experimental/get_code.py +++ b/airflow/api/common/experimental/get_code.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """Get code APIs.""" +from __future__ import annotations + from deprecated import deprecated from airflow.api.common.experimental import check_and_get_dag diff --git a/airflow/api/common/experimental/get_dag_run_state.py b/airflow/api/common/experimental/get_dag_run_state.py index 7201186ea9331..cdf044a1569b1 100644 --- a/airflow/api/common/experimental/get_dag_run_state.py +++ b/airflow/api/common/experimental/get_dag_run_state.py @@ -16,8 +16,9 @@ # specific language governing permissions and limitations # under the License. """DAG run APIs.""" +from __future__ import annotations + from datetime import datetime -from typing import Dict from deprecated import deprecated @@ -25,7 +26,7 @@ @deprecated(reason="Use DagRun().get_state() instead", version="2.2.4") -def get_dag_run_state(dag_id: str, execution_date: datetime) -> Dict[str, str]: +def get_dag_run_state(dag_id: str, execution_date: datetime) -> dict[str, str]: """Return the Dag Run state identified by the given dag_id and execution_date. 
:param dag_id: DAG id @@ -36,4 +37,4 @@ def get_dag_run_state(dag_id: str, execution_date: datetime) -> Dict[str, str]: dagrun = check_and_get_dagrun(dag, execution_date) - return {'state': dagrun.get_state()} + return {"state": dagrun.get_state()} diff --git a/airflow/api/common/experimental/get_dag_runs.py b/airflow/api/common/experimental/get_dag_runs.py index 2064d6eb51267..2761bb45fd2cd 100644 --- a/airflow/api/common/experimental/get_dag_runs.py +++ b/airflow/api/common/experimental/get_dag_runs.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """DAG runs APIs.""" -from typing import Any, Dict, List, Optional +from __future__ import annotations + +from typing import Any from flask import url_for @@ -25,9 +27,9 @@ from airflow.utils.state import DagRunState -def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any]]: +def get_dag_runs(dag_id: str, state: str | None = None) -> list[dict[str, Any]]: """ - Returns a list of Dag Runs for a specific DAG ID. + Return a list of Dag Runs for a specific DAG ID. :param dag_id: String identifier of a DAG :param state: queued|running|success... @@ -41,13 +43,13 @@ def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any for run in DagRun.find(dag_id=dag_id, state=state): dag_runs.append( { - 'id': run.id, - 'run_id': run.run_id, - 'state': run.state, - 'dag_id': run.dag_id, - 'execution_date': run.execution_date.isoformat(), - 'start_date': ((run.start_date or '') and run.start_date.isoformat()), - 'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id, execution_date=run.execution_date), + "id": run.id, + "run_id": run.run_id, + "state": run.state, + "dag_id": run.dag_id, + "execution_date": run.execution_date.isoformat(), + "start_date": ((run.start_date or "") and run.start_date.isoformat()), + "dag_run_url": url_for("Airflow.graph", dag_id=run.dag_id, execution_date=run.execution_date), } ) diff --git a/airflow/api/common/experimental/get_lineage.py b/airflow/api/common/experimental/get_lineage.py index 461590b70248f..73bc9dd862ef3 100644 --- a/airflow/api/common/experimental/get_lineage.py +++ b/airflow/api/common/experimental/get_lineage.py @@ -15,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
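
A note on the recurring pattern in these hunks: adding "from __future__ import annotations" turns every annotation into a string (PEP 563), which is what makes the PEP 604 unions ("str | None") and built-in generics ("list", "dict", "tuple") used throughout this changeset safe on the older Python 3.7/3.8 interpreters Airflow 2.x still supports. A minimal, self-contained sketch of the style; the function and its names are illustrative and not part of the changeset:

from __future__ import annotations


def describe_runs(dag_id: str, state: str | None = None) -> list[dict[str, str]]:
    # Illustrative only: without the future import, evaluating these annotations
    # eagerly on Python 3.8 would raise TypeError at import time.
    return [{"dag_id": dag_id, "state": state or "queued"}]
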
-"""Lineage apis""" +"""Lineage APIs.""" +from __future__ import annotations + import collections import datetime -from typing import Any, Dict +from typing import Any from sqlalchemy.orm import Session @@ -31,15 +33,15 @@ @provide_session def get_lineage( dag_id: str, execution_date: datetime.datetime, *, session: Session = NEW_SESSION -) -> Dict[str, Dict[str, Any]]: - """Gets the lineage information for dag specified.""" +) -> dict[str, dict[str, Any]]: + """Get lineage information for dag specified.""" dag = check_and_get_dag(dag_id) dagrun = check_and_get_dagrun(dag, execution_date) inlets = XCom.get_many(dag_ids=dag_id, run_id=dagrun.run_id, key=PIPELINE_INLETS, session=session) outlets = XCom.get_many(dag_ids=dag_id, run_id=dagrun.run_id, key=PIPELINE_OUTLETS, session=session) - lineage: Dict[str, Dict[str, Any]] = collections.defaultdict(dict) + lineage: dict[str, dict[str, Any]] = collections.defaultdict(dict) for meta in inlets: lineage[meta.task_id]["inlets"] = meta.value for meta in outlets: diff --git a/airflow/api/common/experimental/get_task.py b/airflow/api/common/experimental/get_task.py index 4589cc6ce4d42..34e0fac37983f 100644 --- a/airflow/api/common/experimental/get_task.py +++ b/airflow/api/common/experimental/get_task.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Task APIs..""" +"""Task APIs.""" +from __future__ import annotations + from deprecated import deprecated from airflow.api.common.experimental import check_and_get_dag diff --git a/airflow/api/common/experimental/get_task_instance.py b/airflow/api/common/experimental/get_task_instance.py index 7361efdc4c796..cc8c734338284 100644 --- a/airflow/api/common/experimental/get_task_instance.py +++ b/airflow/api/common/experimental/get_task_instance.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Task Instance APIs.""" +"""Task instance APIs.""" +from __future__ import annotations + from datetime import datetime from deprecated import deprecated @@ -34,7 +36,7 @@ def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> Ta # Get task instance object and check that it exists task_instance = dagrun.get_task_instance(task_id) if not task_instance: - error_message = f'Task {task_id} instance for date {execution_date} not found' + error_message = f"Task {task_id} instance for date {execution_date} not found" raise TaskInstanceNotFound(error_message) return task_instance diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index 81cff3e30dad8..303c9f98ee59e 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -15,6 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Task Instance APIs.""" +from __future__ import annotations + import warnings from airflow.api.common.mark_tasks import ( # noqa diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py index 12ebc19dfd41a..a37c3e4086042 100644 --- a/airflow/api/common/experimental/pool.py +++ b/airflow/api/common/experimental/pool.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
"""Pool APIs.""" +from __future__ import annotations + from deprecated import deprecated from airflow.exceptions import AirflowBadRequest, PoolNotFound diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py index d52631281f534..123b09cb1c02d 100644 --- a/airflow/api/common/experimental/trigger_dag.py +++ b/airflow/api/common/experimental/trigger_dag.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import warnings diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 83bdb2081f0b8..8e25e0f4fd660 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. """Marks tasks APIs.""" +from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple from sqlalchemy import or_ from sqlalchemy.orm import lazyload @@ -38,7 +39,7 @@ class _DagRunInfo(NamedTuple): logical_date: datetime - data_interval: Tuple[datetime, datetime] + data_interval: tuple[datetime, datetime] def _create_dagruns( @@ -78,9 +79,9 @@ def _create_dagruns( @provide_session def set_state( *, - tasks: Collection[Union[Operator, Tuple[Operator, int]]], - run_id: Optional[str] = None, - execution_date: Optional[datetime] = None, + tasks: Collection[Operator | tuple[Operator, int]], + run_id: str | None = None, + execution_date: datetime | None = None, upstream: bool = False, downstream: bool = False, future: bool = False, @@ -88,13 +89,14 @@ def set_state( state: TaskInstanceState = TaskInstanceState.SUCCESS, commit: bool = False, session: SASession = NEW_SESSION, -) -> List[TaskInstance]: +) -> list[TaskInstance]: """ - Set the state of a task instance and if needed its relatives. Can set state - for future tasks (calculated from run_id) and retroactively + Set the state of a task instance and if needed its relatives. + + Can set state for future tasks (calculated from run_id) and retroactively for past tasks. Will verify integrity of past dag runs in order to create tasks that did not exist. It will not create dag runs that are missing - on the schedule (but it will as for subdag dag runs if needed). + on the schedule (but it will, as for subdag, dag runs if needed). :param tasks: the iterable of tasks or (task, map_index) tuples from which to work. ``task.dag`` needs to be set @@ -152,6 +154,10 @@ def set_state( qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates) tis_altered += qry_sub_dag.with_for_update().all() for task_instance in tis_altered: + # The try_number was decremented when setting to up_for_reschedule and deferred. 
+ # Increment it back when changing the state again + if task_instance.state in [State.DEFERRED, State.UP_FOR_RESCHEDULE]: + task_instance._try_number += 1 task_instance.set_state(state, session=session) session.flush() else: @@ -163,12 +169,12 @@ def set_state( def all_subdag_tasks_query( - sub_dag_run_ids: List[str], + sub_dag_run_ids: list[str], session: SASession, state: TaskInstanceState, confirmed_dates: Iterable[datetime], ): - """Get *all* tasks of the sub dags""" + """Get *all* tasks of the sub dags.""" qry_sub_dag = ( session.query(TaskInstance) .filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates)) @@ -181,10 +187,10 @@ def get_all_dag_task_query( dag: DAG, session: SASession, state: TaskInstanceState, - task_ids: List[Union[str, Tuple[str, int]]], + task_ids: list[str | tuple[str, int]], run_ids: Iterable[str], ): - """Get all tasks of the main dag that will be affected by a state change""" + """Get all tasks of the main dag that will be affected by a state change.""" qry_dag = session.query(TaskInstance).filter( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id.in_(run_ids), @@ -201,7 +207,7 @@ def _iter_subdag_run_ids( dag: DAG, session: SASession, state: DagRunState, - task_ids: List[str], + task_ids: list[str], commit: bool, confirmed_infos: Iterable[_DagRunInfo], ) -> Iterator[str]: @@ -244,7 +250,7 @@ def verify_dagruns( session: SASession, current_task: Operator, ): - """Verifies integrity of dag_runs. + """Verify integrity of dag_runs. :param dag_runs: dag runs to verify :param commit: whether dag runs state should be updated @@ -261,7 +267,7 @@ def verify_dagruns( session.merge(dag_run) -def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str], session: SASession) -> Iterator[_DagRunInfo]: +def _iter_existing_dag_run_infos(dag: DAG, run_ids: list[str], session: SASession) -> Iterator[_DagRunInfo]: for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids, session=session): dag_run.dag = dag dag_run.verify_integrity(session=session) @@ -288,8 +294,8 @@ def find_task_relatives(tasks, downstream, upstream): @provide_session def get_execution_dates( dag: DAG, execution_date: datetime, future: bool, past: bool, *, session: SASession = NEW_SESSION -) -> List[datetime]: - """Returns dates of DAG execution""" +) -> list[datetime]: + """Return DAG execution dates.""" latest_execution_date = dag.get_latest_execution_date(session=session) if latest_execution_date is None: raise ValueError(f"Received non-localized date {execution_date}") @@ -317,7 +323,7 @@ def get_execution_dates( @provide_session def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION): - """Returns run_ids of DAG execution""" + """Return DAG executions' run_ids.""" last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, session=session) current_dagrun = dag.get_dagrun(run_id=run_id, session=session) first_dagrun = ( @@ -328,7 +334,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess ) if last_dagrun is None: - raise ValueError(f'DagRun for {dag.dag_id} not found') + raise ValueError(f"DagRun for {dag.dag_id} not found") # determine run_id range of dag runs and tasks to consider end_date = last_dagrun.logical_date if future else current_dagrun.logical_date @@ -350,7 +356,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SASession = NEW_SESSION): """ - 
Helper method that set dag run state in the DB. + Set dag run state in the DB. :param dag_id: dag_id of target dag run :param run_id: run id of target dag run @@ -371,14 +377,15 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA def set_dag_run_state_to_success( *, dag: DAG, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, + execution_date: datetime | None = None, + run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, -) -> List[TaskInstance]: +) -> list[TaskInstance]: """ - Set the dag run for a specific execution date and its task instances - to success. + Set the dag run's state to success. + + Set for a specific execution date and its task instances to success. :param dag: the DAG of which to alter state :param execution_date: the execution date from which to start looking(deprecated) @@ -400,10 +407,10 @@ def set_dag_run_state_to_success( raise ValueError(f"Received non-localized date {execution_date}") dag_run = dag.get_dagrun(execution_date=execution_date) if not dag_run: - raise ValueError(f'DagRun with execution_date: {execution_date} not found') + raise ValueError(f"DagRun with execution_date: {execution_date} not found") run_id = dag_run.run_id if not run_id: - raise ValueError(f'Invalid dag_run_id: {run_id}') + raise ValueError(f"Invalid dag_run_id: {run_id}") # Mark the dag run to success. if commit: _set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session) @@ -418,14 +425,15 @@ def set_dag_run_state_to_success( def set_dag_run_state_to_failed( *, dag: DAG, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, + execution_date: datetime | None = None, + run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, -) -> List[TaskInstance]: +) -> list[TaskInstance]: """ - Set the dag run for a specific execution date or run_id and its running task instances - to failed. + Set the dag run's state to failed. + + Set for a specific execution date and its task instances to failed. :param dag: the DAG of which to alter state :param execution_date: the execution date from which to start looking(deprecated) @@ -446,11 +454,11 @@ def set_dag_run_state_to_failed( raise ValueError(f"Received non-localized date {execution_date}") dag_run = dag.get_dagrun(execution_date=execution_date) if not dag_run: - raise ValueError(f'DagRun with execution_date: {execution_date} not found') + raise ValueError(f"DagRun with execution_date: {execution_date} not found") run_id = dag_run.run_id if not run_id: - raise ValueError(f'Invalid dag_run_id: {run_id}') + raise ValueError(f"Invalid dag_run_id: {run_id}") # Mark the dag run to failed. 
if commit: @@ -462,7 +470,7 @@ def set_dag_run_state_to_failed( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id == run_id, TaskInstance.task_id.in_(task_ids), - TaskInstance.state.in_(State.running), + TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]), ) task_ids_of_running_tis = [task_instance.task_id for task_instance in tis] @@ -478,7 +486,7 @@ def set_dag_run_state_to_failed( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id == run_id, TaskInstance.state.not_in(State.finished), - TaskInstance.state.not_in(State.running), + TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]), ) tis = [ti for ti in tis] @@ -493,11 +501,11 @@ def __set_dag_run_state_to_running_or_queued( *, new_state: DagRunState, dag: DAG, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, + execution_date: datetime | None = None, + run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, -) -> List[TaskInstance]: +) -> list[TaskInstance]: """ Set the dag run for a specific execution date to running. @@ -509,7 +517,7 @@ def __set_dag_run_state_to_running_or_queued( :return: If commit is true, list of tasks that have been updated, otherwise list of tasks that will be updated """ - res: List[TaskInstance] = [] + res: list[TaskInstance] = [] if not (execution_date is None) ^ (run_id is None): return res @@ -523,10 +531,10 @@ def __set_dag_run_state_to_running_or_queued( raise ValueError(f"Received non-localized date {execution_date}") dag_run = dag.get_dagrun(execution_date=execution_date) if not dag_run: - raise ValueError(f'DagRun with execution_date: {execution_date} not found') + raise ValueError(f"DagRun with execution_date: {execution_date} not found") run_id = dag_run.run_id if not run_id: - raise ValueError(f'DagRun with run_id: {run_id} not found') + raise ValueError(f"DagRun with run_id: {run_id} not found") # Mark the dag run to running. if commit: _set_dag_run_state(dag.dag_id, run_id, new_state, session) @@ -539,11 +547,16 @@ def __set_dag_run_state_to_running_or_queued( def set_dag_run_state_to_running( *, dag: DAG, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, + execution_date: datetime | None = None, + run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, -) -> List[TaskInstance]: +) -> list[TaskInstance]: + """ + Set the dag run's state to running. + + Set for a specific execution date and its task instances to running. + """ return __set_dag_run_state_to_running_or_queued( new_state=DagRunState.RUNNING, dag=dag, @@ -558,11 +571,16 @@ def set_dag_run_state_to_running( def set_dag_run_state_to_queued( *, dag: DAG, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, + execution_date: datetime | None = None, + run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, -) -> List[TaskInstance]: +) -> list[TaskInstance]: + """ + Set the dag run's state to queued. + + Set for a specific execution date and its task instances to queued. + """ return __set_dag_run_state_to_running_or_queued( new_state=DagRunState.QUEUED, dag=dag, diff --git a/airflow/api/common/trigger_dag.py b/airflow/api/common/trigger_dag.py index 225efad683717..01da7745c7bf3 100644 --- a/airflow/api/common/trigger_dag.py +++ b/airflow/api/common/trigger_dag.py @@ -16,11 +16,10 @@ # specific language governing permissions and limitations # under the License. 
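
For context on the airflow/api/common/mark_tasks.py changes above: set_state is the function whose signature is being modernised, and it can be exercised directly. A rough sketch under stated assumptions (an example DAG shipped with Airflow, and an existing manual run of it in the metadata database; commit=False keeps it a dry run):

from airflow.api.common.mark_tasks import set_state
from airflow.models import DagBag
from airflow.utils.state import TaskInstanceState

dag = DagBag().get_dag("example_bash_operator")  # any DAG the DagBag can load
would_change = set_state(
    tasks=[dag.get_task("runme_0")],
    run_id="manual__2022-01-01T00:00:00+00:00",  # assumed to exist in the metadata DB
    downstream=True,
    state=TaskInstanceState.SUCCESS,
    commit=False,  # dry run: return the task instances that would be updated
)
print(f"{len(would_change)} task instances would be updated")
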
"""Triggering DAG runs APIs.""" +from __future__ import annotations + import json from datetime import datetime -from typing import List, Optional, Union - -import pendulum from airflow.exceptions import DagNotFound, DagRunAlreadyExists from airflow.models import DagBag, DagModel, DagRun @@ -32,11 +31,11 @@ def _trigger_dag( dag_id: str, dag_bag: DagBag, - run_id: Optional[str] = None, - conf: Optional[Union[dict, str]] = None, - execution_date: Optional[datetime] = None, + run_id: str | None = None, + conf: dict | str | None = None, + execution_date: datetime | None = None, replace_microseconds: bool = True, -) -> List[Optional[DagRun]]: +) -> list[DagRun | None]: """Triggers DAG run. :param dag_id: DAG ID @@ -60,21 +59,23 @@ def _trigger_dag( if replace_microseconds: execution_date = execution_date.replace(microsecond=0) - if dag.default_args and 'start_date' in dag.default_args: + if dag.default_args and "start_date" in dag.default_args: min_dag_start_date = dag.default_args["start_date"] if min_dag_start_date and execution_date < min_dag_start_date: raise ValueError( f"The execution_date [{execution_date.isoformat()}] should be >= start_date " f"[{min_dag_start_date.isoformat()}] from DAG's default_args" ) + logical_date = timezone.coerce_datetime(execution_date) - run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date) + data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) + run_id = run_id or dag.timetable.generate_run_id( + run_type=DagRunType.MANUAL, logical_date=logical_date, data_interval=data_interval + ) dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id) if dag_run: - raise DagRunAlreadyExists( - f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}" - ) + raise DagRunAlreadyExists(dag_run=dag_run, execution_date=execution_date, run_id=run_id) run_conf = None if conf: @@ -90,9 +91,7 @@ def _trigger_dag( conf=run_conf, external_trigger=True, dag_hash=dag_bag.dags_hash.get(dag_id), - data_interval=_dag.timetable.infer_manual_data_interval( - run_after=pendulum.instance(execution_date) - ), + data_interval=data_interval, ) dag_runs.append(dag_run) @@ -101,12 +100,12 @@ def _trigger_dag( def trigger_dag( dag_id: str, - run_id: Optional[str] = None, - conf: Optional[Union[dict, str]] = None, - execution_date: Optional[datetime] = None, + run_id: str | None = None, + conf: dict | str | None = None, + execution_date: datetime | None = None, replace_microseconds: bool = True, -) -> Optional[DagRun]: - """Triggers execution of DAG specified by dag_id +) -> DagRun | None: + """Triggers execution of DAG specified by dag_id. :param dag_id: DAG ID :param run_id: ID of the dag_run diff --git a/airflow/api_connexion/endpoints/config_endpoint.py b/airflow/api_connexion/endpoints/config_endpoint.py index bdd2b3a959547..a9353234f0b4e 100644 --- a/airflow/api_connexion/endpoints/config_endpoint.py +++ b/airflow/api_connexion/endpoints/config_endpoint.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations from http import HTTPStatus @@ -26,11 +27,11 @@ from airflow.security import permissions from airflow.settings import json -LINE_SEP = '\n' # `\n` cannot appear in f-strings +LINE_SEP = "\n" # `\n` cannot appear in f-strings def _conf_dict_to_config(conf_dict: dict) -> Config: - """Convert config dict to a Config object""" + """Convert config dict to a Config object.""" config = Config( sections=[ ConfigSection( @@ -43,25 +44,25 @@ def _conf_dict_to_config(conf_dict: dict) -> Config: def _option_to_text(config_option: ConfigOption) -> str: - """Convert a single config option to text""" - return f'{config_option.key} = {config_option.value}' + """Convert a single config option to text.""" + return f"{config_option.key} = {config_option.value}" def _section_to_text(config_section: ConfigSection) -> str: - """Convert a single config section to text""" + """Convert a single config section to text.""" return ( - f'[{config_section.name}]{LINE_SEP}' - f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}' + f"[{config_section.name}]{LINE_SEP}" + f"{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}" ) def _config_to_text(config: Config) -> str: - """Convert the entire config to text""" + """Convert the entire config to text.""" return LINE_SEP.join(_section_to_text(s) for s in config.sections) def _config_to_json(config: Config) -> str: - """Convert a Config object to a JSON formatted string""" + """Convert a Config object to a JSON formatted string.""" return json.dumps(config_schema.dump(config), indent=4) @@ -69,8 +70,8 @@ def _config_to_json(config: Config) -> str: def get_config() -> Response: """Get current configuration.""" serializer = { - 'text/plain': _config_to_text, - 'application/json': _config_to_json, + "text/plain": _config_to_text, + "application/json": _config_to_json, } return_type = request.accept_mimetypes.best_match(serializer.keys()) if return_type not in serializer: @@ -79,11 +80,11 @@ def get_config() -> Response: conf_dict = conf.as_dict(display_source=False, display_sensitive=True) config = _conf_dict_to_config(conf_dict) config_text = serializer[return_type](config) - return Response(config_text, headers={'Content-Type': return_type}) + return Response(config_text, headers={"Content-Type": return_type}) else: raise PermissionDenied( detail=( - 'Your Airflow administrator chose not to expose the configuration, most likely for security' - ' reasons.' + "Your Airflow administrator chose not to expose the configuration, most likely for security" + " reasons." ) ) diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index b196b3236b911..40cb474bda4ea 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import os from http import HTTPStatus @@ -37,18 +38,28 @@ from airflow.models import Connection from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.security import permissions +from airflow.utils.log.action_logger import action_event_from_permission from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.strings import get_random_string +from airflow.www.decorators import action_logging + +RESOURCE_EVENT_PREFIX = "connection" @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)]) @provide_session +@action_logging( + event=action_event_from_permission( + prefix=RESOURCE_EVENT_PREFIX, + permission=permissions.ACTION_CAN_DELETE, + ), +) def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse: - """Delete a connection entry""" + """Delete a connection entry.""" connection = session.query(Connection).filter_by(conn_id=connection_id).one_or_none() if connection is None: raise NotFound( - 'Connection not found', + "Connection not found", detail=f"The Connection with connection_id: `{connection_id}` was not found", ) session.delete(connection) @@ -58,7 +69,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) -> @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) @provide_session def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse: - """Get a connection entry""" + """Get a connection entry.""" connection = session.query(Connection).filter(Connection.conn_id == connection_id).one_or_none() if connection is None: raise NotFound( @@ -69,7 +80,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) -@format_parameters({'limit': check_limit}) +@format_parameters({"limit": check_limit}) @provide_session def get_connections( *, @@ -78,9 +89,9 @@ def get_connections( order_by: str = "id", session: Session = NEW_SESSION, ) -> APIResponse: - """Get all connection entries""" + """Get all connection entries.""" to_replace = {"connection_id": "conn_id"} - allowed_filter_attrs = ['connection_id', 'conn_type', 'description', 'host', 'port', 'id'] + allowed_filter_attrs = ["connection_id", "conn_type", "description", "host", "port", "id"] total_entries = session.query(func.count(Connection.id)).scalar() query = session.query(Connection) @@ -93,26 +104,32 @@ def get_connections( @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION)]) @provide_session +@action_logging( + event=action_event_from_permission( + prefix=RESOURCE_EVENT_PREFIX, + permission=permissions.ACTION_CAN_EDIT, + ), +) def patch_connection( *, connection_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION, ) -> APIResponse: - """Update a connection entry""" + """Update a connection entry.""" try: data = connection_schema.load(request.json, partial=True) except ValidationError as err: # If validation get to here, it is extra field validation. 
raise BadRequest(detail=str(err.messages)) - non_update_fields = ['connection_id', 'conn_id'] + non_update_fields = ["connection_id", "conn_id"] connection = session.query(Connection).filter_by(conn_id=connection_id).first() if connection is None: raise NotFound( "Connection not found", detail=f"The Connection with connection_id: `{connection_id}` was not found", ) - if data.get('conn_id') and connection.conn_id != data['conn_id']: + if data.get("conn_id") and connection.conn_id != data["conn_id"]: raise BadRequest(detail="The connection_id cannot be updated.") if update_mask: update_mask = [i.strip() for i in update_mask] @@ -132,14 +149,20 @@ def patch_connection( @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) @provide_session +@action_logging( + event=action_event_from_permission( + prefix=RESOURCE_EVENT_PREFIX, + permission=permissions.ACTION_CAN_CREATE, + ), +) def post_connection(*, session: Session = NEW_SESSION) -> APIResponse: - """Create connection entry""" + """Create connection entry.""" body = request.json try: data = connection_schema.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - conn_id = data['conn_id'] + conn_id = data["conn_id"] query = session.query(Connection) connection = query.filter_by(conn_id=conn_id).first() if not connection: @@ -153,16 +176,18 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) def test_connection() -> APIResponse: """ - To test a connection, this method first creates an in-memory dummy conn_id & exports that to an + Test an API connection. + + This method first creates an in-memory dummy conn_id & exports that to an env var, as some hook classes tries to find out the conn from their __init__ method & errors out if not found. It also deletes the conn id env variable after the test. """ body = request.json dummy_conn_id = get_random_string() - conn_env_var = f'{CONN_ENV_PREFIX}{dummy_conn_id.upper()}' + conn_env_var = f"{CONN_ENV_PREFIX}{dummy_conn_id.upper()}" try: data = connection_schema.load(body) - data['conn_id'] = dummy_conn_id + data["conn_id"] = dummy_conn_id conn = Connection(**data) os.environ[conn_env_var] = conn.get_uri() status, message = conn.test_connection() diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 0505f864ee333..be66dc814e95b 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
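
The test_connection docstring above describes an environment-variable round trip: a throwaway connection is exported under the AIRFLOW_CONN_ prefix so hooks that resolve their connection in __init__ can find it, and the variable is removed afterwards. The same trick can be reproduced outside the API layer; a small sketch with placeholder connection details:

from __future__ import annotations

import os

from airflow.models import Connection

conn = Connection(conn_id="adhoc_check", conn_type="http", host="example.com")
env_var = f"AIRFLOW_CONN_{conn.conn_id.upper()}"
os.environ[env_var] = conn.get_uri()  # make the connection discoverable by hooks
try:
    status, message = conn.test_connection()  # returns (bool, str)
    print(status, message)
finally:
    os.environ.pop(env_var, None)  # mirror the endpoint's cleanup
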
+from __future__ import annotations from http import HTTPStatus -from typing import Collection, Optional +from typing import Collection from connexion import NoContent -from flask import current_app, g, request +from flask import g, request from marshmallow import ValidationError from sqlalchemy.orm import Session from sqlalchemy.sql.expression import or_ @@ -38,6 +39,7 @@ from airflow.exceptions import AirflowException, DagNotFound from airflow.models.dag import DagModel, DagTag from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -56,21 +58,21 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) def get_dag_details(*, dag_id: str) -> APIResponse: """Get details of DAG.""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found") return dag_detail_schema.dump(dag) @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) -@format_parameters({'limit': check_limit}) +@format_parameters({"limit": check_limit}) @provide_session def get_dags( *, limit: int, offset: int = 0, - tags: Optional[Collection[str]] = None, - dag_id_pattern: Optional[str] = None, + tags: Collection[str] | None = None, + dag_id_pattern: str | None = None, only_active: bool = True, session: Session = NEW_SESSION, ) -> APIResponse: @@ -81,9 +83,9 @@ def get_dags( dags_query = session.query(DagModel).filter(~DagModel.is_subdag) if dag_id_pattern: - dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%')) + dags_query = dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) - readable_dags = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) + readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags)) if tags: @@ -100,27 +102,27 @@ def get_dags( @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) @provide_session def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse: - """Update the specific DAG""" + """Update the specific DAG.""" try: patch_body = dag_schema.load(request.json, session=session) except ValidationError as err: raise BadRequest(detail=str(err.messages)) if update_mask: patch_body_ = {} - if update_mask != ['is_paused']: + if update_mask != ["is_paused"]: raise BadRequest(detail="Only `is_paused` field can be updated through the REST API") patch_body_[update_mask[0]] = patch_body[update_mask[0]] patch_body = patch_body_ dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none() if not dag: raise NotFound(f"Dag with id: '{dag_id}' not found") - dag.is_paused = patch_body['is_paused'] + dag.is_paused = patch_body["is_paused"] session.flush() return dag_schema.dump(dag) @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) -@format_parameters({'limit': check_limit}) +@format_parameters({"limit": check_limit}) @provide_session def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None): """Patch multiple DAGs.""" @@ -130,7 +132,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat raise 
BadRequest(detail=str(err.messages)) if update_mask: patch_body_ = {} - if update_mask != ['is_paused']: + if update_mask != ["is_paused"]: raise BadRequest(detail="Only `is_paused` field can be updated through the REST API") update_mask = update_mask[0] patch_body_[update_mask] = patch_body[update_mask] @@ -140,10 +142,10 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat else: dags_query = session.query(DagModel).filter(~DagModel.is_subdag) - if dag_id_pattern == '~': - dag_id_pattern = '%' - dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%')) - editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user) + if dag_id_pattern == "~": + dag_id_pattern = "%" + dags_query = dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) + editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user) dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags)) if tags: @@ -156,7 +158,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat dags_to_update = {dag.dag_id for dag in dags} session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update( - {DagModel.is_paused: patch_body['is_paused']}, synchronize_session='fetch' + {DagModel.is_paused: patch_body["is_paused"]}, synchronize_session="fetch" ) session.flush() diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index e510126534b12..b566ba2df72f6 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + from http import HTTPStatus -from typing import List, Optional, Tuple import pendulum from connexion import NoContent -from flask import current_app, g, request +from flask import g from marshmallow import ValidationError from sqlalchemy import or_ from sqlalchemy.orm import Query, Session @@ -30,6 +31,7 @@ set_dag_run_state_to_success, ) from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters from airflow.api_connexion.schemas.dag_run_schema import ( @@ -38,8 +40,13 @@ dagrun_collection_schema, dagrun_schema, dagruns_batch_form_schema, + set_dagrun_note_form_schema, set_dagrun_state_form_schema, ) +from airflow.api_connexion.schemas.dataset_schema import ( + DatasetEventCollection, + dataset_event_collection_schema, +) from airflow.api_connexion.schemas.task_instance_schema import ( TaskInstanceReferenceCollection, task_instance_reference_collection_schema, @@ -47,6 +54,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import DagModel, DagRun from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -60,7 +68,7 @@ ) @provide_session def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: - """Delete a DAG Run""" + """Delete a DAG Run.""" if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0: raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found") return NoContent, HTTPStatus.NO_CONTENT @@ -84,19 +92,50 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) return dagrun_schema.dump(dag_run) +@security.requires_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), + ], +) +@provide_session +def get_upstream_dataset_events( + *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION +) -> APIResponse: + """If dag run is dataset-triggered, return the dataset events that triggered it.""" + dag_run: DagRun | None = ( + session.query(DagRun) + .filter( + DagRun.dag_id == dag_id, + DagRun.run_id == dag_run_id, + ) + .one_or_none() + ) + if dag_run is None: + raise NotFound( + "DAGRun not found", + detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found", + ) + events = dag_run.consumed_dataset_events + return dataset_event_collection_schema.dump( + DatasetEventCollection(dataset_events=events, total_entries=len(events)) + ) + + def _fetch_dag_runs( query: Query, *, - end_date_gte: Optional[str], - end_date_lte: Optional[str], - execution_date_gte: Optional[str], - execution_date_lte: Optional[str], - start_date_gte: Optional[str], - start_date_lte: Optional[str], - limit: Optional[int], - offset: Optional[int], + end_date_gte: str | None, + end_date_lte: str | None, + execution_date_gte: str | None, + execution_date_lte: str | None, + start_date_gte: str | None, + start_date_lte: str | None, + limit: int | None, + offset: int | None, order_by: str, -) -> 
Tuple[List[DagRun], int]: +) -> tuple[list[DagRun], int]: if start_date_gte: query = query.filter(DagRun.start_date >= start_date_gte) if start_date_lte: @@ -137,28 +176,28 @@ def _fetch_dag_runs( ) @format_parameters( { - 'start_date_gte': format_datetime, - 'start_date_lte': format_datetime, - 'execution_date_gte': format_datetime, - 'execution_date_lte': format_datetime, - 'end_date_gte': format_datetime, - 'end_date_lte': format_datetime, - 'limit': check_limit, + "start_date_gte": format_datetime, + "start_date_lte": format_datetime, + "execution_date_gte": format_datetime, + "execution_date_lte": format_datetime, + "end_date_gte": format_datetime, + "end_date_lte": format_datetime, + "limit": check_limit, } ) @provide_session def get_dag_runs( *, dag_id: str, - start_date_gte: Optional[str] = None, - start_date_lte: Optional[str] = None, - execution_date_gte: Optional[str] = None, - execution_date_lte: Optional[str] = None, - end_date_gte: Optional[str] = None, - end_date_lte: Optional[str] = None, - state: Optional[List[str]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, + start_date_gte: str | None = None, + start_date_lte: str | None = None, + execution_date_gte: str | None = None, + execution_date_lte: str | None = None, + end_date_gte: str | None = None, + end_date_lte: str | None = None, + state: list[str] | None = None, + offset: int | None = None, + limit: int | None = None, order_by: str = "id", session: Session = NEW_SESSION, ): @@ -167,7 +206,7 @@ def get_dag_runs( # This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs. if dag_id == "~": - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder query = query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user))) else: query = query.filter(DagRun.dag_id == dag_id) @@ -198,14 +237,14 @@ def get_dag_runs( ) @provide_session def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: - """Get list of DAG Runs""" - body = request.get_json() + """Get list of DAG Runs.""" + body = get_json_request_dict() try: data = dagruns_batch_form_schema.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) query = session.query(DagRun) if data.get("dag_ids"): @@ -252,7 +291,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: detail=f"DAG with dag_id: '{dag_id}' has import errors", ) try: - post_body = dagrun_schema.load(request.json, session=session) + post_body = dagrun_schema.load(get_json_request_dict(), session=session) except ValidationError as err: raise BadRequest(detail=str(err)) @@ -268,7 +307,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: ) if not dagrun_instance: try: - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) dag_run = dag.create_dagrun( run_type=DagRunType.MANUAL, run_id=run_id, @@ -277,7 +316,8 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: state=DagRunState.QUEUED, conf=post_body.get("conf"), external_trigger=True, - dag_hash=current_app.dag_bag.dags_hash.get(dag_id), + dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id), + session=session, ) return dagrun_schema.dump(dag_run) except ValueError as ve: @@ -303,19 +343,19 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) 
-> APIResponse: @provide_session def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of a dag run.""" - dag_run: Optional[DagRun] = ( + dag_run: DagRun | None = ( session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() ) if dag_run is None: - error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' + error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}" raise NotFound(error_message) try: - post_body = set_dagrun_state_form_schema.load(request.json) + post_body = set_dagrun_state_form_schema.load(get_json_request_dict()) except ValidationError as err: raise BadRequest(detail=str(err)) - state = post_body['state'] - dag = current_app.dag_bag.get_dag(dag_id) + state = post_body["state"] + dag = get_airflow_app().dag_bag.get_dag(dag_id) if state == DagRunState.SUCCESS: set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True) elif state == DagRunState.QUEUED: @@ -335,19 +375,19 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW @provide_session def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear a dag run.""" - dag_run: Optional[DagRun] = ( + dag_run: DagRun | None = ( session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() ) if dag_run is None: - error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}' + error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}" raise NotFound(error_message) try: - post_body = clear_dagrun_form_schema.load(request.json) + post_body = clear_dagrun_form_schema.load(get_json_request_dict()) except ValidationError as err: raise BadRequest(detail=str(err)) - dry_run = post_body.get('dry_run', False) - dag = current_app.dag_bag.get_dag(dag_id) + dry_run = post_body.get("dry_run", False) + dag = get_airflow_app().dag_bag.get_dag(dag_id) start_date = dag_run.logical_date end_date = dag_run.logical_date @@ -373,5 +413,38 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO include_parentdag=True, only_failed=False, ) - dag_run.refresh_from_db() + dag_run = session.query(DagRun).filter(DagRun.id == dag_run.id).one() return dagrun_schema.dump(dag_run) + + +@security.requires_access( + [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + ], +) +@provide_session +def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: + """Set the note for a dag run.""" + dag_run: DagRun | None = ( + session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none() + ) + if dag_run is None: + error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}" + raise NotFound(error_message) + try: + post_body = set_dagrun_note_form_schema.load(get_json_request_dict()) + new_note = post_body["note"] + except ValidationError as err: + raise BadRequest(detail=str(err)) + + from flask_login import current_user + + current_user_id = getattr(current_user, "id", None) + if dag_run.dag_run_note is None: + dag_run.note = (new_note, current_user_id) + else: + dag_run.dag_run_note.content = new_note + dag_run.dag_run_note.user_id = current_user_id + session.commit() + return dagrun_schema.dump(dag_run) diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py 
b/airflow/api_connexion/endpoints/dag_source_endpoint.py index ad6209221e523..42ccd4e5d8671 100644 --- a/airflow/api_connexion/endpoints/dag_source_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from http import HTTPStatus @@ -29,7 +30,7 @@ @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)]) def get_dag_source(*, file_token: str) -> Response: - """Get source code using file token""" + """Get source code using file token.""" secret_key = current_app.config["SECRET_KEY"] auth_s = URLSafeSerializer(secret_key) try: @@ -38,10 +39,10 @@ def get_dag_source(*, file_token: str) -> Response: except (BadSignature, FileNotFoundError): raise NotFound("Dag source not found") - return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json']) - if return_type == 'text/plain': - return Response(dag_source, headers={'Content-Type': return_type}) - if return_type == 'application/json': + return_type = request.accept_mimetypes.best_match(["text/plain", "application/json"]) + if return_type == "text/plain": + return Response(dag_source, headers={"Content-Type": return_type}) + if return_type == "application/json": content = dag_source_schema.dumps(dict(content=dag_source)) - return Response(content, headers={'Content-Type': return_type}) + return Response(content, headers={"Content-Type": return_type}) return Response("Not Allowed Accept Header", status=HTTPStatus.NOT_ACCEPTABLE) diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py new file mode 100644 index 0000000000000..5a73afd1a33a8 --- /dev/null +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from sqlalchemy.orm import Session + +from airflow.api_connexion import security +from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters +from airflow.api_connexion.schemas.dag_warning_schema import ( + DagWarningCollection, + dag_warning_collection_schema, +) +from airflow.api_connexion.types import APIResponse +from airflow.models.dagwarning import DagWarning as DagWarningModel +from airflow.security import permissions +from airflow.utils.session import NEW_SESSION, provide_session + + +@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING)]) +@format_parameters({"limit": check_limit}) +@provide_session +def get_dag_warnings( + *, + limit: int, + dag_id: str | None = None, + warning_type: str | None = None, + offset: int | None = None, + order_by: str = "timestamp", + session: Session = NEW_SESSION, +) -> APIResponse: + """Get DAG warnings. + + :param dag_id: the dag_id to optionally filter by + :param warning_type: the warning type to optionally filter by + """ + allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"] + query = session.query(DagWarningModel) + if dag_id: + query = query.filter(DagWarningModel.dag_id == dag_id) + if warning_type: + query = query.filter(DagWarningModel.warning_type == warning_type) + total_entries = query.count() + query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs) + dag_warnings = query.offset(offset).limit(limit).all() + return dag_warning_collection_schema.dump( + DagWarningCollection(dag_warnings=dag_warnings, total_entries=total_entries) + ) diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py new file mode 100644 index 0000000000000..42e8bb3c36c33 --- /dev/null +++ b/airflow/api_connexion/endpoints/dataset_endpoint.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
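
The get_dag_warnings view added above is a plain collection endpoint over the new DagWarning model. Assuming it is exposed at /api/v1/dagWarnings like the other stable-API collections (the path, host, and credentials below are assumptions for illustration, not stated in this diff), a client call could look like:

import requests

resp = requests.get(
    "http://localhost:8080/api/v1/dagWarnings",  # assumed route
    params={"dag_id": "example_dag", "limit": 25, "order_by": "timestamp"},
    auth=("admin", "admin"),  # placeholder basic-auth credentials
    timeout=10,
)
resp.raise_for_status()
for warning in resp.json()["dag_warnings"]:
    print(warning["dag_id"], warning["warning_type"], warning["message"])
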
+from __future__ import annotations + +from sqlalchemy import func +from sqlalchemy.orm import Session, joinedload, subqueryload + +from airflow.api_connexion import security +from airflow.api_connexion.exceptions import NotFound +from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters +from airflow.api_connexion.schemas.dataset_schema import ( + DatasetCollection, + DatasetEventCollection, + dataset_collection_schema, + dataset_event_collection_schema, + dataset_schema, +) +from airflow.api_connexion.types import APIResponse +from airflow.models.dataset import DatasetEvent, DatasetModel +from airflow.security import permissions +from airflow.utils.session import NEW_SESSION, provide_session + + +@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@provide_session +def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse: + """Get a Dataset.""" + dataset = ( + session.query(DatasetModel) + .filter(DatasetModel.uri == uri) + .options(joinedload(DatasetModel.consuming_dags), joinedload(DatasetModel.producing_tasks)) + .one_or_none() + ) + if not dataset: + raise NotFound( + "Dataset not found", + detail=f"The Dataset with uri: `{uri}` was not found", + ) + return dataset_schema.dump(dataset) + + +@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@format_parameters({"limit": check_limit}) +@provide_session +def get_datasets( + *, + limit: int, + offset: int = 0, + uri_pattern: str | None = None, + order_by: str = "id", + session: Session = NEW_SESSION, +) -> APIResponse: + """Get datasets.""" + allowed_attrs = ["id", "uri", "created_at", "updated_at"] + + total_entries = session.query(func.count(DatasetModel.id)).scalar() + query = session.query(DatasetModel) + if uri_pattern: + query = query.filter(DatasetModel.uri.ilike(f"%{uri_pattern}%")) + query = apply_sorting(query, order_by, {}, allowed_attrs) + datasets = ( + query.options(subqueryload(DatasetModel.consuming_dags), subqueryload(DatasetModel.producing_tasks)) + .offset(offset) + .limit(limit) + .all() + ) + return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries)) + + +@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@provide_session +@format_parameters({"limit": check_limit}) +def get_dataset_events( + *, + limit: int, + offset: int = 0, + order_by: str = "timestamp", + dataset_id: int | None = None, + source_dag_id: str | None = None, + source_task_id: str | None = None, + source_run_id: str | None = None, + source_map_index: int | None = None, + session: Session = NEW_SESSION, +) -> APIResponse: + """Get dataset events.""" + allowed_attrs = ["source_dag_id", "source_task_id", "source_run_id", "source_map_index", "timestamp"] + + query = session.query(DatasetEvent) + + if dataset_id: + query = query.filter(DatasetEvent.dataset_id == dataset_id) + if source_dag_id: + query = query.filter(DatasetEvent.source_dag_id == source_dag_id) + if source_task_id: + query = query.filter(DatasetEvent.source_task_id == source_task_id) + if source_run_id: + query = query.filter(DatasetEvent.source_run_id == source_run_id) + if source_map_index: + query = query.filter(DatasetEvent.source_map_index == source_map_index) + + query = query.options(subqueryload(DatasetEvent.created_dagruns)) + + total_entries = query.count() + query = apply_sorting(query, order_by, {}, allowed_attrs) + events = query.offset(offset).limit(limit).all() + 
return dataset_event_collection_schema.dump( + DatasetEventCollection(dataset_events=events, total_entries=total_entries) + ) diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py index 590190ab98a20..94fa73a431597 100644 --- a/airflow/api_connexion/endpoints/event_log_endpoint.py +++ b/airflow/api_connexion/endpoints/event_log_endpoint.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -from typing import Optional +from __future__ import annotations from sqlalchemy import func from sqlalchemy.orm import Session @@ -37,7 +36,7 @@ @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)]) @provide_session def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIResponse: - """Get a log entry""" + """Get a log entry.""" event_log = session.query(Log).get(event_log_id) if event_log is None: raise NotFound("Event Log not found") @@ -50,14 +49,14 @@ def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIRe def get_event_logs( *, limit: int, - offset: Optional[int] = None, + offset: int | None = None, order_by: str = "event_log_id", session: Session = NEW_SESSION, ) -> APIResponse: - """Get all log entries from event log""" + """Get all log entries from event log.""" to_replace = {"event_log_id": "id", "when": "dttm"} allowed_filter_attrs = [ - 'event_log_id', + "event_log_id", "when", "dag_id", "task_id", diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py index 3e9535603bda3..2b12667e7cdfa 100644 --- a/airflow/api_connexion/endpoints/extra_link_endpoint.py +++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
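
The dataset endpoints above back the new data-aware scheduling feature: DatasetModel rows come from DAG and task definitions, and DatasetEvent rows are written when a producing task finishes successfully. For context, a minimal producer/consumer pair that would populate them (URIs and DAG ids are made up):

from __future__ import annotations

import pendulum

from airflow import DAG, Dataset
from airflow.operators.empty import EmptyOperator

orders = Dataset("s3://example-bucket/orders.parquet")

with DAG("orders_producer", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule="@daily"):
    EmptyOperator(task_id="write_orders", outlets=[orders])

# The consumer has no time-based schedule; a run is created whenever a task
# declaring outlets=[orders] completes successfully.
with DAG("orders_consumer", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=[orders]):
    EmptyOperator(task_id="read_orders")
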
+from __future__ import annotations -from flask import current_app from sqlalchemy.orm.session import Session from airflow import DAG @@ -25,6 +25,7 @@ from airflow.exceptions import TaskNotFound from airflow.models.dagbag import DagBag from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -43,10 +44,10 @@ def get_extra_links( task_id: str, session: Session = NEW_SESSION, ) -> APIResponse: - """Get extra links for task instance""" + """Get extra links for task instance.""" from airflow.models.taskinstance import TaskInstance - dagbag: DagBag = current_app.dag_bag + dagbag: DagBag = get_airflow_app().dag_bag dag: DAG = dagbag.get_dag(dag_id) if not dag: raise NotFound("DAG not found", detail=f'DAG with ID = "{dag_id}" not found') @@ -73,6 +74,6 @@ def get_extra_links( (link_name, task.get_extra_links(ti, link_name)) for link_name in task.extra_links ) all_extra_links = { - link_name: link_url if link_url else None for link_name, link_url in all_extra_link_pairs + link_name: link_url if link_url else None for link_name, link_url in sorted(all_extra_link_pairs) } return all_extra_links diff --git a/airflow/api_connexion/endpoints/health_endpoint.py b/airflow/api_connexion/endpoints/health_endpoint.py index 380225bf16e4d..f833a5d72815b 100644 --- a/airflow/api_connexion/endpoints/health_endpoint.py +++ b/airflow/api_connexion/endpoints/health_endpoint.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.api_connexion.schemas.health_schema import health_schema from airflow.api_connexion.types import APIResponse @@ -24,7 +25,7 @@ def get_health() -> APIResponse: - """Return the health of the airflow scheduler and metadatabase""" + """Return the health of the airflow scheduler and metadatabase.""" metadatabase_status = HEALTHY latest_scheduler_heartbeat = None scheduler_status = UNHEALTHY diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index 9a46b3aac4392..f5798fe99e8c4 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
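A short aside on the sorted() call introduced in get_extra_links above: sorting the (link_name, link_url) pairs before building the dict gives the response a deterministic key order, because Python dicts preserve insertion order. A self-contained sketch with made-up link names:

# Hypothetical pairs as produced by task.get_extra_links(ti, link_name).
all_extra_link_pairs = [("Monitoring", "https://example.com/run/42"), ("Docs", "")]

# Same comprehension as in the endpoint: empty URLs become None, keys come out sorted.
all_extra_links = {
    link_name: link_url if link_url else None
    for link_name, link_url in sorted(all_extra_link_pairs)
}
assert list(all_extra_links) == ["Docs", "Monitoring"]
assert all_extra_links["Docs"] is None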
- -from typing import Optional +from __future__ import annotations from sqlalchemy import func from sqlalchemy.orm import Session @@ -37,7 +36,7 @@ @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) @provide_session def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse: - """Get an import error""" + """Get an import error.""" error = session.query(ImportErrorModel).get(import_error_id) if error is None: @@ -49,18 +48,18 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) -@format_parameters({'limit': check_limit}) +@format_parameters({"limit": check_limit}) @provide_session def get_import_errors( *, limit: int, - offset: Optional[int] = None, + offset: int | None = None, order_by: str = "import_error_id", session: Session = NEW_SESSION, ) -> APIResponse: - """Get all import errors""" - to_replace = {"import_error_id": 'id'} - allowed_filter_attrs = ['import_error_id', "timestamp", "filename"] + """Get all import errors.""" + to_replace = {"import_error_id": "id"} + allowed_filter_attrs = ["import_error_id", "timestamp", "filename"] total_entries = session.query(func.count(ImportErrorModel.id)).scalar() query = session.query(ImportErrorModel) query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs) diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index f1335fe527451..388b164727e90 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Optional +from typing import Any -from flask import Response, current_app, request +from flask import Response, request from itsdangerous.exc import BadSignature from itsdangerous.url_safe import URLSafeSerializer from sqlalchemy.orm.session import Session @@ -29,6 +30,7 @@ from airflow.exceptions import TaskNotFound from airflow.models import TaskInstance from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.session import NEW_SESSION, provide_session @@ -48,11 +50,12 @@ def get_log( task_id: str, task_try_number: int, full_content: bool = False, - token: Optional[str] = None, + map_index: int = -1, + token: str | None = None, session: Session = NEW_SESSION, ) -> APIResponse: - """Get logs for specific task instance""" - key = current_app.config["SECRET_KEY"] + """Get logs for specific task instance.""" + key = get_airflow_app().config["SECRET_KEY"] if not token: metadata = {} else: @@ -61,47 +64,48 @@ def get_log( except BadSignature: raise BadRequest("Bad Signature. 
Please use only the tokens provided by the API.") - if metadata.get('download_logs') and metadata['download_logs']: + if metadata.get("download_logs") and metadata["download_logs"]: full_content = True if full_content: - metadata['download_logs'] = True + metadata["download_logs"] = True else: - metadata['download_logs'] = False + metadata["download_logs"] = False task_log_reader = TaskLogReader() if not task_log_reader.supports_read: raise BadRequest("Task log handler does not support read logs.") - ti = ( session.query(TaskInstance) .filter( TaskInstance.task_id == task_id, TaskInstance.dag_id == dag_id, TaskInstance.run_id == dag_run_id, + TaskInstance.map_index == map_index, ) .join(TaskInstance.dag_run) .one_or_none() ) if ti is None: - metadata['end_of_log'] = True + metadata["end_of_log"] = True raise NotFound(title="TaskInstance not found") - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if dag: try: ti.task = dag.get_task(ti.task_id) except TaskNotFound: pass - return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json']) + return_type = request.accept_mimetypes.best_match(["text/plain", "application/json"]) # return_type would be either the above two or None logs: Any - if return_type == 'application/json' or return_type is None: # default + if return_type == "application/json" or return_type is None: # default logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata) logs = logs[0] if task_try_number is not None else logs - token = URLSafeSerializer(key).dumps(metadata) + # we must have token here, so we can safely ignore it + token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment] return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs)) # text/plain. Stream logs = task_log_reader.read_log_stream(ti, task_try_number, metadata) diff --git a/airflow/api_connexion/endpoints/plugin_endpoint.py b/airflow/api_connexion/endpoints/plugin_endpoint.py index 4b4b17bb96105..a2febda42c9b8 100644 --- a/airflow/api_connexion/endpoints/plugin_endpoint.py +++ b/airflow/api_connexion/endpoints/plugin_endpoint.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
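The continuation token handled above is just the log reader's metadata dict signed with the webserver's SECRET_KEY via itsdangerous; a minimal sketch of the round trip, using a stand-in key rather than a real Airflow config value:

from itsdangerous.exc import BadSignature
from itsdangerous.url_safe import URLSafeSerializer

key = "stand-in-secret-key"  # the endpoint uses get_airflow_app().config["SECRET_KEY"]
serializer = URLSafeSerializer(key)

# What the endpoint hands back as continuation_token ...
token = serializer.dumps({"download_logs": False, "end_of_log": False})

# ... and how it is read back on the next request.
metadata = serializer.loads(token)
assert metadata["end_of_log"] is False

# A token signed with a different key fails verification, which the API turns into a 400.
try:
    URLSafeSerializer("another-key").loads(token)
except BadSignature:
    print("rejected, as expected")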
+from __future__ import annotations + from airflow.api_connexion import security from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.plugin_schema import PluginCollection, plugin_collection_schema @@ -25,9 +27,7 @@ @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)]) @format_parameters({"limit": check_limit}) def get_plugins(*, limit: int, offset: int = 0) -> APIResponse: - """Get plugins endpoint""" + """Get plugins endpoint.""" plugins_info = get_plugin_info() - total_entries = len(plugins_info) - plugins_info = plugins_info[offset:] - plugins_info = plugins_info[:limit] - return plugin_collection_schema.dump(PluginCollection(plugins=plugins_info, total_entries=total_entries)) + collection = PluginCollection(plugins=plugins_info[offset:][:limit], total_entries=len(plugins_info)) + return plugin_collection_schema.dump(collection) diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 1d24fea63d756..e6f7903e0ca4d 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -14,17 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from http import HTTPStatus -from typing import Optional -from flask import Response, request +from flask import Response from marshmallow import ValidationError from sqlalchemy import func from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema @@ -37,7 +38,7 @@ @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL)]) @provide_session def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: - """Delete a pool""" + """Delete a pool.""" if pool_name == "default_pool": raise BadRequest(detail="Default Pool can't be deleted") affected_count = session.query(Pool).filter(Pool.pool == pool_name).delete() @@ -49,7 +50,7 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)]) @provide_session def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: - """Get a pool""" + """Get a pool.""" obj = session.query(Pool).filter(Pool.pool == pool_name).one_or_none() if obj is None: raise NotFound(detail=f"Pool with name:'{pool_name}' not found") @@ -63,12 +64,12 @@ def get_pools( *, limit: int, order_by: str = "id", - offset: Optional[int] = None, + offset: int | None = None, session: Session = NEW_SESSION, ) -> APIResponse: - """Get all pools""" + """Get all pools.""" to_replace = {"name": "pool"} - allowed_filter_attrs = ['name', 'slots', "id"] + allowed_filter_attrs = ["name", "slots", "id"] total_entries = session.query(func.count(Pool.id)).scalar() query = session.query(Pool) query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs) @@ -84,10 +85,11 @@ def patch_pool( update_mask: UpdateMask = None, session: 
Session = NEW_SESSION, ) -> APIResponse: - """Update a pool""" + """Update a pool.""" + request_dict = get_json_request_dict() # Only slots can be modified in 'default_pool' try: - if pool_name == Pool.DEFAULT_POOL_NAME and request.json["name"] != Pool.DEFAULT_POOL_NAME: + if pool_name == Pool.DEFAULT_POOL_NAME and request_dict["name"] != Pool.DEFAULT_POOL_NAME: if update_mask and len(update_mask) == 1 and update_mask[0].strip() == "slots": pass else: @@ -100,7 +102,7 @@ def patch_pool( raise NotFound(detail=f"Pool with name:'{pool_name}' not found") try: - patch_body = pool_schema.load(request.json) + patch_body = pool_schema.load(request_dict) except ValidationError as err: raise BadRequest(detail=str(err.messages)) @@ -121,7 +123,7 @@ def patch_pool( else: required_fields = {"name", "slots"} - fields_diff = required_fields - set(request.json.keys()) + fields_diff = required_fields - set(get_json_request_dict().keys()) if fields_diff: raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}") @@ -134,14 +136,14 @@ def patch_pool( @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL)]) @provide_session def post_pool(*, session: Session = NEW_SESSION) -> APIResponse: - """Create a pool""" + """Create a pool.""" required_fields = {"name", "slots"} # Pool would require both fields in the post request - fields_diff = required_fields - set(request.json.keys()) + fields_diff = required_fields - set(get_json_request_dict().keys()) if fields_diff: raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}") try: - post_body = pool_schema.load(request.json, session=session) + post_body = pool_schema.load(get_json_request_dict(), session=session) except ValidationError as err: raise BadRequest(detail=str(err.messages)) diff --git a/airflow/api_connexion/endpoints/provider_endpoint.py b/airflow/api_connexion/endpoints/provider_endpoint.py index 7526e284beb06..c829d9c968d61 100644 --- a/airflow/api_connexion/endpoints/provider_endpoint.py +++ b/airflow/api_connexion/endpoints/provider_endpoint.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import re @@ -42,7 +43,7 @@ def _provider_mapper(provider: ProviderInfo) -> Provider: @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)]) def get_providers() -> APIResponse: - """Get providers""" + """Get providers.""" providers = [_provider_mapper(d) for d in ProvidersManager().providers.values()] total_entries = len(providers) return provider_collection_schema.dump( diff --git a/airflow/api_connexion/endpoints/request_dict.py b/airflow/api_connexion/endpoints/request_dict.py new file mode 100644 index 0000000000000..b07e06c0b63f8 --- /dev/null +++ b/airflow/api_connexion/endpoints/request_dict.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, Mapping, cast + + +def get_json_request_dict() -> Mapping[str, Any]: + """Cast request dictionary to JSON.""" + from flask import request + + return cast(Mapping[str, Any], request.get_json()) diff --git a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py index 88a68341c129e..4ed40caae508c 100644 --- a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py +++ b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from http import HTTPStatus -from typing import List, Optional, Tuple from connexion import NoContent -from flask import current_app, request +from flask import request from marshmallow import ValidationError from sqlalchemy import asc, desc, func @@ -35,13 +35,14 @@ ) from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import Action, Role from airflow.www.security import AirflowSecurityManager -def _check_action_and_resource(sm: AirflowSecurityManager, perms: List[Tuple[str, str]]) -> None: +def _check_action_and_resource(sm: AirflowSecurityManager, perms: list[tuple[str, str]]) -> None: """ - Checks if the action or resource exists and raise 400 if not + Checks if the action or resource exists and otherwise raise 400. 
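The request_dict helper added above only narrows flask.request.get_json() to a Mapping so endpoints can index the body without Optional checks. A self-contained sketch of the same idea, re-declared locally so it runs without an Airflow installation:

from typing import Any, Mapping, cast

from flask import Flask, request


def get_json_request_dict() -> Mapping[str, Any]:
    # Mirrors airflow.api_connexion.endpoints.request_dict.get_json_request_dict.
    return cast(Mapping[str, Any], request.get_json())


app = Flask(__name__)
with app.test_request_context(json={"name": "my_pool", "slots": 3}):
    body = get_json_request_dict()
    # The pool endpoints use the result like this to check required fields.
    missing = {"name", "slots"} - set(body.keys())
    assert not missing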
This function is intended for use in the REST API because it raise 400 """ @@ -54,8 +55,8 @@ def _check_action_and_resource(sm: AirflowSecurityManager, perms: List[Tuple[str @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)]) def get_role(*, role_name: str) -> APIResponse: - """Get role""" - ab_security_manager = current_app.appbuilder.sm + """Get role.""" + ab_security_manager = get_airflow_app().appbuilder.sm role = ab_security_manager.find_role(name=role_name) if not role: raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found") @@ -64,9 +65,9 @@ def get_role(*, role_name: str) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)]) @format_parameters({"limit": check_limit}) -def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = None) -> APIResponse: - """Get roles""" - appbuilder = current_app.appbuilder +def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None) -> APIResponse: + """Get roles.""" + appbuilder = get_airflow_app().appbuilder session = appbuilder.get_session total_entries = session.query(func.count(Role.id)).scalar() direction = desc if order_by.startswith("-") else asc @@ -87,10 +88,10 @@ def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = Non @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION)]) -@format_parameters({'limit': check_limit}) -def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse: - """Get permissions""" - session = current_app.appbuilder.get_session +@format_parameters({"limit": check_limit}) +def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse: + """Get permissions.""" + session = get_airflow_app().appbuilder.get_session total_entries = session.query(func.count(Action.id)).scalar() query = session.query(Action) actions = query.offset(offset).limit(limit).all() @@ -99,8 +100,8 @@ def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE)]) def delete_role(*, role_name: str) -> APIResponse: - """Delete a role""" - ab_security_manager = current_app.appbuilder.sm + """Delete a role.""" + ab_security_manager = get_airflow_app().appbuilder.sm role = ab_security_manager.find_role(name=role_name) if not role: raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found") @@ -110,8 +111,8 @@ def delete_role(*, role_name: str) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE)]) def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse: - """Update a role""" - appbuilder = current_app.appbuilder + """Update a role.""" + appbuilder = get_airflow_app().appbuilder security_manager = appbuilder.sm body = request.json try: @@ -128,7 +129,7 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse if field in data and not field == "permissions": data_[field] = data[field] elif field == "actions": - data_["permissions"] = data['permissions'] + data_["permissions"] = data["permissions"] else: raise BadRequest(detail=f"'{field}' in update_mask is unknown") data = data_ @@ -144,17 +145,17 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE)]) def post_role() -> 
APIResponse: - """Create a new role""" - appbuilder = current_app.appbuilder + """Create a new role.""" + appbuilder = get_airflow_app().appbuilder security_manager = appbuilder.sm body = request.json try: data = role_schema.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - role = security_manager.find_role(name=data['name']) + role = security_manager.find_role(name=data["name"]) if not role: - perms = [(item['action']['name'], item['resource']['name']) for item in data['permissions'] if item] + perms = [(item["action"]["name"], item["resource"]["name"]) for item in data["permissions"] if item] _check_action_and_resource(security_manager, perms) security_manager.bulk_sync_roles([{"role": data["name"], "perms": perms}]) return role_schema.dump(role) diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index 28c39b000c28d..23c2b32487b31 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from operator import attrgetter +from __future__ import annotations -from flask import current_app +from operator import attrgetter from airflow import DAG from airflow.api_connexion import security @@ -25,6 +25,7 @@ from airflow.api_connexion.types import APIResponse from airflow.exceptions import TaskNotFound from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app @security.requires_access( @@ -35,7 +36,7 @@ ) def get_task(*, dag_id: str, task_id: str) -> APIResponse: """Get simplified representation of a task.""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found") @@ -53,14 +54,14 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse: ], ) def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse: - """Get tasks for DAG""" - dag: DAG = current_app.dag_bag.get_dag(dag_id) + """Get tasks for DAG.""" + dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound("DAG not found") tasks = dag.tasks try: - tasks = sorted(tasks, key=attrgetter(order_by.lstrip('-')), reverse=(order_by[0:1] == '-')) + tasks = sorted(tasks, key=attrgetter(order_by.lstrip("-")), reverse=(order_by[0:1] == "-")) except AttributeError as err: raise BadRequest(detail=str(err)) task_collection = TaskCollection(tasks=tasks, total_entries=len(tasks)) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index c2416ab0d9d44..9d5d54ba5809c 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
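For the order_by handling in get_tasks above, a leading "-" flips the sort direction while attrgetter pulls the attribute by name; the behaviour is easy to see on plain objects (FakeTask is a stand-in, not an Airflow class):

from dataclasses import dataclass
from operator import attrgetter


@dataclass
class FakeTask:
    task_id: str


tasks = [FakeTask("extract"), FakeTask("transform"), FakeTask("load")]

order_by = "-task_id"  # as passed by an API client
ordered = sorted(tasks, key=attrgetter(order_by.lstrip("-")), reverse=(order_by[0:1] == "-"))
assert [t.task_id for t in ordered] == ["transform", "load", "extract"]

# An unknown attribute raises AttributeError, which the endpoint converts to a 400.
try:
    sorted(tasks, key=attrgetter("not_a_field"))
except AttributeError:
    print("rejected, as expected")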
-from typing import Any, Iterable, List, Optional, Tuple, TypeVar +from __future__ import annotations + +from typing import Any, Iterable, TypeVar -from flask import current_app, request from marshmallow import ValidationError from sqlalchemy import and_, func, or_ from sqlalchemy.exc import MultipleResultsFound @@ -25,23 +26,29 @@ from sqlalchemy.sql import ClauseElement from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import format_datetime, format_parameters from airflow.api_connexion.schemas.task_instance_schema import ( TaskInstanceCollection, TaskInstanceReferenceCollection, clear_task_instance_form, + set_single_task_instance_state_form, + set_task_instance_note_form_schema, set_task_instance_state_form, task_instance_batch_form, task_instance_collection_schema, task_instance_reference_collection_schema, + task_instance_reference_schema, task_instance_schema, ) from airflow.api_connexion.types import APIResponse from airflow.models import SlaMiss from airflow.models.dagrun import DagRun as DR +from airflow.models.operator import needs_expansion from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState, State @@ -63,7 +70,7 @@ def get_task_instance( task_id: str, session: Session = NEW_SESSION, ) -> APIResponse: - """Get task instance""" + """Get task instance.""" query = ( session.query(TI) .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id) @@ -112,7 +119,7 @@ def get_mapped_task_instance( map_index: int, session: Session = NEW_SESSION, ) -> APIResponse: - """Get task instance""" + """Get task instance.""" query = ( session.query(TI) .filter( @@ -160,20 +167,20 @@ def get_mapped_task_instances( dag_id: str, dag_run_id: str, task_id: str, - execution_date_gte: Optional[str] = None, - execution_date_lte: Optional[str] = None, - start_date_gte: Optional[str] = None, - start_date_lte: Optional[str] = None, - end_date_gte: Optional[str] = None, - end_date_lte: Optional[str] = None, - duration_gte: Optional[float] = None, - duration_lte: Optional[float] = None, - state: Optional[List[str]] = None, - pool: Optional[List[str]] = None, - queue: Optional[List[str]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - order_by: Optional[str] = None, + execution_date_gte: str | None = None, + execution_date_lte: str | None = None, + start_date_gte: str | None = None, + start_date_lte: str | None = None, + end_date_gte: str | None = None, + end_date_lte: str | None = None, + duration_gte: float | None = None, + duration_lte: float | None = None, + state: list[str] | None = None, + pool: list[str] | None = None, + queue: list[str] | None = None, + limit: int | None = None, + offset: int | None = None, + order_by: str | None = None, session: Session = NEW_SESSION, ) -> APIResponse: """Get list of task instances.""" @@ -187,8 +194,8 @@ def get_mapped_task_instances( ) # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404 - if base_query.with_entities(func.count('*')).scalar() == 0: - dag = current_app.dag_bag.get_dag(dag_id) + if base_query.with_entities(func.count("*")).scalar() == 0: + dag = 
get_airflow_app().dag_bag.get_dag(dag_id) if not dag: error_message = f"DAG {dag_id} not found" raise NotFound(error_message) @@ -196,7 +203,7 @@ def get_mapped_task_instances( if not task: error_message = f"Task id {task_id} not found" raise NotFound(error_message) - if not task.is_mapped: + if not needs_expansion(task): error_message = f"Task id {task_id} is not mapped" raise NotFound(error_message) @@ -214,7 +221,7 @@ def get_mapped_task_instances( query = _apply_array_filter(query, key=TI.queue, values=queue) # Count elements before joining extra columns - total_entries = query.with_entities(func.count('*')).scalar() + total_entries = query.with_entities(func.count("*")).scalar() # Add SLA miss query = ( @@ -232,11 +239,11 @@ def get_mapped_task_instances( ) if order_by: - if order_by == 'state': + if order_by == "state": query = query.order_by(TI.state.asc(), TI.map_index.asc()) - elif order_by == '-state': + elif order_by == "-state": query = query.order_by(TI.state.desc(), TI.map_index.asc()) - elif order_by == '-map_index': + elif order_by == "-map_index": query = query.order_by(TI.map_index.desc()) else: raise BadRequest(detail=f"Ordering with '{order_by}' is not supported") @@ -249,20 +256,20 @@ def get_mapped_task_instances( ) -def _convert_state(states: Optional[Iterable[str]]) -> Optional[List[Optional[str]]]: +def _convert_state(states: Iterable[str] | None) -> list[str | None] | None: if not states: return None return [State.NONE if s == "none" else s for s in states] -def _apply_array_filter(query: Query, key: ClauseElement, values: Optional[Iterable[Any]]) -> Query: +def _apply_array_filter(query: Query, key: ClauseElement, values: Iterable[Any] | None) -> Query: if values is not None: cond = ((key == v) for v in values) query = query.filter(or_(*cond)) return query -def _apply_range_filter(query: Query, key: ClauseElement, value_range: Tuple[T, T]) -> Query: +def _apply_range_filter(query: Query, key: ClauseElement, value_range: tuple[T, T]) -> Query: gte_value, lte_value = value_range if gte_value is not None: query = query.filter(key >= gte_value) @@ -292,20 +299,20 @@ def _apply_range_filter(query: Query, key: ClauseElement, value_range: Tuple[T, def get_task_instances( *, limit: int, - dag_id: Optional[str] = None, - dag_run_id: Optional[str] = None, - execution_date_gte: Optional[str] = None, - execution_date_lte: Optional[str] = None, - start_date_gte: Optional[str] = None, - start_date_lte: Optional[str] = None, - end_date_gte: Optional[str] = None, - end_date_lte: Optional[str] = None, - duration_gte: Optional[float] = None, - duration_lte: Optional[float] = None, - state: Optional[List[str]] = None, - pool: Optional[List[str]] = None, - queue: Optional[List[str]] = None, - offset: Optional[int] = None, + dag_id: str | None = None, + dag_run_id: str | None = None, + execution_date_gte: str | None = None, + execution_date_lte: str | None = None, + start_date_gte: str | None = None, + start_date_lte: str | None = None, + end_date_gte: str | None = None, + end_date_lte: str | None = None, + duration_gte: float | None = None, + duration_lte: float | None = None, + state: list[str] | None = None, + pool: list[str] | None = None, + queue: list[str] | None = None, + offset: int | None = None, session: Session = NEW_SESSION, ) -> APIResponse: """Get list of task instances.""" @@ -333,7 +340,7 @@ def get_task_instances( base_query = _apply_array_filter(base_query, key=TI.queue, values=queue) # Count elements before joining extra columns - total_entries = 
base_query.with_entities(func.count('*')).scalar() + total_entries = base_query.with_entities(func.count("*")).scalar() # Add join query = ( base_query.join( @@ -364,12 +371,12 @@ def get_task_instances( @provide_session def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: """Get list of task instances.""" - body = request.get_json() + body = get_json_request_dict() try: data = task_instance_batch_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - states = _convert_state(data['state']) + states = _convert_state(data["state"]) base_query = session.query(TI).join(TI.dag_run) base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"]) @@ -394,7 +401,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"]) # Count elements before joining extra columns - total_entries = base_query.with_entities(func.count('*')).scalar() + total_entries = base_query.with_entities(func.count("*")).scalar() # Add join base_query = base_query.join( SlaMiss, @@ -423,20 +430,50 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: @provide_session def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear task instances.""" - body = request.get_json() + body = get_json_request_dict() try: data = clear_task_instance_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: error_message = f"Dag id {dag_id} not found" raise NotFound(error_message) - reset_dag_runs = data.pop('reset_dag_runs') - dry_run = data.pop('dry_run') + reset_dag_runs = data.pop("reset_dag_runs") + dry_run = data.pop("dry_run") # We always pass dry_run here, otherwise this would try to confirm on the terminal! - task_instances = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **data) + dag_run_id = data.pop("dag_run_id", None) + future = data.pop("include_future", False) + past = data.pop("include_past", False) + downstream = data.pop("include_downstream", False) + upstream = data.pop("include_upstream", False) + if dag_run_id is not None: + dag_run: DR | None = ( + session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none() + ) + if dag_run is None: + error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}" + raise NotFound(error_message) + data["start_date"] = dag_run.logical_date + data["end_date"] = dag_run.logical_date + if past: + data["start_date"] = None + if future: + data["end_date"] = None + task_ids = data.pop("task_ids", None) + if task_ids is not None: + task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids] + dag = dag.partial_subset( + task_ids_or_regex=task_id, + include_downstream=downstream, + include_upstream=upstream, + ) + + if len(dag.task_dict) > 1: + # If we had upstream/downstream etc then also include those! 
+ task_ids.extend(tid for tid in dag.task_dict if tid != task_id) + task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data) if not dry_run: clear_task_instances( task_instances.all(), @@ -460,26 +497,26 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of task instances.""" - body = request.get_json() + body = get_json_request_dict() try: data = set_task_instance_state_form.load(body) except ValidationError as err: raise BadRequest(detail=str(err.messages)) error_message = f"Dag ID {dag_id} not found" - dag = current_app.dag_bag.get_dag(dag_id) + dag = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: raise NotFound(error_message) - task_id = data['task_id'] + task_id = data["task_id"] task = dag.task_dict.get(task_id) if not task: error_message = f"Task ID {task_id} not found" raise NotFound(error_message) - execution_date = data.get('execution_date') - run_id = data.get('dag_run_id') + execution_date = data.get("execution_date") + run_id = data.get("dag_run_id") if ( execution_date and ( @@ -494,7 +531,7 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION ) if run_id and not session.query(TI).get( - {'task_id': task_id, 'dag_id': dag_id, 'run_id': run_id, 'map_index': -1} + {"task_id": task_id, "dag_id": dag_id, "run_id": run_id, "map_index": -1} ): error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}" raise NotFound(detail=error_message) @@ -512,3 +549,135 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION session=session, ) return task_instance_reference_collection_schema.dump(TaskInstanceReferenceCollection(task_instances=tis)) + + +def set_mapped_task_instance_note( + *, dag_id: str, dag_run_id: str, task_id: str, map_index: int +) -> APIResponse: + """Set the note for a Mapped Task instance.""" + return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index) + + +@security.requires_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ], +) +@provide_session +def patch_task_instance( + *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION +) -> APIResponse: + """Update the state of a task instance.""" + body = get_json_request_dict() + try: + data = set_single_task_instance_state_form.load(body) + except ValidationError as err: + raise BadRequest(detail=str(err.messages)) + + dag = get_airflow_app().dag_bag.get_dag(dag_id) + if not dag: + raise NotFound("DAG not found", detail=f"DAG {dag_id!r} not found") + + if not dag.has_task(task_id): + raise NotFound("Task not found", detail=f"Task {task_id!r} not found in DAG {dag_id!r}") + + ti: TI | None = session.query(TI).get( + {"task_id": task_id, "dag_id": dag_id, "run_id": dag_run_id, "map_index": map_index} + ) + + if not ti: + error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {dag_run_id!r}" + raise NotFound(detail=error_message) + + if not data["dry_run"]: + ti = dag.set_task_instance_state( + task_id=task_id, + run_id=dag_run_id, + map_indexes=[map_index], + state=data["new_state"], + commit=True, + session=session, + ) + + return 
task_instance_reference_schema.dump(ti) + + +@security.requires_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ], +) +@provide_session +def patch_mapped_task_instance( + *, dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: Session = NEW_SESSION +) -> APIResponse: + """Update the state of a mapped task instance.""" + return patch_task_instance( + dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index, session=session + ) + + +@security.requires_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ], +) +@provide_session +def set_task_instance_note( + *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION +) -> APIResponse: + """Set the note for a Task instance. This supports both Mapped and non-Mapped Task instances.""" + try: + post_body = set_task_instance_note_form_schema.load(get_json_request_dict()) + new_note = post_body["note"] + except ValidationError as err: + raise BadRequest(detail=str(err)) + + query = ( + session.query(TI) + .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id) + .join(TI.dag_run) + .outerjoin( + SlaMiss, + and_( + SlaMiss.dag_id == TI.dag_id, + SlaMiss.execution_date == DR.execution_date, + SlaMiss.task_id == TI.task_id, + ), + ) + .add_entity(SlaMiss) + .options(joinedload(TI.rendered_task_instance_fields)) + ) + if map_index == -1: + query = query.filter(or_(TI.map_index == -1, TI.map_index is None)) + else: + query = query.filter(TI.map_index == map_index) + + try: + result = query.one_or_none() + except MultipleResultsFound: + raise NotFound( + "Task instance not found", detail="Task instance is mapped, add the map_index value to the URL" + ) + if result is None: + error_message = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}" + raise NotFound(error_message) + + ti, sla_miss = result + from flask_login import current_user + + current_user_id = getattr(current_user, "id", None) + if ti.task_instance_note is None: + ti.note = (new_note, current_user_id) + else: + ti.task_instance_note.content = new_note + ti.task_instance_note.user_id = current_user_id + session.commit() + return task_instance_schema.dump((ti, sla_miss)) diff --git a/airflow/api_connexion/endpoints/user_endpoint.py b/airflow/api_connexion/endpoints/user_endpoint.py index 6b4e984a69559..506e11e00612c 100644 --- a/airflow/api_connexion/endpoints/user_endpoint.py +++ b/airflow/api_connexion/endpoints/user_endpoint.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
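To try the note and state-patching handlers above over HTTP, something like the following should be close; it is only a sketch, assuming a running local webserver, basic-auth placeholder credentials, and an existing DAG run (all ids below are placeholders):

import requests

BASE = "http://localhost:8080/api/v1"
AUTH = ("admin", "admin")  # placeholder credentials

dag_id = "example_dag"
run_id = "manual__2022-01-01T00:00:00+00:00"
task_id = "my_task"

# Attach a note to a non-mapped task instance (the new setNote route).
resp = requests.patch(
    f"{BASE}/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/setNote",
    json={"note": "re-ran manually after the upstream fix"},
    auth=AUTH,
)
resp.raise_for_status()

# Dry-run a state change for the same task instance via patch_task_instance.
resp = requests.patch(
    f"{BASE}/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}",
    json={"dry_run": True, "new_state": "failed"},
    auth=AUTH,
)
print(resp.json())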
+from __future__ import annotations + from http import HTTPStatus -from typing import List, Optional from connexion import NoContent -from flask import current_app, request +from flask import request from marshmallow import ValidationError from sqlalchemy import asc, desc, func from werkzeug.security import generate_password_hash @@ -34,13 +35,14 @@ ) from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.www.fab_security.sqla.models import Role, User @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER)]) def get_user(*, username: str) -> APIResponse: - """Get a user""" - ab_security_manager = current_app.appbuilder.sm + """Get a user.""" + ab_security_manager = get_airflow_app().appbuilder.sm user = ab_security_manager.find_user(username=username) if not user: raise NotFound(title="User not found", detail=f"The User with username `{username}` was not found") @@ -49,9 +51,9 @@ def get_user(*, username: str) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER)]) @format_parameters({"limit": check_limit}) -def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) -> APIResponse: - """Get users""" - appbuilder = current_app.appbuilder +def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) -> APIResponse: + """Get users.""" + appbuilder = get_airflow_app().appbuilder session = appbuilder.get_session total_entries = session.query(func.count(User.id)).scalar() direction = desc if order_by.startswith("-") else asc @@ -59,7 +61,7 @@ def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) order_param = order_by.strip("-") order_param = to_replace.get(order_param, order_param) allowed_filter_attrs = [ - 'id', + "id", "first_name", "last_name", "user_name", @@ -81,13 +83,13 @@ def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_USER)]) def post_user() -> APIResponse: - """Create a new user""" + """Create a new user.""" try: data = user_schema.load(request.json) except ValidationError as e: raise BadRequest(detail=str(e.messages)) - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm username = data["username"] email = data["email"] @@ -124,26 +126,26 @@ def post_user() -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_USER)]) def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: - """Update a user""" + """Update a user.""" try: data = user_schema.load(request.json) except ValidationError as e: raise BadRequest(detail=str(e.messages)) - security_manager = current_app.appbuilder.sm + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(username=username) if user is None: detail = f"The User with username `{username}` was not found" raise NotFound(title="User not found", detail=detail) # Check unique username - new_username = data.get('username') + new_username = data.get("username") if new_username and new_username != username: if security_manager.find_user(username=new_username): raise AlreadyExists(detail=f"The username `{new_username}` already exists") # Check unique email - email = data.get('email') + email = data.get("email") if email and email != user.email: if 
security_manager.find_user(email=email): raise AlreadyExists(detail=f"The email `{email}` already exists") @@ -163,7 +165,7 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: raise BadRequest(detail=detail) data = masked_data - roles_to_update: Optional[List[Role]] + roles_to_update: list[Role] | None if "roles" in data: roles_to_update = [] missing_role_names = [] @@ -193,8 +195,8 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER)]) def delete_user(*, username: str) -> APIResponse: - """Delete a user""" - security_manager = current_app.appbuilder.sm + """Delete a user.""" + security_manager = get_airflow_app().appbuilder.sm user = security_manager.find_user(username=username) if user is None: diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 487d2cc486c83..3111ff18d424d 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -14,40 +14,52 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from http import HTTPStatus -from typing import Optional -from flask import Response, request +from flask import Response from marshmallow import ValidationError from sqlalchemy import func from sqlalchemy.orm import Session from airflow.api_connexion import security +from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.variable_schema import variable_collection_schema, variable_schema from airflow.api_connexion.types import UpdateMask from airflow.models import Variable from airflow.security import permissions +from airflow.utils.log.action_logger import action_event_from_permission from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.decorators import action_logging + +RESOURCE_EVENT_PREFIX = "variable" @security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE)]) +@action_logging( + event=action_event_from_permission( + prefix=RESOURCE_EVENT_PREFIX, + permission=permissions.ACTION_CAN_DELETE, + ), +) def delete_variable(*, variable_key: str) -> Response: - """Delete variable""" + """Delete variable.""" if Variable.delete(variable_key) == 0: raise NotFound("Variable not found") return Response(status=HTTPStatus.NO_CONTENT) @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) -def get_variable(*, variable_key: str) -> Response: - """Get a variables by key""" - try: - var = Variable.get(variable_key) - except KeyError: +@provide_session +def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Response: + """Get a variable by key.""" + var = session.query(Variable).filter(Variable.key == variable_key) + if not var.count(): raise NotFound("Variable not found") - return variable_schema.dump({"key": variable_key, "val": var}) + return variable_schema.dump(var.first()) @security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) @@ -55,15 +67,15 @@ def get_variable(*, variable_key: str) -> Response: @provide_session def get_variables( 
*, - limit: Optional[int], + limit: int | None, order_by: str = "id", - offset: Optional[int] = None, + offset: int | None = None, session: Session = NEW_SESSION, ) -> Response: - """Get all variable values""" + """Get all variable values.""" total_entries = session.query(func.count(Variable.id)).scalar() to_replace = {"value": "val"} - allowed_filter_attrs = ['value', 'key', 'id'] + allowed_filter_attrs = ["value", "key", "id"] query = session.query(Variable) query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs) variables = query.offset(offset).limit(limit).all() @@ -76,10 +88,16 @@ def get_variables( @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE)]) +@action_logging( + event=action_event_from_permission( + prefix=RESOURCE_EVENT_PREFIX, + permission=permissions.ACTION_CAN_EDIT, + ), +) def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response: - """Update a variable by key""" + """Update a variable by key.""" try: - data = variable_schema.load(request.json) + data = variable_schema.load(get_json_request_dict()) except ValidationError as err: raise BadRequest("Invalid Variable schema", detail=str(err.messages)) @@ -97,10 +115,16 @@ def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Resp @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)]) +@action_logging( + event=action_event_from_permission( + prefix=RESOURCE_EVENT_PREFIX, + permission=permissions.ACTION_CAN_CREATE, + ), +) def post_variables() -> Response: - """Create a variable""" + """Create a variable.""" try: - data = variable_schema.load(request.json) + data = variable_schema.load(get_json_request_dict()) except ValidationError as err: raise BadRequest("Invalid Variable schema", detail=str(err.messages)) diff --git a/airflow/api_connexion/endpoints/version_endpoint.py b/airflow/api_connexion/endpoints/version_endpoint.py index 077d7f8a1cfe4..79b4d2f1e1719 100644 --- a/airflow/api_connexion/endpoints/version_endpoint.py +++ b/airflow/api_connexion/endpoints/version_endpoint.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import NamedTuple, Optional +from typing import NamedTuple import airflow from airflow.api_connexion.schemas.version_schema import version_info_schema @@ -24,14 +25,14 @@ class VersionInfo(NamedTuple): - """Version information""" + """Version information.""" version: str - git_version: Optional[str] + git_version: str | None def get_version() -> APIResponse: - """Get version information""" + """Get version information.""" airflow_version = airflow.__version__ git_version = get_airflow_git_version() diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 9cc6b6d79a933..2ab5ec26f5a2a 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
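A similar client sketch for the variable endpoints touched above, reusing the same placeholder server and credentials; the key/value field names follow variable_schema:

import requests

BASE = "http://localhost:8080/api/v1"
AUTH = ("admin", "admin")  # placeholder credentials

# Create a variable, read it back, then list variables ordered by key.
requests.post(f"{BASE}/variables", json={"key": "env_name", "value": "staging"}, auth=AUTH).raise_for_status()
print(requests.get(f"{BASE}/variables/env_name", auth=AUTH).json())
print(requests.get(f"{BASE}/variables", params={"order_by": "key", "limit": 5}, auth=AUTH).json())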
-from typing import Optional +from __future__ import annotations -from flask import current_app, g +import copy + +from flask import g from sqlalchemy import and_ from sqlalchemy.orm import Session @@ -27,6 +29,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import DagRun as DR, XCom from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -45,14 +48,14 @@ def get_xcom_entries( dag_id: str, dag_run_id: str, task_id: str, - limit: Optional[int], - offset: Optional[int] = None, + limit: int | None, + offset: int | None = None, session: Session = NEW_SESSION, ) -> APIResponse: - """Get all XCom values""" + """Get all XCom values.""" query = session.query(XCom) - if dag_id == '~': - appbuilder = current_app.appbuilder + if dag_id == "~": + appbuilder = get_airflow_app().appbuilder readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) query = query.filter(XCom.dag_id.in_(readable_dag_ids)) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) @@ -60,14 +63,14 @@ def get_xcom_entries( query = query.filter(XCom.dag_id == dag_id) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) - if task_id != '~': + if task_id != "~": query = query.filter(XCom.task_id == task_id) - if dag_run_id != '~': + if dag_run_id != "~": query = query.filter(DR.run_id == dag_run_id) query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key) total_entries = query.count() query = query.offset(offset).limit(limit) - return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries)) + return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries)) @security.requires_access( @@ -85,14 +88,28 @@ def get_xcom_entry( task_id: str, dag_run_id: str, xcom_key: str, + deserialize: bool = False, session: Session = NEW_SESSION, ) -> APIResponse: - """Get an XCom entry""" - query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key) + """Get an XCom entry.""" + if deserialize: + query = session.query(XCom, XCom.value) + else: + query = session.query(XCom) + + query = query.filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) query = query.filter(DR.run_id == dag_run_id) - query_object = query.one_or_none() - if not query_object: + item = query.one_or_none() + if item is None: raise NotFound("XCom entry not found") - return xcom_schema.dump(query_object) + + if deserialize: + xcom, value = item + stub = copy.copy(xcom) + stub.value = value + stub.value = XCom.deserialize_value(stub) + item = stub + + return xcom_schema.dump(item) diff --git a/airflow/api_connexion/exceptions.py b/airflow/api_connexion/exceptions.py index 8fb7f2e78883b..11468e1506feb 100644 --- a/airflow/api_connexion/exceptions.py +++ b/airflow/api_connexion/exceptions.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
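The deserialize flag added to get_xcom_entry above surfaces as a query parameter on the XCom entry route (documented further down in v1.yaml). A hedged client sketch with the same placeholder server, ids and credentials as the earlier examples:

import requests

BASE = "http://localhost:8080/api/v1"
AUTH = ("admin", "admin")  # placeholder credentials

url = (
    f"{BASE}/dags/example_dag/dagRuns/manual__2022-01-01T00:00:00+00:00"
    "/taskInstances/my_task/xcomEntries/return_value"
)

# Default behaviour: the value comes from XCom.orm_deserialize_value().
print(requests.get(url, auth=AUTH).json()["value"])

# With a custom XCom backend, ask the server to run the full deserialize_value() instead.
print(requests.get(url, params={"deserialize": "true"}, auth=AUTH).json()["value"])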
+from __future__ import annotations + from http import HTTPStatus -from typing import Any, Dict, Optional +from typing import Any import flask import werkzeug @@ -71,13 +73,13 @@ def common_error_handler(exception: BaseException) -> flask.Response: class NotFound(ProblemException): - """Raise when the object cannot be found""" + """Raise when the object cannot be found.""" def __init__( self, - title: str = 'Not Found', - detail: Optional[str] = None, - headers: Optional[Dict] = None, + title: str = "Not Found", + detail: str | None = None, + headers: dict | None = None, **kwargs: Any, ) -> None: super().__init__( @@ -91,13 +93,13 @@ def __init__( class BadRequest(ProblemException): - """Raise when the server processes a bad request""" + """Raise when the server processes a bad request.""" def __init__( self, title: str = "Bad Request", - detail: Optional[str] = None, - headers: Optional[Dict] = None, + detail: str | None = None, + headers: dict | None = None, **kwargs: Any, ) -> None: super().__init__( @@ -111,13 +113,13 @@ def __init__( class Unauthenticated(ProblemException): - """Raise when the user is not authenticated""" + """Raise when the user is not authenticated.""" def __init__( self, title: str = "Unauthorized", - detail: Optional[str] = None, - headers: Optional[Dict] = None, + detail: str | None = None, + headers: dict | None = None, **kwargs: Any, ): super().__init__( @@ -131,13 +133,13 @@ def __init__( class PermissionDenied(ProblemException): - """Raise when the user does not have the required permissions""" + """Raise when the user does not have the required permissions.""" def __init__( self, title: str = "Forbidden", - detail: Optional[str] = None, - headers: Optional[Dict] = None, + detail: str | None = None, + headers: dict | None = None, **kwargs: Any, ) -> None: super().__init__( @@ -151,13 +153,13 @@ def __init__( class AlreadyExists(ProblemException): - """Raise when the object already exists""" + """Raise when the object already exists.""" def __init__( self, title="Conflict", - detail: Optional[str] = None, - headers: Optional[Dict] = None, + detail: str | None = None, + headers: dict | None = None, **kwargs: Any, ): super().__init__( @@ -171,13 +173,13 @@ def __init__( class Unknown(ProblemException): - """Returns a response body and status code for HTTP 500 exception""" + """Returns a response body and status code for HTTP 500 exception.""" def __init__( self, title: str = "Internal Server Error", - detail: Optional[str] = None, - headers: Optional[Dict] = None, + detail: str | None = None, + headers: dict | None = None, **kwargs: Any, ) -> None: super().__init__( diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index fc71fbb7f5e3d..3f25b54640741 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -229,7 +229,7 @@ info: This means that the server encountered an unexpected condition that prevented it from fulfilling the request. 
- version: '1.0.0' + version: '2.5.1' license: name: Apache 2.0 url: http://www.apache.org/licenses/LICENSE-2.0.html @@ -569,7 +569,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ClearTaskInstance' + $ref: '#/components/schemas/ClearTaskInstances' responses: '200': @@ -585,6 +585,85 @@ paths: '404': $ref: '#/components/responses/NotFound' + /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/setNote: + parameters: + - $ref: '#/components/parameters/DAGID' + - $ref: '#/components/parameters/DAGRunID' + - $ref: '#/components/parameters/TaskID' + + patch: + summary: Update the TaskInstance note. + description: | + Update the manual user note of a non-mapped Task Instance. + + *New in version 2.5.0* + x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint + operationId: set_task_instance_note + tags: [TaskInstance] + requestBody: + description: Parameters of set Task Instance note. + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SetTaskInstanceNote' + + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstance' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + + /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/setNote: + parameters: + - $ref: '#/components/parameters/DAGID' + - $ref: '#/components/parameters/DAGRunID' + - $ref: '#/components/parameters/TaskID' + - $ref: '#/components/parameters/MapIndex' + + patch: + summary: Update the TaskInstance note. + description: | + Update the manual user note of a mapped Task Instance. + + *New in version 2.5.0* + x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint + operationId: set_mapped_task_instance_note + tags: [TaskInstance] + requestBody: + description: Parameters of set Task Instance note. + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SetTaskInstanceNote' + + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstance' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + /dags/{dag_id}/updateTaskInstancesState: parameters: - $ref: '#/components/parameters/DAGID' @@ -818,6 +897,70 @@ paths: '404': $ref: '#/components/responses/NotFound' + /dags/{dag_id}/dagRuns/{dag_run_id}/upstreamDatasetEvents: + parameters: + - $ref: '#/components/parameters/DAGID' + - $ref: '#/components/parameters/DAGRunID' + get: + summary: Get dataset events for a DAG run + description: | + Get datasets for a dag run. + + *New in version 2.4.0* + x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint + operationId: get_upstream_dataset_events + tags: [DAGRun, Dataset] + responses: + '200': + description: Success. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/DatasetEventCollection' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + + /dags/{dag_id}/dagRuns/{dag_run_id}/setNote: + parameters: + - $ref: '#/components/parameters/DAGID' + - $ref: '#/components/parameters/DAGRunID' + patch: + summary: Update the DagRun note. + description: | + Update the manual user note of a DagRun. + + *New in version 2.5.0* + x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint + operationId: set_dag_run_note + tags: [DAGRun] + requestBody: + description: Parameters of set DagRun note. + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SetDagRunNote' + + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRun' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + /eventLogs: get: summary: List log entries @@ -1116,6 +1259,37 @@ paths: '404': $ref: '#/components/responses/NotFound' + patch: + summary: Updates the state of a task instance + description: > + Updates the state for single task instance. + + *New in version 2.5.0* + x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint + operationId: patch_task_instance + tags: [TaskInstance] + requestBody: + description: Parameters of action + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateTaskInstance' + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceReference' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + + /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}: parameters: - $ref: '#/components/parameters/DAGID' @@ -1146,6 +1320,36 @@ paths: '404': $ref: '#/components/responses/NotFound' + patch: + summary: Updates the state of a mapped task instance + description: > + Updates the state for single mapped task instance. + + *New in version 2.5.0* + x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint + operationId: patch_mapped_task_instance + tags: [TaskInstance] + requestBody: + description: Parameters of action + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateTaskInstance' + responses: + '200': + description: Success. 
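A sketch of the single task instance PATCH defined above: the payload follows the UpdateTaskInstance schema, and with dry_run the server only reports the TaskInstanceReference it would modify. Host, credentials and identifiers are placeholders.

```python
# Hedged sketch of patch_task_instance: mark one task instance failed, previewing first.
import requests

url = "http://localhost:8080/api/v1/dags/example_dag/dagRuns/my_run_id/taskInstances/my_task"
auth = ("admin", "admin")

preview = requests.patch(url, json={"dry_run": True, "new_state": "failed"}, auth=auth)
print(preview.json())  # TaskInstanceReference of the instance that would be affected

requests.patch(url, json={"new_state": "failed"}, auth=auth)  # apply it for real
```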
+ content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceReference' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + + /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped: parameters: - $ref: '#/components/parameters/DAGID' @@ -1182,7 +1386,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/TaskInstance' + $ref: '#/components/schemas/TaskInstanceCollection' '401': $ref: '#/components/responses/Unauthenticated' '403': @@ -1384,6 +1588,23 @@ paths: x-openapi-router-controller: airflow.api_connexion.endpoints.xcom_endpoint operationId: get_xcom_entry tags: [XCom] + parameters: + - in: query + name: deserialize + schema: + type: boolean + default: false + required: false + description: | + Whether to deserialize an XCom value when using a custom XCom backend. + + The XCom API endpoint calls `orm_deserialize_value` by default since an XCom may contain value + that is potentially expensive to deserialize in the web server. Setting this to true overrides + the consideration, and calls `deserialize_value` instead. + + This parameter is not meaningful when using the default XCom backend. + + *New in version 2.4.0* responses: '200': description: Success. @@ -1433,6 +1654,7 @@ paths: - $ref: '#/components/parameters/TaskID' - $ref: '#/components/parameters/TaskTryNumber' - $ref: '#/components/parameters/FullContent' + - $ref: '#/components/parameters/FilterMapIndex' - $ref: '#/components/parameters/ContinuationToken' get: @@ -1465,6 +1687,7 @@ paths: '404': $ref: '#/components/responses/NotFound' + /dags/{dag_id}/details: parameters: - $ref: '#/components/parameters/DAGID' @@ -1573,6 +1796,123 @@ paths: '406': $ref: '#/components/responses/NotAcceptable' + /dagWarnings: + get: + summary: List dag warnings + x-openapi-router-controller: airflow.api_connexion.endpoints.dag_warning_endpoint + operationId: get_dag_warnings + tags: [DagWarning] + parameters: + - name: dag_id + in: query + schema: + type: string + required: false + description: If set, only return DAG warnings with this dag_id. + - name: warning_type + in: query + schema: + type: string + required: false + description: If set, only return DAG warnings with this type. + - $ref: '#/components/parameters/PageLimit' + - $ref: '#/components/parameters/PageOffset' + - $ref: '#/components/parameters/OrderBy' + + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/DagWarningCollection' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + + /datasets: + get: + summary: List datasets + x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint + operationId: get_datasets + tags: [Dataset] + parameters: + - $ref: '#/components/parameters/PageLimit' + - $ref: '#/components/parameters/PageOffset' + - $ref: '#/components/parameters/OrderBy' + - name: uri_pattern + in: query + schema: + type: string + required: false + description: | + If set, only return datasets with uris matching this pattern. + responses: + '200': + description: Success. 
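The new deserialize flag on the XCom entry endpoint only matters with a custom XCom backend, as the description above explains: it switches the server from orm_deserialize_value to deserialize_value. A hedged sketch, with placeholder host, credentials and identifiers:

```python
# Hedged sketch: fetch an XCom value with deserialize=true (custom XCom backends only).
import requests

resp = requests.get(
    "http://localhost:8080/api/v1/dags/example_dag/dagRuns/my_run_id/"
    "taskInstances/my_task/xcomEntries/return_value",
    params={"deserialize": "true"},  # call deserialize_value instead of orm_deserialize_value
    auth=("admin", "admin"),
)
print(resp.json()["value"])
```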
+ content: + application/json: + schema: + $ref: '#/components/schemas/DatasetCollection' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + + /datasets/{uri}: + parameters: + - $ref: '#/components/parameters/DatasetURI' + get: + summary: Get a dataset + description: Get a dataset by uri. + x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint + operationId: get_dataset + tags: [Dataset] + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/Dataset' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + + /datasets/events: + parameters: + - $ref: '#/components/parameters/PageLimit' + - $ref: '#/components/parameters/PageOffset' + - $ref: '#/components/parameters/OrderBy' + - $ref: '#/components/parameters/FilterDatasetID' + - $ref: '#/components/parameters/FilterSourceDAGID' + - $ref: '#/components/parameters/FilterSourceTaskID' + - $ref: '#/components/parameters/FilterSourceRunID' + - $ref: '#/components/parameters/FilterSourceMapIndex' + get: + summary: Get dataset events + description: Get dataset events + x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint + operationId: get_dataset_events + tags: [Dataset] + responses: + '200': + description: Success. + content: + application/json: + schema: + $ref: '#/components/schemas/DatasetEventCollection' + '401': + $ref: '#/components/responses/Unauthenticated' + '403': + $ref: '#/components/responses/PermissionDenied' + '404': + $ref: '#/components/responses/NotFound' + /config: get: summary: Get current configuration @@ -1988,15 +2328,13 @@ components: description: | The user's first name. - *Changed in version 2.2.0*: A minimum character length requirement ('minLength') is added. - minLength: 1 + *Changed in version 2.4.0*: The requirement for this to be non-empty was removed. last_name: type: string description: | The user's last name. - *Changed in version 2.2.0*: A minimum character length requirement ('minLength') is added. - minLength: 1 + *Changed in version 2.4.0*: The requirement for this to be non-empty was removed. username: type: string description: | @@ -2099,6 +2437,10 @@ components: conn_type: type: string description: The connection type. + description: + type: string + description: The description of the connection. + nullable: true host: type: string nullable: true @@ -2451,6 +2793,7 @@ components: - backfill - manual - scheduled + - dataset_triggered state: $ref: '#/components/schemas/DagState' readOnly: true @@ -2465,6 +2808,13 @@ components: The value of this field can be set only when creating the object. If you try to modify the field of an existing object, the request fails with an BAD_REQUEST error. + note: + type: string + description: | + Contains manually entered notes by the user about the DagRun. + + *New in version 2.5.0* + nullable: true UpdateDagRunState: type: object @@ -2496,20 +2846,62 @@ components: $ref: '#/components/schemas/DAGRun' - $ref: '#/components/schemas/CollectionInfo' - EventLog: + DagWarning: type: object - description: Log of user operations via CLI or Web UI. properties: - event_log_id: - description: The event log ID - type: integer + dag_id: + type: string readOnly: true - when: - description: The time when these events happened. - format: date-time + description: The dag_id. 
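A hedged sketch of the dataset endpoints introduced above; the base URL, credentials and example URI are placeholders. Since the {uri} path parameter is declared with format: path, the URI is URL-encoded here for safety.

```python
# Hedged sketch: list datasets, fetch one by URI, and filter dataset events by source DAG.
from urllib.parse import quote

import requests

base, auth = "http://localhost:8080/api/v1", ("admin", "admin")

datasets = requests.get(f"{base}/datasets", params={"uri_pattern": "s3://"}, auth=auth).json()

one = requests.get(f"{base}/datasets/{quote('s3://my-bucket/my-key', safe='')}", auth=auth).json()

events = requests.get(
    f"{base}/datasets/events", params={"source_dag_id": "producer_dag"}, auth=auth
).json()
```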
+ warning_type: type: string readOnly: true - dag_id: + description: The warning type for the dag warning. + message: + type: string + readOnly: true + description: The message for the dag warning. + timestamp: + type: string + format: datetime + readOnly: true + description: The time when this warning was logged. + + DagWarningCollection: + type: object + description: | + Collection of DAG warnings. + + allOf: + - type: object + properties: + import_errors: + type: array + items: + $ref: '#/components/schemas/DagWarning' + - $ref: '#/components/schemas/CollectionInfo' + + SetDagRunNote: + type: object + properties: + note: + description: Custom notes left by users for this Dag Run. + type: string + + EventLog: + type: object + description: Log of user operations via CLI or Web UI. + properties: + event_log_id: + description: The event log ID + type: integer + readOnly: true + when: + description: The time when these events happened. + format: date-time + type: string + readOnly: true + dag_id: description: The DAG ID type: string readOnly: true @@ -2621,7 +3013,6 @@ components: readOnly: true nullable: true - Pool: description: The pool type: object @@ -2726,6 +3117,59 @@ components: nullable: true notification_sent: type: boolean + nullable: true + + Trigger: + type: object + properties: + id: + type: integer + classpath: + type: string + kwargs: + type: string + created_date: + type: string + format: datetime + triggerer_id: + type: integer + nullable: true + + Job: + type: object + properties: + id: + type: integer + dag_id: + type: string + nullable: true + state: + type: string + nullable: true + job_type: + type: string + nullable: true + start_date: + type: string + format: datetime + nullable: true + end_date: + type: string + format: datetime + nullable: true + latest_heartbeat: + type: string + format: datetime + nullable: true + executor_class: + type: string + nullable: true + hostname: + type: string + nullable: true + unixname: + type: string + nullable: true TaskInstance: type: object @@ -2759,6 +3203,8 @@ components: nullable: true try_number: type: integer + map_index: + type: integer max_tries: type: integer hostname: @@ -2771,8 +3217,10 @@ components: type: integer queue: type: string + nullable: true priority_weight: type: integer + nullable: true operator: type: string nullable: true @@ -2795,6 +3243,19 @@ components: *New in version 2.3.0* type: object + trigger: + $ref: '#/components/schemas/Trigger' + nullable: true + triggerer_job: + $ref: '#/components/schemas/Job' + nullable: true + note: + type: string + description: | + Contains manually entered notes by the user about the TaskInstance. + + *New in version 2.5.0* + nullable: true TaskInstanceCollection: type: object @@ -2850,6 +3311,13 @@ components: properties: key: type: string + description: + type: string + description: | + The description of the variable. + + *New in version 2.4.0* + nullable: true VariableCollection: type: object @@ -3085,6 +3553,7 @@ components: queue: type: string readOnly: true + nullable: true pool: type: string readOnly: true @@ -3139,9 +3608,6 @@ components: *New in version 2.1.0* properties: - number: - type: string - description: The plugin number name: type: string description: The name of the plugin @@ -3302,6 +3768,211 @@ components: $ref: '#/components/schemas/Resource' description: The permission resource + Dataset: + description: | + A dataset item. 
+ + *New in version 2.4.0* + type: object + properties: + id: + type: integer + description: The dataset id + uri: + type: string + description: The dataset uri + nullable: false + extra: + type: object + description: The dataset extra + nullable: true + created_at: + type: string + description: The dataset creation time + nullable: false + updated_at: + type: string + description: The dataset update time + nullable: false + consuming_dags: + type: array + items: + $ref: '#/components/schemas/DagScheduleDatasetReference' + producing_tasks: + type: array + items: + $ref: '#/components/schemas/TaskOutletDatasetReference' + + + TaskOutletDatasetReference: + description: | + A datasets reference to an upstream task. + + *New in version 2.4.0* + type: object + properties: + dag_id: + type: string + description: The DAG ID that updates the dataset. + nullable: true + task_id: + type: string + description: The task ID that updates the dataset. + nullable: true + created_at: + type: string + description: The dataset creation time + nullable: false + updated_at: + type: string + description: The dataset update time + nullable: false + + DagScheduleDatasetReference: + description: | + A datasets reference to a downstream DAG. + + *New in version 2.4.0* + type: object + properties: + dag_id: + type: string + description: The DAG ID that depends on the dataset. + nullable: true + created_at: + type: string + description: The dataset reference creation time + nullable: false + updated_at: + type: string + description: The dataset reference update time + nullable: false + + DatasetCollection: + description: | + A collection of datasets. + + *New in version 2.4.0* + type: object + allOf: + - type: object + properties: + datasets: + type: array + items: + $ref: '#/components/schemas/Dataset' + - $ref: '#/components/schemas/CollectionInfo' + + DatasetEvent: + description: | + A dataset event. + + *New in version 2.4.0* + type: object + properties: + dataset_id: + type: integer + description: The dataset id + dataset_uri: + type: string + description: The URI of the dataset + nullable: false + extra: + type: object + description: The dataset event extra + nullable: true + source_dag_id: + type: string + description: The DAG ID that updated the dataset. + nullable: true + source_task_id: + type: string + description: The task ID that updated the dataset. + nullable: true + source_run_id: + type: string + description: The DAG run ID that updated the dataset. + nullable: true + source_map_index: + type: integer + description: The task map index that updated the dataset. + nullable: true + created_dagruns: + type: array + items: + $ref: '#/components/schemas/BasicDAGRun' + timestamp: + type: string + description: The dataset event creation time + nullable: false + + BasicDAGRun: + type: object + properties: + run_id: + type: string + description: | + Run ID. + dag_id: + type: string + readOnly: true + logical_date: + type: string + description: | + The logical date (previously called execution date). This is the time or interval covered by + this DAG run, according to the DAG definition. + + The value of this field can be set only when creating the object. If you try to modify the + field of an existing object, the request fails with an BAD_REQUEST error. + + This together with DAG_ID are a unique key. + + *New in version 2.2.0* + format: date-time + start_date: + type: string + format: date-time + description: | + The start time. The time when DAG run was actually created. 
+ + *Changed in version 2.1.3*: Field becomes nullable. + readOnly: true + nullable: true + end_date: + type: string + format: date-time + readOnly: true + nullable: true + data_interval_start: + type: string + format: date-time + readOnly: true + nullable: true + data_interval_end: + type: string + format: date-time + readOnly: true + nullable: true + state: + $ref: '#/components/schemas/DagState' + readOnly: true + + DatasetEventCollection: + description: | + A collection of dataset events. + + *New in version 2.4.0* + type: object + allOf: + - type: object + properties: + dataset_events: + type: array + items: + $ref: '#/components/schemas/DatasetEvent' + - $ref: '#/components/schemas/CollectionInfo' + + # Configuration ConfigOption: type: object @@ -3358,7 +4029,7 @@ components: type: boolean default: true - ClearTaskInstance: + ClearTaskInstances: type: object properties: dry_run: @@ -3410,6 +4081,31 @@ components: description: Set state of DAG runs to RUNNING. type: boolean + dag_run_id: + type: string + description: The DagRun ID for this task instance + nullable: true + + include_upstream: + description: If set to true, upstream tasks are also affected. + type: boolean + default: false + + include_downstream: + description: If set to true, downstream tasks are also affected. + type: boolean + default: false + + include_future: + description: If set to True, also tasks from future DAG Runs are affected. + type: boolean + default: false + + include_past: + description: If set to True, also tasks from past DAG Runs are affected. + type: boolean + default: false + UpdateTaskInstancesState: type: object properties: @@ -3459,6 +4155,31 @@ components: - success - failed + UpdateTaskInstance: + type: object + properties: + dry_run: + description: | + If set, don't actually run this operation. The response will contain the task instance + planned to be affected, but won't be modified in any way. + type: boolean + default: false + + new_state: + description: Expected new state. + type: string + enum: + - success + - failed + SetTaskInstanceNote: + type: object + required: + - note + properties: + note: + description: The custom note to set for this Task Instance. + type: string + ListDagRunsForm: type: object properties: @@ -3634,6 +4355,7 @@ components: description: | Schedule interval. Defines how often DAG runs, this object gets added to your latest task instance's execution_date to figure out the next schedule. + nullable: true readOnly: true anyOf: - $ref: '#/components/schemas/TimeDelta' @@ -3779,7 +4501,10 @@ components: *Changed in version 2.0.2*: 'removed' is added as a possible value. - *Changed in version 2.2.0*: 'deferred' and 'sensing' is added as a possible value. + *Changed in version 2.2.0*: 'deferred' is added as a possible value. + + *Changed in version 2.4.0*: 'sensing' state has been removed. + *Changed in version 2.4.2*: 'restarting' is added as a possible value type: string enum: - success @@ -3793,8 +4518,8 @@ components: - none - scheduled - deferred - - sensing - removed + - restarting DagState: description: | @@ -3945,6 +4670,15 @@ components: required: true description: The import error ID. + DatasetURI: + in: path + name: uri + schema: + type: string + format: path + required: true + description: The encoded Dataset URI + PoolName: in: path name: pool_name @@ -3958,6 +4692,7 @@ components: name: variable_key schema: type: string + format: path required: true description: The variable Key. 
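A hedged sketch of a clearTaskInstances request exercising the fields added above (dag_run_id plus the include_* flags); the path is assumed to be the existing /dags/{dag_id}/clearTaskInstances endpoint, and all identifiers are placeholders.

```python
# Hedged sketch: preview clearing one task and its downstream tasks in a single DAG run.
import requests

payload = {
    "dry_run": True,            # only report which task instances would be cleared
    "dag_run_id": "my_run_id",  # restrict clearing to a single DAG run
    "task_ids": ["my_task"],
    "include_downstream": True,
    "include_future": False,
}
resp = requests.post(
    "http://localhost:8080/api/v1/dags/example_dag/clearTaskInstances",
    json=payload,
    auth=("admin", "admin"),
)
print(resp.json())
```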
@@ -3989,7 +4724,8 @@ components: type: string required: true description: The XCom key. - # Filter + + # Filters FilterExecutionDateGTE: in: query name: execution_date_gte @@ -4118,6 +4854,48 @@ components: *New in version 2.2.0* + FilterDatasetID: + in: query + name: dataset_id + schema: + type: integer + description: The Dataset ID that updated the dataset. + + FilterSourceDAGID: + in: query + name: source_dag_id + schema: + type: string + description: The DAG ID that updated the dataset. + + FilterSourceTaskID: + in: query + name: source_task_id + schema: + type: string + description: The task ID that updated the dataset. + + FilterSourceRunID: + in: query + name: source_run_id + schema: + type: string + description: The DAG run ID that updated the dataset. + + FilterSourceMapIndex: + in: query + name: source_map_index + schema: + type: integer + description: The map index that updated the dataset. + + FilterMapIndex: + in: query + name: map_index + schema: + type: integer + description: Filter on map index for mapped task. + OrderBy: in: query name: order_by @@ -4143,7 +4921,6 @@ components: *New in version 2.1.1* # Other parameters - FileToken: in: path name: file_token @@ -4274,6 +5051,8 @@ tags: - name: Role - name: Permission - name: User + - name: DagWarning + - name: Dataset externalDocs: url: https://airflow.apache.org/docs/apache-airflow/stable/ diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py index 81d1bd9280abb..8064912d921fa 100644 --- a/airflow/api_connexion/parameters.py +++ b/airflow/api_connexion/parameters.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from datetime import datetime from functools import wraps -from typing import Any, Callable, Container, Dict, Optional, TypeVar, cast +from typing import Any, Callable, Container, TypeVar, cast from pendulum.parsing import ParserError from sqlalchemy import text @@ -28,21 +30,23 @@ def validate_istimezone(value: datetime) -> None: - """Validates that a datetime is not naive""" + """Validates that a datetime is not naive.""" if not value.tzinfo: raise BadRequest("Invalid datetime format", detail="Naive datetime is disallowed") def format_datetime(value: str) -> datetime: """ + Format datetime objects. + Datetime format parser for args since connexion doesn't parse datetimes https://github.com/zalando/connexion/issues/476 This should only be used within connection views because it raises 400 """ value = value.strip() - if value[-1] != 'Z': - value = value.replace(" ", '+') + if value[-1] != "Z": + value = value.replace(" ", "+") try: return timezone.parse(value) except (ParserError, TypeError) as err: @@ -51,6 +55,8 @@ def format_datetime(value: str) -> datetime: def check_limit(value: int) -> int: """ + Check the limit does not exceed configured value. + This checks the limit passed to view and raises BadRequest if limit exceed user configured value """ @@ -69,7 +75,7 @@ def check_limit(value: int) -> int: T = TypeVar("T", bound=Callable) -def format_parameters(params_formatters: Dict[str, Callable[[Any], Any]]) -> Callable[[T], T]: +def format_parameters(params_formatters: dict[str, Callable[[Any], Any]]) -> Callable[[T], T]: """ Decorator factory that create decorator that convert parameters using given formatters. 
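For context on the format_parameters decorator touched here, a hedged sketch of how it is typically applied to an endpoint so query parameters are converted before the view runs; the view function itself is a placeholder, not part of this change.

```python
# Hedged sketch: convert/validate selected query parameters before the view body runs.
from airflow.api_connexion.parameters import check_limit, format_datetime, format_parameters


@format_parameters({"limit": check_limit, "start_date_gte": format_datetime})
def list_things(*, limit: int = 100, start_date_gte=None, **kwargs):
    # "limit" is checked against the configured maximum page limit and
    # "start_date_gte" parsed into an aware datetime by the formatters above.
    return {"limit": limit, "start_date_gte": start_date_gte}
```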
@@ -94,11 +100,11 @@ def wrapped_function(*args, **kwargs): def apply_sorting( query: Query, order_by: str, - to_replace: Optional[Dict[str, str]] = None, - allowed_attrs: Optional[Container[str]] = None, + to_replace: dict[str, str] | None = None, + allowed_attrs: Container[str] | None = None, ) -> Query: - """Apply sorting to query""" - lstriped_orderby = order_by.lstrip('-') + """Apply sorting to query.""" + lstriped_orderby = order_by.lstrip("-") if allowed_attrs and lstriped_orderby not in allowed_attrs: raise BadRequest( detail=f"Ordering with '{lstriped_orderby}' is disallowed or " diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py index 502d5b60bdddd..cf510137621b6 100644 --- a/airflow/api_connexion/schemas/common_schema.py +++ b/airflow/api_connexion/schemas/common_schema.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import datetime import inspect @@ -31,13 +32,13 @@ class CronExpression(typing.NamedTuple): - """Cron expression schema""" + """Cron expression schema.""" value: str class TimeDeltaSchema(Schema): - """Time delta schema""" + """Time delta schema.""" objectType = fields.Constant("TimeDelta", data_key="__type") days = fields.Integer() @@ -46,14 +47,14 @@ class TimeDeltaSchema(Schema): @marshmallow.post_load def make_time_delta(self, data, **kwargs): - """Create time delta based on data""" + """Create time delta based on data.""" if "objectType" in data: del data["objectType"] return datetime.timedelta(**data) class RelativeDeltaSchema(Schema): - """Relative delta schema""" + """Relative delta schema.""" objectType = fields.Constant("RelativeDelta", data_key="__type") years = fields.Integer() @@ -74,7 +75,7 @@ class RelativeDeltaSchema(Schema): @marshmallow.post_load def make_relative_delta(self, data, **kwargs): - """Create relative delta based on data""" + """Create relative delta based on data.""" if "objectType" in data: del data["objectType"] @@ -82,14 +83,14 @@ def make_relative_delta(self, data, **kwargs): class CronExpressionSchema(Schema): - """Cron expression schema""" + """Cron expression schema.""" objectType = fields.Constant("CronExpression", data_key="__type") value = fields.String(required=True) @marshmallow.post_load def make_cron_expression(self, data, **kwargs): - """Create cron expression based on data""" + """Create cron expression based on data.""" return CronExpression(data["value"]) @@ -118,7 +119,7 @@ def _dump(self, obj, update_fields=True, **kwargs): return super()._dump(obj, update_fields=update_fields, **kwargs) def get_obj_type(self, obj): - """Select schema based on object type""" + """Select schema based on object type.""" if isinstance(obj, datetime.timedelta): return "TimeDelta" elif isinstance(obj, relativedelta.relativedelta): @@ -130,7 +131,7 @@ def get_obj_type(self, obj): class ColorField(fields.String): - """Schema for color property""" + """Schema for color property.""" def __init__(self, **metadata): super().__init__(**metadata) @@ -138,7 +139,7 @@ def __init__(self, **metadata): class WeightRuleField(fields.String): - """Schema for WeightRule""" + """Schema for WeightRule.""" def __init__(self, **metadata): super().__init__(**metadata) @@ -146,7 +147,7 @@ def __init__(self, **metadata): class TimezoneField(fields.String): - """Schema for timezone""" + """Schema for timezone.""" class ClassReferenceSchema(Schema): diff --git 
a/airflow/api_connexion/schemas/config_schema.py b/airflow/api_connexion/schemas/config_schema.py index 2eb459ce14263..28095c177a1df 100644 --- a/airflow/api_connexion/schemas/config_schema.py +++ b/airflow/api_connexion/schemas/config_schema.py @@ -14,50 +14,51 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields class ConfigOptionSchema(Schema): - """Config Option Schema""" + """Config Option Schema.""" key = fields.String(required=True) value = fields.String(required=True) class ConfigOption(NamedTuple): - """Config option""" + """Config option.""" key: str value: str class ConfigSectionSchema(Schema): - """Config Section Schema""" + """Config Section Schema.""" name = fields.String(required=True) options = fields.List(fields.Nested(ConfigOptionSchema)) class ConfigSection(NamedTuple): - """List of config options within a section""" + """List of config options within a section.""" name: str - options: List[ConfigOption] + options: list[ConfigOption] class ConfigSchema(Schema): - """Config Schema""" + """Config Schema.""" sections = fields.List(fields.Nested(ConfigSectionSchema)) class Config(NamedTuple): - """List of config sections with their options""" + """List of config sections with their options.""" - sections: List[ConfigSection] + sections: list[ConfigSection] config_schema = ConfigSchema() diff --git a/airflow/api_connexion/schemas/connection_schema.py b/airflow/api_connexion/schemas/connection_schema.py index f06da92bacdab..4288ce079c554 100644 --- a/airflow/api_connexion/schemas/connection_schema.py +++ b/airflow/api_connexion/schemas/connection_schema.py @@ -15,8 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import json -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -25,14 +27,14 @@ class ConnectionCollectionItemSchema(SQLAlchemySchema): - """Schema for a connection item""" + """Schema for a connection item.""" class Meta: - """Meta""" + """Meta.""" model = Connection - connection_id = auto_field('conn_id', required=True) + connection_id = auto_field("conn_id", required=True) conn_type = auto_field(required=True) description = auto_field() host = auto_field() @@ -42,10 +44,10 @@ class Meta: class ConnectionSchema(ConnectionCollectionItemSchema): - """Connection schema""" + """Connection schema.""" password = auto_field(load_only=True) - extra = fields.Method('serialize_extra', deserialize='deserialize_extra', allow_none=True) + extra = fields.Method("serialize_extra", deserialize="deserialize_extra", allow_none=True) @staticmethod def serialize_extra(obj: Connection): @@ -66,21 +68,21 @@ def deserialize_extra(value): # an explicit deserialize method is required for class ConnectionCollection(NamedTuple): - """List of Connections with meta""" + """List of Connections with meta.""" - connections: List[Connection] + connections: list[Connection] total_entries: int class ConnectionCollectionSchema(Schema): - """Connection Collection Schema""" + """Connection Collection Schema.""" connections = fields.List(fields.Nested(ConnectionCollectionItemSchema)) total_entries = fields.Int() class ConnectionTestSchema(Schema): - """connection Test Schema""" + """connection Test Schema.""" status = fields.Boolean(required=True) message = fields.String(required=True) diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index 5cd79b20228fb..7ca857951b9fe 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -15,8 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import json -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import fields, post_dump, pre_load, validate from marshmallow.schema import Schema @@ -34,7 +36,7 @@ class ConfObject(fields.Field): - """The conf field""" + """The conf field.""" def _serialize(self, value, attr, obj, **kwargs): if not value: @@ -51,15 +53,15 @@ def _deserialize(self, value, attr, data, **kwargs): class DAGRunSchema(SQLAlchemySchema): - """Schema for DAGRun""" + """Schema for DAGRun.""" class Meta: - """Meta""" + """Meta.""" model = DagRun dateformat = "iso" - run_id = auto_field(data_key='dag_run_id') + run_id = auto_field(data_key="dag_run_id") dag_id = auto_field(dump_only=True) execution_date = auto_field(data_key="logical_date", validate=validate_istimezone) start_date = auto_field(dump_only=True) @@ -71,6 +73,7 @@ class Meta: data_interval_end = auto_field(dump_only=True) last_scheduling_decision = auto_field(dump_only=True) run_type = auto_field(dump_only=True) + note = auto_field(dump_only=True) @pre_load def autogenerate(self, data, **kwargs): @@ -110,7 +113,7 @@ def autofill(self, data, **kwargs): class SetDagRunStateFormSchema(Schema): - """Schema for handling the request of setting state of DAG run""" + """Schema for handling the request of setting state of DAG run.""" state = DagStateField( validate=validate.OneOf( @@ -120,32 +123,32 @@ class SetDagRunStateFormSchema(Schema): class ClearDagRunStateFormSchema(Schema): - """Schema for handling the request of clearing a DAG run""" + """Schema for handling the request of clearing a DAG run.""" dry_run = fields.Boolean(load_default=True) class DAGRunCollection(NamedTuple): - """List of DAGRuns with metadata""" + """List of DAGRuns with metadata.""" - dag_runs: List[DagRun] + dag_runs: list[DagRun] total_entries: int class DAGRunCollectionSchema(Schema): - """DAGRun Collection schema""" + """DAGRun Collection schema.""" dag_runs = fields.List(fields.Nested(DAGRunSchema)) total_entries = fields.Int() class DagRunsBatchFormSchema(Schema): - """Schema to validate and deserialize the Form(request payload) submitted to DagRun Batch endpoint""" + """Schema to validate and deserialize the Form(request payload) submitted to DagRun Batch endpoint.""" class Meta: - """Meta""" + """Meta.""" - datetimeformat = 'iso' + datetimeformat = "iso" strict = True order_by = fields.String() @@ -161,8 +164,15 @@ class Meta: end_date_lte = fields.DateTime(load_default=None, validate=validate_istimezone) +class SetDagRunNoteFormSchema(Schema): + """Schema for handling the request of clearing a DAG run.""" + + note = fields.String(allow_none=True, validate=validate.Length(max=1000)) + + dagrun_schema = DAGRunSchema() dagrun_collection_schema = DAGRunCollectionSchema() set_dagrun_state_form_schema = SetDagRunStateFormSchema() clear_dagrun_form_schema = ClearDagRunStateFormSchema() dagruns_batch_form_schema = DagRunsBatchFormSchema() +set_dagrun_note_form_schema = SetDagRunNoteFormSchema() diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py index 2f369113290d9..182bbf180334f 100644 --- a/airflow/api_connexion/schemas/dag_schema.py +++ b/airflow/api_connexion/schemas/dag_schema.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
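As a usage note for SetDagRunNoteFormSchema defined just above: it caps a DagRun note at 1000 characters, and the corresponding API call is a PATCH on the new .../dagRuns/{dag_run_id}/setNote endpoint. A minimal, hedged sketch of the validation:

```python
# Hedged sketch: SetDagRunNoteFormSchema accepts short notes and rejects oversized ones.
from marshmallow import ValidationError

from airflow.api_connexion.schemas.dag_run_schema import set_dagrun_note_form_schema

set_dagrun_note_form_schema.load({"note": "Backfilled manually after the outage."})  # passes
try:
    set_dagrun_note_form_schema.load({"note": "x" * 1001})
except ValidationError as err:
    print(err.messages)  # reports that the note exceeds the 1000 character limit
```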
+from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from itsdangerous import URLSafeSerializer from marshmallow import Schema, fields @@ -28,10 +29,10 @@ class DagTagSchema(SQLAlchemySchema): - """Dag Tag schema""" + """Dag Tag schema.""" class Meta: - """Meta""" + """Meta.""" model = DagTag @@ -39,10 +40,10 @@ class Meta: class DAGSchema(SQLAlchemySchema): - """DAG schema""" + """DAG schema.""" class Meta: - """Meta""" + """Meta.""" model = DagModel @@ -75,20 +76,20 @@ class Meta: @staticmethod def get_owners(obj: DagModel): - """Convert owners attribute to DAG representation""" - if not getattr(obj, 'owners', None): + """Convert owners attribute to DAG representation.""" + if not getattr(obj, "owners", None): return [] return obj.owners.split(",") @staticmethod def get_token(obj: DagModel): - """Return file token""" - serializer = URLSafeSerializer(conf.get('webserver', 'secret_key')) + """Return file token.""" + serializer = URLSafeSerializer(conf.get_mandatory_value("webserver", "secret_key")) return serializer.dumps(obj.fileloc) class DAGDetailSchema(DAGSchema): - """DAG details""" + """DAG details.""" owners = fields.Method("get_owners", dump_only=True) timezone = TimezoneField() @@ -100,7 +101,7 @@ class DAGDetailSchema(DAGSchema): dag_run_timeout = fields.Nested(TimeDeltaSchema, attribute="dagrun_timeout") doc_md = fields.String() default_view = fields.String() - params = fields.Method('get_params', dump_only=True) + params = fields.Method("get_params", dump_only=True) tags = fields.Method("get_tags", dump_only=True) # type: ignore is_paused = fields.Method("get_is_paused", dump_only=True) is_active = fields.Method("get_is_active", dump_only=True) @@ -108,7 +109,7 @@ class DAGDetailSchema(DAGSchema): end_date = fields.DateTime(dump_only=True) template_search_path = fields.String(dump_only=True) render_template_as_native_obj = fields.Boolean(dump_only=True) - last_loaded = fields.DateTime(dump_only=True, data_key='last_parsed') + last_loaded = fields.DateTime(dump_only=True, data_key="last_parsed") @staticmethod def get_concurrency(obj: DAG): @@ -116,7 +117,7 @@ def get_concurrency(obj: DAG): @staticmethod def get_tags(obj: DAG): - """Dumps tags as objects""" + """Dumps tags as objects.""" tags = obj.tags if tags: return [DagTagSchema().dump(dict(name=tag)) for tag in tags] @@ -124,37 +125,37 @@ def get_tags(obj: DAG): @staticmethod def get_owners(obj: DAG): - """Convert owners attribute to DAG representation""" - if not getattr(obj, 'owner', None): + """Convert owners attribute to DAG representation.""" + if not getattr(obj, "owner", None): return [] return obj.owner.split(",") @staticmethod def get_is_paused(obj: DAG): - """Checks entry in DAG table to see if this DAG is paused""" + """Checks entry in DAG table to see if this DAG is paused.""" return obj.get_is_paused() @staticmethod def get_is_active(obj: DAG): - """Checks entry in DAG table to see if this DAG is active""" + """Checks entry in DAG table to see if this DAG is active.""" return obj.get_is_active() @staticmethod def get_params(obj: DAG): - """Get the Params defined in a DAG""" + """Get the Params defined in a DAG.""" params = obj.params return {k: v.dump() for k, v in params.items()} class DAGCollection(NamedTuple): - """List of DAGs with metadata""" + """List of DAGs with metadata.""" - dags: List[DagModel] + dags: list[DagModel] total_entries: int class DAGCollectionSchema(Schema): - """DAG Collection schema""" + """DAG Collection schema.""" dags = 
fields.List(fields.Nested(DAGSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/dag_source_schema.py b/airflow/api_connexion/schemas/dag_source_schema.py index d142454bc1f6d..adb89ce76a569 100644 --- a/airflow/api_connexion/schemas/dag_source_schema.py +++ b/airflow/api_connexion/schemas/dag_source_schema.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from marshmallow import Schema, fields class DagSourceSchema(Schema): - """Dag Source schema""" + """Dag Source schema.""" content = fields.String(dump_only=True) diff --git a/airflow/api_connexion/schemas/dag_warning_schema.py b/airflow/api_connexion/schemas/dag_warning_schema.py new file mode 100644 index 0000000000000..35c9830d273c7 --- /dev/null +++ b/airflow/api_connexion/schemas/dag_warning_schema.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import NamedTuple + +from marshmallow import Schema, fields +from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + +from airflow.models.dagwarning import DagWarning + + +class DagWarningSchema(SQLAlchemySchema): + """Import error schema.""" + + class Meta: + """Meta.""" + + model = DagWarning + + dag_id = auto_field(data_key="dag_id", dump_only=True) + warning_type = auto_field() + message = auto_field() + timestamp = auto_field(format="iso") + + +class DagWarningCollection(NamedTuple): + """List of dag warnings with metadata.""" + + dag_warnings: list[DagWarning] + total_entries: int + + +class DagWarningCollectionSchema(Schema): + """Import error collection schema.""" + + dag_warnings = fields.List(fields.Nested(DagWarningSchema)) + total_entries = fields.Int() + + +dag_warning_schema = DagWarningSchema() +dag_warning_collection_schema = DagWarningCollectionSchema() diff --git a/airflow/api_connexion/schemas/dataset_schema.py b/airflow/api_connexion/schemas/dataset_schema.py new file mode 100644 index 0000000000000..bfdd0d24231c8 --- /dev/null +++ b/airflow/api_connexion/schemas/dataset_schema.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import NamedTuple + +from marshmallow import Schema, fields +from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + +from airflow.api_connexion.schemas.common_schema import JsonObjectField +from airflow.models.dagrun import DagRun +from airflow.models.dataset import ( + DagScheduleDatasetReference, + DatasetEvent, + DatasetModel, + TaskOutletDatasetReference, +) + + +class TaskOutletDatasetReferenceSchema(SQLAlchemySchema): + """TaskOutletDatasetReference DB schema.""" + + class Meta: + """Meta.""" + + model = TaskOutletDatasetReference + + dag_id = auto_field() + task_id = auto_field() + created_at = auto_field() + updated_at = auto_field() + + +class DagScheduleDatasetReferenceSchema(SQLAlchemySchema): + """DagScheduleDatasetReference DB schema.""" + + class Meta: + """Meta.""" + + model = DagScheduleDatasetReference + + dag_id = auto_field() + created_at = auto_field() + updated_at = auto_field() + + +class DatasetSchema(SQLAlchemySchema): + """Dataset DB schema.""" + + class Meta: + """Meta.""" + + model = DatasetModel + + id = auto_field() + uri = auto_field() + extra = JsonObjectField() + created_at = auto_field() + updated_at = auto_field() + producing_tasks = fields.List(fields.Nested(TaskOutletDatasetReferenceSchema)) + consuming_dags = fields.List(fields.Nested(DagScheduleDatasetReferenceSchema)) + + +class DatasetCollection(NamedTuple): + """List of Datasets with meta.""" + + datasets: list[DatasetModel] + total_entries: int + + +class DatasetCollectionSchema(Schema): + """Dataset Collection Schema.""" + + datasets = fields.List(fields.Nested(DatasetSchema)) + total_entries = fields.Int() + + +dataset_schema = DatasetSchema() +dataset_collection_schema = DatasetCollectionSchema() + + +class BasicDAGRunSchema(SQLAlchemySchema): + """Basic Schema for DAGRun.""" + + class Meta: + """Meta.""" + + model = DagRun + dateformat = "iso" + + run_id = auto_field(data_key="dag_run_id") + dag_id = auto_field(dump_only=True) + execution_date = auto_field(data_key="logical_date", dump_only=True) + start_date = auto_field(dump_only=True) + end_date = auto_field(dump_only=True) + state = auto_field(dump_only=True) + data_interval_start = auto_field(dump_only=True) + data_interval_end = auto_field(dump_only=True) + + +class DatasetEventSchema(SQLAlchemySchema): + """Dataset Event DB schema.""" + + class Meta: + """Meta.""" + + model = DatasetEvent + + id = auto_field() + dataset_id = auto_field() + dataset_uri = fields.String(attribute="dataset.uri", dump_only=True) + extra = JsonObjectField() + source_task_id = auto_field() + source_dag_id = auto_field() + source_run_id = auto_field() + source_map_index = auto_field() + created_dagruns = fields.List(fields.Nested(BasicDAGRunSchema)) + timestamp = auto_field() + + +class DatasetEventCollection(NamedTuple): + """List of Dataset events with meta.""" + + dataset_events: list[DatasetEvent] + total_entries: int + + +class DatasetEventCollectionSchema(Schema): + """Dataset Event Collection Schema.""" + + dataset_events = fields.List(fields.Nested(DatasetEventSchema)) + total_entries 
= fields.Int() + + +dataset_event_schema = DatasetEventSchema() +dataset_event_collection_schema = DatasetEventCollectionSchema() diff --git a/airflow/api_connexion/schemas/enum_schemas.py b/airflow/api_connexion/schemas/enum_schemas.py index 71faf9fa20436..981a3669b1b58 100644 --- a/airflow/api_connexion/schemas/enum_schemas.py +++ b/airflow/api_connexion/schemas/enum_schemas.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from marshmallow import fields, validate @@ -21,7 +22,7 @@ class DagStateField(fields.String): - """Schema for DagState Enum""" + """Schema for DagState Enum.""" def __init__(self, **metadata): super().__init__(**metadata) @@ -29,7 +30,7 @@ def __init__(self, **metadata): class TaskInstanceStateField(fields.String): - """Schema for TaskInstanceState Enum""" + """Schema for TaskInstanceState Enum.""" def __init__(self, **metadata): super().__init__(**metadata) diff --git a/airflow/api_connexion/schemas/error_schema.py b/airflow/api_connexion/schemas/error_schema.py index c9462b5f967a8..dcd4d37ff7781 100644 --- a/airflow/api_connexion/schemas/error_schema.py +++ b/airflow/api_connexion/schemas/error_schema.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, NamedTuple +from __future__ import annotations + +from typing import NamedTuple from marshmallow import Schema, fields from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -23,10 +25,10 @@ class ImportErrorSchema(SQLAlchemySchema): - """Import error schema""" + """Import error schema.""" class Meta: - """Meta""" + """Meta.""" model = ImportError @@ -39,14 +41,14 @@ class Meta: class ImportErrorCollection(NamedTuple): - """List of import errors with metadata""" + """List of import errors with metadata.""" - import_errors: List[ImportError] + import_errors: list[ImportError] total_entries: int class ImportErrorCollectionSchema(Schema): - """Import error collection schema""" + """Import error collection schema.""" import_errors = fields.List(fields.Nested(ImportErrorSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/event_log_schema.py b/airflow/api_connexion/schemas/event_log_schema.py index d97c223bffa23..5bf4ccf00d0b3 100644 --- a/airflow/api_connexion/schemas/event_log_schema.py +++ b/airflow/api_connexion/schemas/event_log_schema.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
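One usage sketch for the DAG warning pieces above: /dagWarnings returns a collection serialized by DagWarningCollectionSchema, so the list is keyed dag_warnings; note that the OpenAPI DagWarningCollection schema earlier in this diff still names the array import_errors, which appears to be a copy-paste from ImportErrorCollection. Host, credentials and the filter value are placeholders.

```python
# Hedged sketch: list DAG warnings through the new endpoint, optionally filtered by type.
import requests

resp = requests.get(
    "http://localhost:8080/api/v1/dagWarnings",
    params={"warning_type": "non-existent pool"},  # example filter value
    auth=("admin", "admin"),
)
for warning in resp.json()["dag_warnings"]:
    print(warning["dag_id"], warning["message"])
```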
+from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -24,15 +25,15 @@ class EventLogSchema(SQLAlchemySchema): - """Event log schema""" + """Event log schema.""" class Meta: - """Meta""" + """Meta.""" model = Log - id = auto_field(data_key='event_log_id', dump_only=True) - dttm = auto_field(data_key='when', dump_only=True) + id = auto_field(data_key="event_log_id", dump_only=True) + dttm = auto_field(data_key="when", dump_only=True) dag_id = auto_field(dump_only=True) task_id = auto_field(dump_only=True) event = auto_field(dump_only=True) @@ -42,14 +43,14 @@ class Meta: class EventLogCollection(NamedTuple): - """List of import errors with metadata""" + """List of import errors with metadata.""" - event_logs: List[Log] + event_logs: list[Log] total_entries: int class EventLogCollectionSchema(Schema): - """EventLog Collection Schema""" + """EventLog Collection Schema.""" event_logs = fields.List(fields.Nested(EventLogSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/health_schema.py b/airflow/api_connexion/schemas/health_schema.py index 7089babb6230b..67155406c1a79 100644 --- a/airflow/api_connexion/schemas/health_schema.py +++ b/airflow/api_connexion/schemas/health_schema.py @@ -14,28 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from marshmallow import Schema, fields class BaseInfoSchema(Schema): - """Base status field for metadatabase and scheduler""" + """Base status field for metadatabase and scheduler.""" status = fields.String(dump_only=True) class MetaDatabaseInfoSchema(BaseInfoSchema): - """Schema for Metadatabase info""" + """Schema for Metadatabase info.""" class SchedulerInfoSchema(BaseInfoSchema): - """Schema for Metadatabase info""" + """Schema for Metadatabase info.""" latest_scheduler_heartbeat = fields.String(dump_only=True) class HealthInfoSchema(Schema): - """Schema for the Health endpoint""" + """Schema for the Health endpoint.""" metadatabase = fields.Nested(MetaDatabaseInfoSchema) scheduler = fields.Nested(SchedulerInfoSchema) diff --git a/airflow/api_connexion/schemas/job_schema.py b/airflow/api_connexion/schemas/job_schema.py new file mode 100644 index 0000000000000..4d98d39c92030 --- /dev/null +++ b/airflow/api_connexion/schemas/job_schema.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + +from airflow.jobs.base_job import BaseJob + + +class JobSchema(SQLAlchemySchema): + """Sla Miss Schema.""" + + class Meta: + """Meta.""" + + model = BaseJob + + id = auto_field(dump_only=True) + dag_id = auto_field(dump_only=True) + state = auto_field(dump_only=True) + job_type = auto_field(dump_only=True) + start_date = auto_field(dump_only=True) + end_date = auto_field(dump_only=True) + latest_heartbeat = auto_field(dump_only=True) + executor_class = auto_field(dump_only=True) + hostname = auto_field(dump_only=True) + unixname = auto_field(dump_only=True) diff --git a/airflow/api_connexion/schemas/log_schema.py b/airflow/api_connexion/schemas/log_schema.py index eff97e1723d63..82e291fafc42c 100644 --- a/airflow/api_connexion/schemas/log_schema.py +++ b/airflow/api_connexion/schemas/log_schema.py @@ -14,23 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import NamedTuple, Optional +from __future__ import annotations + +from typing import NamedTuple from marshmallow import Schema, fields class LogsSchema(Schema): - """Schema for logs""" + """Schema for logs.""" content = fields.Str() continuation_token = fields.Str() class LogResponseObject(NamedTuple): - """Log Response Object""" + """Log Response Object.""" content: str - continuation_token: Optional[str] + continuation_token: str | None logs_schema = LogsSchema() diff --git a/airflow/api_connexion/schemas/plugin_schema.py b/airflow/api_connexion/schemas/plugin_schema.py index 89704cd191be6..780fef17bf76a 100644 --- a/airflow/api_connexion/schemas/plugin_schema.py +++ b/airflow/api_connexion/schemas/plugin_schema.py @@ -14,37 +14,37 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields class PluginSchema(Schema): - """Plugin schema""" + """Plugin schema.""" - number = fields.Int() name = fields.String() hooks = fields.List(fields.String()) executors = fields.List(fields.String()) - macros = fields.List(fields.String()) - flask_blueprints = fields.List(fields.String()) - appbuilder_views = fields.List(fields.String()) + macros = fields.List(fields.Dict()) + flask_blueprints = fields.List(fields.Dict()) + appbuilder_views = fields.List(fields.Dict()) appbuilder_menu_items = fields.List(fields.Dict()) - global_operator_extra_links = fields.List(fields.String()) - operator_extra_links = fields.List(fields.String()) + global_operator_extra_links = fields.List(fields.Dict()) + operator_extra_links = fields.List(fields.Dict()) source = fields.String() class PluginCollection(NamedTuple): - """Plugin List""" + """Plugin List.""" - plugins: List + plugins: list total_entries: int class PluginCollectionSchema(Schema): - """Plugin Collection List""" + """Plugin Collection List.""" plugins = fields.List(fields.Nested(PluginSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/pool_schema.py b/airflow/api_connexion/schemas/pool_schema.py index 1f91b0d5bdd1c..4e25287d1d357 100644 --- a/airflow/api_connexion/schemas/pool_schema.py +++ b/airflow/api_connexion/schemas/pool_schema.py @@ -14,8 +14,9 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -24,10 +25,10 @@ class PoolSchema(SQLAlchemySchema): - """Pool schema""" + """Pool schema.""" class Meta: - """Meta""" + """Meta.""" model = Pool @@ -36,6 +37,7 @@ class Meta: occupied_slots = fields.Method("get_occupied_slots", dump_only=True) running_slots = fields.Method("get_running_slots", dump_only=True) queued_slots = fields.Method("get_queued_slots", dump_only=True) + scheduled_slots = fields.Method("get_scheduled_slots", dump_only=True) open_slots = fields.Method("get_open_slots", dump_only=True) description = auto_field() @@ -54,6 +56,11 @@ def get_queued_slots(obj: Pool) -> int: """Returns the queued slots of the pool.""" return obj.queued_slots() + @staticmethod + def get_scheduled_slots(obj: Pool) -> int: + """Returns the scheduled slots of the pool.""" + return obj.scheduled_slots() + @staticmethod def get_open_slots(obj: Pool) -> float: """Returns the open slots of the pool.""" @@ -61,14 +68,14 @@ def get_open_slots(obj: Pool) -> float: class PoolCollection(NamedTuple): - """List of Pools with metadata""" + """List of Pools with metadata.""" - pools: List[Pool] + pools: list[Pool] total_entries: int class PoolCollectionSchema(Schema): - """Pool Collection schema""" + """Pool Collection schema.""" pools = fields.List(fields.Nested(PoolSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/provider_schema.py b/airflow/api_connexion/schemas/provider_schema.py index 4c9867380bd99..ad62f4ae26f10 100644 --- a/airflow/api_connexion/schemas/provider_schema.py +++ b/airflow/api_connexion/schemas/provider_schema.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields @@ -41,7 +42,7 @@ class Provider(TypedDict): class ProviderCollection(NamedTuple): """List of Providers.""" - providers: List[Provider] + providers: list[Provider] total_entries: int diff --git a/airflow/api_connexion/schemas/role_and_permission_schema.py b/airflow/api_connexion/schemas/role_and_permission_schema.py index 4031750199f83..324336c288668 100644 --- a/airflow/api_connexion/schemas/role_and_permission_schema.py +++ b/airflow/api_connexion/schemas/role_and_permission_schema.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
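As a quick sketch of the scheduled_slots addition to PoolSchema above, the pools listing now reports scheduled slots alongside the existing counters; host and credentials are placeholders.

```python
# Hedged sketch: read the new scheduled_slots counter from the pools listing.
import requests

resp = requests.get("http://localhost:8080/api/v1/pools", auth=("admin", "admin"))
for pool in resp.json()["pools"]:
    print(pool["name"], pool["open_slots"], pool["queued_slots"], pool["scheduled_slots"])
```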
+from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -24,10 +25,10 @@ class ActionSchema(SQLAlchemySchema): - """Action Action Schema""" + """Action Schema.""" class Meta: - """Meta""" + """Meta.""" model = Action @@ -35,10 +36,10 @@ class Meta: class ResourceSchema(SQLAlchemySchema): - """View menu Schema""" + """View menu Schema.""" class Meta: - """Meta""" + """Meta.""" model = Resource @@ -46,24 +47,24 @@ class Meta: class ActionCollection(NamedTuple): - """Action Action Collection""" + """Action Collection.""" - actions: List[Action] + actions: list[Action] total_entries: int class ActionCollectionSchema(Schema): - """Permissions list schema""" + """Permissions list schema.""" actions = fields.List(fields.Nested(ActionSchema)) total_entries = fields.Int() class ActionResourceSchema(SQLAlchemySchema): - """Action View Schema""" + """Action View Schema.""" class Meta: - """Meta""" + """Meta.""" model = Permission @@ -72,26 +73,26 @@ class Meta: class RoleSchema(SQLAlchemySchema): - """Role item schema""" + """Role item schema.""" class Meta: - """Meta""" + """Meta.""" model = Role name = auto_field() - permissions = fields.List(fields.Nested(ActionResourceSchema), data_key='actions') + permissions = fields.List(fields.Nested(ActionResourceSchema), data_key="actions") class RoleCollection(NamedTuple): - """List of roles""" + """List of roles.""" - roles: List[Role] + roles: list[Role] total_entries: int class RoleCollectionSchema(Schema): - """List of roles""" + """List of roles.""" roles = fields.List(fields.Nested(RoleSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/sla_miss_schema.py b/airflow/api_connexion/schemas/sla_miss_schema.py index 9413e37cbde21..97a462e186d59 100644 --- a/airflow/api_connexion/schemas/sla_miss_schema.py +++ b/airflow/api_connexion/schemas/sla_miss_schema.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -21,10 +22,10 @@ class SlaMissSchema(SQLAlchemySchema): - """Sla Miss Schema""" + """Sla Miss Schema.""" class Meta: - """Meta""" + """Meta.""" model = SlaMiss diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index 37005256f6cdc..970ef9a0fd4d6 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ b/airflow/api_connexion/schemas/task_instance_schema.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations -from typing import List, NamedTuple, Optional, Tuple +from typing import NamedTuple from marshmallow import Schema, ValidationError, fields, validate, validates_schema from marshmallow.utils import get_value @@ -24,17 +25,19 @@ from airflow.api_connexion.parameters import validate_istimezone from airflow.api_connexion.schemas.common_schema import JsonObjectField from airflow.api_connexion.schemas.enum_schemas import TaskInstanceStateField +from airflow.api_connexion.schemas.job_schema import JobSchema from airflow.api_connexion.schemas.sla_miss_schema import SlaMissSchema +from airflow.api_connexion.schemas.trigger_schema import TriggerSchema from airflow.models import SlaMiss, TaskInstance from airflow.utils.helpers import exactly_one from airflow.utils.state import State class TaskInstanceSchema(SQLAlchemySchema): - """Task instance schema""" + """Task instance schema.""" class Meta: - """Meta""" + """Meta.""" model = TaskInstance @@ -59,8 +62,11 @@ class Meta: queued_dttm = auto_field(data_key="queued_when") pid = auto_field() executor_config = auto_field() + note = auto_field() sla_miss = fields.Nested(SlaMissSchema, dump_default=None) - rendered_fields = JsonObjectField(default={}) + rendered_fields = JsonObjectField(dump_default={}) + trigger = fields.Nested(TriggerSchema) + triggerer_job = fields.Nested(JobSchema) def get_attribute(self, obj, attr, default): if attr == "sla_miss": @@ -75,21 +81,21 @@ def get_attribute(self, obj, attr, default): class TaskInstanceCollection(NamedTuple): - """List of task instances with metadata""" + """List of task instances with metadata.""" - task_instances: List[Tuple[TaskInstance, Optional[SlaMiss]]] + task_instances: list[tuple[TaskInstance, SlaMiss | None]] total_entries: int class TaskInstanceCollectionSchema(Schema): - """Task instance collection schema""" + """Task instance collection schema.""" task_instances = fields.List(fields.Nested(TaskInstanceSchema)) total_entries = fields.Int() class TaskInstanceBatchFormSchema(Schema): - """Schema for the request form passed to Task Instance Batch endpoint""" + """Schema for the request form passed to Task Instance Batch endpoint.""" page_offset = fields.Int(load_default=0, validate=validate.Range(min=0)) page_limit = fields.Int(load_default=100, validate=validate.Range(min=1)) @@ -108,7 +114,7 @@ class TaskInstanceBatchFormSchema(Schema): class ClearTaskInstanceFormSchema(Schema): - """Schema for handling the request of clearing task instance of a Dag""" + """Schema for handling the request of clearing task instance of a Dag.""" dry_run = fields.Boolean(load_default=True) start_date = fields.DateTime(load_default=None, validate=validate_istimezone) @@ -119,19 +125,30 @@ class ClearTaskInstanceFormSchema(Schema): include_parentdag = fields.Boolean(load_default=False) reset_dag_runs = fields.Boolean(load_default=False) task_ids = fields.List(fields.String(), validate=validate.Length(min=1)) + dag_run_id = fields.Str(load_default=None) + include_upstream = fields.Boolean(load_default=False) + include_downstream = fields.Boolean(load_default=False) + include_future = fields.Boolean(load_default=False) + include_past = fields.Boolean(load_default=False) @validates_schema def validate_form(self, data, **kwargs): - """Validates clear task instance form""" + """Validates clear task instance form.""" if data["only_failed"] and data["only_running"]: raise ValidationError("only_failed and only_running both are set to True") if data["start_date"] and data["end_date"]: if 
data["start_date"] > data["end_date"]:
                raise ValidationError("end_date is sooner than start_date")
+        if data["start_date"] and data["end_date"] and data["dag_run_id"]:
+            raise ValidationError("Exactly one of dag_run_id or (start_date and end_date) must be provided")
+        if data["start_date"] and data["dag_run_id"]:
+            raise ValidationError("Exactly one of dag_run_id or start_date must be provided")
+        if data["end_date"] and data["dag_run_id"]:
+            raise ValidationError("Exactly one of dag_run_id or end_date must be provided")


 class SetTaskInstanceStateFormSchema(Schema):
-    """Schema for handling the request of setting state of task instance of a DAG"""
+    """Schema for handling the request of setting state of task instance of a DAG."""

     dry_run = fields.Boolean(dump_default=True)
     task_id = fields.Str(required=True)
@@ -145,13 +162,20 @@ class SetTaskInstanceStateFormSchema(Schema):

     @validates_schema
     def validate_form(self, data, **kwargs):
-        """Validates set task instance state form"""
+        """Validates set task instance state form."""
         if not exactly_one(data.get("execution_date"), data.get("dag_run_id")):
             raise ValidationError("Exactly one of execution_date or dag_run_id must be provided")


+class SetSingleTaskInstanceStateFormSchema(Schema):
+    """Schema for handling the request of updating state of a single task instance."""
+
+    dry_run = fields.Boolean(dump_default=True)
+    new_state = TaskInstanceStateField(required=True, validate=validate.OneOf([State.SUCCESS, State.FAILED]))
+
+
 class TaskInstanceReferenceSchema(Schema):
-    """Schema for the task instance reference schema"""
+    """Schema for the task instance reference schema."""

     task_id = fields.Str()
     run_id = fields.Str(data_key="dag_run_id")
@@ -160,21 +184,31 @@ class TaskInstanceReferenceSchema(Schema):


 class TaskInstanceReferenceCollection(NamedTuple):
-    """List of objects with metadata about taskinstance and dag_run_id"""
+    """List of objects with metadata about taskinstance and dag_run_id."""

-    task_instances: List[Tuple[TaskInstance, str]]
+    task_instances: list[tuple[TaskInstance, str]]


 class TaskInstanceReferenceCollectionSchema(Schema):
-    """Collection schema for task reference"""
+    """Collection schema for task reference."""

     task_instances = fields.List(fields.Nested(TaskInstanceReferenceSchema))


+class SetTaskInstanceNoteFormSchema(Schema):
+    """Schema for setting a note for a TaskInstance."""
+
+    # Note: We can't add map_index to the url as subpaths can't start with dashes.
+    map_index = fields.Int(allow_none=False)
+    note = fields.String(allow_none=True, validate=validate.Length(max=1000))
+
+
 task_instance_schema = TaskInstanceSchema()
 task_instance_collection_schema = TaskInstanceCollectionSchema()
 task_instance_batch_form = TaskInstanceBatchFormSchema()
 clear_task_instance_form = ClearTaskInstanceFormSchema()
 set_task_instance_state_form = SetTaskInstanceStateFormSchema()
+set_single_task_instance_state_form = SetSingleTaskInstanceStateFormSchema()
 task_instance_reference_schema = TaskInstanceReferenceSchema()
 task_instance_reference_collection_schema = TaskInstanceReferenceCollectionSchema()
+set_task_instance_note_form_schema = SetTaskInstanceNoteFormSchema()
diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py
index aa9a4703088b0..5715ca2ea0fa3 100644
--- a/airflow/api_connexion/schemas/task_schema.py
+++ b/airflow/api_connexion/schemas/task_schema.py
@@ -14,8 +14,9 @@
 # KIND, either express or implied.
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import List, NamedTuple +from typing import NamedTuple from marshmallow import Schema, fields @@ -26,13 +27,15 @@ WeightRuleField, ) from airflow.api_connexion.schemas.dag_schema import DAGSchema +from airflow.models.mappedoperator import MappedOperator from airflow.models.operator import Operator class TaskSchema(Schema): - """Task schema""" + """Task schema.""" class_ref = fields.Method("_get_class_reference", dump_only=True) + operator_name = fields.Method("_get_operator_name", dump_only=True) task_id = fields.String(dump_only=True) owner = fields.String(dump_only=True) start_date = fields.DateTime(dump_only=True) @@ -57,29 +60,38 @@ class TaskSchema(Schema): template_fields = fields.List(fields.String(), dump_only=True) sub_dag = fields.Nested(DAGSchema, dump_only=True) downstream_task_ids = fields.List(fields.String(), dump_only=True) - params = fields.Method('get_params', dump_only=True) - is_mapped = fields.Boolean(dump_only=True) + params = fields.Method("_get_params", dump_only=True) + is_mapped = fields.Method("_get_is_mapped", dump_only=True) - def _get_class_reference(self, obj): + @staticmethod + def _get_class_reference(obj): result = ClassReferenceSchema().dump(obj) return result.data if hasattr(result, "data") else result @staticmethod - def get_params(obj): - """Get the Params defined in a Task""" + def _get_operator_name(obj): + return obj.operator_name + + @staticmethod + def _get_params(obj): + """Get the Params defined in a Task.""" params = obj.params return {k: v.dump() for k, v in params.items()} + @staticmethod + def _get_is_mapped(obj): + return isinstance(obj, MappedOperator) + class TaskCollection(NamedTuple): - """List of Tasks with metadata""" + """List of Tasks with metadata.""" - tasks: List[Operator] + tasks: list[Operator] total_entries: int class TaskCollectionSchema(Schema): - """Schema for TaskCollection""" + """Schema for TaskCollection.""" tasks = fields.List(fields.Nested(TaskSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/schemas/trigger_schema.py b/airflow/api_connexion/schemas/trigger_schema.py new file mode 100644 index 0000000000000..15d180a5732ff --- /dev/null +++ b/airflow/api_connexion/schemas/trigger_schema.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from __future__ import annotations
+
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.models import Trigger
+
+
+class TriggerSchema(SQLAlchemySchema):
+    """Trigger Schema."""
+
+    class Meta:
+        """Meta."""
+
+        model = Trigger
+
+    id = auto_field(dump_only=True)
+    classpath = auto_field(dump_only=True)
+    kwargs = auto_field(dump_only=True)
+    created_date = auto_field(dump_only=True)
+    triggerer_id = auto_field(dump_only=True)
diff --git a/airflow/api_connexion/schemas/user_schema.py b/airflow/api_connexion/schemas/user_schema.py
index 3d36aa91c8031..843ad32f0245c 100644
--- a/airflow/api_connexion/schemas/user_schema.py
+++ b/airflow/api_connexion/schemas/user_schema.py
@@ -14,7 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import List, NamedTuple
+from __future__ import annotations
+
+from typing import NamedTuple

 from marshmallow import Schema, fields
 from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -25,10 +27,10 @@ class UserCollectionItemSchema(SQLAlchemySchema):
-    """user collection item schema"""
+    """user collection item schema."""

     class Meta:
-        """Meta"""
+        """Meta."""

         model = User
         dateformat = "iso"
@@ -41,26 +43,26 @@ class Meta:
     last_login = auto_field(dump_only=True)
     login_count = auto_field(dump_only=True)
     fail_login_count = auto_field(dump_only=True)
-    roles = fields.List(fields.Nested(RoleSchema, only=('name',)))
+    roles = fields.List(fields.Nested(RoleSchema, only=("name",)))
     created_on = auto_field(validate=validate_istimezone, dump_only=True)
     changed_on = auto_field(validate=validate_istimezone, dump_only=True)


 class UserSchema(UserCollectionItemSchema):
-    """User schema"""
+    """User schema."""

     password = auto_field(load_only=True)


 class UserCollection(NamedTuple):
-    """User collection"""
+    """User collection."""

-    users: List[User]
+    users: list[User]
     total_entries: int


 class UserCollectionSchema(Schema):
-    """User collection schema"""
+    """User collection schema."""

     users = fields.List(fields.Nested(UserCollectionItemSchema))
     total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/variable_schema.py b/airflow/api_connexion/schemas/variable_schema.py
index 6b5d16e4227d6..ffe54f0742907 100644
--- a/airflow/api_connexion/schemas/variable_schema.py
+++ b/airflow/api_connexion/schemas/variable_schema.py
@@ -14,19 +14,21 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations

 from marshmallow import Schema, fields


 class VariableSchema(Schema):
-    """Variable Schema"""
+    """Variable Schema."""

     key = fields.String(required=True)
     value = fields.String(attribute="val", required=True)
+    description = fields.String(attribute="description", required=False)


 class VariableCollectionSchema(Schema):
-    """Variable Collection Schema"""
+    """Variable Collection Schema."""

     variables = fields.List(fields.Nested(VariableSchema))
     total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/version_schema.py b/airflow/api_connexion/schemas/version_schema.py
index 24bd9337c1c36..519f91c55e816 100644
--- a/airflow/api_connexion/schemas/version_schema.py
+++ b/airflow/api_connexion/schemas/version_schema.py
@@ -14,12 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations from marshmallow import Schema, fields class VersionInfoSchema(Schema): - """Version information schema""" + """Version information schema.""" version = fields.String(dump_only=True) git_version = fields.String(dump_only=True) diff --git a/airflow/api_connexion/schemas/xcom_schema.py b/airflow/api_connexion/schemas/xcom_schema.py index b3f3f0dd021a2..09d2505bf7d4d 100644 --- a/airflow/api_connexion/schemas/xcom_schema.py +++ b/airflow/api_connexion/schemas/xcom_schema.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, NamedTuple +from __future__ import annotations + +from typing import NamedTuple from marshmallow import Schema, fields from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field @@ -23,10 +25,10 @@ class XComCollectionItemSchema(SQLAlchemySchema): - """Schema for a xcom item""" + """Schema for a xcom item.""" class Meta: - """Meta""" + """Meta.""" model = XCom @@ -38,20 +40,20 @@ class Meta: class XComSchema(XComCollectionItemSchema): - """XCom schema""" + """XCom schema.""" value = auto_field() class XComCollection(NamedTuple): - """List of XComs with meta""" + """List of XComs with meta.""" - xcom_entries: List[XCom] + xcom_entries: list[XCom] total_entries: int class XComCollectionSchema(Schema): - """XCom Collection Schema""" + """XCom Collection Schema.""" xcom_entries = fields.List(fields.Nested(XComCollectionItemSchema)) total_entries = fields.Int() diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 3562c98eb4b35..0e3e9b1f3748b 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -14,20 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations from functools import wraps -from typing import Callable, Optional, Sequence, Tuple, TypeVar, cast +from typing import Callable, Sequence, TypeVar, cast -from flask import Response, current_app +from flask import Response from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated +from airflow.utils.airflow_flask_app import get_airflow_app T = TypeVar("T", bound=Callable) def check_authentication() -> None: """Checks that the request has valid authorization information.""" - for auth in current_app.api_auth: + for auth in get_airflow_app().api_auth: response = auth.requires_authentication(Response)() if response.status_code == 200: return @@ -36,16 +38,16 @@ def check_authentication() -> None: raise Unauthenticated(headers=response.headers) -def requires_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]: +def requires_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: """Factory for decorator that checks current user's permissions against required permissions.""" - appbuilder = current_app.appbuilder + appbuilder = get_airflow_app().appbuilder appbuilder.sm.sync_resource_permissions(permissions) def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): check_authentication() - if appbuilder.sm.check_authorization(permissions, kwargs.get('dag_id')): + if appbuilder.sm.check_authorization(permissions, kwargs.get("dag_id")): return func(*args, **kwargs) raise PermissionDenied() diff --git a/airflow/api_connexion/types.py b/airflow/api_connexion/types.py index f640d14bd7e79..3a6f89d9bb52a 100644 --- a/airflow/api_connexion/types.py +++ b/airflow/api_connexion/types.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import Any, Mapping, Optional, Sequence, Tuple, Union diff --git a/airflow/callbacks/base_callback_sink.py b/airflow/callbacks/base_callback_sink.py index e7cbf23e7b959..c243f0fbd640f 100644 --- a/airflow/callbacks/base_callback_sink.py +++ b/airflow/callbacks/base_callback_sink.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations from airflow.callbacks.callback_requests import CallbackRequest diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py index 8112589cd0262..8ec0187978db6 100644 --- a/airflow/callbacks/callback_requests.py +++ b/airflow/callbacks/callback_requests.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from airflow.models.taskinstance import SimpleTaskInstance @@ -28,10 +29,17 @@ class CallbackRequest: :param full_filepath: File Path to use to run the callback :param msg: Additional Message that can be used for logging + :param processor_subdir: Directory used by Dag Processor when parsed the dag. 
""" - def __init__(self, full_filepath: str, msg: Optional[str] = None): + def __init__( + self, + full_filepath: str, + processor_subdir: str | None = None, + msg: str | None = None, + ): self.full_filepath = full_filepath + self.processor_subdir = processor_subdir self.msg = msg def __eq__(self, other): @@ -53,6 +61,8 @@ def from_json(cls, json_str: str): class TaskCallbackRequest(CallbackRequest): """ + Task callback status information. + A Class with information about the success/failure TI callback to be executed. Currently, only failure callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess. @@ -60,31 +70,33 @@ class TaskCallbackRequest(CallbackRequest): :param simple_task_instance: Simplified Task Instance representation :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback :param msg: Additional Message that can be used for logging to determine failure/zombie + :param processor_subdir: Directory used by Dag Processor when parsed the dag. """ def __init__( self, full_filepath: str, - simple_task_instance: "SimpleTaskInstance", - is_failure_callback: Optional[bool] = True, - msg: Optional[str] = None, + simple_task_instance: SimpleTaskInstance, + is_failure_callback: bool | None = True, + processor_subdir: str | None = None, + msg: str | None = None, ): - super().__init__(full_filepath=full_filepath, msg=msg) + super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg) self.simple_task_instance = simple_task_instance self.is_failure_callback = is_failure_callback def to_json(self) -> str: - dict_obj = self.__dict__.copy() - dict_obj["simple_task_instance"] = dict_obj["simple_task_instance"].__dict__ - return json.dumps(dict_obj) + from airflow.serialization.serialized_objects import BaseSerialization + + val = BaseSerialization.serialize(self.__dict__, strict=True) + return json.dumps(val) @classmethod def from_json(cls, json_str: str): - from airflow.models.taskinstance import SimpleTaskInstance + from airflow.serialization.serialized_objects import BaseSerialization - kwargs = json.loads(json_str) - simple_ti = SimpleTaskInstance.from_dict(obj_dict=kwargs.pop("simple_task_instance")) - return cls(simple_task_instance=simple_ti, **kwargs) + val = json.loads(json_str) + return cls(**BaseSerialization.deserialize(val)) class DagCallbackRequest(CallbackRequest): @@ -94,6 +106,7 @@ class DagCallbackRequest(CallbackRequest): :param full_filepath: File Path to use to run the callback :param dag_id: DAG ID :param run_id: Run ID for the DagRun + :param processor_subdir: Directory used by Dag Processor when parsed the dag. 
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback :param msg: Additional Message that can be used for logging """ @@ -103,10 +116,11 @@ def __init__( full_filepath: str, dag_id: str, run_id: str, - is_failure_callback: Optional[bool] = True, - msg: Optional[str] = None, + processor_subdir: str | None, + is_failure_callback: bool | None = True, + msg: str | None = None, ): - super().__init__(full_filepath=full_filepath, msg=msg) + super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg) self.dag_id = dag_id self.run_id = run_id self.is_failure_callback = is_failure_callback @@ -118,8 +132,15 @@ class SlaCallbackRequest(CallbackRequest): :param full_filepath: File Path to use to run the callback :param dag_id: DAG ID + :param processor_subdir: Directory used by Dag Processor when parsed the dag. """ - def __init__(self, full_filepath: str, dag_id: str, msg: Optional[str] = None): - super().__init__(full_filepath, msg) + def __init__( + self, + full_filepath: str, + dag_id: str, + processor_subdir: str | None, + msg: str | None = None, + ): + super().__init__(full_filepath, processor_subdir=processor_subdir, msg=msg) self.dag_id = dag_id diff --git a/airflow/callbacks/database_callback_sink.py b/airflow/callbacks/database_callback_sink.py index b9b81e6ed745a..24306170dfea2 100644 --- a/airflow/callbacks/database_callback_sink.py +++ b/airflow/callbacks/database_callback_sink.py @@ -15,13 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations from sqlalchemy.orm import Session from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest -from airflow.models import DbCallbackRequest +from airflow.models.db_callback_request import DbCallbackRequest from airflow.utils.session import NEW_SESSION, provide_session diff --git a/airflow/callbacks/pipe_callback_sink.py b/airflow/callbacks/pipe_callback_sink.py index 1e11ffd4f5650..d702a781fa57c 100644 --- a/airflow/callbacks/pipe_callback_sink.py +++ b/airflow/callbacks/pipe_callback_sink.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + from multiprocessing.connection import Connection as MultiprocessingConnection from typing import Callable diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index 60c77c7a37feb..33e513b586484 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -16,7 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Command-line interface""" +"""Command-line interface.""" +from __future__ import annotations import argparse import json @@ -24,7 +25,7 @@ import textwrap from argparse import Action, ArgumentError, RawTextHelpFormatter from functools import lru_cache -from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Union +from typing import Callable, Iterable, NamedTuple, Union import lazy_object_proxy @@ -43,8 +44,8 @@ def lazy_load_command(import_path: str) -> Callable: - """Create a lazy loader for command""" - _, _, name = import_path.rpartition('.') + """Create a lazy loader for command.""" + _, _, name = import_path.rpartition(".") def command(*args, **kwargs): func = import_string(import_path) @@ -56,12 +57,12 @@ def command(*args, **kwargs): class DefaultHelpParser(argparse.ArgumentParser): - """CustomParser to display help message""" + """CustomParser to display help message.""" def _check_value(self, action, value): - """Override _check_value and check conditionally added command""" - if action.dest == 'subcommand' and value == 'celery': - executor = conf.get('core', 'EXECUTOR') + """Override _check_value and check conditionally added command.""" + if action.dest == "subcommand" and value == "celery": + executor = conf.get("core", "EXECUTOR") if executor not in (CELERY_EXECUTOR, CELERY_KUBERNETES_EXECUTOR): executor_cls, _ = ExecutorLoader.import_executor_cls(executor) classes = () @@ -83,12 +84,12 @@ def _check_value(self, action, value): pass if not issubclass(executor_cls, classes): message = ( - f'celery subcommand works only with CeleryExecutor, CeleryKubernetesExecutor and ' - f'executors derived from them, your current executor: {executor}, subclassed from: ' + f"celery subcommand works only with CeleryExecutor, CeleryKubernetesExecutor and " + f"executors derived from them, your current executor: {executor}, subclassed from: " f'{", ".join([base_cls.__qualname__ for base_cls in executor_cls.__bases__])}' ) raise ArgumentError(action, message) - if action.dest == 'subcommand' and value == 'kubernetes': + if action.dest == "subcommand" and value == "kubernetes": try: import kubernetes.client # noqa: F401 except ImportError: @@ -104,9 +105,9 @@ def _check_value(self, action, value): super()._check_value(action, value) def error(self, message): - """Override error and use print_instead of print_usage""" + """Override error and use print_instead of print_usage.""" self.print_help() - self.exit(2, f'\n{self.prog} command error: {message}, see help above.\n') + self.exit(2, f"\n{self.prog} command error: {message}, see help above.\n") # Used in Arg to enable `None' as a distinct value from "not passed" @@ -114,7 +115,7 @@ def error(self, message): class Arg: - """Class to keep information about command line argument""" + """Class to keep information about command line argument.""" def __init__( self, @@ -140,7 +141,7 @@ def __init__( self.kwargs[k] = v def add_to_parser(self, parser: argparse.ArgumentParser): - """Add this argument to an ArgumentParser""" + """Add this argument to an ArgumentParser.""" parser.add_argument(*self.flags, **self.kwargs) @@ -162,12 +163,12 @@ def _check(value): def string_list_type(val): - """Parses comma-separated list and returns list of string (strips whitespace)""" - return [x.strip() for x in val.split(',')] + """Parses comma-separated list and returns list of string (strips whitespace).""" + return [x.strip() for x in val.split(",")] def string_lower_type(val): - """Lowers arg""" + """Lowers arg.""" if not val: return return 
val.strip().lower() @@ -177,8 +178,16 @@ def string_lower_type(val): ARG_DAG_ID = Arg(("dag_id",), help="The id of the dag") ARG_TASK_ID = Arg(("task_id",), help="The id of the task") ARG_EXECUTION_DATE = Arg(("execution_date",), help="The execution date of the DAG", type=parsedate) +ARG_EXECUTION_DATE_OPTIONAL = Arg( + ("execution_date",), nargs="?", help="The execution date of the DAG (optional)", type=parsedate +) ARG_EXECUTION_DATE_OR_RUN_ID = Arg( - ('execution_date_or_run_id',), help="The execution_date of the DAG or run_id of the DAGRun" + ("execution_date_or_run_id",), help="The execution_date of the DAG or run_id of the DAGRun" +) +ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL = Arg( + ("execution_date_or_run_id",), + nargs="?", + help="The execution_date of the DAG or run_id of the DAGRun (optional)", ) ARG_TASK_REGEX = Arg( ("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)" @@ -190,7 +199,7 @@ def string_lower_type(val): "Defaults to '[AIRFLOW_HOME]/dags' where [AIRFLOW_HOME] is the " "value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' " ), - default='[AIRFLOW_HOME]/dags' if BUILD_DOCS else settings.DAGS_FOLDER, + default="[AIRFLOW_HOME]/dags" if BUILD_DOCS else settings.DAGS_FOLDER, ) ARG_START_DATE = Arg(("-s", "--start-date"), help="Override start_date YYYY-MM-DD", type=parsedate) ARG_END_DATE = Arg(("-e", "--end-date"), help="Override end_date YYYY-MM-DD", type=parsedate) @@ -208,7 +217,7 @@ def string_lower_type(val): help="Perform a dry run for each task. Only renders Template Fields for each task, nothing else", action="store_true", ) -ARG_PID = Arg(("--pid",), help="PID file location", nargs='?') +ARG_PID = Arg(("--pid",), help="PID file location", nargs="?") ARG_DAEMON = Arg( ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true" ) @@ -232,7 +241,7 @@ def string_lower_type(val): default="table", ) ARG_COLOR = Arg( - ('--color',), + ("--color",), help="Do emit colored output (default: auto)", choices={ColorMode.ON, ColorMode.OFF, ColorMode.AUTO}, default=ColorMode.AUTO, @@ -245,7 +254,7 @@ def string_lower_type(val): default=None, ) ARG_REVISION_RANGE = Arg( - ('--revision-range',), + ("--revision-range",), help=( "Migration revision range(start:end) to use for offline sql generation. 
" "Example: ``a13f7613ad25:7b2661a43ba3``" @@ -254,13 +263,16 @@ def string_lower_type(val): ) # list_dag_runs -ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag") +ARG_DAG_ID_REQ_FLAG = Arg( + ("-d", "--dag-id"), required=True, help="The id of the dag" +) # TODO: convert this to a positional arg in Airflow 3 ARG_NO_BACKFILL = Arg( ("--no-backfill",), help="filter all the backfill dagruns given the dag id", action="store_true" ) ARG_STATE = Arg(("--state",), help="Only list the dag runs corresponding to the state") # list_jobs +ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag") ARG_LIMIT = Arg(("--limit",), help="Return a limited number of records") # next_execution @@ -339,6 +351,11 @@ def string_lower_type(val): help=("if set, the backfill will keep going even if some of the tasks failed"), action="store_true", ) +ARG_DISABLE_RETRY = Arg( + ("--disable-retry",), + help=("if set, the backfill will set tasks as failed without retrying."), + action="store_true", +) ARG_RUN_BACKWARDS = Arg( ( "-B", @@ -351,6 +368,11 @@ def string_lower_type(val): ), action="store_true", ) +ARG_TREAT_DAG_AS_REGEX = Arg( + ("--treat-dag-as-regex",), + help=("if set, dag_id will be treated as regex instead of an exact string"), + action="store_true", +) # test_dag ARG_SHOW_DAGRUN = Arg( ("--show-dagrun",), @@ -359,7 +381,7 @@ def string_lower_type(val): "\n" "The diagram is in DOT language\n" ), - action='store_true', + action="store_true", ) ARG_IMGCAT_DAGRUN = Arg( ("--imgcat-dagrun",), @@ -367,7 +389,7 @@ def string_lower_type(val): "After completing the dag run, prints a diagram on the screen for the " "current DAG Run using the imgcat tool.\n" ), - action='store_true', + action="store_true", ) ARG_SAVE_DAGRUN = Arg( ("--save-dagrun",), @@ -405,11 +427,11 @@ def string_lower_type(val): # show_dag ARG_SAVE = Arg(("-s", "--save"), help="Saves the result to the indicated file.") -ARG_IMGCAT = Arg(("--imgcat",), help="Displays graph using the imgcat tool.", action='store_true') +ARG_IMGCAT = Arg(("--imgcat",), help="Displays graph using the imgcat tool.", action="store_true") # trigger_dag ARG_RUN_ID = Arg(("-r", "--run-id"), help="Helps to identify this run") -ARG_CONF = Arg(('-c', '--conf'), help="JSON string that gets pickled into the DagRun's conf attribute") +ARG_CONF = Arg(("-c", "--conf"), help="JSON string that gets pickled into the DagRun's conf attribute") ARG_EXEC_DATE = Arg(("-e", "--exec-date"), help="The execution date of the DAG", type=parsedate) # db @@ -434,10 +456,15 @@ def string_lower_type(val): help="Perform a dry run", action="store_true", ) +ARG_DB_SKIP_ARCHIVE = Arg( + ("--skip-archive",), + help="Don't preserve purged records in an archive table.", + action="store_true", +) # pool -ARG_POOL_NAME = Arg(("pool",), metavar='NAME', help="Pool name") +ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name") ARG_POOL_SLOTS = Arg(("slots",), type=int, help="Pool slots") ARG_POOL_DESCRIPTION = Arg(("description",), help="Pool description") ARG_POOL_IMPORT = Arg( @@ -446,11 +473,11 @@ def string_lower_type(val): help="Import pools from JSON file. 
Example format::\n" + textwrap.indent( textwrap.dedent( - ''' + """ { "pool_1": {"slots": 5, "description": ""}, "pool_2": {"slots": 10, "description": "test"} - }''' + }""" ), " " * 4, ), @@ -460,22 +487,23 @@ def string_lower_type(val): # variables ARG_VAR = Arg(("key",), help="Variable key") -ARG_VAR_VALUE = Arg(("value",), metavar='VALUE', help="Variable value") +ARG_VAR_VALUE = Arg(("value",), metavar="VALUE", help="Variable value") ARG_DEFAULT = Arg( ("-d", "--default"), metavar="VAL", default=None, help="Default value returned if variable does not exist" ) -ARG_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true") +ARG_DESERIALIZE_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true") +ARG_SERIALIZE_JSON = Arg(("-j", "--json"), help="Serialize JSON variable", action="store_true") ARG_VAR_IMPORT = Arg(("file",), help="Import variables from JSON file") ARG_VAR_EXPORT = Arg(("file",), help="Export all variables to JSON file") # kerberos -ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs='?') -ARG_KEYTAB = Arg(("-k", "--keytab"), help="keytab", nargs='?', default=conf.get('kerberos', 'keytab')) +ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs="?") +ARG_KEYTAB = Arg(("-k", "--keytab"), help="keytab", nargs="?", default=conf.get("kerberos", "keytab")) # run ARG_INTERACTIVE = Arg( - ('-N', '--interactive'), - help='Do not capture standard output and error streams (useful for interactive debugging)', - action='store_true', + ("-N", "--interactive"), + help="Do not capture standard output and error streams (useful for interactive debugging)", + action="store_true", ) # TODO(aoen): "force" is a poor choice of name here since it implies it overrides # all dependencies (not just past success), e.g. the ignore_depends_on_past @@ -514,7 +542,7 @@ def string_lower_type(val): ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)") ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS) ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg") -ARG_MAP_INDEX = Arg(('--map-index',), type=int, default=-1, help="Mapped task index") +ARG_MAP_INDEX = Arg(("--map-index",), type=int, default=-1, help="Mapped task index") # database @@ -524,6 +552,14 @@ def string_lower_type(val): type=int, default=60, ) +ARG_DB_RESERIALIZE_DAGS = Arg( + ("--no-reserialize-dags",), + # Not intended for user, so dont show in help + help=argparse.SUPPRESS, + action="store_false", + default=True, + dest="reserialize_dags", +) ARG_DB_VERSION__UPGRADE = Arg( ("-n", "--to-version"), help=( @@ -554,7 +590,7 @@ def string_lower_type(val): ARG_DB_SQL_ONLY = Arg( ("-s", "--show-sql-only"), help="Don't actually run migrations; just print out sql scripts for offline migration. 
" - "Required if using either `--from-version` or `--from-version`.", + "Required if using either `--from-revision` or `--from-version`.", action="store_true", default=False, ) @@ -568,41 +604,41 @@ def string_lower_type(val): # webserver ARG_PORT = Arg( ("-p", "--port"), - default=conf.get('webserver', 'WEB_SERVER_PORT'), + default=conf.get("webserver", "WEB_SERVER_PORT"), type=int, help="The port on which to run the server", ) ARG_SSL_CERT = Arg( ("--ssl-cert",), - default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'), + default=conf.get("webserver", "WEB_SERVER_SSL_CERT"), help="Path to the SSL certificate for the webserver", ) ARG_SSL_KEY = Arg( ("--ssl-key",), - default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'), + default=conf.get("webserver", "WEB_SERVER_SSL_KEY"), help="Path to the key to use with the SSL certificate", ) ARG_WORKERS = Arg( ("-w", "--workers"), - default=conf.get('webserver', 'WORKERS'), + default=conf.get("webserver", "WORKERS"), type=int, help="Number of workers to run the webserver on", ) ARG_WORKERCLASS = Arg( ("-k", "--workerclass"), - default=conf.get('webserver', 'WORKER_CLASS'), - choices=['sync', 'eventlet', 'gevent', 'tornado'], + default=conf.get("webserver", "WORKER_CLASS"), + choices=["sync", "eventlet", "gevent", "tornado"], help="The worker class to use for Gunicorn", ) ARG_WORKER_TIMEOUT = Arg( ("-t", "--worker-timeout"), - default=conf.get('webserver', 'WEB_SERVER_WORKER_TIMEOUT'), + default=conf.get("webserver", "WEB_SERVER_WORKER_TIMEOUT"), type=int, help="The timeout for waiting on webserver workers", ) ARG_HOSTNAME = Arg( ("-H", "--hostname"), - default=conf.get('webserver', 'WEB_SERVER_HOST'), + default=conf.get("webserver", "WEB_SERVER_HOST"), help="Set the hostname on which to run the web server", ) ARG_DEBUG = Arg( @@ -610,24 +646,24 @@ def string_lower_type(val): ) ARG_ACCESS_LOGFILE = Arg( ("-A", "--access-logfile"), - default=conf.get('webserver', 'ACCESS_LOGFILE'), - help="The logfile to store the webserver access log. Use '-' to print to stderr", + default=conf.get("webserver", "ACCESS_LOGFILE"), + help="The logfile to store the webserver access log. Use '-' to print to stdout", ) ARG_ERROR_LOGFILE = Arg( ("-E", "--error-logfile"), - default=conf.get('webserver', 'ERROR_LOGFILE'), + default=conf.get("webserver", "ERROR_LOGFILE"), help="The logfile to store the webserver error log. 
Use '-' to print to stderr", ) ARG_ACCESS_LOGFORMAT = Arg( ("-L", "--access-logformat"), - default=conf.get('webserver', 'ACCESS_LOGFORMAT'), + default=conf.get("webserver", "ACCESS_LOGFORMAT"), help="The access log format for gunicorn logs", ) # scheduler ARG_NUM_RUNS = Arg( ("-n", "--num-runs"), - default=conf.getint('scheduler', 'num_runs'), + default=conf.getint("scheduler", "num_runs"), type=int, help="Set the number of runs to execute before exiting", ) @@ -646,13 +682,13 @@ def string_lower_type(val): ARG_QUEUES = Arg( ("-q", "--queues"), help="Comma delimited list of queues to serve", - default=conf.get('operators', 'DEFAULT_QUEUE'), + default=conf.get("operators", "DEFAULT_QUEUE"), ) ARG_CONCURRENCY = Arg( ("-c", "--concurrency"), type=int, help="The number of worker processes", - default=conf.get('celery', 'worker_concurrency'), + default=conf.get("celery", "worker_concurrency"), ) ARG_CELERY_HOSTNAME = Arg( ("-H", "--celery-hostname"), @@ -661,18 +697,17 @@ def string_lower_type(val): ARG_UMASK = Arg( ("-u", "--umask"), help="Set the umask of celery worker in daemon mode", - default=conf.get('celery', 'worker_umask'), ) ARG_WITHOUT_MINGLE = Arg( ("--without-mingle",), default=False, - help="Don’t synchronize with other workers at start-up", + help="Don't synchronize with other workers at start-up", action="store_true", ) ARG_WITHOUT_GOSSIP = Arg( ("--without-gossip",), default=False, - help="Don’t subscribe to other workers events", + help="Don't subscribe to other workers events", action="store_true", ) @@ -680,22 +715,22 @@ def string_lower_type(val): ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API") ARG_FLOWER_HOSTNAME = Arg( ("-H", "--hostname"), - default=conf.get('celery', 'FLOWER_HOST'), + default=conf.get("celery", "FLOWER_HOST"), help="Set the hostname on which to run the server", ) ARG_FLOWER_PORT = Arg( ("-p", "--port"), - default=conf.get('celery', 'FLOWER_PORT'), + default=conf.get("celery", "FLOWER_PORT"), type=int, help="The port on which to run the server", ) ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower") ARG_FLOWER_URL_PREFIX = Arg( - ("-u", "--url-prefix"), default=conf.get('celery', 'FLOWER_URL_PREFIX'), help="URL prefix for Flower" + ("-u", "--url-prefix"), default=conf.get("celery", "FLOWER_URL_PREFIX"), help="URL prefix for Flower" ) ARG_FLOWER_BASIC_AUTH = Arg( ("-A", "--basic-auth"), - default=conf.get('celery', 'FLOWER_BASIC_AUTH'), + default=conf.get("celery", "FLOWER_BASIC_AUTH"), help=( "Securing Flower with Basic Authentication. " "Accepts user:password pairs separated by a comma. 
" @@ -713,91 +748,91 @@ def string_lower_type(val): ) # connections -ARG_CONN_ID = Arg(('conn_id',), help='Connection id, required to get/add/delete a connection', type=str) +ARG_CONN_ID = Arg(("conn_id",), help="Connection id, required to get/add/delete a connection", type=str) ARG_CONN_ID_FILTER = Arg( - ('--conn-id',), help='If passed, only items with the specified connection ID will be displayed', type=str + ("--conn-id",), help="If passed, only items with the specified connection ID will be displayed", type=str ) ARG_CONN_URI = Arg( - ('--conn-uri',), help='Connection URI, required to add a connection without conn_type', type=str + ("--conn-uri",), help="Connection URI, required to add a connection without conn_type", type=str ) ARG_CONN_JSON = Arg( - ('--conn-json',), help='Connection JSON, required to add a connection using JSON representation', type=str + ("--conn-json",), help="Connection JSON, required to add a connection using JSON representation", type=str ) ARG_CONN_TYPE = Arg( - ('--conn-type',), help='Connection type, required to add a connection without conn_uri', type=str + ("--conn-type",), help="Connection type, required to add a connection without conn_uri", type=str ) ARG_CONN_DESCRIPTION = Arg( - ('--conn-description',), help='Connection description, optional when adding a connection', type=str + ("--conn-description",), help="Connection description, optional when adding a connection", type=str ) -ARG_CONN_HOST = Arg(('--conn-host',), help='Connection host, optional when adding a connection', type=str) -ARG_CONN_LOGIN = Arg(('--conn-login',), help='Connection login, optional when adding a connection', type=str) +ARG_CONN_HOST = Arg(("--conn-host",), help="Connection host, optional when adding a connection", type=str) +ARG_CONN_LOGIN = Arg(("--conn-login",), help="Connection login, optional when adding a connection", type=str) ARG_CONN_PASSWORD = Arg( - ('--conn-password',), help='Connection password, optional when adding a connection', type=str + ("--conn-password",), help="Connection password, optional when adding a connection", type=str ) ARG_CONN_SCHEMA = Arg( - ('--conn-schema',), help='Connection schema, optional when adding a connection', type=str + ("--conn-schema",), help="Connection schema, optional when adding a connection", type=str ) -ARG_CONN_PORT = Arg(('--conn-port',), help='Connection port, optional when adding a connection', type=str) +ARG_CONN_PORT = Arg(("--conn-port",), help="Connection port, optional when adding a connection", type=str) ARG_CONN_EXTRA = Arg( - ('--conn-extra',), help='Connection `Extra` field, optional when adding a connection', type=str + ("--conn-extra",), help="Connection `Extra` field, optional when adding a connection", type=str ) ARG_CONN_EXPORT = Arg( - ('file',), - help='Output file path for exporting the connections', - type=argparse.FileType('w', encoding='UTF-8'), + ("file",), + help="Output file path for exporting the connections", + type=argparse.FileType("w", encoding="UTF-8"), ) ARG_CONN_EXPORT_FORMAT = Arg( - ('--format',), - help='Deprecated -- use `--file-format` instead. File format to use for the export.', + ("--format",), + help="Deprecated -- use `--file-format` instead. 
File format to use for the export.", type=str, - choices=['json', 'yaml', 'env'], + choices=["json", "yaml", "env"], ) ARG_CONN_EXPORT_FILE_FORMAT = Arg( - ('--file-format',), help='File format for the export', type=str, choices=['json', 'yaml', 'env'] + ("--file-format",), help="File format for the export", type=str, choices=["json", "yaml", "env"] ) ARG_CONN_SERIALIZATION_FORMAT = Arg( - ('--serialization-format',), - help='When exporting as `.env` format, defines how connections should be serialized. Default is `uri`.', + ("--serialization-format",), + help="When exporting as `.env` format, defines how connections should be serialized. Default is `uri`.", type=string_lower_type, - choices=['json', 'uri'], + choices=["json", "uri"], ) ARG_CONN_IMPORT = Arg(("file",), help="Import connections from a file") # providers ARG_PROVIDER_NAME = Arg( - ('provider_name',), help='Provider name, required to get provider information', type=str + ("provider_name",), help="Provider name, required to get provider information", type=str ) ARG_FULL = Arg( - ('-f', '--full'), - help='Full information about the provider, including documentation information.', + ("-f", "--full"), + help="Full information about the provider, including documentation information.", required=False, action="store_true", ) # users -ARG_USERNAME = Arg(('-u', '--username'), help='Username of the user', required=True, type=str) -ARG_USERNAME_OPTIONAL = Arg(('-u', '--username'), help='Username of the user', type=str) -ARG_FIRSTNAME = Arg(('-f', '--firstname'), help='First name of the user', required=True, type=str) -ARG_LASTNAME = Arg(('-l', '--lastname'), help='Last name of the user', required=True, type=str) +ARG_USERNAME = Arg(("-u", "--username"), help="Username of the user", required=True, type=str) +ARG_USERNAME_OPTIONAL = Arg(("-u", "--username"), help="Username of the user", type=str) +ARG_FIRSTNAME = Arg(("-f", "--firstname"), help="First name of the user", required=True, type=str) +ARG_LASTNAME = Arg(("-l", "--lastname"), help="Last name of the user", required=True, type=str) ARG_ROLE = Arg( - ('-r', '--role'), - help='Role of the user. Existing roles include Admin, User, Op, Viewer, and Public', + ("-r", "--role"), + help="Role of the user. Existing roles include Admin, User, Op, Viewer, and Public", required=True, type=str, ) -ARG_EMAIL = Arg(('-e', '--email'), help='Email of the user', required=True, type=str) -ARG_EMAIL_OPTIONAL = Arg(('-e', '--email'), help='Email of the user', type=str) +ARG_EMAIL = Arg(("-e", "--email"), help="Email of the user", required=True, type=str) +ARG_EMAIL_OPTIONAL = Arg(("-e", "--email"), help="Email of the user", type=str) ARG_PASSWORD = Arg( - ('-p', '--password'), - help='Password of the user, required to create a user without --use-random-password', + ("-p", "--password"), + help="Password of the user, required to create a user without --use-random-password", type=str, ) ARG_USE_RANDOM_PASSWORD = Arg( - ('--use-random-password',), - help='Do not prompt for password. Use random string instead.' - ' Required to create a user without --password ', + ("--use-random-password",), + help="Do not prompt for password. Use random string instead." + " Required to create a user without --password ", default=False, - action='store_true', + action="store_true", ) ARG_USER_IMPORT = Arg( ("import",), @@ -805,7 +840,7 @@ def string_lower_type(val): help="Import users from JSON file. 
Example format::\n" + textwrap.indent( textwrap.dedent( - ''' + """ [ { "email": "foo@bar.org", @@ -814,7 +849,7 @@ def string_lower_type(val): "roles": ["Public"], "username": "jondoe" } - ]''' + ]""" ), " " * 4, ), @@ -822,10 +857,14 @@ def string_lower_type(val): ARG_USER_EXPORT = Arg(("export",), metavar="FILEPATH", help="Export all users to JSON file") # roles -ARG_CREATE_ROLE = Arg(('-c', '--create'), help='Create a new role', action='store_true') -ARG_LIST_ROLES = Arg(('-l', '--list'), help='List roles', action='store_true') -ARG_ROLES = Arg(('role',), help='The name of a role', nargs='*') -ARG_AUTOSCALE = Arg(('-a', '--autoscale'), help="Minimum and Maximum number of worker to autoscale") +ARG_CREATE_ROLE = Arg(("-c", "--create"), help="Create a new role", action="store_true") +ARG_LIST_ROLES = Arg(("-l", "--list"), help="List roles", action="store_true") +ARG_ROLES = Arg(("role",), help="The name of a role", nargs="*") +ARG_PERMISSIONS = Arg(("-p", "--permission"), help="Show role permissions", action="store_true") +ARG_ROLE_RESOURCE = Arg(("-r", "--resource"), help="The name of permissions", nargs="*", required=True) +ARG_ROLE_ACTION = Arg(("-a", "--action"), help="The action of permissions", nargs="*") +ARG_ROLE_ACTION_REQUIRED = Arg(("-a", "--action"), help="The action of permissions", nargs="*", required=True) +ARG_AUTOSCALE = Arg(("-a", "--autoscale"), help="Minimum and Maximum number of worker to autoscale") ARG_SKIP_SERVE_LOGS = Arg( ("-s", "--skip-serve-logs"), default=False, @@ -835,19 +874,19 @@ def string_lower_type(val): ARG_ROLE_IMPORT = Arg(("file",), help="Import roles from JSON file", nargs=None) ARG_ROLE_EXPORT = Arg(("file",), help="Export all roles to JSON file", nargs=None) ARG_ROLE_EXPORT_FMT = Arg( - ('-p', '--pretty'), - help='Format output JSON file by sorting role names and indenting by 4 spaces', - action='store_true', + ("-p", "--pretty"), + help="Format output JSON file by sorting role names and indenting by 4 spaces", + action="store_true", ) # info ARG_ANONYMIZE = Arg( - ('--anonymize',), - help='Minimize any personal identifiable information. Use it when sharing output with others.', - action='store_true', + ("--anonymize",), + help="Minimize any personal identifiable information. Use it when sharing output with others.", + action="store_true", ) ARG_FILE_IO = Arg( - ('--file-io',), help='Send output to file.io service and returns link.', action='store_true' + ("--file-io",), help="Send output to file.io service and returns link.", action="store_true" ) # config @@ -863,7 +902,7 @@ def string_lower_type(val): # kubernetes cleanup-pods ARG_NAMESPACE = Arg( ("--namespace",), - default=conf.get('kubernetes', 'namespace'), + default=conf.get("kubernetes_executor", "namespace"), help="Kubernetes Namespace. 
Default value is `[kubernetes] namespace` in configuration.", ) @@ -879,10 +918,10 @@ def string_lower_type(val): # jobs check ARG_JOB_TYPE_FILTER = Arg( - ('--job-type',), - choices=('BackfillJob', 'LocalTaskJob', 'SchedulerJob', 'TriggererJob'), - action='store', - help='The type of job(s) that will be checked.', + ("--job-type",), + choices=("BackfillJob", "LocalTaskJob", "SchedulerJob", "TriggererJob"), + action="store", + help="The type of job(s) that will be checked.", ) ARG_JOB_HOSTNAME_FILTER = Arg( @@ -892,6 +931,13 @@ def string_lower_type(val): help="The hostname of job(s) that will be checked.", ) +ARG_JOB_HOSTNAME_CALLABLE_FILTER = Arg( + ("--local",), + action="store_true", + help="If passed, this command will only show jobs from the local host " + "(those with a hostname matching what `hostname_callable` returns).", +) + ARG_JOB_LIMIT = Arg( ("--limit",), default=1, @@ -901,7 +947,7 @@ def string_lower_type(val): ARG_ALLOW_MULTIPLE = Arg( ("--allow-multiple",), - action='store_true', + action="store_true", help="If passed, this command will be successful even if multiple matching alive jobs are found.", ) @@ -936,49 +982,49 @@ def string_lower_type(val): class ActionCommand(NamedTuple): - """Single CLI command""" + """Single CLI command.""" name: str help: str func: Callable args: Iterable[Arg] - description: Optional[str] = None - epilog: Optional[str] = None + description: str | None = None + epilog: str | None = None class GroupCommand(NamedTuple): - """ClI command with subcommands""" + """ClI command with subcommands.""" name: str help: str subcommands: Iterable - description: Optional[str] = None - epilog: Optional[str] = None + description: str | None = None + epilog: str | None = None CLICommand = Union[ActionCommand, GroupCommand] DAGS_COMMANDS = ( ActionCommand( - name='list', + name="list", help="List all the DAGs", - func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_dags'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_dags"), args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='list-import-errors', + name="list-import-errors", help="List all the DAGs that have import errors", - func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_import_errors'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_import_errors"), args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='report', - help='Show DagBag loading report', - func=lazy_load_command('airflow.cli.commands.dag_command.dag_report'), + name="report", + help="Show DagBag loading report", + func=lazy_load_command("airflow.cli.commands.dag_command.dag_report"), args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='list-runs', + name="list-runs", help="List DAG runs given a DAG id", description=( "List DAG runs given a DAG id. If state option is given, it will only search for all the " @@ -987,9 +1033,9 @@ class GroupCommand(NamedTuple): "dagruns that were executed before this date. If end_date is given, it will filter out " "all the dagruns that were executed after this date. 
" ), - func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_dag_runs'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_dag_runs"), args=( - ARG_DAG_ID_OPT, + ARG_DAG_ID_REQ_FLAG, ARG_NO_BACKFILL, ARG_STATE, ARG_OUTPUT, @@ -999,53 +1045,53 @@ class GroupCommand(NamedTuple): ), ), ActionCommand( - name='list-jobs', + name="list-jobs", help="List the jobs", - func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_jobs'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_jobs"), args=(ARG_DAG_ID_OPT, ARG_STATE, ARG_LIMIT, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='state', + name="state", help="Get the status of a dag run", - func=lazy_load_command('airflow.cli.commands.dag_command.dag_state'), - args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_state"), + args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_VERBOSE), ), ActionCommand( - name='next-execution', + name="next-execution", help="Get the next execution datetimes of a DAG", description=( "Get the next execution datetimes of a DAG. It returns one execution unless the " "num-executions option is given" ), - func=lazy_load_command('airflow.cli.commands.dag_command.dag_next_execution'), - args=(ARG_DAG_ID, ARG_SUBDIR, ARG_NUM_EXECUTIONS), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_next_execution"), + args=(ARG_DAG_ID, ARG_SUBDIR, ARG_NUM_EXECUTIONS, ARG_VERBOSE), ), ActionCommand( - name='pause', - help='Pause a DAG', - func=lazy_load_command('airflow.cli.commands.dag_command.dag_pause'), - args=(ARG_DAG_ID, ARG_SUBDIR), + name="pause", + help="Pause a DAG", + func=lazy_load_command("airflow.cli.commands.dag_command.dag_pause"), + args=(ARG_DAG_ID, ARG_SUBDIR, ARG_VERBOSE), ), ActionCommand( - name='unpause', - help='Resume a paused DAG', - func=lazy_load_command('airflow.cli.commands.dag_command.dag_unpause'), - args=(ARG_DAG_ID, ARG_SUBDIR), + name="unpause", + help="Resume a paused DAG", + func=lazy_load_command("airflow.cli.commands.dag_command.dag_unpause"), + args=(ARG_DAG_ID, ARG_SUBDIR, ARG_VERBOSE), ), ActionCommand( - name='trigger', - help='Trigger a DAG run', - func=lazy_load_command('airflow.cli.commands.dag_command.dag_trigger'), - args=(ARG_DAG_ID, ARG_SUBDIR, ARG_RUN_ID, ARG_CONF, ARG_EXEC_DATE), + name="trigger", + help="Trigger a DAG run", + func=lazy_load_command("airflow.cli.commands.dag_command.dag_trigger"), + args=(ARG_DAG_ID, ARG_SUBDIR, ARG_RUN_ID, ARG_CONF, ARG_EXEC_DATE, ARG_VERBOSE), ), ActionCommand( - name='delete', + name="delete", help="Delete all DB records related to the specified DAG", - func=lazy_load_command('airflow.cli.commands.dag_command.dag_delete'), - args=(ARG_DAG_ID, ARG_YES), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_delete"), + args=(ARG_DAG_ID, ARG_YES, ARG_VERBOSE), ), ActionCommand( - name='show', + name="show", help="Displays DAG's tasks with their dependencies", description=( "The --imgcat option only works in iTerm.\n" @@ -1064,16 +1110,17 @@ class GroupCommand(NamedTuple): "If you want to create a DOT file then you should execute the following command:\n" "airflow dags show --save output.dot\n" ), - func=lazy_load_command('airflow.cli.commands.dag_command.dag_show'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_show"), args=( ARG_DAG_ID, ARG_SUBDIR, ARG_SAVE, ARG_IMGCAT, + ARG_VERBOSE, ), ), ActionCommand( - name='show-dependencies', + name="show-dependencies", help="Displays DAGs with their 
dependencies", description=( "The --imgcat option only works in iTerm.\n" @@ -1092,15 +1139,16 @@ class GroupCommand(NamedTuple): "If you want to create a DOT file then you should execute the following command:\n" "airflow dags show-dependencies --save output.dot\n" ), - func=lazy_load_command('airflow.cli.commands.dag_command.dag_dependencies_show'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_dependencies_show"), args=( ARG_SUBDIR, ARG_SAVE, ARG_IMGCAT, + ARG_VERBOSE, ), ), ActionCommand( - name='backfill', + name="backfill", help="Run subsections of a DAG for a specified date range", description=( "Run subsections of a DAG for a specified date range. If reset_dag_run option is used, " @@ -1108,7 +1156,7 @@ class GroupCommand(NamedTuple): "task_instances within the backfill date range. If rerun_failed_tasks is used, backfill " "will auto re-run the previous failed task instances within the backfill date range" ), - func=lazy_load_command('airflow.cli.commands.dag_command.dag_backfill'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_backfill"), args=( ARG_DAG_ID, ARG_TASK_REGEX, @@ -1119,6 +1167,7 @@ class GroupCommand(NamedTuple): ARG_DONOT_PICKLE, ARG_YES, ARG_CONTINUE_ON_FAILURES, + ARG_DISABLE_RETRY, ARG_BF_IGNORE_DEPENDENCIES, ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST, ARG_SUBDIR, @@ -1130,14 +1179,14 @@ class GroupCommand(NamedTuple): ARG_RESET_DAG_RUN, ARG_RERUN_FAILED_TASKS, ARG_RUN_BACKWARDS, + ARG_TREAT_DAG_AS_REGEX, ), ), ActionCommand( - name='test', + name="test", help="Execute one single DagRun", description=( - "Execute one single DagRun for a given DAG and execution date, " - "using the DebugExecutor.\n" + "Execute one single DagRun for a given DAG and execution date.\n" "\n" "The --imgcat-dagrun option only works in iTerm.\n" "\n" @@ -1155,39 +1204,45 @@ class GroupCommand(NamedTuple): "If you want to create a DOT file then you should execute the following command:\n" "airflow dags test --save-dagrun output.dot\n" ), - func=lazy_load_command('airflow.cli.commands.dag_command.dag_test'), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_test"), args=( ARG_DAG_ID, - ARG_EXECUTION_DATE, + ARG_EXECUTION_DATE_OPTIONAL, + ARG_CONF, ARG_SUBDIR, ARG_SHOW_DAGRUN, ARG_IMGCAT_DAGRUN, ARG_SAVE_DAGRUN, + ARG_VERBOSE, ), ), ActionCommand( - name='reserialize', + name="reserialize", help="Reserialize all DAGs by parsing the DagBag files", description=( "Drop all serialized dags from the metadata DB. This will cause all DAGs to be reserialized " "from the DagBag folder. This can be helpful if your serialized DAGs get out of sync with the " "version of Airflow that you are running." 
), - func=lazy_load_command('airflow.cli.commands.dag_command.dag_reserialize'), - args=(ARG_CLEAR_ONLY,), + func=lazy_load_command("airflow.cli.commands.dag_command.dag_reserialize"), + args=( + ARG_CLEAR_ONLY, + ARG_SUBDIR, + ARG_VERBOSE, + ), ), ) TASKS_COMMANDS = ( ActionCommand( - name='list', + name="list", help="List the tasks within a DAG", - func=lazy_load_command('airflow.cli.commands.task_command.task_list'), + func=lazy_load_command("airflow.cli.commands.task_command.task_list"), args=(ARG_DAG_ID, ARG_TREE, ARG_SUBDIR, ARG_VERBOSE), ), ActionCommand( - name='clear', + name="clear", help="Clear a set of task instance, as if they never ran", - func=lazy_load_command('airflow.cli.commands.task_command.task_clear'), + func=lazy_load_command("airflow.cli.commands.task_command.task_clear"), args=( ARG_DAG_ID, ARG_TASK_REGEX, @@ -1202,12 +1257,13 @@ class GroupCommand(NamedTuple): ARG_EXCLUDE_SUBDAGS, ARG_EXCLUDE_PARENTDAG, ARG_DAG_REGEX, + ARG_VERBOSE, ), ), ActionCommand( - name='state', + name="state", help="Get the status of a task instance", - func=lazy_load_command('airflow.cli.commands.task_command.task_state'), + func=lazy_load_command("airflow.cli.commands.task_command.task_state"), args=( ARG_DAG_ID, ARG_TASK_ID, @@ -1218,20 +1274,20 @@ class GroupCommand(NamedTuple): ), ), ActionCommand( - name='failed-deps', + name="failed-deps", help="Returns the unmet dependencies for a task instance", description=( "Returns the unmet dependencies for a task instance from the perspective of the scheduler. " "In other words, why a task instance doesn't get scheduled and then queued by the scheduler, " "and then run by an executor." ), - func=lazy_load_command('airflow.cli.commands.task_command.task_failed_deps'), - args=(ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE_OR_RUN_ID, ARG_SUBDIR, ARG_MAP_INDEX), + func=lazy_load_command("airflow.cli.commands.task_command.task_failed_deps"), + args=(ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE_OR_RUN_ID, ARG_SUBDIR, ARG_MAP_INDEX, ARG_VERBOSE), ), ActionCommand( - name='render', + name="render", help="Render a task instance's template(s)", - func=lazy_load_command('airflow.cli.commands.task_command.task_render'), + func=lazy_load_command("airflow.cli.commands.task_command.task_render"), args=( ARG_DAG_ID, ARG_TASK_ID, @@ -1242,9 +1298,9 @@ class GroupCommand(NamedTuple): ), ), ActionCommand( - name='run', + name="run", help="Run a single task instance", - func=lazy_load_command('airflow.cli.commands.task_command.task_run'), + func=lazy_load_command("airflow.cli.commands.task_command.task_run"), args=( ARG_DAG_ID, ARG_TASK_ID, @@ -1265,133 +1321,135 @@ class GroupCommand(NamedTuple): ARG_INTERACTIVE, ARG_SHUT_DOWN_LOGGING, ARG_MAP_INDEX, + ARG_VERBOSE, ), ), ActionCommand( - name='test', + name="test", help="Test a task instance", description=( "Test a task instance. 
This will run a task without checking for dependencies or recording " "its state in the database" ), - func=lazy_load_command('airflow.cli.commands.task_command.task_test'), + func=lazy_load_command("airflow.cli.commands.task_command.task_test"), args=( ARG_DAG_ID, ARG_TASK_ID, - ARG_EXECUTION_DATE_OR_RUN_ID, + ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL, ARG_SUBDIR, ARG_DRY_RUN, ARG_TASK_PARAMS, ARG_POST_MORTEM, ARG_ENV_VARS, ARG_MAP_INDEX, + ARG_VERBOSE, ), ), ActionCommand( - name='states-for-dag-run', + name="states-for-dag-run", help="Get the status of all task instances in a dag run", - func=lazy_load_command('airflow.cli.commands.task_command.task_states_for_dag_run'), + func=lazy_load_command("airflow.cli.commands.task_command.task_states_for_dag_run"), args=(ARG_DAG_ID, ARG_EXECUTION_DATE_OR_RUN_ID, ARG_OUTPUT, ARG_VERBOSE), ), ) POOLS_COMMANDS = ( ActionCommand( - name='list', - help='List pools', - func=lazy_load_command('airflow.cli.commands.pool_command.pool_list'), + name="list", + help="List pools", + func=lazy_load_command("airflow.cli.commands.pool_command.pool_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='get', - help='Get pool size', - func=lazy_load_command('airflow.cli.commands.pool_command.pool_get'), + name="get", + help="Get pool size", + func=lazy_load_command("airflow.cli.commands.pool_command.pool_get"), args=(ARG_POOL_NAME, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='set', - help='Configure pool', - func=lazy_load_command('airflow.cli.commands.pool_command.pool_set'), + name="set", + help="Configure pool", + func=lazy_load_command("airflow.cli.commands.pool_command.pool_set"), args=(ARG_POOL_NAME, ARG_POOL_SLOTS, ARG_POOL_DESCRIPTION, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='delete', - help='Delete pool', - func=lazy_load_command('airflow.cli.commands.pool_command.pool_delete'), + name="delete", + help="Delete pool", + func=lazy_load_command("airflow.cli.commands.pool_command.pool_delete"), args=(ARG_POOL_NAME, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='import', - help='Import pools', - func=lazy_load_command('airflow.cli.commands.pool_command.pool_import'), + name="import", + help="Import pools", + func=lazy_load_command("airflow.cli.commands.pool_command.pool_import"), args=(ARG_POOL_IMPORT, ARG_VERBOSE), ), ActionCommand( - name='export', - help='Export all pools', - func=lazy_load_command('airflow.cli.commands.pool_command.pool_export'), - args=(ARG_POOL_EXPORT,), + name="export", + help="Export all pools", + func=lazy_load_command("airflow.cli.commands.pool_command.pool_export"), + args=(ARG_POOL_EXPORT, ARG_VERBOSE), ), ) VARIABLES_COMMANDS = ( ActionCommand( - name='list', - help='List variables', - func=lazy_load_command('airflow.cli.commands.variable_command.variables_list'), + name="list", + help="List variables", + func=lazy_load_command("airflow.cli.commands.variable_command.variables_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='get', - help='Get variable', - func=lazy_load_command('airflow.cli.commands.variable_command.variables_get'), - args=(ARG_VAR, ARG_JSON, ARG_DEFAULT, ARG_VERBOSE), + name="get", + help="Get variable", + func=lazy_load_command("airflow.cli.commands.variable_command.variables_get"), + args=(ARG_VAR, ARG_DESERIALIZE_JSON, ARG_DEFAULT, ARG_VERBOSE), ), ActionCommand( - name='set', - help='Set variable', - func=lazy_load_command('airflow.cli.commands.variable_command.variables_set'), - args=(ARG_VAR, ARG_VAR_VALUE, ARG_JSON), + name="set", + help="Set variable", + 
func=lazy_load_command("airflow.cli.commands.variable_command.variables_set"), + args=(ARG_VAR, ARG_VAR_VALUE, ARG_SERIALIZE_JSON, ARG_VERBOSE), ), ActionCommand( - name='delete', - help='Delete variable', - func=lazy_load_command('airflow.cli.commands.variable_command.variables_delete'), - args=(ARG_VAR,), + name="delete", + help="Delete variable", + func=lazy_load_command("airflow.cli.commands.variable_command.variables_delete"), + args=(ARG_VAR, ARG_VERBOSE), ), ActionCommand( - name='import', - help='Import variables', - func=lazy_load_command('airflow.cli.commands.variable_command.variables_import'), - args=(ARG_VAR_IMPORT,), + name="import", + help="Import variables", + func=lazy_load_command("airflow.cli.commands.variable_command.variables_import"), + args=(ARG_VAR_IMPORT, ARG_VERBOSE), ), ActionCommand( - name='export', - help='Export all variables', - func=lazy_load_command('airflow.cli.commands.variable_command.variables_export'), - args=(ARG_VAR_EXPORT,), + name="export", + help="Export all variables", + func=lazy_load_command("airflow.cli.commands.variable_command.variables_export"), + args=(ARG_VAR_EXPORT, ARG_VERBOSE), ), ) DB_COMMANDS = ( ActionCommand( - name='init', + name="init", help="Initialize the metadata database", - func=lazy_load_command('airflow.cli.commands.db_command.initdb'), - args=(), + func=lazy_load_command("airflow.cli.commands.db_command.initdb"), + args=(ARG_VERBOSE,), ), ActionCommand( name="check-migrations", help="Check if migration have finished", description="Check if migration have finished (or continually check until timeout)", - func=lazy_load_command('airflow.cli.commands.db_command.check_migrations'), - args=(ARG_MIGRATION_TIMEOUT,), + func=lazy_load_command("airflow.cli.commands.db_command.check_migrations"), + args=(ARG_MIGRATION_TIMEOUT, ARG_VERBOSE), ), ActionCommand( - name='reset', + name="reset", help="Burn down and rebuild the metadata database", - func=lazy_load_command('airflow.cli.commands.db_command.resetdb'), - args=(ARG_YES, ARG_DB_SKIP_INIT), + func=lazy_load_command("airflow.cli.commands.db_command.resetdb"), + args=(ARG_YES, ARG_DB_SKIP_INIT, ARG_VERBOSE), ), ActionCommand( - name='upgrade', + name="upgrade", help="Upgrade the metadata database to latest version", description=( "Upgrade the schema of the metadata database. " @@ -1400,17 +1458,19 @@ class GroupCommand(NamedTuple): "``--show-sql-only``, because if actually *running* migrations, we should only " "migrate from the *current* Alembic revision." ), - func=lazy_load_command('airflow.cli.commands.db_command.upgradedb'), + func=lazy_load_command("airflow.cli.commands.db_command.upgradedb"), args=( ARG_DB_REVISION__UPGRADE, ARG_DB_VERSION__UPGRADE, ARG_DB_SQL_ONLY, ARG_DB_FROM_REVISION, ARG_DB_FROM_VERSION, + ARG_DB_RESERIALIZE_DAGS, + ARG_VERBOSE, ), ), ActionCommand( - name='downgrade', + name="downgrade", help="Downgrade the schema of the metadata database.", description=( "Downgrade the schema of the metadata database. " @@ -1420,7 +1480,7 @@ class GroupCommand(NamedTuple): "because if actually *running* migrations, we should only migrate from the *current* Alembic " "revision." 
), - func=lazy_load_command('airflow.cli.commands.db_command.downgrade'), + func=lazy_load_command("airflow.cli.commands.db_command.downgrade"), args=( ARG_DB_REVISION__DOWNGRADE, ARG_DB_VERSION__DOWNGRADE, @@ -1428,61 +1488,63 @@ class GroupCommand(NamedTuple): ARG_YES, ARG_DB_FROM_REVISION, ARG_DB_FROM_VERSION, + ARG_VERBOSE, ), ), ActionCommand( - name='shell', + name="shell", help="Runs a shell to access the database", - func=lazy_load_command('airflow.cli.commands.db_command.shell'), - args=(), + func=lazy_load_command("airflow.cli.commands.db_command.shell"), + args=(ARG_VERBOSE,), ), ActionCommand( - name='check', + name="check", help="Check if the database can be reached", - func=lazy_load_command('airflow.cli.commands.db_command.check'), - args=(), + func=lazy_load_command("airflow.cli.commands.db_command.check"), + args=(ARG_VERBOSE,), ), ActionCommand( - name='clean', + name="clean", help="Purge old records in metastore tables", - func=lazy_load_command('airflow.cli.commands.db_command.cleanup_tables'), + func=lazy_load_command("airflow.cli.commands.db_command.cleanup_tables"), args=( ARG_DB_TABLES, ARG_DB_DRY_RUN, ARG_DB_CLEANUP_TIMESTAMP, ARG_VERBOSE, ARG_YES, + ARG_DB_SKIP_ARCHIVE, ), ), ) CONNECTIONS_COMMANDS = ( ActionCommand( - name='get', - help='Get a connection', - func=lazy_load_command('airflow.cli.commands.connection_command.connections_get'), + name="get", + help="Get a connection", + func=lazy_load_command("airflow.cli.commands.connection_command.connections_get"), args=(ARG_CONN_ID, ARG_COLOR, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='list', - help='List connections', - func=lazy_load_command('airflow.cli.commands.connection_command.connections_list'), + name="list", + help="List connections", + func=lazy_load_command("airflow.cli.commands.connection_command.connections_list"), args=(ARG_OUTPUT, ARG_VERBOSE, ARG_CONN_ID_FILTER), ), ActionCommand( - name='add', - help='Add a connection', - func=lazy_load_command('airflow.cli.commands.connection_command.connections_add'), + name="add", + help="Add a connection", + func=lazy_load_command("airflow.cli.commands.connection_command.connections_add"), args=(ARG_CONN_ID, ARG_CONN_URI, ARG_CONN_JSON, ARG_CONN_EXTRA) + tuple(ALTERNATIVE_CONN_SPECS_ARGS), ), ActionCommand( - name='delete', - help='Delete a connection', - func=lazy_load_command('airflow.cli.commands.connection_command.connections_delete'), - args=(ARG_CONN_ID, ARG_COLOR), + name="delete", + help="Delete a connection", + func=lazy_load_command("airflow.cli.commands.connection_command.connections_delete"), + args=(ARG_CONN_ID, ARG_COLOR, ARG_VERBOSE), ), ActionCommand( - name='export', - help='Export all connections', + name="export", + help="Export all connections", description=( "All connections can be exported in STDOUT using the following command:\n" "airflow connections export -\n" @@ -1498,96 +1560,100 @@ class GroupCommand(NamedTuple): "is used to serialize the connection by passing `uri` or `json` with option " "`--serialization-format`.\n" ), - func=lazy_load_command('airflow.cli.commands.connection_command.connections_export'), + func=lazy_load_command("airflow.cli.commands.connection_command.connections_export"), args=( ARG_CONN_EXPORT, ARG_CONN_EXPORT_FORMAT, ARG_CONN_EXPORT_FILE_FORMAT, ARG_CONN_SERIALIZATION_FORMAT, + ARG_VERBOSE, ), ), ActionCommand( - name='import', - help='Import connections from a file', + name="import", + help="Import connections from a file", description=( "Connections can be imported from the output of the export 
command.\n" "The filetype must by json, yaml or env and will be automatically inferred." ), - func=lazy_load_command('airflow.cli.commands.connection_command.connections_import'), - args=(ARG_CONN_IMPORT,), + func=lazy_load_command("airflow.cli.commands.connection_command.connections_import"), + args=( + ARG_CONN_IMPORT, + ARG_VERBOSE, + ), ), ) PROVIDERS_COMMANDS = ( ActionCommand( - name='list', - help='List installed providers', - func=lazy_load_command('airflow.cli.commands.provider_command.providers_list'), + name="list", + help="List installed providers", + func=lazy_load_command("airflow.cli.commands.provider_command.providers_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='get', - help='Get detailed information about a provider', - func=lazy_load_command('airflow.cli.commands.provider_command.provider_get'), + name="get", + help="Get detailed information about a provider", + func=lazy_load_command("airflow.cli.commands.provider_command.provider_get"), args=(ARG_OUTPUT, ARG_VERBOSE, ARG_FULL, ARG_COLOR, ARG_PROVIDER_NAME), ), ActionCommand( - name='links', - help='List extra links registered by the providers', - func=lazy_load_command('airflow.cli.commands.provider_command.extra_links_list'), + name="links", + help="List extra links registered by the providers", + func=lazy_load_command("airflow.cli.commands.provider_command.extra_links_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='widgets', - help='Get information about registered connection form widgets', - func=lazy_load_command('airflow.cli.commands.provider_command.connection_form_widget_list'), + name="widgets", + help="Get information about registered connection form widgets", + func=lazy_load_command("airflow.cli.commands.provider_command.connection_form_widget_list"), args=( ARG_OUTPUT, ARG_VERBOSE, ), ), ActionCommand( - name='hooks', - help='List registered provider hooks', - func=lazy_load_command('airflow.cli.commands.provider_command.hooks_list'), + name="hooks", + help="List registered provider hooks", + func=lazy_load_command("airflow.cli.commands.provider_command.hooks_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='behaviours', - help='Get information about registered connection types with custom behaviours', - func=lazy_load_command('airflow.cli.commands.provider_command.connection_field_behaviours'), + name="behaviours", + help="Get information about registered connection types with custom behaviours", + func=lazy_load_command("airflow.cli.commands.provider_command.connection_field_behaviours"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='logging', - help='Get information about task logging handlers provided', - func=lazy_load_command('airflow.cli.commands.provider_command.logging_list'), + name="logging", + help="Get information about task logging handlers provided", + func=lazy_load_command("airflow.cli.commands.provider_command.logging_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='secrets', - help='Get information about secrets backends provided', - func=lazy_load_command('airflow.cli.commands.provider_command.secrets_backends_list'), + name="secrets", + help="Get information about secrets backends provided", + func=lazy_load_command("airflow.cli.commands.provider_command.secrets_backends_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='auth', - help='Get information about API auth backends provided', - func=lazy_load_command('airflow.cli.commands.provider_command.auth_backend_list'), + name="auth", + 
help="Get information about API auth backends provided", + func=lazy_load_command("airflow.cli.commands.provider_command.auth_backend_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ) USERS_COMMANDS = ( ActionCommand( - name='list', - help='List users', - func=lazy_load_command('airflow.cli.commands.user_command.users_list'), + name="list", + help="List users", + func=lazy_load_command("airflow.cli.commands.user_command.users_list"), args=(ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - name='create', - help='Create a user', - func=lazy_load_command('airflow.cli.commands.user_command.users_create'), + name="create", + help="Create a user", + func=lazy_load_command("airflow.cli.commands.user_command.users_create"), args=( ARG_ROLE, ARG_USERNAME, @@ -1596,82 +1662,101 @@ class GroupCommand(NamedTuple): ARG_LASTNAME, ARG_PASSWORD, ARG_USE_RANDOM_PASSWORD, + ARG_VERBOSE, ), epilog=( - 'examples:\n' + "examples:\n" 'To create an user with "Admin" role and username equals to "admin", run:\n' - '\n' - ' $ airflow users create \\\n' - ' --username admin \\\n' - ' --firstname FIRST_NAME \\\n' - ' --lastname LAST_NAME \\\n' - ' --role Admin \\\n' - ' --email admin@example.org' + "\n" + " $ airflow users create \\\n" + " --username admin \\\n" + " --firstname FIRST_NAME \\\n" + " --lastname LAST_NAME \\\n" + " --role Admin \\\n" + " --email admin@example.org" ), ), ActionCommand( - name='delete', - help='Delete a user', - func=lazy_load_command('airflow.cli.commands.user_command.users_delete'), - args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL), + name="delete", + help="Delete a user", + func=lazy_load_command("airflow.cli.commands.user_command.users_delete"), + args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_VERBOSE), ), ActionCommand( - name='add-role', - help='Add role to a user', - func=lazy_load_command('airflow.cli.commands.user_command.add_role'), - args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE), + name="add-role", + help="Add role to a user", + func=lazy_load_command("airflow.cli.commands.user_command.add_role"), + args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE, ARG_VERBOSE), ), ActionCommand( - name='remove-role', - help='Remove role from a user', - func=lazy_load_command('airflow.cli.commands.user_command.remove_role'), - args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE), + name="remove-role", + help="Remove role from a user", + func=lazy_load_command("airflow.cli.commands.user_command.remove_role"), + args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE, ARG_VERBOSE), ), ActionCommand( - name='import', - help='Import users', - func=lazy_load_command('airflow.cli.commands.user_command.users_import'), - args=(ARG_USER_IMPORT,), + name="import", + help="Import users", + func=lazy_load_command("airflow.cli.commands.user_command.users_import"), + args=(ARG_USER_IMPORT, ARG_VERBOSE), ), ActionCommand( - name='export', - help='Export all users', - func=lazy_load_command('airflow.cli.commands.user_command.users_export'), - args=(ARG_USER_EXPORT,), + name="export", + help="Export all users", + func=lazy_load_command("airflow.cli.commands.user_command.users_export"), + args=(ARG_USER_EXPORT, ARG_VERBOSE), ), ) ROLES_COMMANDS = ( ActionCommand( - name='list', - help='List roles', - func=lazy_load_command('airflow.cli.commands.role_command.roles_list'), - args=(ARG_OUTPUT, ARG_VERBOSE), + name="list", + help="List roles", + func=lazy_load_command("airflow.cli.commands.role_command.roles_list"), + args=(ARG_PERMISSIONS, ARG_OUTPUT, ARG_VERBOSE), ), ActionCommand( - 
name='create', - help='Create role', - func=lazy_load_command('airflow.cli.commands.role_command.roles_create'), + name="create", + help="Create role", + func=lazy_load_command("airflow.cli.commands.role_command.roles_create"), args=(ARG_ROLES, ARG_VERBOSE), ), ActionCommand( - name='export', - help='Export roles (without permissions) from db to JSON file', - func=lazy_load_command('airflow.cli.commands.role_command.roles_export'), + name="delete", + help="Delete role", + func=lazy_load_command("airflow.cli.commands.role_command.roles_delete"), + args=(ARG_ROLES, ARG_VERBOSE), + ), + ActionCommand( + name="add-perms", + help="Add roles permissions", + func=lazy_load_command("airflow.cli.commands.role_command.roles_add_perms"), + args=(ARG_ROLES, ARG_ROLE_RESOURCE, ARG_ROLE_ACTION_REQUIRED, ARG_VERBOSE), + ), + ActionCommand( + name="del-perms", + help="Delete roles permissions", + func=lazy_load_command("airflow.cli.commands.role_command.roles_del_perms"), + args=(ARG_ROLES, ARG_ROLE_RESOURCE, ARG_ROLE_ACTION, ARG_VERBOSE), + ), + ActionCommand( + name="export", + help="Export roles (without permissions) from db to JSON file", + func=lazy_load_command("airflow.cli.commands.role_command.roles_export"), args=(ARG_ROLE_EXPORT, ARG_ROLE_EXPORT_FMT, ARG_VERBOSE), ), ActionCommand( - name='import', - help='Import roles (without permissions) from JSON file to db', - func=lazy_load_command('airflow.cli.commands.role_command.roles_import'), + name="import", + help="Import roles (without permissions) from JSON file to db", + func=lazy_load_command("airflow.cli.commands.role_command.roles_import"), args=(ARG_ROLE_IMPORT, ARG_VERBOSE), ), ) CELERY_COMMANDS = ( ActionCommand( - name='worker', + name="worker", help="Start a Celery worker node", - func=lazy_load_command('airflow.cli.commands.celery_command.worker'), + func=lazy_load_command("airflow.cli.commands.celery_command.worker"), args=( ARG_QUEUES, ARG_CONCURRENCY, @@ -1686,12 +1771,13 @@ class GroupCommand(NamedTuple): ARG_SKIP_SERVE_LOGS, ARG_WITHOUT_MINGLE, ARG_WITHOUT_GOSSIP, + ARG_VERBOSE, ), ), ActionCommand( - name='flower', + name="flower", help="Start a Celery Flower", - func=lazy_load_command('airflow.cli.commands.celery_command.flower'), + func=lazy_load_command("airflow.cli.commands.celery_command.flower"), args=( ARG_FLOWER_HOSTNAME, ARG_FLOWER_PORT, @@ -1704,117 +1790,135 @@ class GroupCommand(NamedTuple): ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE, + ARG_VERBOSE, ), ), ActionCommand( - name='stop', + name="stop", help="Stop the Celery worker gracefully", - func=lazy_load_command('airflow.cli.commands.celery_command.stop_worker'), - args=(ARG_PID,), + func=lazy_load_command("airflow.cli.commands.celery_command.stop_worker"), + args=(ARG_PID, ARG_VERBOSE), ), ) CONFIG_COMMANDS = ( ActionCommand( - name='get-value', - help='Print the value of the configuration', - func=lazy_load_command('airflow.cli.commands.config_command.get_value'), + name="get-value", + help="Print the value of the configuration", + func=lazy_load_command("airflow.cli.commands.config_command.get_value"), args=( ARG_SECTION, ARG_OPTION, + ARG_VERBOSE, ), ), ActionCommand( - name='list', - help='List options for the configuration', - func=lazy_load_command('airflow.cli.commands.config_command.show_config'), - args=(ARG_COLOR,), + name="list", + help="List options for the configuration", + func=lazy_load_command("airflow.cli.commands.config_command.show_config"), + args=(ARG_COLOR, ARG_VERBOSE), ), ) KUBERNETES_COMMANDS = ( ActionCommand( - name='cleanup-pods', + 
name="cleanup-pods", help=( "Clean up Kubernetes pods " "(created by KubernetesExecutor/KubernetesPodOperator) " "in evicted/failed/succeeded/pending states" ), - func=lazy_load_command('airflow.cli.commands.kubernetes_command.cleanup_pods'), - args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES), + func=lazy_load_command("airflow.cli.commands.kubernetes_command.cleanup_pods"), + args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES, ARG_VERBOSE), ), ActionCommand( - name='generate-dag-yaml', + name="generate-dag-yaml", help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without " "launching into a cluster", - func=lazy_load_command('airflow.cli.commands.kubernetes_command.generate_pod_yaml'), - args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH), + func=lazy_load_command("airflow.cli.commands.kubernetes_command.generate_pod_yaml"), + args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH, ARG_VERBOSE), ), ) JOBS_COMMANDS = ( ActionCommand( - name='check', + name="check", help="Checks if job(s) are still alive", - func=lazy_load_command('airflow.cli.commands.jobs_command.check'), - args=(ARG_JOB_TYPE_FILTER, ARG_JOB_HOSTNAME_FILTER, ARG_JOB_LIMIT, ARG_ALLOW_MULTIPLE), + func=lazy_load_command("airflow.cli.commands.jobs_command.check"), + args=( + ARG_JOB_TYPE_FILTER, + ARG_JOB_HOSTNAME_FILTER, + ARG_JOB_HOSTNAME_CALLABLE_FILTER, + ARG_JOB_LIMIT, + ARG_ALLOW_MULTIPLE, + ARG_VERBOSE, + ), epilog=( - 'examples:\n' - 'To check if the local scheduler is still working properly, run:\n' - '\n' - ' $ airflow jobs check --job-type SchedulerJob --hostname "$(hostname)"\n' - '\n' - 'To check if any scheduler is running when you are using high availability, run:\n' - '\n' - ' $ airflow jobs check --job-type SchedulerJob --allow-multiple --limit 100' + "examples:\n" + "To check if the local scheduler is still working properly, run:\n" + "\n" + ' $ airflow jobs check --job-type SchedulerJob --local"\n' + "\n" + "To check if any scheduler is running when you are using high availability, run:\n" + "\n" + " $ airflow jobs check --job-type SchedulerJob --allow-multiple --limit 100" ), ), ) -airflow_commands: List[CLICommand] = [ +airflow_commands: list[CLICommand] = [ GroupCommand( - name='dags', - help='Manage DAGs', + name="dags", + help="Manage DAGs", subcommands=DAGS_COMMANDS, ), GroupCommand( - name="kubernetes", help='Tools to help run the KubernetesExecutor', subcommands=KUBERNETES_COMMANDS + name="kubernetes", help="Tools to help run the KubernetesExecutor", subcommands=KUBERNETES_COMMANDS ), GroupCommand( - name='tasks', - help='Manage tasks', + name="tasks", + help="Manage tasks", subcommands=TASKS_COMMANDS, ), GroupCommand( - name='pools', + name="pools", help="Manage pools", subcommands=POOLS_COMMANDS, ), GroupCommand( - name='variables', + name="variables", help="Manage variables", subcommands=VARIABLES_COMMANDS, ), GroupCommand( - name='jobs', + name="jobs", help="Manage jobs", subcommands=JOBS_COMMANDS, ), GroupCommand( - name='db', + name="db", help="Database operations", subcommands=DB_COMMANDS, ), ActionCommand( - name='kerberos', + name="kerberos", help="Start a kerberos ticket renewer", - func=lazy_load_command('airflow.cli.commands.kerberos_command.kerberos'), - args=(ARG_PRINCIPAL, ARG_KEYTAB, ARG_PID, ARG_DAEMON, ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE), + func=lazy_load_command("airflow.cli.commands.kerberos_command.kerberos"), + args=( + ARG_PRINCIPAL, + ARG_KEYTAB, + ARG_PID, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, + ARG_VERBOSE, + ), 
), ActionCommand( - name='webserver', + name="webserver", help="Start a Airflow webserver instance", - func=lazy_load_command('airflow.cli.commands.webserver_command.webserver'), + func=lazy_load_command("airflow.cli.commands.webserver_command.webserver"), args=( ARG_PORT, ARG_WORKERS, @@ -1835,9 +1939,9 @@ class GroupCommand(NamedTuple): ), ), ActionCommand( - name='scheduler', + name="scheduler", help="Start a scheduler instance", - func=lazy_load_command('airflow.cli.commands.scheduler_command.scheduler'), + func=lazy_load_command("airflow.cli.commands.scheduler_command.scheduler"), args=( ARG_SUBDIR, ARG_NUM_RUNS, @@ -1848,20 +1952,21 @@ class GroupCommand(NamedTuple): ARG_STDERR, ARG_LOG_FILE, ARG_SKIP_SERVE_LOGS, + ARG_VERBOSE, ), epilog=( - 'Signals:\n' - '\n' - ' - SIGUSR2: Dump a snapshot of task state being tracked by the executor.\n' - '\n' - ' Example:\n' + "Signals:\n" + "\n" + " - SIGUSR2: Dump a snapshot of task state being tracked by the executor.\n" + "\n" + " Example:\n" ' pkill -f -USR2 "airflow scheduler"' ), ), ActionCommand( - name='triggerer', + name="triggerer", help="Start a triggerer instance", - func=lazy_load_command('airflow.cli.commands.triggerer_command.triggerer'), + func=lazy_load_command("airflow.cli.commands.triggerer_command.triggerer"), args=( ARG_PID, ARG_DAEMON, @@ -1869,12 +1974,13 @@ class GroupCommand(NamedTuple): ARG_STDERR, ARG_LOG_FILE, ARG_CAPACITY, + ARG_VERBOSE, ), ), ActionCommand( - name='dag-processor', + name="dag-processor", help="Start a standalone Dag Processor instance", - func=lazy_load_command('airflow.cli.commands.dag_processor_command.dag_processor'), + func=lazy_load_command("airflow.cli.commands.dag_processor_command.dag_processor"), args=( ARG_PID, ARG_DAEMON, @@ -1884,62 +1990,63 @@ class GroupCommand(NamedTuple): ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE, + ARG_VERBOSE, ), ), ActionCommand( - name='version', + name="version", help="Show the version", - func=lazy_load_command('airflow.cli.commands.version_command.version'), + func=lazy_load_command("airflow.cli.commands.version_command.version"), args=(), ), ActionCommand( - name='cheat-sheet', + name="cheat-sheet", help="Display cheat sheet", - func=lazy_load_command('airflow.cli.commands.cheat_sheet_command.cheat_sheet'), + func=lazy_load_command("airflow.cli.commands.cheat_sheet_command.cheat_sheet"), args=(ARG_VERBOSE,), ), GroupCommand( - name='connections', + name="connections", help="Manage connections", subcommands=CONNECTIONS_COMMANDS, ), GroupCommand( - name='providers', + name="providers", help="Display providers", subcommands=PROVIDERS_COMMANDS, ), GroupCommand( - name='users', + name="users", help="Manage users", subcommands=USERS_COMMANDS, ), GroupCommand( - name='roles', - help='Manage roles', + name="roles", + help="Manage roles", subcommands=ROLES_COMMANDS, ), ActionCommand( - name='sync-perm', + name="sync-perm", help="Update permissions for existing roles and optionally DAGs", - func=lazy_load_command('airflow.cli.commands.sync_perm_command.sync_perm'), - args=(ARG_INCLUDE_DAGS,), + func=lazy_load_command("airflow.cli.commands.sync_perm_command.sync_perm"), + args=(ARG_INCLUDE_DAGS, ARG_VERBOSE), ), ActionCommand( - name='rotate-fernet-key', - func=lazy_load_command('airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key'), - help='Rotate encrypted connection credentials and variables', + name="rotate-fernet-key", + func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"), + help="Rotate encrypted connection credentials 
and variables", description=( - 'Rotate all encrypted connection credentials and variables; see ' - 'https://airflow.apache.org/docs/apache-airflow/stable/howto/secure-connections.html' - '#rotating-encryption-keys' + "Rotate all encrypted connection credentials and variables; see " + "https://airflow.apache.org/docs/apache-airflow/stable/howto/secure-connections.html" + "#rotating-encryption-keys" ), args=(), ), - GroupCommand(name="config", help='View configuration', subcommands=CONFIG_COMMANDS), + GroupCommand(name="config", help="View configuration", subcommands=CONFIG_COMMANDS), ActionCommand( - name='info', - help='Show information about current Airflow and environment', - func=lazy_load_command('airflow.cli.commands.info_command.show_info'), + name="info", + help="Show information about current Airflow and environment", + func=lazy_load_command("airflow.cli.commands.info_command.show_info"), args=( ARG_ANONYMIZE, ARG_FILE_IO, @@ -1948,53 +2055,53 @@ class GroupCommand(NamedTuple): ), ), ActionCommand( - name='plugins', - help='Dump information about loaded plugins', - func=lazy_load_command('airflow.cli.commands.plugins_command.dump_plugins'), + name="plugins", + help="Dump information about loaded plugins", + func=lazy_load_command("airflow.cli.commands.plugins_command.dump_plugins"), args=(ARG_OUTPUT, ARG_VERBOSE), ), GroupCommand( name="celery", - help='Celery components', + help="Celery components", description=( - 'Start celery components. Works only when using CeleryExecutor. For more information, see ' - 'https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html' + "Start celery components. Works only when using CeleryExecutor. For more information, see " + "https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html" ), subcommands=CELERY_COMMANDS, ), ActionCommand( - name='standalone', - help='Run an all-in-one copy of Airflow', - func=lazy_load_command('airflow.cli.commands.standalone_command.standalone'), + name="standalone", + help="Run an all-in-one copy of Airflow", + func=lazy_load_command("airflow.cli.commands.standalone_command.standalone"), args=tuple(), ), ] -ALL_COMMANDS_DICT: Dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands} +ALL_COMMANDS_DICT: dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands} def _remove_dag_id_opt(command: ActionCommand): cmd = command._asdict() - cmd['args'] = (arg for arg in command.args if arg is not ARG_DAG_ID) + cmd["args"] = (arg for arg in command.args if arg is not ARG_DAG_ID) return ActionCommand(**cmd) -dag_cli_commands: List[CLICommand] = [ +dag_cli_commands: list[CLICommand] = [ GroupCommand( - name='dags', - help='Manage DAGs', + name="dags", + help="Manage DAGs", subcommands=[ _remove_dag_id_opt(sp) for sp in DAGS_COMMANDS - if sp.name in ['backfill', 'list-runs', 'pause', 'unpause'] + if sp.name in ["backfill", "list-runs", "pause", "unpause", "test"] ], ), GroupCommand( - name='tasks', - help='Manage tasks', - subcommands=[_remove_dag_id_opt(sp) for sp in TASKS_COMMANDS if sp.name in ['list', 'test', 'run']], + name="tasks", + help="Manage tasks", + subcommands=[_remove_dag_id_opt(sp) for sp in TASKS_COMMANDS if sp.name in ["list", "test", "run"]], ), ] -DAG_CLI_DICT: Dict[str, CLICommand] = {sp.name: sp for sp in dag_cli_commands} +DAG_CLI_DICT: dict[str, CLICommand] = {sp.name: sp for sp in dag_cli_commands} class AirflowHelpFormatter(argparse.HelpFormatter): @@ -2009,7 +2116,7 @@ def _format_action(self, action: Action): parts = [] action_header = 
self._format_action_invocation(action) - action_header = '%*s%s\n' % (self._current_indent, '', action_header) + action_header = "%*s%s\n" % (self._current_indent, "", action_header) parts.append(action_header) self._indent() @@ -2018,14 +2125,14 @@ def _format_action(self, action: Action): lambda d: isinstance(ALL_COMMANDS_DICT[d.dest], GroupCommand), subactions ) parts.append("\n") - parts.append('%*s%s:\n' % (self._current_indent, '', "Groups")) + parts.append("%*s%s:\n" % (self._current_indent, "", "Groups")) self._indent() for subaction in group_subcommands: parts.append(self._format_action(subaction)) self._dedent() parts.append("\n") - parts.append('%*s%s:\n' % (self._current_indent, '', "Commands")) + parts.append("%*s%s:\n" % (self._current_indent, "", "Commands")) self._indent() for subaction in action_subcommands: @@ -2041,9 +2148,9 @@ def _format_action(self, action: Action): @lru_cache(maxsize=None) def get_parser(dag_parser: bool = False) -> argparse.ArgumentParser: - """Creates and returns command line argument parser""" + """Creates and returns command line argument parser.""" parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) - subparsers = parser.add_subparsers(dest='subcommand', metavar="GROUP_OR_COMMAND") + subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") subparsers.required = True command_dict = DAG_CLI_DICT if dag_parser else ALL_COMMANDS_DICT @@ -2056,10 +2163,10 @@ def get_parser(dag_parser: bool = False) -> argparse.ArgumentParser: def _sort_args(args: Iterable[Arg]) -> Iterable[Arg]: - """Sort subcommand optional args, keep positional args""" + """Sort subcommand optional args, keep positional args.""" def get_long_option(arg: Arg): - """Get long option from Arg.flags""" + """Get long option from Arg.flags.""" return arg.flags[0] if len(arg.flags) == 1 else arg.flags[1] positional, optional = partition(lambda x: x.flags[0].startswith("-"), args) diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py index affe032411c79..0d3e1295dcd4e 100644 --- a/airflow/cli/commands/celery_command.py +++ b/airflow/cli/commands/celery_command.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
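The cli_parser.py hunks above are dominated by three mechanical changes: the switch to double-quoted string literals, Optional[...]/List[...] annotations rewritten as X | None and list[...], and ARG_VERBOSE appended to nearly every command's args tuple. The underlying registry pattern is untouched: commands are declared as plain NamedTuples, and their handlers are imported only when actually invoked, via lazy_load_command. The sketch below is a deliberately simplified, self-contained version of that pattern, not Airflow's actual implementation; the "greet" command and "mini-cli" program name are made up for illustration.

from __future__ import annotations

import argparse
import importlib
from typing import Callable, Iterable, NamedTuple


def lazy_load_command(import_path: str) -> Callable:
    """Return a thin wrapper that imports the real handler only when it is called."""
    module_path, _, name = import_path.rpartition(".")

    def command(*args, **kwargs):
        func = getattr(importlib.import_module(module_path), name)
        return func(*args, **kwargs)

    return command


class Arg(NamedTuple):
    flags: tuple[str, ...]
    help: str = ""
    action: str | None = None


class ActionCommand(NamedTuple):
    name: str
    help: str
    func: Callable
    args: Iterable[Arg] = ()


ARG_VERBOSE = Arg(("--verbose",), help="Make logging output more verbose", action="store_true")

COMMANDS = (
    ActionCommand(
        name="greet",
        help="Print a greeting",
        func=lazy_load_command("builtins.print"),  # stands in for a heavyweight command module
        args=(ARG_VERBOSE,),
    ),
)


def get_parser() -> argparse.ArgumentParser:
    """Build the argparse tree from the declarative command table."""
    parser = argparse.ArgumentParser(prog="mini-cli")
    subparsers = parser.add_subparsers(dest="subcommand", required=True)
    for cmd in COMMANDS:
        sub = subparsers.add_parser(cmd.name, help=cmd.help)
        for arg in cmd.args:
            kwargs = {"help": arg.help}
            if arg.action:
                kwargs["action"] = arg.action
            sub.add_argument(*arg.flags, **kwargs)
        sub.set_defaults(func=cmd.func)
    return parser


if __name__ == "__main__":
    ns = get_parser().parse_args(["greet", "--verbose"])
    ns.func("hello", "(verbose)" if ns.verbose else "")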
-"""Celery command""" +"""Celery command.""" +from __future__ import annotations +from contextlib import contextmanager from multiprocessing import Process -from typing import Optional import daemon import psutil @@ -39,10 +40,10 @@ @cli_utils.action_cli def flower(args): - """Starts Flower, Celery monitoring tool""" + """Starts Flower, Celery monitoring tool.""" options = [ "flower", - conf.get('celery', 'BROKER_URL'), + conf.get("celery", "BROKER_URL"), f"--address={args.hostname}", f"--port={args.port}", ] @@ -67,11 +68,15 @@ def flower(args): stderr=args.stderr, log=args.log_file, ) - with open(stdout, "w+") as stdout, open(stderr, "w+") as stderr: + with open(stdout, "a") as stdout, open(stderr, "a") as stderr: + stdout.truncate(0) + stderr.truncate(0) + ctx = daemon.DaemonContext( pidfile=TimeoutPIDLockFile(pidfile, -1), stdout=stdout, stderr=stderr, + umask=int(settings.DAEMON_UMASK, 8), ) with ctx: celery_app.start(options) @@ -79,27 +84,21 @@ def flower(args): celery_app.start(options) -def _serve_logs(skip_serve_logs: bool = False) -> Optional[Process]: - """Starts serve_logs sub-process""" +@contextmanager +def _serve_logs(skip_serve_logs: bool = False): + """Starts serve_logs sub-process.""" + sub_proc = None if skip_serve_logs is False: sub_proc = Process(target=serve_logs) sub_proc.start() - return sub_proc - return None - - -def _run_worker(options, skip_serve_logs): - sub_proc = _serve_logs(skip_serve_logs) - try: - celery_app.worker_main(options) - finally: - if sub_proc: - sub_proc.terminate() + yield + if sub_proc: + sub_proc.terminate() @cli_utils.action_cli def worker(args): - """Starts Airflow Celery worker""" + """Starts Airflow Celery worker.""" # Disable connection pool so that celery worker does not hold an unnecessary db connection settings.reconfigure_orm(disable_connection_pool=True) if not settings.validate_session(): @@ -120,7 +119,7 @@ def worker(args): log=args.log_file, ) - if hasattr(celery_app.backend, 'ResultSession'): + if hasattr(celery_app.backend, "ResultSession"): # Pre-create the database tables now, otherwise SQLA via Celery has a # race condition where one of the subprocesses can die with "Table # already exists" error, because SQLA checks for which tables exist, @@ -137,31 +136,31 @@ def worker(args): pass # backwards-compatible: https://github.com/apache/airflow/pull/21506#pullrequestreview-879893763 - celery_log_level = conf.get('logging', 'CELERY_LOGGING_LEVEL') + celery_log_level = conf.get("logging", "CELERY_LOGGING_LEVEL") if not celery_log_level: - celery_log_level = conf.get('logging', 'LOGGING_LEVEL') + celery_log_level = conf.get("logging", "LOGGING_LEVEL") # Setup Celery worker options = [ - 'worker', - '-O', - 'fair', - '--queues', + "worker", + "-O", + "fair", + "--queues", args.queues, - '--concurrency', + "--concurrency", args.concurrency, - '--hostname', + "--hostname", args.celery_hostname, - '--loglevel', + "--loglevel", celery_log_level, - '--pidfile', + "--pidfile", pid_file_path, ] if autoscale: - options.extend(['--autoscale', autoscale]) + options.extend(["--autoscale", autoscale]) if args.without_mingle: - options.append('--without-mingle') + options.append("--without-mingle") if args.without_gossip: - options.append('--without-gossip') + options.append("--without-gossip") if conf.has_option("celery", "pool"): pool = conf.get("celery", "pool") @@ -171,32 +170,39 @@ def worker(args): # https://eventlet.net/doc/patching.html#monkey-patch # Otherwise task instances hang on the workers and are never # executed. 
- maybe_patch_concurrency(['-P', pool]) + maybe_patch_concurrency(["-P", pool]) if args.daemon: # Run Celery worker as daemon handle = setup_logging(log_file) - with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle: if args.umask: umask = args.umask + else: + umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK) - ctx = daemon.DaemonContext( + stdout_handle.truncate(0) + stderr_handle.truncate(0) + + daemon_context = daemon.DaemonContext( files_preserve=[handle], umask=int(umask, 8), stdout=stdout_handle, stderr=stderr_handle, ) - with ctx: - _run_worker(options=options, skip_serve_logs=skip_serve_logs) + with daemon_context, _serve_logs(skip_serve_logs): + celery_app.worker_main(options) + else: # Run Celery worker in the same process - _run_worker(options=options, skip_serve_logs=skip_serve_logs) + with _serve_logs(skip_serve_logs): + celery_app.worker_main(options) @cli_utils.action_cli def stop_worker(args): - """Sends SIGTERM to Celery worker""" + """Sends SIGTERM to Celery worker.""" # Read PID from file if args.pid: pid_file_path = args.pid diff --git a/airflow/cli/commands/cheat_sheet_command.py b/airflow/cli/commands/cheat_sheet_command.py index 001a8721330cb..88d9c5940c7ca 100644 --- a/airflow/cli/commands/cheat_sheet_command.py +++ b/airflow/cli/commands/cheat_sheet_command.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Iterable, List, Optional, Union +from __future__ import annotations + +from typing import Iterable from airflow.cli.cli_parser import ActionCommand, GroupCommand, airflow_commands from airflow.cli.simple_table import AirflowConsole, SimpleTable @@ -31,12 +33,12 @@ def display_commands_index(): """Display list of all commands.""" def display_recursive( - prefix: List[str], - commands: Iterable[Union[GroupCommand, ActionCommand]], - help_msg: Optional[str] = None, + prefix: list[str], + commands: Iterable[GroupCommand | ActionCommand], + help_msg: str | None = None, ): - actions: List[ActionCommand] = [] - groups: List[GroupCommand] = [] + actions: list[ActionCommand] = [] + groups: list[GroupCommand] = [] for command in commands: if isinstance(command, GroupCommand): groups.append(command) diff --git a/airflow/cli/commands/config_command.py b/airflow/cli/commands/config_command.py index 1c2674fc811b2..f3a1eecfee958 100644 --- a/airflow/cli/commands/config_command.py +++ b/airflow/cli/commands/config_command.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
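The celery_command.py hunks above fold the old _serve_logs()/_run_worker() pair into a single @contextmanager helper, so starting and stopping the log-serving subprocess becomes one `with _serve_logs(...)` block wrapped around celery_app.worker_main(options); they also reopen the daemonized stdout/stderr handles with "a" plus an explicit truncate(0) instead of "w+". The snippet below is a minimal, stand-alone sketch of the context-manager half only; _dummy_serve_logs is a made-up placeholder for the real log-serving entrypoint.

from __future__ import annotations

import time
from contextlib import contextmanager
from multiprocessing import Process


def _dummy_serve_logs() -> None:
    # Placeholder for the real log-serving target; it just idles until terminated.
    while True:
        time.sleep(1)


@contextmanager
def serve_logs_subprocess(skip: bool = False):
    """Start the helper subprocess, yield to the caller, then terminate it."""
    sub_proc = None
    if not skip:
        sub_proc = Process(target=_dummy_serve_logs)
        sub_proc.start()
    yield
    if sub_proc:
        sub_proc.terminate()
        sub_proc.join()


if __name__ == "__main__":
    with serve_logs_subprocess():
        # In the real command this is celery_app.worker_main(options).
        print("worker would run here")

Note that, as in the hunk above, the yield is not wrapped in try/finally, so the teardown runs only when the body exits normally; adding try/finally would make the subprocess cleanup exception-safe as well.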
-"""Config sub-commands""" +"""Config sub-commands.""" +from __future__ import annotations + import io import pygments @@ -26,7 +28,7 @@ def show_config(args): - """Show current application configuration""" + """Show current application configuration.""" with io.StringIO() as output: conf.write(output) code = output.getvalue() @@ -36,12 +38,12 @@ def show_config(args): def get_value(args): - """Get one value from configuration""" + """Get one value from configuration.""" if not conf.has_section(args.section): - raise SystemExit(f'The section [{args.section}] is not found in config.') + raise SystemExit(f"The section [{args.section}] is not found in config.") if not conf.has_option(args.section, args.option): - raise SystemExit(f'The option [{args.section}/{args.option}] is not found in config.') + raise SystemExit(f"The option [{args.section}/{args.option}] is not found in config.") value = conf.get(args.section, args.option) print(value) diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py index 8a0c0a3acb80a..7737f4fa2cab3 100644 --- a/airflow/cli/commands/connection_command.py +++ b/airflow/cli/commands/connection_command.py @@ -14,15 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Connection sub-commands""" +"""Connection sub-commands.""" +from __future__ import annotations + import io import json import os import sys import warnings from pathlib import Path -from typing import Any, Dict, List -from urllib.parse import urlparse, urlunparse +from typing import Any +from urllib.parse import urlsplit, urlunsplit from sqlalchemy.orm import exc @@ -38,21 +40,21 @@ from airflow.utils.session import create_session -def _connection_mapper(conn: Connection) -> Dict[str, Any]: +def _connection_mapper(conn: Connection) -> dict[str, Any]: return { - 'id': conn.id, - 'conn_id': conn.conn_id, - 'conn_type': conn.conn_type, - 'description': conn.description, - 'host': conn.host, - 'schema': conn.schema, - 'login': conn.login, - 'password': conn.password, - 'port': conn.port, - 'is_encrypted': conn.is_encrypted, - 'is_extra_encrypted': conn.is_encrypted, - 'extra_dejson': conn.extra_dejson, - 'get_uri': conn.get_uri(), + "id": conn.id, + "conn_id": conn.conn_id, + "conn_type": conn.conn_type, + "description": conn.description, + "host": conn.host, + "schema": conn.schema, + "login": conn.login, + "password": conn.password, + "port": conn.port, + "is_encrypted": conn.is_encrypted, + "is_extra_encrypted": conn.is_encrypted, + "extra_dejson": conn.extra_dejson, + "get_uri": conn.get_uri(), } @@ -72,7 +74,7 @@ def connections_get(args): @suppress_logs_and_warning def connections_list(args): - """Lists all connections at the command line""" + """Lists all connections at the command line.""" with create_session() as session: query = session.query(Connection) if args.conn_id: @@ -99,14 +101,14 @@ def _connection_to_dict(conn: Connection) -> dict: ) -def _format_connections(conns: List[Connection], file_format: str, serialization_format: str) -> str: - if serialization_format == 'json': +def _format_connections(conns: list[Connection], file_format: str, serialization_format: str) -> str: + if serialization_format == "json": serializer_func = lambda x: json.dumps(_connection_to_dict(x)) - elif serialization_format == 'uri': + elif serialization_format == "uri": serializer_func = Connection.get_uri else: raise SystemExit(f"Received unexpected value for 
`--serialization-format`: {serialization_format!r}") - if file_format == '.env': + if file_format == ".env": connections_env = "" for conn in conns: connections_env += f"{conn.conn_id}={serializer_func(conn)}\n" @@ -116,29 +118,29 @@ def _format_connections(conns: List[Connection], file_format: str, serialization for conn in conns: connections_dict[conn.conn_id] = _connection_to_dict(conn) - if file_format == '.yaml': + if file_format == ".yaml": return yaml.dump(connections_dict) - if file_format == '.json': + if file_format == ".json": return json.dumps(connections_dict, indent=2) return json.dumps(connections_dict) def _is_stdout(fileio: io.TextIOWrapper) -> bool: - return fileio.name == '' + return fileio.name == "" def _valid_uri(uri: str) -> bool: - """Check if a URI is valid, by checking if both scheme and netloc are available""" - uri_parts = urlparse(uri) - return uri_parts.scheme != '' and uri_parts.netloc != '' + """Check if a URI is valid, by checking if both scheme and netloc are available.""" + uri_parts = urlsplit(uri) + return uri_parts.scheme != "" and uri_parts.netloc != "" @cache -def _get_connection_types(): +def _get_connection_types() -> list[str]: """Returns connection types available.""" - _connection_types = ['fs', 'mesos_framework-id', 'email', 'generic'] + _connection_types = ["fs", "mesos_framework-id", "email", "generic"] providers_manager = ProvidersManager() for connection_type, provider_info in providers_manager.hooks.items(): if provider_info: @@ -146,18 +148,14 @@ def _get_connection_types(): return _connection_types -def _valid_conn_type(conn_type: str) -> bool: - return conn_type in _get_connection_types() - - def connections_export(args): - """Exports all connections to a file""" - file_formats = ['.yaml', '.json', '.env'] + """Exports all connections to a file.""" + file_formats = [".yaml", ".json", ".env"] if args.format: warnings.warn("Option `--format` is deprecated. Use `--file-format` instead.", DeprecationWarning) if args.format and args.file_format: - raise SystemExit('Option `--format` is deprecated. Use `--file-format` instead.') - default_format = '.json' + raise SystemExit("Option `--format` is deprecated. Use `--file-format` instead.") + default_format = ".json" provided_file_format = None if args.format or args.file_format: provided_file_format = f".{(args.format or args.file_format).lower()}" @@ -175,7 +173,7 @@ def connections_export(args): f"Unsupported file format. The file must have the extension {', '.join(file_formats)}." 
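The _format_connections hunk above keeps the existing behaviour — connections can be exported as .env, .json or .yaml, with each connection serialized either as a URI or as JSON — and only normalizes the quoting. Reduced to plain dictionaries instead of Airflow Connection objects (so the yaml branch and Connection.get_uri() are left out), the shape of that formatting step is roughly:

from __future__ import annotations

import json


def format_connections(conns: dict[str, dict], file_format: str) -> str:
    """Render a conn_id -> attributes mapping in the requested file format."""
    if file_format == ".env":
        # One KEY=value line per connection, ready to source as environment variables.
        return "".join(f"{conn_id}={json.dumps(attrs)}\n" for conn_id, attrs in conns.items())
    if file_format == ".json":
        return json.dumps(conns, indent=2)
    raise SystemExit(f"Unsupported file format: {file_format!r}")


print(format_connections({"my_postgres": {"conn_type": "postgres", "host": "localhost", "port": 5432}}, ".env"))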
) - if args.serialization_format and not filetype == '.env': + if args.serialization_format and not filetype == ".env": raise SystemExit("Option `--serialization-format` may only be used with file type `env`.") with create_session() as session: @@ -184,7 +182,7 @@ def connections_export(args): msg = _format_connections( conns=connections, file_format=filetype, - serialization_format=args.serialization_format or 'uri', + serialization_format=args.serialization_format or "uri", ) with args.file as f: @@ -196,34 +194,34 @@ def connections_export(args): print(f"Connections successfully exported to {args.file.name}.") -alternative_conn_specs = ['conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port'] +alternative_conn_specs = ["conn_type", "conn_host", "conn_login", "conn_password", "conn_schema", "conn_port"] @cli_utils.action_cli def connections_add(args): - """Adds new connection""" + """Adds new connection.""" has_uri = bool(args.conn_uri) has_json = bool(args.conn_json) has_type = bool(args.conn_type) if not has_type and not (has_json or has_uri): - raise SystemExit('Must supply either conn-uri or conn-json if not supplying conn-type') + raise SystemExit("Must supply either conn-uri or conn-json if not supplying conn-type") if has_json and has_uri: - raise SystemExit('Cannot supply both conn-uri and conn-json') + raise SystemExit("Cannot supply both conn-uri and conn-json") if has_type and not (args.conn_type in _get_connection_types()): - warnings.warn(f'The type provided to --conn-type is invalid: {args.conn_type}') + warnings.warn(f"The type provided to --conn-type is invalid: {args.conn_type}") warnings.warn( - f'Supported --conn-types are:{_get_connection_types()}.' - 'Hence overriding the conn-type with generic' + f"Supported --conn-types are:{_get_connection_types()}." + "Hence overriding the conn-type with generic" ) - args.conn_type = 'generic' + args.conn_type = "generic" if has_uri or has_json: invalid_args = [] if has_uri and not _valid_uri(args.conn_uri): - raise SystemExit(f'The URI provided to --conn-uri is invalid: {args.conn_uri}') + raise SystemExit(f"The URI provided to --conn-uri is invalid: {args.conn_uri}") for arg in alternative_conn_specs: if getattr(args, arg) is not None: @@ -245,7 +243,7 @@ def connections_add(args): elif args.conn_json: new_conn = Connection.from_json(conn_id=args.conn_id, value=args.conn_json) if not new_conn.conn_type: - raise SystemExit('conn-json is invalid; must supply conn-type') + raise SystemExit("conn-json is invalid; must supply conn-type") else: new_conn = Connection( conn_id=args.conn_id, @@ -263,38 +261,37 @@ def connections_add(args): with create_session() as session: if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first(): session.add(new_conn) - msg = 'Successfully added `conn_id`={conn_id} : {uri}' + msg = "Successfully added `conn_id`={conn_id} : {uri}" msg = msg.format( conn_id=new_conn.conn_id, uri=args.conn_uri - or urlunparse( + or urlunsplit( ( new_conn.conn_type, f"{new_conn.login or ''}:{'******' if new_conn.password else ''}" f"@{new_conn.host or ''}:{new_conn.port or ''}", - new_conn.schema or '', - '', - '', - '', + new_conn.schema or "", + "", + "", ) ), ) print(msg) else: - msg = f'A connection with `conn_id`={new_conn.conn_id} already exists.' + msg = f"A connection with `conn_id`={new_conn.conn_id} already exists." 
raise SystemExit(msg) @cli_utils.action_cli def connections_delete(args): - """Deletes connection from DB""" + """Deletes connection from DB.""" with create_session() as session: try: to_delete = session.query(Connection).filter(Connection.conn_id == args.conn_id).one() except exc.NoResultFound: - raise SystemExit(f'Did not find a connection with `conn_id`={args.conn_id}') + raise SystemExit(f"Did not find a connection with `conn_id`={args.conn_id}") except exc.MultipleResultsFound: - raise SystemExit(f'Found more than one connection with `conn_id`={args.conn_id}') + raise SystemExit(f"Found more than one connection with `conn_id`={args.conn_id}") else: session.delete(to_delete) print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}") @@ -302,7 +299,7 @@ def connections_delete(args): @cli_utils.action_cli(check_db=False) def connections_import(args): - """Imports connections from a file""" + """Imports connections from a file.""" if os.path.exists(args.file): _import_helper(args.file) else: @@ -315,9 +312,9 @@ def _import_helper(file_path): with create_session() as session: for conn_id, conn in connections_dict.items(): if session.query(Connection).filter(Connection.conn_id == conn_id).first(): - print(f'Could not import connection {conn_id}: connection already exists.') + print(f"Could not import connection {conn_id}: connection already exists.") continue session.add(conn) session.commit() - print(f'Imported connection {conn_id}') + print(f"Imported connection {conn_id}") diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index a06c5a706907a..0465c474ed916 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
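connection_command.py above also trades urlparse/urlunparse for urlsplit/urlunsplit. For the validity check the only thing that matters is that both a scheme and a network location are present, and urlsplit is the better fit because, unlike urlparse, it does not try to split off the legacy ;parameters component. A stand-alone check in the same spirit:

from urllib.parse import urlsplit


def valid_uri(uri: str) -> bool:
    """A connection URI is considered usable when both scheme and netloc are present."""
    parts = urlsplit(uri)
    return parts.scheme != "" and parts.netloc != ""


assert valid_uri("postgresql://user:secret@localhost:5432/mydb")
assert not valid_uri("just-a-string")
assert not valid_uri("postgresql:")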
+"""Dag sub-commands.""" +from __future__ import annotations -"""Dag sub-commands""" import ast import errno import json @@ -23,31 +24,32 @@ import signal import subprocess import sys -from typing import Optional from graphviz.dot import Dot +from sqlalchemy.orm import Session from sqlalchemy.sql.functions import func from airflow import settings from airflow.api.client import get_current_api_client from airflow.cli.simple_table import AirflowConsole from airflow.configuration import conf -from airflow.exceptions import AirflowException, BackfillUnfinished -from airflow.executors.debug_executor import DebugExecutor +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.jobs.base_job import BaseJob from airflow.models import DagBag, DagModel, DagRun, TaskInstance from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel -from airflow.utils import cli as cli_utils -from airflow.utils.cli import get_dag, process_subdir, sigint_handler, suppress_logs_and_warning +from airflow.utils import cli as cli_utils, timezone +from airflow.utils.cli import get_dag, get_dags, process_subdir, sigint_handler, suppress_logs_and_warning from airflow.utils.dot_renderer import render_dag, render_dag_dependencies from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState +log = logging.getLogger(__name__) + @cli_utils.action_cli def dag_backfill(args, dag=None): - """Creates backfill job or dry run for a DAG""" + """Creates backfill job or dry run for a DAG or list of DAGs using regex.""" logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) signal.signal(signal.SIGTERM, sigint_handler) @@ -55,8 +57,8 @@ def dag_backfill(args, dag=None): import warnings warnings.warn( - '--ignore-first-depends-on-past is deprecated as the value is always set to True', - category=PendingDeprecationWarning, + "--ignore-first-depends-on-past is deprecated as the value is always set to True", + category=RemovedInAirflow3Warning, ) if args.ignore_first_depends_on_past is False: @@ -65,69 +67,79 @@ def dag_backfill(args, dag=None): if not args.start_date and not args.end_date: raise AirflowException("Provide a start_date and/or end_date") - dag = dag or get_dag(args.subdir, args.dag_id) + if not dag: + dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex) + else: + dags = dag if type(dag) == list else [dag] + + dags.sort(key=lambda d: d.dag_id) # If only one date is passed, using same as start and end args.end_date = args.end_date or args.start_date args.start_date = args.start_date or args.end_date - if args.task_regex: - dag = dag.partial_subset( - task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies - ) - if not dag.task_dict: - raise AirflowException( - f"There are no tasks that match '{args.task_regex}' regex. Nothing to run, exiting..." 
- ) - run_conf = None if args.conf: run_conf = json.loads(args.conf) - if args.dry_run: - print(f"Dry run of DAG {args.dag_id} on {args.start_date}") - dr = DagRun(dag.dag_id, execution_date=args.start_date) - for task in dag.tasks: - print(f"Task {task.task_id}") - ti = TaskInstance(task, run_id=None) - ti.dag_run = dr - ti.dry_run() - else: - if args.reset_dagruns: - DAG.clear_dags( - [dag], - start_date=args.start_date, - end_date=args.end_date, - confirm_prompt=not args.yes, - include_subdags=True, - dag_run_state=DagRunState.QUEUED, - ) - - try: - dag.run( - start_date=args.start_date, - end_date=args.end_date, - mark_success=args.mark_success, - local=args.local, - donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')), - ignore_first_depends_on_past=args.ignore_first_depends_on_past, - ignore_task_deps=args.ignore_dependencies, - pool=args.pool, - delay_on_limit_secs=args.delay_on_limit, - verbose=args.verbose, - conf=run_conf, - rerun_failed_tasks=args.rerun_failed_tasks, - run_backwards=args.run_backwards, - continue_on_failures=args.continue_on_failures, + for dag in dags: + if args.task_regex: + dag = dag.partial_subset( + task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies ) - except ValueError as vr: - print(str(vr)) - sys.exit(1) + if not dag.task_dict: + raise AirflowException( + f"There are no tasks that match '{args.task_regex}' regex. Nothing to run, exiting..." + ) + + if args.dry_run: + print(f"Dry run of DAG {dag.dag_id} on {args.start_date}") + dr = DagRun(dag.dag_id, execution_date=args.start_date) + for task in dag.tasks: + print(f"Task {task.task_id} located in DAG {dag.dag_id}") + ti = TaskInstance(task, run_id=None) + ti.dag_run = dr + ti.dry_run() + else: + if args.reset_dagruns: + DAG.clear_dags( + [dag], + start_date=args.start_date, + end_date=args.end_date, + confirm_prompt=not args.yes, + include_subdags=True, + dag_run_state=DagRunState.QUEUED, + ) + + try: + dag.run( + start_date=args.start_date, + end_date=args.end_date, + mark_success=args.mark_success, + local=args.local, + donot_pickle=(args.donot_pickle or conf.getboolean("core", "donot_pickle")), + ignore_first_depends_on_past=args.ignore_first_depends_on_past, + ignore_task_deps=args.ignore_dependencies, + pool=args.pool, + delay_on_limit_secs=args.delay_on_limit, + verbose=args.verbose, + conf=run_conf, + rerun_failed_tasks=args.rerun_failed_tasks, + run_backwards=args.run_backwards, + continue_on_failures=args.continue_on_failures, + disable_retry=args.disable_retry, + ) + except ValueError as vr: + print(str(vr)) + sys.exit(1) + + if len(dags) > 1: + log.info("All of the backfills are done.") @cli_utils.action_cli def dag_trigger(args): - """Creates a dag run for the specified dag""" + """Creates a dag run for the specified dag.""" api_client = get_current_api_client() try: message = api_client.trigger_dag( @@ -140,7 +152,7 @@ def dag_trigger(args): @cli_utils.action_cli def dag_delete(args): - """Deletes all DB records related to the specified dag""" + """Deletes all DB records related to the specified dag.""" api_client = get_current_api_client() if ( args.yes @@ -158,18 +170,18 @@ def dag_delete(args): @cli_utils.action_cli def dag_pause(args): - """Pauses a DAG""" + """Pauses a DAG.""" set_is_paused(True, args) @cli_utils.action_cli def dag_unpause(args): - """Unpauses a DAG""" + """Unpauses a DAG.""" set_is_paused(False, args) def set_is_paused(is_paused, args): - """Sets is_paused for DAG by a given dag_id""" + """Sets is_paused for DAG by 
a given dag_id.""" dag = DagModel.get_dagmodel(args.dag_id) if not dag: @@ -181,7 +193,7 @@ def set_is_paused(is_paused, args): def dag_dependencies_show(args): - """Displays DAG dependencies, save to file or show as imgcat image""" + """Displays DAG dependencies, save to file or show as imgcat image.""" dot = render_dag_dependencies(SerializedDagModel.get_dag_dependencies()) filename = args.save imgcat = args.imgcat @@ -200,7 +212,7 @@ def dag_dependencies_show(args): def dag_show(args): - """Displays DAG or saves it's graphic representation to the file""" + """Displays DAG or saves it's graphic representation to the file.""" dag = get_dag(args.subdir, args.dag_id) dot = render_dag(dag) filename = args.save @@ -220,25 +232,23 @@ def dag_show(args): def _display_dot_via_imgcat(dot: Dot): - data = dot.pipe(format='png') + data = dot.pipe(format="png") try: with subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) as proc: out, err = proc.communicate(data) if out: - print(out.decode('utf-8')) + print(out.decode("utf-8")) if err: - print(err.decode('utf-8')) + print(err.decode("utf-8")) except OSError as e: if e.errno == errno.ENOENT: - raise SystemExit( - "Failed to execute. Make sure the imgcat executables are on your systems \'PATH\'" - ) + raise SystemExit("Failed to execute. Make sure the imgcat executables are on your systems 'PATH'") else: raise def _save_dot_to_file(dot: Dot, filename: str): - filename_without_ext, _, ext = filename.rpartition('.') + filename_without_ext, _, ext = filename.rpartition(".") dot.render(filename=filename_without_ext, format=ext, cleanup=True) print(f"File {filename} saved") @@ -253,16 +263,15 @@ def dag_state(args, session=NEW_SESSION): >>> airflow dags state a_dag_with_conf_passed 2015-01-01T00:00:00.000000 failed, {"name": "bob", "age": "42"} """ - dag = DagModel.get_dagmodel(args.dag_id, session=session) if not dag: raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table") dr = session.query(DagRun).filter_by(dag_id=args.dag_id, execution_date=args.execution_date).one_or_none() out = dr.state if dr else None - conf_out = '' + conf_out = "" if out and dr.conf: - conf_out = ', ' + json.dumps(dr.conf) + conf_out = ", " + json.dumps(dr.conf) print(str(out) + conf_out) @@ -284,7 +293,7 @@ def dag_next_execution(args): .filter(DagRun.dag_id == dag.dag_id) .subquery() ) - max_date_run: Optional[DagRun] = ( + max_date_run: DagRun | None = ( session.query(DagRun) .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == max_date_subq.c.max_date) .one_or_none() @@ -314,7 +323,7 @@ def dag_next_execution(args): @cli_utils.action_cli @suppress_logs_and_warning def dag_list_dags(args): - """Displays dags with or without stats at the command line""" + """Displays dags with or without stats at the command line.""" dagbag = DagBag(process_subdir(args.subdir)) if dagbag.import_errors: from rich import print as rich_print @@ -339,7 +348,7 @@ def dag_list_dags(args): @cli_utils.action_cli @suppress_logs_and_warning def dag_list_import_errors(args): - """Displays dags with import errors on the command line""" + """Displays dags with import errors on the command line.""" dagbag = DagBag(process_subdir(args.subdir)) data = [] for filename, errors in dagbag.import_errors.items(): @@ -353,7 +362,7 @@ def dag_list_import_errors(args): @cli_utils.action_cli @suppress_logs_and_warning def dag_report(args): - """Displays dagbag stats at the command line""" + """Displays dagbag stats at the command line.""" dagbag = 
DagBag(process_subdir(args.subdir)) AirflowConsole().print_as( data=dagbag.dagbag_stats, @@ -372,7 +381,7 @@ def dag_report(args): @suppress_logs_and_warning @provide_session def dag_list_jobs(args, dag=None, session=NEW_SESSION): - """Lists latest n jobs""" + """Lists latest n jobs.""" queries = [] if dag: args.dag_id = dag.dag_id @@ -386,7 +395,7 @@ def dag_list_jobs(args, dag=None, session=NEW_SESSION): if args.state: queries.append(BaseJob.state == args.state) - fields = ['dag_id', 'state', 'job_type', 'start_date', 'end_date'] + fields = ["dag_id", "state", "job_type", "start_date", "end_date"] all_jobs = ( session.query(BaseJob).filter(*queries).order_by(BaseJob.start_date.desc()).limit(args.limit).all() ) @@ -402,7 +411,7 @@ def dag_list_jobs(args, dag=None, session=NEW_SESSION): @suppress_logs_and_warning @provide_session def dag_list_dag_runs(args, dag=None, session=NEW_SESSION): - """Lists dag runs for a given DAG""" + """Lists dag runs for a given DAG.""" if dag: args.dag_id = dag.dag_id else: @@ -430,31 +439,25 @@ def dag_list_dag_runs(args, dag=None, session=NEW_SESSION): "run_id": dr.run_id, "state": dr.state, "execution_date": dr.execution_date.isoformat(), - "start_date": dr.start_date.isoformat() if dr.start_date else '', - "end_date": dr.end_date.isoformat() if dr.end_date else '', + "start_date": dr.start_date.isoformat() if dr.start_date else "", + "end_date": dr.end_date.isoformat() if dr.end_date else "", }, ) @provide_session @cli_utils.action_cli -def dag_test(args, session=None): - """Execute one single DagRun for a given DAG and execution date, using the DebugExecutor.""" - dag = get_dag(subdir=args.subdir, dag_id=args.dag_id) - dag.clear(start_date=args.execution_date, end_date=args.execution_date, dag_run_state=False) - try: - dag.run( - executor=DebugExecutor(), - start_date=args.execution_date, - end_date=args.execution_date, - # Always run the DAG at least once even if no logical runs are - # available. This does not make a lot of sense, but Airflow has - # been doing this prior to 2.2 so we keep compatibility. - run_at_least_once=True, - ) - except BackfillUnfinished as e: - print(str(e)) - +def dag_test(args, dag=None, session=None): + """Execute one single DagRun for a given DAG and execution date.""" + run_conf = None + if args.conf: + try: + run_conf = json.loads(args.conf) + except ValueError as e: + raise SystemExit(f"Configuration {args.conf!r} is not valid JSON. 
Error: {e}") + execution_date = args.execution_date or timezone.utcnow() + dag = dag or get_dag(subdir=args.subdir, dag_id=args.dag_id) + dag.test(execution_date=execution_date, run_conf=run_conf, session=session) show_dagrun = args.show_dagrun imgcat = args.imgcat_dagrun filename = args.save_dagrun @@ -463,7 +466,7 @@ def dag_test(args, session=None): session.query(TaskInstance) .filter( TaskInstance.dag_id == args.dag_id, - TaskInstance.execution_date == args.execution_date, + TaskInstance.execution_date == execution_date, ) .all() ) @@ -480,10 +483,10 @@ def dag_test(args, session=None): @provide_session @cli_utils.action_cli -def dag_reserialize(args, session=None): +def dag_reserialize(args, session: Session = NEW_SESSION): + """Serialize a DAG instance.""" session.query(SerializedDagModel).delete(synchronize_session=False) if not args.clear_only: - dagbag = DagBag() - dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db() + dagbag = DagBag(process_subdir(args.subdir)) + dagbag.sync_to_db(session=session) diff --git a/airflow/cli/commands/dag_processor_command.py b/airflow/cli/commands/dag_processor_command.py index d57e26510c974..f8ce65663bd40 100644 --- a/airflow/cli/commands/dag_processor_command.py +++ b/airflow/cli/commands/dag_processor_command.py @@ -14,14 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""DagProcessor command.""" +from __future__ import annotations -"""DagProcessor command""" import logging from datetime import timedelta import daemon from daemon.pidfile import TimeoutPIDLockFile +from airflow import settings from airflow.configuration import conf from airflow.dag_processing.manager import DagFileProcessorManager from airflow.utils import cli as cli_utils @@ -32,7 +34,7 @@ def _create_dag_processor_manager(args) -> DagFileProcessorManager: """Creates DagFileProcessorProcess instance.""" - processor_timeout_seconds: int = conf.getint('core', 'dag_file_processor_timeout') + processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout") processor_timeout = timedelta(seconds=processor_timeout_seconds) return DagFileProcessorManager( dag_directory=args.subdir, @@ -45,13 +47,13 @@ def _create_dag_processor_manager(args) -> DagFileProcessorManager: @cli_utils.action_cli def dag_processor(args): - """Starts Airflow Dag Processor Job""" + """Starts Airflow Dag Processor Job.""" if not conf.getboolean("scheduler", "standalone_dag_processor"): - raise SystemExit('The option [scheduler/standalone_dag_processor] must be True.') + raise SystemExit("The option [scheduler/standalone_dag_processor] must be True.") - sql_conn: str = conf.get('database', 'sql_alchemy_conn').lower() - if sql_conn.startswith('sqlite'): - raise SystemExit('Standalone DagProcessor is not supported when using sqlite.') + sql_conn: str = conf.get("database", "sql_alchemy_conn").lower() + if sql_conn.startswith("sqlite"): + raise SystemExit("Standalone DagProcessor is not supported when using sqlite.") manager = _create_dag_processor_manager(args) @@ -60,12 +62,16 @@ def dag_processor(args): "dag-processor", args.pid, args.stdout, args.stderr, args.log_file ) handle = setup_logging(log_file) - with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle: + stdout_handle.truncate(0) + stderr_handle.truncate(0) + ctx = daemon.DaemonContext( 
pidfile=TimeoutPIDLockFile(pid, -1), files_preserve=[handle], stdout=stdout_handle, stderr=stderr_handle, + umask=int(settings.DAEMON_UMASK, 8), ) with ctx: try: diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py index c9201ad59ba80..4064285db8b45 100644 --- a/airflow/cli/commands/db_command.py +++ b/airflow/cli/commands/db_command.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Database sub-commands""" +"""Database sub-commands.""" +from __future__ import annotations + import os import textwrap from tempfile import NamedTemporaryFile @@ -30,14 +32,14 @@ def initdb(args): - """Initializes the metadata database""" + """Initializes the metadata database.""" print("DB: " + repr(settings.engine.url)) db.initdb() print("Initialization done") def resetdb(args): - """Resets the metadata database""" + """Resets the metadata database.""" print("DB: " + repr(settings.engine.url)) if not (args.yes or input("This will drop existing tables if they exist. Proceed? (y/n)").upper() == "Y"): raise SystemExit("Cancelled") @@ -46,7 +48,7 @@ def resetdb(args): @cli_utils.action_cli(check_db=False) def upgradedb(args): - """Upgrades the metadata database""" + """Upgrades the metadata database.""" print("DB: " + repr(settings.engine.url)) if args.to_revision and args.to_version: raise SystemExit("Cannot supply both `--to-revision` and `--to-version`.") @@ -61,7 +63,7 @@ def upgradedb(args): if args.from_revision: from_revision = args.from_revision elif args.from_version: - if parse_version(args.from_version) < parse_version('2.0.0'): + if parse_version(args.from_version) < parse_version("2.0.0"): raise SystemExit("--from-version must be greater or equal to than 2.0.0") from_revision = REVISION_HEADS_MAP.get(args.from_version) if not from_revision: @@ -79,14 +81,19 @@ def upgradedb(args): else: print("Generating sql for upgrade -- upgrade commands will *not* be submitted.") - db.upgradedb(to_revision=to_revision, from_revision=from_revision, show_sql_only=args.show_sql_only) + db.upgradedb( + to_revision=to_revision, + from_revision=from_revision, + show_sql_only=args.show_sql_only, + reserialize_dags=args.reserialize_dags, + ) if not args.show_sql_only: print("Upgrades done") @cli_utils.action_cli(check_db=False) def downgrade(args): - """Downgrades the metadata database""" + """Downgrades the metadata database.""" if args.to_revision and args.to_version: raise SystemExit("Cannot supply both `--to-revision` and `--to-version`.") if args.from_version and args.from_revision: @@ -132,17 +139,17 @@ def downgrade(args): def check_migrations(args): - """Function to wait for all airflow migrations to complete. Used for launching airflow in k8s""" + """Function to wait for all airflow migrations to complete. 
Used for launching airflow in k8s.""" db.check_migrations(timeout=args.migration_wait_timeout) @cli_utils.action_cli(check_db=False) def shell(args): - """Run a shell that allows to access metadata database""" + """Run a shell that allows to access metadata database.""" url = settings.engine.url print("DB: " + repr(url)) - if url.get_backend_name() == 'mysql': + if url.get_backend_name() == "mysql": with NamedTemporaryFile(suffix="my.cnf") as f: content = textwrap.dedent( f""" @@ -157,23 +164,23 @@ def shell(args): f.write(content.encode()) f.flush() execute_interactive(["mysql", f"--defaults-extra-file={f.name}"]) - elif url.get_backend_name() == 'sqlite': + elif url.get_backend_name() == "sqlite": execute_interactive(["sqlite3", url.database]) - elif url.get_backend_name() == 'postgresql': + elif url.get_backend_name() == "postgresql": env = os.environ.copy() - env['PGHOST'] = url.host or "" - env['PGPORT'] = str(url.port or "5432") - env['PGUSER'] = url.username or "" + env["PGHOST"] = url.host or "" + env["PGPORT"] = str(url.port or "5432") + env["PGUSER"] = url.username or "" # PostgreSQL does not allow the use of PGPASSFILE if the current user is root. env["PGPASSWORD"] = url.password or "" - env['PGDATABASE'] = url.database + env["PGDATABASE"] = url.database execute_interactive(["psql"], env=env) - elif url.get_backend_name() == 'mssql': + elif url.get_backend_name() == "mssql": env = os.environ.copy() - env['MSSQL_CLI_SERVER'] = url.host - env['MSSQL_CLI_DATABASE'] = url.database - env['MSSQL_CLI_USER'] = url.username - env['MSSQL_CLI_PASSWORD'] = url.password + env["MSSQL_CLI_SERVER"] = url.host + env["MSSQL_CLI_DATABASE"] = url.database + env["MSSQL_CLI_USER"] = url.username + env["MSSQL_CLI_PASSWORD"] = url.password execute_interactive(["mssql-cli"], env=env) else: raise AirflowException(f"Unknown driver: {url.drivername}") @@ -191,11 +198,12 @@ def check(_): @cli_utils.action_cli(check_db=False) def cleanup_tables(args): - """Purges old records in metadata database""" + """Purges old records in metadata database.""" run_cleanup( table_names=args.tables, dry_run=args.dry_run, clean_before_timestamp=args.clean_before_timestamp, verbose=args.verbose, confirm=not args.yes, + skip_archive=args.skip_archive, ) diff --git a/airflow/cli/commands/info_command.py b/airflow/cli/commands/info_command.py index fc03615210af3..7261dfc484156 100644 --- a/airflow/cli/commands/info_command.py +++ b/airflow/cli/commands/info_command.py @@ -14,14 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Config sub-commands""" +"""Config sub-commands.""" +from __future__ import annotations + import locale import logging import os import platform import subprocess import sys -from typing import List, Optional +from enum import Enum from urllib.parse import urlsplit, urlunsplit import httpx @@ -42,13 +44,13 @@ class Anonymizer(Protocol): """Anonymizer protocol.""" def process_path(self, value) -> str: - """Remove pii from paths""" + """Remove pii from paths.""" def process_username(self, value) -> str: - """Remove pii from username""" + """Remove pii from username.""" def process_url(self, value) -> str: - """Remove pii from URL""" + """Remove pii from URL.""" class NullAnonymizer(Anonymizer): @@ -123,17 +125,18 @@ def process_url(self, value) -> str: return urlunsplit((url_parts.scheme, netloc, url_parts.path, url_parts.query, url_parts.fragment)) -class OperatingSystem: - """Operating system""" +class OperatingSystem(Enum): + """Operating system.""" WINDOWS = "Windows" LINUX = "Linux" MACOSX = "Mac OS" CYGWIN = "Cygwin" + UNKNOWN = "Unknown" @staticmethod - def get_current() -> Optional[str]: - """Get current operating system""" + def get_current() -> OperatingSystem: + """Get current operating system.""" if os.name == "nt": return OperatingSystem.WINDOWS elif "linux" in sys.platform: @@ -142,24 +145,26 @@ def get_current() -> Optional[str]: return OperatingSystem.MACOSX elif "cygwin" in sys.platform: return OperatingSystem.CYGWIN - return None + return OperatingSystem.UNKNOWN -class Architecture: - """Compute architecture""" +class Architecture(Enum): + """Compute architecture.""" X86_64 = "x86_64" X86 = "x86" PPC = "ppc" ARM = "arm" + UNKNOWN = "unknown" @staticmethod - def get_current(): - """Get architecture""" - return _MACHINE_TO_ARCHITECTURE.get(platform.machine().lower()) + def get_current() -> Architecture: + """Get architecture.""" + current_architecture = _MACHINE_TO_ARCHITECTURE.get(platform.machine().lower()) + return current_architecture if current_architecture else Architecture.UNKNOWN -_MACHINE_TO_ARCHITECTURE = { +_MACHINE_TO_ARCHITECTURE: dict[str, Architecture] = { "amd64": Architecture.X86_64, "x86_64": Architecture.X86_64, "i686-64": Architecture.X86_64, @@ -175,17 +180,18 @@ def get_current(): "arm64": Architecture.ARM, "armv7": Architecture.ARM, "armv7l": Architecture.ARM, + "aarch64": Architecture.ARM, } class AirflowInfo: - """Renders information about Airflow instance""" + """Renders information about Airflow instance.""" def __init__(self, anonymizer): self.anonymizer = anonymizer @staticmethod - def _get_version(cmd: List[str], grep: Optional[bytes] = None): + def _get_version(cmd: list[str], grep: bytes | None = None): """Return tools version.""" try: with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc: @@ -209,10 +215,10 @@ def get_fullname(o): if module is None or module == str.__class__.__module__: return o.__class__.__name__ # Avoid reporting __builtin__ else: - return module + '.' + o.__class__.__name__ + return module + "." 
+ o.__class__.__name__ try: - handler_names = [get_fullname(handler) for handler in logging.getLogger('airflow.task').handlers] + handler_names = [get_fullname(handler) for handler in logging.getLogger("airflow.task").handlers] return ", ".join(handler_names) except Exception: return "NOT AVAILABLE" @@ -257,8 +263,8 @@ def _system_info(self): python_version = sys.version.replace("\n", " ") return [ - ("OS", operating_system or "NOT AVAILABLE"), - ("architecture", arch or "NOT AVAILABLE"), + ("OS", operating_system.value), + ("architecture", arch.value), ("uname", str(uname)), ("locale", str(_locale)), ("python_version", python_version), @@ -304,10 +310,10 @@ def _paths_info(self): @property def _providers_info(self): - return [(p.data['package-name'], p.version) for p in ProvidersManager().providers.values()] + return [(p.data["package-name"], p.version) for p in ProvidersManager().providers.values()] - def show(self, output: str, console: Optional[AirflowConsole] = None) -> None: - """Shows information about Airflow instance""" + def show(self, output: str, console: AirflowConsole | None = None) -> None: + """Shows information about Airflow instance.""" all_info = { "Apache Airflow": self._airflow_info, "System info": self._system_info, @@ -329,7 +335,7 @@ def show(self, output: str, console: Optional[AirflowConsole] = None) -> None: ) def render_text(self, output: str) -> str: - """Exports the info to string""" + """Exports the info to string.""" console = AirflowConsole(record=True) with console.capture(): self.show(output=output, console=console) @@ -337,7 +343,7 @@ def render_text(self, output: str) -> str: class FileIoException(Exception): - """Raises when error happens in FileIo.io integration""" + """Raises when error happens in FileIo.io integration.""" @tenacity.retry( @@ -348,7 +354,7 @@ class FileIoException(Exception): after=tenacity.after_log(log, logging.DEBUG), ) def _upload_text_to_fileio(content): - """Upload text file to File.io service and return lnk""" + """Upload text file to File.io service and return link.""" resp = httpx.post("https://file.io", content=content) if resp.status_code not in [200, 201]: print(resp.json()) diff --git a/airflow/cli/commands/jobs_command.py b/airflow/cli/commands/jobs_command.py index f6d8d55fb6b9a..c030b5ea9b31b 100644 --- a/airflow/cli/commands/jobs_command.py +++ b/airflow/cli/commands/jobs_command.py @@ -14,19 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -from typing import List +from __future__ import annotations from airflow.jobs.base_job import BaseJob +from airflow.utils.net import get_hostname from airflow.utils.session import provide_session from airflow.utils.state import State @provide_session def check(args, session=None): - """Checks if job(s) are still alive""" + """Checks if job(s) are still alive.""" if args.allow_multiple and not args.limit > 1: raise SystemExit("To use option --allow-multiple, you must set the limit to a value greater than 1.") + if args.hostname and args.local: + raise SystemExit("You can't use --hostname and --local at the same time") + query = ( session.query(BaseJob) .filter(BaseJob.state == State.RUNNING) @@ -36,10 +39,12 @@ def check(args, session=None): query = query.filter(BaseJob.job_type == args.job_type) if args.hostname: query = query.filter(BaseJob.hostname == args.hostname) + if args.local: + query = query.filter(BaseJob.hostname == get_hostname()) if args.limit > 0: query = query.limit(args.limit) - jobs: List[BaseJob] = query.all() + jobs: list[BaseJob] = query.all() alive_jobs = [job for job in jobs if job.is_alive()] count_alive_jobs = len(alive_jobs) diff --git a/airflow/cli/commands/kerberos_command.py b/airflow/cli/commands/kerberos_command.py index fea874349923b..4bbe3f6df919d 100644 --- a/airflow/cli/commands/kerberos_command.py +++ b/airflow/cli/commands/kerberos_command.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Kerberos command.""" +from __future__ import annotations -"""Kerberos command""" import daemon from daemon.pidfile import TimeoutPIDLockFile @@ -27,18 +28,22 @@ @cli_utils.action_cli def kerberos(args): - """Start a kerberos ticket renewer""" + """Start a kerberos ticket renewer.""" print(settings.HEADER) if args.daemon: pid, stdout, stderr, _ = setup_locations( "kerberos", args.pid, args.stdout, args.stderr, args.log_file ) - with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle: + stdout_handle.truncate(0) + stderr_handle.truncate(0) + ctx = daemon.DaemonContext( pidfile=TimeoutPIDLockFile(pid, -1), stdout=stdout_handle, stderr=stderr_handle, + umask=int(settings.DAEMON_UMASK, 8), ) with ctx: diff --git a/airflow/cli/commands/kubernetes_command.py b/airflow/cli/commands/kubernetes_command.py index 7c26821780575..052bf339c07b0 100644 --- a/airflow/cli/commands/kubernetes_command.py +++ b/airflow/cli/commands/kubernetes_command.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Kubernetes sub-commands""" +"""Kubernetes sub-commands.""" +from __future__ import annotations + import os import sys from datetime import datetime, timedelta @@ -35,7 +37,7 @@ @cli_utils.action_cli def generate_pod_yaml(args): - """Generates yaml files for each task in the DAG. Used for testing output of KubernetesExecutor""" + """Generates yaml files for each task in the DAG. 
Used for testing output of KubernetesExecutor.""" execution_date = args.execution_date dag = get_dag(subdir=args.subdir, dag_id=args.dag_id) yaml_output_path = args.output_path @@ -70,7 +72,7 @@ def generate_pod_yaml(args): @cli_utils.action_cli def cleanup_pods(args): - """Clean up k8s pods in evicted/failed/succeeded/pending states""" + """Clean up k8s pods in evicted/failed/succeeded/pending states.""" namespace = args.namespace min_pending_minutes = args.min_pending_minutes @@ -80,42 +82,42 @@ def cleanup_pods(args): # https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/ # All Containers in the Pod have terminated in success, and will not be restarted. - pod_succeeded = 'succeeded' + pod_succeeded = "succeeded" # The Pod has been accepted by the Kubernetes cluster, # but one or more of the containers has not been set up and made ready to run. - pod_pending = 'pending' + pod_pending = "pending" # All Containers in the Pod have terminated, and at least one Container has terminated in failure. # That is, the Container either exited with non-zero status or was terminated by the system. - pod_failed = 'failed' + pod_failed = "failed" # https://kubernetes.io/docs/tasks/administer-cluster/out-of-resource/ - pod_reason_evicted = 'evicted' + pod_reason_evicted = "evicted" # If pod is failed and restartPolicy is: # * Always: Restart Container; Pod phase stays Running. # * OnFailure: Restart Container; Pod phase stays Running. # * Never: Pod phase becomes Failed. - pod_restart_policy_never = 'never' + pod_restart_policy_never = "never" - print('Loading Kubernetes configuration') + print("Loading Kubernetes configuration") kube_client = get_kube_client() - print(f'Listing pods in namespace {namespace}') + print(f"Listing pods in namespace {namespace}") airflow_pod_labels = [ - 'dag_id', - 'task_id', - 'try_number', - 'airflow_version', + "dag_id", + "task_id", + "try_number", + "airflow_version", ] - list_kwargs = {"namespace": namespace, "limit": 500, "label_selector": ','.join(airflow_pod_labels)} + list_kwargs = {"namespace": namespace, "limit": 500, "label_selector": ",".join(airflow_pod_labels)} while True: pod_list = kube_client.list_namespaced_pod(**list_kwargs) for pod in pod_list.items: pod_name = pod.metadata.name - print(f'Inspecting pod {pod_name}') + print(f"Inspecting pod {pod_name}") pod_phase = pod.status.phase.lower() - pod_reason = pod.status.reason.lower() if pod.status.reason else '' + pod_reason = pod.status.reason.lower() if pod.status.reason else "" pod_restart_policy = pod.spec.restart_policy.lower() current_time = datetime.now(pod.metadata.creation_timestamp.tzinfo) @@ -138,7 +140,7 @@ def cleanup_pods(args): except ApiException as e: print(f"Can't remove POD: {e}", file=sys.stderr) continue - print(f'No action taken on pod {pod_name}') + print(f"No action taken on pod {pod_name}") continue_token = pod_list.metadata._continue if not continue_token: break @@ -146,7 +148,7 @@ def cleanup_pods(args): def _delete_pod(name, namespace): - """Helper Function for cleanup_pods""" + """Helper Function for cleanup_pods.""" core_v1 = client.CoreV1Api() delete_options = client.V1DeleteOptions() print(f'Deleting POD "{name}" from "{namespace}" namespace') diff --git a/airflow/cli/commands/legacy_commands.py b/airflow/cli/commands/legacy_commands.py index 94f9b690327b2..910c6e442703f 100644 --- a/airflow/cli/commands/legacy_commands.py +++ b/airflow/cli/commands/legacy_commands.py @@ -14,6 +14,7 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from argparse import ArgumentError @@ -49,7 +50,7 @@ def check_legacy_command(action, value): - """Checks command value and raise error if value is in removed command""" + """Checks command value and raise error if value is in removed command.""" new_command = COMMAND_MAP.get(value) if new_command is not None: msg = f"`airflow {value}` command, has been removed, please use `airflow {new_command}`" diff --git a/airflow/cli/commands/plugins_command.py b/airflow/cli/commands/plugins_command.py index 2d59e901a187a..50ee583099110 100644 --- a/airflow/cli/commands/plugins_command.py +++ b/airflow/cli/commands/plugins_command.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import inspect -from typing import Any, Dict, List, Union +from typing import Any from airflow import plugins_manager from airflow.cli.simple_table import AirflowConsole @@ -31,15 +33,15 @@ def _get_name(class_like_object) -> str: return class_like_object.__class__.__name__ -def _join_plugins_names(value: Union[List[Any], Any]) -> str: +def _join_plugins_names(value: list[Any] | Any) -> str: value = value if isinstance(value, list) else [value] return ",".join(_get_name(v) for v in value) @suppress_logs_and_warning def dump_plugins(args): - """Dump plugins information""" - plugins_info: List[Dict[str, str]] = get_plugin_info() + """Dump plugins information.""" + plugins_info: list[dict[str, str]] = get_plugin_info() if not plugins_manager.plugins: print("No plugins loaded") return diff --git a/airflow/cli/commands/pool_command.py b/airflow/cli/commands/pool_command.py index e435c2a4833bc..aa56ba8fea7de 100644 --- a/airflow/cli/commands/pool_command.py +++ b/airflow/cli/commands/pool_command.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Pools sub-commands""" +"""Pools sub-commands.""" +from __future__ import annotations + import json import os from json import JSONDecodeError @@ -41,7 +43,7 @@ def _show_pools(pools, output): @suppress_logs_and_warning def pool_list(args): - """Displays info of all the pools""" + """Displays info of all the pools.""" api_client = get_current_api_client() pools = api_client.get_pools() _show_pools(pools=pools, output=args.output) @@ -49,7 +51,7 @@ def pool_list(args): @suppress_logs_and_warning def pool_get(args): - """Displays pool info by a given name""" + """Displays pool info by a given name.""" api_client = get_current_api_client() try: pools = [api_client.get_pool(name=args.pool)] @@ -61,7 +63,7 @@ def pool_get(args): @cli_utils.action_cli @suppress_logs_and_warning def pool_set(args): - """Creates new pool with a given name and slots""" + """Creates new pool with a given name and slots.""" api_client = get_current_api_client() api_client.create_pool(name=args.pool, slots=args.slots, description=args.description) print(f"Pool {args.pool} created") @@ -70,7 +72,7 @@ def pool_set(args): @cli_utils.action_cli @suppress_logs_and_warning def pool_delete(args): - """Deletes pool by a given name""" + """Deletes pool by a given name.""" api_client = get_current_api_client() try: api_client.delete_pool(name=args.pool) @@ -82,7 +84,7 @@ def pool_delete(args): @cli_utils.action_cli @suppress_logs_and_warning def pool_import(args): - """Imports pools from the file""" + """Imports pools from the file.""" if not os.path.exists(args.file): raise SystemExit(f"Missing pools file {args.file}") pools, failed = pool_import_helper(args.file) @@ -92,13 +94,13 @@ def pool_import(args): def pool_export(args): - """Exports all of the pools to the file""" + """Exports all the pools to the file.""" pools = pool_export_helper(args.file) print(f"Exported {len(pools)} pools to {args.file}") def pool_import_helper(filepath): - """Helps import pools from the json file""" + """Helps import pools from the json file.""" api_client = get_current_api_client() with open(filepath) as poolfile: @@ -118,12 +120,12 @@ def pool_import_helper(filepath): def pool_export_helper(filepath): - """Helps export all of the pools to the json file""" + """Helps export all the pools to the json file.""" api_client = get_current_api_client() pool_dict = {} pools = api_client.get_pools() for pool in pools: pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2]} - with open(filepath, 'w') as poolfile: + with open(filepath, "w") as poolfile: poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4)) return pools diff --git a/airflow/cli/commands/provider_command.py b/airflow/cli/commands/provider_command.py index b89f77fb5d2a4..cae605679490f 100644 --- a/airflow/cli/commands/provider_command.py +++ b/airflow/cli/commands/provider_command.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Providers sub-commands""" +"""Providers sub-commands.""" +from __future__ import annotations + import re from airflow.cli.simple_table import AirflowConsole @@ -50,7 +52,7 @@ def provider_get(args): @suppress_logs_and_warning def providers_list(args): - """Lists all providers at the command line""" + """Lists all providers at the command line.""" AirflowConsole().print_as( data=list(ProvidersManager().providers.values()), output=args.output, @@ -64,7 +66,7 @@ def providers_list(args): @suppress_logs_and_warning def hooks_list(args): - """Lists all hooks at the command line""" + """Lists all hooks at the command line.""" AirflowConsole().print_as( data=list(ProvidersManager().hooks.items()), output=args.output, @@ -72,30 +74,30 @@ def hooks_list(args): "connection_type": x[0], "class": x[1].hook_class_name if x[1] else ERROR_IMPORTING_HOOK, "conn_id_attribute_name": x[1].connection_id_attribute_name if x[1] else ERROR_IMPORTING_HOOK, - 'package_name': x[1].package_name if x[1] else ERROR_IMPORTING_HOOK, - 'hook_name': x[1].hook_name if x[1] else ERROR_IMPORTING_HOOK, + "package_name": x[1].package_name if x[1] else ERROR_IMPORTING_HOOK, + "hook_name": x[1].hook_name if x[1] else ERROR_IMPORTING_HOOK, }, ) @suppress_logs_and_warning def connection_form_widget_list(args): - """Lists all custom connection form fields at the command line""" + """Lists all custom connection form fields at the command line.""" AirflowConsole().print_as( - data=list(ProvidersManager().connection_form_widgets.items()), + data=list(sorted(ProvidersManager().connection_form_widgets.items())), output=args.output, mapper=lambda x: { "connection_parameter_name": x[0], "class": x[1].hook_class_name, - 'package_name': x[1].package_name, - 'field_type': x[1].field.field_class.__name__, + "package_name": x[1].package_name, + "field_type": x[1].field.field_class.__name__, }, ) @suppress_logs_and_warning def connection_field_behaviours(args): - """Lists field behaviours""" + """Lists field behaviours.""" AirflowConsole().print_as( data=list(ProvidersManager().field_behaviours.keys()), output=args.output, @@ -107,7 +109,7 @@ def connection_field_behaviours(args): @suppress_logs_and_warning def extra_links_list(args): - """Lists all extra links at the command line""" + """Lists all extra links at the command line.""" AirflowConsole().print_as( data=ProvidersManager().extra_links_class_names, output=args.output, @@ -119,7 +121,7 @@ def extra_links_list(args): @suppress_logs_and_warning def logging_list(args): - """Lists all log task handlers at the command line""" + """Lists all log task handlers at the command line.""" AirflowConsole().print_as( data=list(ProvidersManager().logging_class_names), output=args.output, @@ -131,7 +133,7 @@ def logging_list(args): @suppress_logs_and_warning def secrets_backends_list(args): - """Lists all secrets backends at the command line""" + """Lists all secrets backends at the command line.""" AirflowConsole().print_as( data=list(ProvidersManager().secrets_backend_class_names), output=args.output, @@ -143,7 +145,7 @@ def secrets_backends_list(args): @suppress_logs_and_warning def auth_backend_list(args): - """Lists all API auth backend modules at the command line""" + """Lists all API auth backend modules at the command line.""" AirflowConsole().print_as( data=list(ProvidersManager().auth_backend_module_names), output=args.output, diff --git a/airflow/cli/commands/role_command.py b/airflow/cli/commands/role_command.py index 241a0409c4679..571db3eefe11e 100644 --- 
a/airflow/cli/commands/role_command.py +++ b/airflow/cli/commands/role_command.py @@ -15,8 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -"""Roles sub-commands""" +"""Roles sub-commands.""" +from __future__ import annotations + +import collections +import itertools import json import os @@ -24,33 +27,128 @@ from airflow.utils import cli as cli_utils from airflow.utils.cli import suppress_logs_and_warning from airflow.www.app import cached_app -from airflow.www.security import EXISTING_ROLES +from airflow.www.fab_security.sqla.models import Action, Permission, Resource, Role +from airflow.www.security import EXISTING_ROLES, AirflowSecurityManager @suppress_logs_and_warning def roles_list(args): - """Lists all existing roles""" + """Lists all existing roles.""" appbuilder = cached_app().appbuilder roles = appbuilder.sm.get_all_roles() + + if not args.permission: + AirflowConsole().print_as( + data=sorted(r.name for r in roles), output=args.output, mapper=lambda x: {"name": x} + ) + return + + permission_map: dict[tuple[str, str], list[str]] = collections.defaultdict(list) + for role in roles: + for permission in role.permissions: + permission_map[(role.name, permission.resource.name)].append(permission.action.name) + AirflowConsole().print_as( - data=sorted(r.name for r in roles), output=args.output, mapper=lambda x: {"name": x} + data=sorted(permission_map), + output=args.output, + mapper=lambda x: {"name": x[0], "resource": x[1], "action": ",".join(sorted(permission_map[x]))}, ) @cli_utils.action_cli @suppress_logs_and_warning def roles_create(args): - """Creates new empty role in DB""" + """Creates new empty role in DB.""" appbuilder = cached_app().appbuilder for role_name in args.role: appbuilder.sm.add_role(role_name) print(f"Added {len(args.role)} role(s)") +@cli_utils.action_cli +@suppress_logs_and_warning +def roles_delete(args): + """Deletes role in DB.""" + appbuilder = cached_app().appbuilder + + for role_name in args.role: + role = appbuilder.sm.find_role(role_name) + if not role: + print(f"Role named '{role_name}' does not exist") + exit(1) + + for role_name in args.role: + appbuilder.sm.delete_role(role_name) + print(f"Deleted {len(args.role)} role(s)") + + +def __roles_add_or_remove_permissions(args): + asm: AirflowSecurityManager = cached_app().appbuilder.sm + is_add: bool = args.subcommand.startswith("add") + + role_map = {} + perm_map: dict[tuple[str, str], set[str]] = collections.defaultdict(set) + for name in args.role: + role: Role | None = asm.find_role(name) + if not role: + print(f"Role named '{name}' does not exist") + exit(1) + + role_map[name] = role + for permission in role.permissions: + perm_map[(name, permission.resource.name)].add(permission.action.name) + + for name in args.resource: + resource: Resource | None = asm.get_resource(name) + if not resource: + print(f"Resource named '{name}' does not exist") + exit(1) + + for name in args.action or []: + action: Action | None = asm.get_action(name) + if not action: + print(f"Action named '{name}' does not exist") + exit(1) + + permission_count = 0 + for (role_name, resource_name, action_name) in list( + itertools.product(args.role, args.resource, args.action or [None]) + ): + res_key = (role_name, resource_name) + if is_add and action_name not in perm_map[res_key]: + perm: Permission | None = asm.create_permission(action_name, resource_name) + asm.add_permission_to_role(role_map[role_name], perm) + 
print(f"Added {perm} to role {role_name}") + permission_count += 1 + elif not is_add and res_key in perm_map: + for _action_name in perm_map[res_key] if action_name is None else [action_name]: + perm: Permission | None = asm.get_permission(_action_name, resource_name) + asm.remove_permission_from_role(role_map[role_name], perm) + print(f"Deleted {perm} from role {role_name}") + permission_count += 1 + + print(f"{'Added' if is_add else 'Deleted'} {permission_count} permission(s)") + + +@cli_utils.action_cli +@suppress_logs_and_warning +def roles_add_perms(args): + """Adds permissions to role in DB.""" + __roles_add_or_remove_permissions(args) + + +@cli_utils.action_cli +@suppress_logs_and_warning +def roles_del_perms(args): + """Deletes permissions from role in DB.""" + __roles_add_or_remove_permissions(args) + + @suppress_logs_and_warning def roles_export(args): """ - Exports all the rules from the data base to a file. + Exports all the roles from the database to a file. + Note, this function does not export the permissions associated for each role. Strictly, it exports the role names into the passed role json file. """ @@ -58,8 +156,8 @@ def roles_export(args): roles = appbuilder.sm.get_all_roles() exporting_roles = [role.name for role in roles if role.name not in EXISTING_ROLES] filename = os.path.expanduser(args.file) - kwargs = {} if not args.pretty else {'sort_keys': True, 'indent': 4} - with open(filename, 'w', encoding='utf-8') as f: + kwargs = {} if not args.pretty else {"sort_keys": True, "indent": 4} + with open(filename, "w", encoding="utf-8") as f: json.dump(exporting_roles, f, **kwargs) print(f"{len(exporting_roles)} roles successfully exported to {filename}") @@ -69,6 +167,7 @@ def roles_export(args): def roles_import(args): """ Import all the roles into the db from the given json file. + Note, this function does not import the permissions for different roles and import them as well. Strictly, it imports the role names in the role json file passed. """ diff --git a/airflow/cli/commands/rotate_fernet_key_command.py b/airflow/cli/commands/rotate_fernet_key_command.py index 9334344d9610b..f9e18735976c7 100644 --- a/airflow/cli/commands/rotate_fernet_key_command.py +++ b/airflow/cli/commands/rotate_fernet_key_command.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Rotate Fernet key command""" +"""Rotate Fernet key command.""" +from __future__ import annotations + from airflow.models import Connection, Variable from airflow.utils import cli as cli_utils from airflow.utils.session import create_session @@ -22,7 +24,7 @@ @cli_utils.action_cli def rotate_fernet_key(args): - """Rotates all encrypted connection credentials and variables""" + """Rotates all encrypted connection credentials and variables.""" with create_session() as session: for conn in session.query(Connection).filter(Connection.is_encrypted | Connection.is_extra_encrypted): conn.rotate_fernet_key() diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py index bc6e983ee5c8e..44544716df6da 100644 --- a/airflow/cli/commands/scheduler_command.py +++ b/airflow/cli/commands/scheduler_command.py @@ -14,44 +14,38 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+"""Scheduler command.""" +from __future__ import annotations -"""Scheduler command""" import signal +from contextlib import contextmanager from multiprocessing import Process -from typing import Optional import daemon from daemon.pidfile import TimeoutPIDLockFile from airflow import settings +from airflow.configuration import conf from airflow.jobs.scheduler_job import SchedulerJob from airflow.utils import cli as cli_utils from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler +from airflow.utils.scheduler_health import serve_health_check -def _create_scheduler_job(args): +def _run_scheduler_job(args): job = SchedulerJob( subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle, ) - return job - - -def _run_scheduler_job(args): - skip_serve_logs = args.skip_serve_logs - job = _create_scheduler_job(args) - sub_proc = _serve_logs(skip_serve_logs) - try: + enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK") + with _serve_logs(args.skip_serve_logs), _serve_health_check(enable_health_check): job.run() - finally: - if sub_proc: - sub_proc.terminate() @cli_utils.action_cli def scheduler(args): - """Starts Airflow Scheduler""" + """Starts Airflow Scheduler.""" print(settings.HEADER) if args.daemon: @@ -59,12 +53,16 @@ def scheduler(args): "scheduler", args.pid, args.stdout, args.stderr, args.log_file ) handle = setup_logging(log_file) - with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle: + stdout_handle.truncate(0) + stderr_handle.truncate(0) + ctx = daemon.DaemonContext( pidfile=TimeoutPIDLockFile(pid, -1), files_preserve=[handle], stdout=stdout_handle, stderr=stderr_handle, + umask=int(settings.DAEMON_UMASK, 8), ) with ctx: _run_scheduler_job(args=args) @@ -75,14 +73,29 @@ def scheduler(args): _run_scheduler_job(args=args) -def _serve_logs(skip_serve_logs: bool = False) -> Optional[Process]: - """Starts serve_logs sub-process""" +@contextmanager +def _serve_logs(skip_serve_logs: bool = False): + """Starts serve_logs sub-process.""" from airflow.configuration import conf from airflow.utils.serve_logs import serve_logs + sub_proc = None if conf.get("core", "executor") in ["LocalExecutor", "SequentialExecutor"]: if skip_serve_logs is False: sub_proc = Process(target=serve_logs) sub_proc.start() - return sub_proc - return None + yield + if sub_proc: + sub_proc.terminate() + + +@contextmanager +def _serve_health_check(enable_health_check: bool = False): + """Starts serve_health_check sub-process.""" + sub_proc = None + if enable_health_check: + sub_proc = Process(target=serve_health_check) + sub_proc.start() + yield + if sub_proc: + sub_proc.terminate() diff --git a/airflow/cli/commands/standalone_command.py b/airflow/cli/commands/standalone_command.py index 3860942adb056..2a5670e83f01e 100644 --- a/airflow/cli/commands/standalone_command.py +++ b/airflow/cli/commands/standalone_command.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import logging import os @@ -23,7 +24,6 @@ import threading import time from collections import deque -from typing import Dict, List from termcolor import colored @@ -55,7 +55,7 @@ def __init__(self): self.ready_delay = 3 def run(self): - """Main run loop""" + """Main run loop.""" self.print_output("standalone", "Starting Airflow Standalone") # Silence built-in logging at INFO logging.getLogger("").setLevel(logging.WARNING) @@ -82,7 +82,7 @@ def run(self): env=env, ) - self.web_server_port = conf.getint('webserver', 'WEB_SERVER_PORT', fallback=8080) + self.web_server_port = conf.getint("webserver", "WEB_SERVER_PORT", fallback=8080) # Run subcommand threads for command in self.subcommands.values(): command.start() @@ -116,7 +116,7 @@ def run(self): self.print_output("standalone", "Complete") def update_output(self): - """Drains the output queue and prints its contents to the screen""" + """Drains the output queue and prints its contents to the screen.""" while self.output_queue: # Extract info name, line = self.output_queue.popleft() @@ -126,8 +126,9 @@ def update_output(self): def print_output(self, name: str, output): """ - Prints an output line with name and colouring. You can pass multiple - lines to output if you wish; it will be split for you. + Prints an output line with name and colouring. + + You can pass multiple lines to output if you wish; it will be split for you. """ color = { "webserver": "green", @@ -141,14 +142,16 @@ def print_output(self, name: str, output): def print_error(self, name: str, output): """ - Prints an error message to the console (this is the same as - print_output but with the text red) + Prints an error message to the console. + + This is the same as print_output but with the text red """ self.print_output(name, colored(output, "red")) def calculate_env(self): """ Works out the environment variables needed to run subprocesses. + We override some settings as part of being standalone. """ env = dict(os.environ) @@ -217,7 +220,8 @@ def is_ready(self): def port_open(self, port): """ Checks if the given port is listening on the local machine. - (used to tell if webserver is alive) + + Used to tell if webserver is alive. """ try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -231,8 +235,9 @@ def port_open(self, port): def job_running(self, job): """ - Checks if the given job name is running and heartbeating correctly - (used to tell if scheduler is alive) + Checks if the given job name is running and heartbeating correctly. + + Used to tell if scheduler is alive. """ recent = job.most_recent_job() if not recent: @@ -241,8 +246,9 @@ def job_running(self, job): def print_ready(self): """ - Prints the banner shown when Airflow is ready to go, with login - details. + Prints the banner shown when Airflow is ready to go. + + Include with login details. """ self.print_output("standalone", "") self.print_output("standalone", "Airflow is ready") @@ -260,12 +266,14 @@ def print_ready(self): class SubCommand(threading.Thread): """ + Execute a subcommand on another thread. + Thread that launches a process and then streams its output back to the main command. We use threads to avoid using select() and raw filehandles, and the complex logic that brings doing line buffering. 
""" - def __init__(self, parent, name: str, command: List[str], env: Dict[str, str]): + def __init__(self, parent, name: str, command: list[str], env: dict[str, str]): super().__init__() self.parent = parent self.name = name @@ -273,7 +281,7 @@ def __init__(self, parent, name: str, command: List[str], env: Dict[str, str]): self.env = env def run(self): - """Runs the actual process and captures it output to a queue""" + """Runs the actual process and captures it output to a queue.""" self.process = subprocess.Popen( ["airflow"] + self.command, stdout=subprocess.PIPE, @@ -284,7 +292,7 @@ def run(self): self.parent.output_queue.append((self.name, line)) def stop(self): - """Call to stop this process (and thus this thread)""" + """Call to stop this process (and thus this thread).""" self.process.terminate() diff --git a/airflow/cli/commands/sync_perm_command.py b/airflow/cli/commands/sync_perm_command.py index d580631fcc8a8..6a92ef99d5c8c 100644 --- a/airflow/cli/commands/sync_perm_command.py +++ b/airflow/cli/commands/sync_perm_command.py @@ -15,19 +15,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Sync permission command""" +"""Sync permission command.""" +from __future__ import annotations + from airflow.utils import cli as cli_utils from airflow.www.app import cached_app @cli_utils.action_cli def sync_perm(args): - """Updates permissions for existing roles and DAGs""" + """Updates permissions for existing roles and DAGs.""" appbuilder = cached_app().appbuilder - print('Updating actions and resources for all existing roles') + print("Updating actions and resources for all existing roles") # Add missing permissions for all the Base Views _before_ syncing/creating roles appbuilder.add_permissions(update_perms=True) appbuilder.sm.sync_roles() if args.include_dags: - print('Updating permission on all DAG views') + print("Updating permission on all DAG views") appbuilder.sm.create_dag_specific_permissions() diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index f8caf08487187..2f37579c351d3 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -15,15 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Task sub-commands""" +"""Task sub-commands.""" +from __future__ import annotations + import datetime import importlib import json import logging import os import textwrap -from contextlib import contextmanager, redirect_stderr, redirect_stdout -from typing import Dict, Generator, List, Optional, Tuple, Union +from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress +from typing import Generator, Union from pendulum.parsing.exceptions import ParserError from sqlalchemy.orm.exc import NoResultFound @@ -35,17 +37,18 @@ from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.local_task_job import LocalTaskJob +from airflow.listeners.listener import get_listener_manager from airflow.models import DagPickle, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.dagrun import DagRun +from airflow.models.operator import needs_expansion from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS from airflow.typing_compat import Literal from airflow.utils import cli as cli_utils from airflow.utils.cli import ( get_dag, - get_dag_by_deserialization, get_dag_by_file_location, get_dag_by_pickle, get_dags, @@ -53,6 +56,7 @@ ) from airflow.utils.dates import timezone from airflow.utils.log.logging_mixin import StreamLogWriter +from airflow.utils.log.secrets_masker import RedactedIO from airflow.utils.net import get_hostname from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState @@ -74,10 +78,10 @@ def _generate_temporary_run_id() -> str: def _get_dag_run( *, dag: DAG, - exec_date_or_run_id: str, create_if_necessary: CreateIfNecessary, + exec_date_or_run_id: str | None = None, session: Session, -) -> Tuple[DagRun, bool]: +) -> tuple[DagRun, bool]: """Try to retrieve a DAG run from a string representing either a run ID or logical date. This checks DAG runs like this: @@ -91,33 +95,35 @@ def _get_dag_run( the logical date; otherwise use it as a run ID and set the logical date to the current time. 
""" - dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session) - if dag_run: - return dag_run, False - - try: - execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id) - except (ParserError, TypeError): - execution_date = None - - try: - dag_run = ( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) - .one() - ) - except NoResultFound: - if not create_if_necessary: - raise DagRunNotFound( - f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found" - ) from None - else: - return dag_run, False + if not exec_date_or_run_id and not create_if_necessary: + raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.") + execution_date: datetime.datetime | None = None + if exec_date_or_run_id: + dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session) + if dag_run: + return dag_run, False + with suppress(ParserError, TypeError): + execution_date = timezone.parse(exec_date_or_run_id) + try: + dag_run = ( + session.query(DagRun) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) + .one() + ) + except NoResultFound: + if not create_if_necessary: + raise DagRunNotFound( + f"DagRun for {dag.dag_id} with run_id or execution_date " + f"of {exec_date_or_run_id!r} not found" + ) from None + else: + return dag_run, False if execution_date is not None: dag_run_execution_date = execution_date else: dag_run_execution_date = timezone.utcnow() + if create_if_necessary == "memory": dag_run = DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=dag_run_execution_date) return dag_run, True @@ -135,15 +141,17 @@ def _get_dag_run( @provide_session def _get_ti( task: BaseOperator, - exec_date_or_run_id: str, map_index: int, *, - pool: Optional[str] = None, + exec_date_or_run_id: str | None = None, + pool: str | None = None, create_if_necessary: CreateIfNecessary = False, session: Session = NEW_SESSION, -) -> Tuple[TaskInstance, bool]: - """Get the task instance through DagRun.run_id, if that fails, get the TI the old way""" - if task.is_mapped: +) -> tuple[TaskInstance, bool]: + """Get the task instance through DagRun.run_id, if that fails, get the TI the old way.""" + if not exec_date_or_run_id and not create_if_necessary: + raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.") + if needs_expansion(task): if map_index < 0: raise RuntimeError("No map_index passed to mapped task") elif map_index >= 0: @@ -173,7 +181,9 @@ def _get_ti( def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None: """ - Runs the task in one of 3 modes + Runs the task based on a mode. + + Any of the 3 modes are available: - using LocalTaskJob - as raw task @@ -189,8 +199,9 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None: def _run_task_by_executor(args, dag, ti): """ - Sends the task to the executor for execution. This can result in the task being started by another host - if the executor implementation does + Sends the task to the executor for execution. + + This can result in the task being started by another host if the executor implementation does. 
""" pickle_id = None if args.ship_dag: @@ -201,9 +212,9 @@ def _run_task_by_executor(args, dag, ti): session.add(pickle) pickle_id = pickle.id # TODO: This should be written to a log - print(f'Pickled dag {dag} as pickle_id: {pickle_id}') + print(f"Pickled dag {dag} as pickle_id: {pickle_id}") except Exception as e: - print('Could not pickle the DAG') + print("Could not pickle the DAG") print(e) raise e executor = ExecutorLoader.get_default_executor() @@ -225,7 +236,7 @@ def _run_task_by_executor(args, dag, ti): def _run_task_by_local_task_job(args, ti): - """Run LocalTaskJob, which monitors the raw task execution process""" + """Run LocalTaskJob, which monitors the raw task execution process.""" run_job = LocalTaskJob( task_instance=ti, mark_success=args.mark_success, @@ -254,7 +265,7 @@ def _run_task_by_local_task_job(args, ti): def _run_raw_task(args, ti: TaskInstance) -> None: - """Runs the main task handling code""" + """Runs the main task handling code.""" ti._run_raw_task( mark_success=args.mark_success, job_id=args.job_id, @@ -262,7 +273,7 @@ def _run_raw_task(args, ti: TaskInstance) -> None: ) -def _extract_external_executor_id(args) -> Optional[str]: +def _extract_external_executor_id(args) -> str | None: if hasattr(args, "external_executor_id"): return getattr(args, "external_executor_id") return os.environ.get("external_executor_id", None) @@ -270,7 +281,8 @@ def _extract_external_executor_id(args) -> Optional[str]: @contextmanager def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]: - """Manage logging context for a task run + """ + Manage logging context for a task run. - Replace the root logger configuration with the airflow.task configuration so we can capture logs from any custom loggers used in the task. @@ -280,9 +292,8 @@ def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]: """ modify = not settings.DONOT_MODIFY_HANDLERS - if modify: - root_logger, task_logger = logging.getLogger(), logging.getLogger('airflow.task') + root_logger, task_logger = logging.getLogger(), logging.getLogger("airflow.task") orig_level = root_logger.level root_logger.setLevel(task_logger.level) @@ -303,9 +314,14 @@ def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]: root_logger.handlers[:] = orig_handlers +class TaskCommandMarker: + """Marker for listener hooks, to properly detect from which component they are called.""" + + @cli_utils.action_cli(check_db=False) def task_run(args, dag=None): - """Run a single task instance. + """ + Run a single task instance. Note that there must be at least one DagRun for this to start, i.e. it must have been scheduled and/or triggered previously. @@ -324,8 +340,8 @@ def task_run(args, dag=None): unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)] if unsupported_options: - unsupported_raw_task_flags = ', '.join(f'--{o}' for o in RAW_TASK_UNSUPPORTED_OPTION) - unsupported_flags = ', '.join(f'--{o}' for o in unsupported_options) + unsupported_raw_task_flags = ", ".join(f"--{o}" for o in RAW_TASK_UNSUPPORTED_OPTION) + unsupported_flags = ", ".join(f"--{o}" for o in unsupported_options) raise AirflowException( "Option --raw does not work with some of the other options on this command. " "You can't use --raw option and the following options: " @@ -353,39 +369,42 @@ def task_run(args, dag=None): # processing hundreds of simultaneous tasks. 
settings.reconfigure_orm(disable_connection_pool=True) + get_listener_manager().hook.on_starting(component=TaskCommandMarker()) + if args.pickle: - print(f'Loading pickle id: {args.pickle}') + print(f"Loading pickle id: {args.pickle}") dag = get_dag_by_pickle(args.pickle) elif not dag: - if args.local: - try: - dag = get_dag_by_deserialization(args.dag_id) - except AirflowException: - print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file') - dag = get_dag(args.subdir, args.dag_id) - else: - dag = get_dag(args.subdir, args.dag_id) + dag = get_dag(args.subdir, args.dag_id) else: # Use DAG from parameter pass task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool) + ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool) ti.init_run_context(raw=args.raw) hostname = get_hostname() log.info("Running %s on host %s", ti, hostname) - if args.interactive: - _run_task_by_selected_method(args, dag, ti) - else: - with _capture_task_logs(ti): + try: + if args.interactive: _run_task_by_selected_method(args, dag, ti) + else: + with _capture_task_logs(ti): + _run_task_by_selected_method(args, dag, ti) + finally: + try: + get_listener_manager().hook.before_stopping(component=TaskCommandMarker()) + except Exception: + pass @cli_utils.action_cli(check_db=False) def task_failed_deps(args): """ + Get task instance dependencies that were not met. + Returns the unmet dependencies for a task instance from the perspective of the scheduler (i.e. why a task instance doesn't get scheduled and then queued by the scheduler, and then run by an executor). @@ -397,7 +416,7 @@ def task_failed_deps(args): """ dag = get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index) + ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id) dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS) failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context)) @@ -420,14 +439,14 @@ def task_state(args): """ dag = get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index) + ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id) print(ti.current_state()) @cli_utils.action_cli(check_db=False) @suppress_logs_and_warning def task_list(args, dag=None): - """Lists the tasks within a DAG at the command line""" + """Lists the tasks within a DAG at the command line.""" dag = dag or get_dag(args.subdir, args.dag_id) if args.tree: dag.tree_view() @@ -436,7 +455,7 @@ def task_list(args, dag=None): print("\n".join(tasks)) -SUPPORTED_DEBUGGER_MODULES: List[str] = [ +SUPPORTED_DEBUGGER_MODULES: list[str] = [ "pudb", "web_pdb", "ipdb", @@ -446,8 +465,9 @@ def task_list(args, dag=None): def _guess_debugger(): """ - Trying to guess the debugger used by the user. When it doesn't find any user-installed debugger, - returns ``pdb``. + Trying to guess the debugger used by the user. + + When it doesn't find any user-installed debugger, returns ``pdb``. 
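task_run above now wraps execution in listener hooks: on_starting fires before the work, and before_stopping is always attempted in a finally block with its own failures swallowed so they cannot mask the task's outcome. A stripped-down sketch of that lifecycle pattern (the hook functions here are plain stand-ins, not the real listener manager):

def on_starting(component: object) -> None:
    print(f"on_starting called by {component!r}")

def before_stopping(component: object) -> None:
    print(f"before_stopping called by {component!r}")

class TaskCommandMarker:
    """Marker identifying which component invoked the hooks."""

def run_task() -> None:
    print("running the task")

marker = TaskCommandMarker()
on_starting(component=marker)
try:
    run_task()
finally:
    try:
        before_stopping(component=marker)
    except Exception:
        pass  # never let a listener failure hide the task result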
List of supported debuggers: @@ -468,7 +488,7 @@ def _guess_debugger(): @suppress_logs_and_warning @provide_session def task_states_for_dag_run(args, session=None): - """Get the status of all task instances in a DagRun""" + """Get the status of all task instances in a DagRun.""" dag_run = ( session.query(DagRun) .filter(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id == args.dag_id) @@ -493,7 +513,7 @@ def task_states_for_dag_run(args, session=None): has_mapped_instances = any(ti.map_index >= 0 for ti in dag_run.task_instances) - def format_task_instance(ti: TaskInstance) -> Dict[str, str]: + def format_task_instance(ti: TaskInstance) -> dict[str, str]: data = { "dag_id": ti.dag_id, "execution_date": dag_run.execution_date.isoformat(), @@ -511,23 +531,23 @@ def format_task_instance(ti: TaskInstance) -> Dict[str, str]: @cli_utils.action_cli(check_db=False) def task_test(args, dag=None): - """Tests task for a given dag_id""" + """Tests task for a given dag_id.""" # We want to log output from operators etc to show up here. Normally # airflow.task would redirect to a file, but here we want it to propagate # up to the normal airflow handler. settings.MASK_SECRETS_IN_LOGS = True - handlers = logging.getLogger('airflow.task').handlers + handlers = logging.getLogger("airflow.task").handlers already_has_stream_handler = False for handler in handlers: already_has_stream_handler = isinstance(handler, logging.StreamHandler) if already_has_stream_handler: break if not already_has_stream_handler: - logging.getLogger('airflow.task').propagate = True + logging.getLogger("airflow.task").propagate = True - env_vars = {'AIRFLOW_TEST_MODE': 'True'} + env_vars = {"AIRFLOW_TEST_MODE": "True"} if args.env_vars: env_vars.update(args.env_vars) os.environ.update(env_vars) @@ -543,13 +563,16 @@ def task_test(args, dag=None): if task.params: task.params.validate() - ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db") + ti, dr_created = _get_ti( + task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="db" + ) try: - if args.dry_run: - ti.dry_run() - else: - ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True) + with redirect_stdout(RedactedIO()): + if args.dry_run: + ti.dry_run() + else: + ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True) except Exception: if args.post_mortem: debugger = _guess_debugger() @@ -560,7 +583,7 @@ def task_test(args, dag=None): if not already_has_stream_handler: # Make sure to reset back to normal. 
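task_test above now routes stdout through RedactedIO so sensitive values are masked even in dry-run output. A toy sketch of redirecting stdout through a masking wrapper (the single-string masking rule is a stand-in for Airflow's secrets masker):

import sys
from contextlib import redirect_stdout

class MaskingIO:
    """Write-through wrapper that replaces a secret before text reaches the terminal."""

    def __init__(self, secret: str):
        self.secret = secret
        self.target = sys.stdout  # captured before redirection takes effect

    def write(self, text: str) -> int:
        return self.target.write(text.replace(self.secret, "***"))

    def flush(self) -> None:
        self.target.flush()

with redirect_stdout(MaskingIO("hunter2")):
    print("connecting with password hunter2")
# prints: connecting with password ***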
When run for CLI this doesn't # matter, but it does for test suite - logging.getLogger('airflow.task').propagate = False + logging.getLogger("airflow.task").propagate = False if dr_created: with create_session() as session: session.delete(ti.dag_run) @@ -568,19 +591,22 @@ def task_test(args, dag=None): @cli_utils.action_cli(check_db=False) @suppress_logs_and_warning -def task_render(args): - """Renders and displays templated fields for a given task""" - dag = get_dag(args.subdir, args.dag_id) +def task_render(args, dag=None): + """Renders and displays templated fields for a given task.""" + if not dag: + dag = get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="memory") + ti, _ = _get_ti( + task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory" + ) ti.render_templates() - for attr in task.__class__.template_fields: + for attr in task.template_fields: print( textwrap.dedent( f""" # ---------------------------------------------------------- # property: {attr} # ---------------------------------------------------------- - {getattr(task, attr)} + {getattr(ti.task, attr)} """ ) ) @@ -588,7 +614,7 @@ def task_render(args): @cli_utils.action_cli(check_db=False) def task_clear(args): - """Clears all task instances or only those matched by regex for a DAG(s)""" + """Clears all task instances or only those matched by regex for a DAG(s).""" logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) if args.dag_id and not args.subdir and not args.dag_regex and not args.task_regex: diff --git a/airflow/cli/commands/triggerer_command.py b/airflow/cli/commands/triggerer_command.py index 8bf419268059b..64755f3830291 100644 --- a/airflow/cli/commands/triggerer_command.py +++ b/airflow/cli/commands/triggerer_command.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Triggerer command.""" +from __future__ import annotations -"""Triggerer command""" import signal import daemon @@ -24,12 +25,12 @@ from airflow import settings from airflow.jobs.triggerer_job import TriggererJob from airflow.utils import cli as cli_utils -from airflow.utils.cli import setup_locations, setup_logging, sigquit_handler +from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler @cli_utils.action_cli def triggerer(args): - """Starts Airflow Triggerer""" + """Starts Airflow Triggerer.""" settings.MASK_SECRETS_IN_LOGS = True print(settings.HEADER) job = TriggererJob(capacity=args.capacity) @@ -39,30 +40,22 @@ def triggerer(args): "triggerer", args.pid, args.stdout, args.stderr, args.log_file ) handle = setup_logging(log_file) - with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle: + stdout_handle.truncate(0) + stderr_handle.truncate(0) + ctx = daemon.DaemonContext( pidfile=TimeoutPIDLockFile(pid, -1), files_preserve=[handle], stdout=stdout_handle, stderr=stderr_handle, + umask=int(settings.DAEMON_UMASK, 8), ) with ctx: job.run() else: - # There is a bug in CPython (fixed in March 2022 but not yet released) that - # makes async.io handle SIGTERM improperly by using async unsafe - # functions and hanging the triggerer receive SIGPIPE while handling - # SIGTERN/SIGINT and deadlocking itself. 
Until the bug is handled - # we should rather rely on standard handling of the signals rather than - # adding our own signal handlers. Seems that even if our signal handler - # just run exit(0) - it caused a race condition that led to the hanging. - # - # More details: - # * https://bugs.python.org/issue39622 - # * https://github.com/python/cpython/issues/83803 - # - # signal.signal(signal.SIGINT, sigint_handler) - # signal.signal(signal.SIGTERM, sigint_handler) + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGTERM, sigint_handler) signal.signal(signal.SIGQUIT, sigquit_handler) job.run() diff --git a/airflow/cli/commands/user_command.py b/airflow/cli/commands/user_command.py index ddbb7cfc82d6c..f1c806e942870 100644 --- a/airflow/cli/commands/user_command.py +++ b/airflow/cli/commands/user_command.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""User sub-commands""" +"""User sub-commands.""" +from __future__ import annotations + import functools import getpass import json @@ -22,7 +24,7 @@ import random import re import string -from typing import Any, Dict, List +from typing import Any from marshmallow import Schema, fields, validate from marshmallow.exceptions import ValidationError @@ -34,7 +36,7 @@ class UserSchema(Schema): - """user collection item schema""" + """user collection item schema.""" id = fields.Int() firstname = fields.Str(required=True) @@ -46,10 +48,10 @@ class UserSchema(Schema): @suppress_logs_and_warning def users_list(args): - """Lists users at the command line""" + """Lists users at the command line.""" appbuilder = cached_app().appbuilder users = appbuilder.sm.get_all_users() - fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] + fields = ["id", "username", "email", "first_name", "last_name", "roles"] AirflowConsole().print_as( data=users, output=args.output, mapper=lambda x: {f: x.__getattribute__(f) for f in fields} @@ -58,39 +60,39 @@ def users_list(args): @cli_utils.action_cli(check_db=True) def users_create(args): - """Creates new user in the DB""" + """Creates new user in the DB.""" appbuilder = cached_app().appbuilder role = appbuilder.sm.find_role(args.role) if not role: valid_roles = appbuilder.sm.get_all_roles() - raise SystemExit(f'{args.role} is not a valid role. Valid roles are: {valid_roles}') + raise SystemExit(f"{args.role} is not a valid role. 
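The change above reinstates plain SIGINT/SIGTERM handlers for the foreground triggerer. A small sketch of registering graceful-shutdown handlers (POSIX-only; these handlers are simplified stand-ins for Airflow's sigint_handler and sigquit_handler):

import signal
import sys

def _graceful_exit(signum, frame):
    print(f"received signal {signum}, shutting down")
    sys.exit(0)

def _dump_state(signum, frame):
    # The real sigquit_handler dumps thread stack traces; this stand-in just logs.
    print("SIGQUIT received: would dump thread stack traces here")

signal.signal(signal.SIGINT, _graceful_exit)
signal.signal(signal.SIGTERM, _graceful_exit)
signal.signal(signal.SIGQUIT, _dump_state)

print("signal handlers installed; press Ctrl+C to exit")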
Valid roles are: {valid_roles}") if args.use_random_password: - password = ''.join(random.choice(string.printable) for _ in range(16)) + password = "".join(random.choice(string.printable) for _ in range(16)) elif args.password: password = args.password else: - password = getpass.getpass('Password:') - password_confirmation = getpass.getpass('Repeat for confirmation:') + password = getpass.getpass("Password:") + password_confirmation = getpass.getpass("Repeat for confirmation:") if password != password_confirmation: - raise SystemExit('Passwords did not match') + raise SystemExit("Passwords did not match") if appbuilder.sm.find_user(args.username): - print(f'{args.username} already exist in the db') + print(f"{args.username} already exist in the db") return user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, args.email, role, password) if user: print(f'User "{args.username}" created with role "{args.role}"') else: - raise SystemExit('Failed to create user') + raise SystemExit("Failed to create user") def _find_user(args): if not args.username and not args.email: - raise SystemExit('Missing args: must supply one of --username or --email') + raise SystemExit("Missing args: must supply one of --username or --email") if args.username and args.email: - raise SystemExit('Conflicting args: must supply either --username or --email, but not both') + raise SystemExit("Conflicting args: must supply either --username or --email, but not both") appbuilder = cached_app().appbuilder @@ -102,7 +104,7 @@ def _find_user(args): @cli_utils.action_cli def users_delete(args): - """Deletes user from DB""" + """Deletes user from DB.""" user = _find_user(args) appbuilder = cached_app().appbuilder @@ -110,12 +112,12 @@ def users_delete(args): if appbuilder.sm.del_register_user(user): print(f'User "{user.username}" deleted') else: - raise SystemExit('Failed to delete user') + raise SystemExit("Failed to delete user") @cli_utils.action_cli def users_manage_role(args, remove=False): - """Deletes or appends user roles""" + """Deletes or appends user roles.""" user = _find_user(args) appbuilder = cached_app().appbuilder @@ -142,10 +144,10 @@ def users_manage_role(args, remove=False): def users_export(args): - """Exports all users to the json file""" + """Exports all users to the json file.""" appbuilder = cached_app().appbuilder users = appbuilder.sm.get_all_users() - fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles'] + fields = ["id", "username", "email", "first_name", "last_name", "roles"] # In the User model the first and last name fields have underscores, # but the corresponding parameters in the CLI don't @@ -155,22 +157,22 @@ def remove_underscores(s): users = [ { remove_underscores(field): user.__getattribute__(field) - if field != 'roles' + if field != "roles" else [r.name for r in user.roles] for field in fields } for user in users ] - with open(args.export, 'w') as file: + with open(args.export, "w") as file: file.write(json.dumps(users, sort_keys=True, indent=4)) print(f"{len(users)} users successfully exported to {file.name}") @cli_utils.action_cli def users_import(args): - """Imports users from the json file""" - json_file = getattr(args, 'import') + """Imports users from the json file.""" + json_file = getattr(args, "import") if not os.path.exists(json_file): raise SystemExit(f"File '{json_file}' does not exist") @@ -189,7 +191,7 @@ def users_import(args): print("Updated the following users:\n\t{}".format("\n\t".join(users_updated))) -def 
_import_users(users_list: List[Dict[str, Any]]): +def _import_users(users_list: list[dict[str, Any]]): appbuilder = cached_app().appbuilder users_created = [] users_updated = [] @@ -199,15 +201,15 @@ def _import_users(users_list: List[Dict[str, Any]]): except ValidationError as e: msg = [] for row_num, failure in e.normalized_messages().items(): - msg.append(f'[Item {row_num}]') + msg.append(f"[Item {row_num}]") for key, value in failure.items(): - msg.append(f'\t{key}: {value}') - raise SystemExit("Error: Input file didn't pass validation. See below:\n{}".format('\n'.join(msg))) + msg.append(f"\t{key}: {value}") + raise SystemExit("Error: Input file didn't pass validation. See below:\n{}".format("\n".join(msg))) for user in users_list: roles = [] - for rolename in user['roles']: + for rolename in user["roles"]: role = appbuilder.sm.find_role(rolename) if not role: valid_roles = appbuilder.sm.get_all_roles() @@ -215,31 +217,31 @@ def _import_users(users_list: List[Dict[str, Any]]): roles.append(role) - existing_user = appbuilder.sm.find_user(email=user['email']) + existing_user = appbuilder.sm.find_user(email=user["email"]) if existing_user: print(f"Found existing user with email '{user['email']}'") - if existing_user.username != user['username']: + if existing_user.username != user["username"]: raise SystemExit( f"Error: Changing the username is not allowed - please delete and recreate the user with" f" email {user['email']!r}" ) existing_user.roles = roles - existing_user.first_name = user['firstname'] - existing_user.last_name = user['lastname'] + existing_user.first_name = user["firstname"] + existing_user.last_name = user["lastname"] appbuilder.sm.update_user(existing_user) - users_updated.append(user['email']) + users_updated.append(user["email"]) else: print(f"Creating new user with email '{user['email']}'") appbuilder.sm.add_user( - username=user['username'], - first_name=user['firstname'], - last_name=user['lastname'], - email=user['email'], + username=user["username"], + first_name=user["firstname"], + last_name=user["lastname"], + email=user["email"], role=roles, ) - users_created.append(user['email']) + users_created.append(user["email"]) return users_created, users_updated diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py index 40eb193148a80..009b4704aaaee 100644 --- a/airflow/cli/commands/variable_command.py +++ b/airflow/cli/commands/variable_command.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Variable subcommands""" +"""Variable subcommands.""" +from __future__ import annotations + import json import os from json import JSONDecodeError @@ -29,7 +31,7 @@ @suppress_logs_and_warning def variables_list(args): - """Displays all of the variables""" + """Displays all the variables.""" with create_session() as session: variables = session.query(Variable) AirflowConsole().print_as(data=variables, output=args.output, mapper=lambda x: {"key": x.key}) @@ -37,7 +39,7 @@ def variables_list(args): @suppress_logs_and_warning def variables_get(args): - """Displays variable by a given name""" + """Displays variable by a given name.""" try: if args.default is None: var = Variable.get(args.key, deserialize_json=args.json) @@ -51,21 +53,21 @@ def variables_get(args): @cli_utils.action_cli def variables_set(args): - """Creates new variable with a given name and value""" + """Creates new variable with a given name and value.""" Variable.set(args.key, args.value, serialize_json=args.json) print(f"Variable {args.key} created") @cli_utils.action_cli def variables_delete(args): - """Deletes variable by a given name""" + """Deletes variable by a given name.""" Variable.delete(args.key) print(f"Variable {args.key} deleted") @cli_utils.action_cli def variables_import(args): - """Imports variables from a given file""" + """Imports variables from a given file.""" if os.path.exists(args.file): _import_helper(args.file) else: @@ -73,12 +75,12 @@ def variables_import(args): def variables_export(args): - """Exports all of the variables to the file""" + """Exports all the variables to the file.""" _variable_export_helper(args.file) def _import_helper(filepath): - """Helps import variables from the file""" + """Helps import variables from the file.""" with open(filepath) as varfile: data = varfile.read() @@ -92,7 +94,7 @@ def _import_helper(filepath): try: Variable.set(k, v, serialize_json=not isinstance(v, str)) except Exception as e: - print(f'Variable import failed: {repr(e)}') + print(f"Variable import failed: {repr(e)}") fail_count += 1 else: suc_count += 1 @@ -102,7 +104,7 @@ def _import_helper(filepath): def _variable_export_helper(filepath): - """Helps export all of the variables to the file""" + """Helps export all the variables to the file.""" var_dict = {} with create_session() as session: qry = session.query(Variable).all() @@ -115,6 +117,6 @@ def _variable_export_helper(filepath): val = var.val var_dict[var.key] = val - with open(filepath, 'w') as varfile: + with open(filepath, "w") as varfile: varfile.write(json.dumps(var_dict, sort_keys=True, indent=4)) print(f"{len(var_dict)} variables successfully exported to {filepath}") diff --git a/airflow/cli/commands/version_command.py b/airflow/cli/commands/version_command.py index 7e5190185884f..d3b735951c1fd 100644 --- a/airflow/cli/commands/version_command.py +++ b/airflow/cli/commands/version_command.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Version command""" +"""Version command.""" +from __future__ import annotations + import airflow def version(args): - """Displays Airflow version at the command line""" + """Displays Airflow version at the command line.""" print(airflow.__version__) diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py index c74513d23e2b9..a14f6a38e7357 100644 --- a/airflow/cli/commands/webserver_command.py +++ b/airflow/cli/commands/webserver_command.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Webserver command.""" +from __future__ import annotations -"""Webserver command""" import hashlib import logging import os @@ -26,7 +27,7 @@ import time from contextlib import suppress from time import sleep -from typing import Dict, List, NoReturn +from typing import NoReturn import daemon import psutil @@ -40,16 +41,16 @@ from airflow.utils.cli import setup_locations, setup_logging from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.process_utils import check_if_pidfile_process_is_running -from airflow.www.app import create_app log = logging.getLogger(__name__) class GunicornMonitor(LoggingMixin): """ - Runs forever, monitoring the child processes of @gunicorn_master_proc and - restarting workers occasionally or when files in the plug-in directory - has been modified. + Runs forever. + + Monitoring the child processes of @gunicorn_master_proc and restarting + workers occasionally or when files in the plug-in directory has been modified. Each iteration of the loop traverses one edge of this state transition diagram, where each state (node) represents @@ -103,15 +104,17 @@ def __init__( self._last_plugin_state = self._generate_plugin_state() if reload_on_plugin_change else None self._restart_on_next_plugin_check = False - def _generate_plugin_state(self) -> Dict[str, float]: + def _generate_plugin_state(self) -> dict[str, float]: """ + Get plugin states. + Generate dict of filenames and last modification time of all files in settings.PLUGINS_FOLDER directory. 
""" if not settings.PLUGINS_FOLDER: return {} - all_filenames: List[str] = [] + all_filenames: list[str] = [] for (root, _, filenames) in os.walk(settings.PLUGINS_FOLDER): all_filenames.extend(os.path.join(root, f) for f in filenames) plugin_state = {f: self._get_file_hash(f) for f in sorted(all_filenames)} @@ -119,7 +122,7 @@ def _generate_plugin_state(self) -> Dict[str, float]: @staticmethod def _get_file_hash(fname: str): - """Calculate MD5 hash for file""" + """Calculate MD5 hash for file.""" hash_md5 = hashlib.md5() with open(fname, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): @@ -127,7 +130,7 @@ def _get_file_hash(fname: str): return hash_md5.hexdigest() def _get_num_ready_workers_running(self) -> int: - """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name""" + """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name.""" workers = psutil.Process(self.gunicorn_master_proc.pid).children() def ready_prefix_on_cmdline(proc): @@ -143,12 +146,12 @@ def ready_prefix_on_cmdline(proc): return len(ready_workers) def _get_num_workers_running(self) -> int: - """Returns number of running Gunicorn workers processes""" + """Returns number of running Gunicorn workers processes.""" workers = psutil.Process(self.gunicorn_master_proc.pid).children() return len(workers) def _wait_until_true(self, fn, timeout: int = 0) -> None: - """Sleeps until fn is true""" + """Sleeps until fn is true.""" start_time = time.monotonic() while not fn(): if 0 < timeout <= time.monotonic() - start_time: @@ -188,8 +191,10 @@ def _kill_old_workers(self, count: int) -> None: def _reload_gunicorn(self) -> None: """ - Send signal to reload the gunicorn configuration. When gunicorn receive signals, it reload the - configuration, start the new worker processes with a new configuration and gracefully + Send signal to reload the gunicorn configuration. + + When gunicorn receive signals, it reloads the configuration, + start the new worker processes with a new configuration and gracefully shutdown older workers. """ # HUP: Reload the configuration. @@ -229,7 +234,7 @@ def _check_workers(self) -> None: # Whenever some workers are not ready, wait until all workers are ready if num_ready_workers_running < num_workers_running: self.log.debug( - '[%d / %d] Some workers are starting up, waiting...', + "[%d / %d] Some workers are starting up, waiting...", num_ready_workers_running, num_workers_running, ) @@ -241,7 +246,7 @@ def _check_workers(self) -> None: if num_workers_running > self.num_workers_expected: excess = min(num_workers_running - self.num_workers_expected, self.worker_refresh_batch_size) self.log.debug( - '[%d / %d] Killing %s workers', num_ready_workers_running, num_workers_running, excess + "[%d / %d] Killing %s workers", num_ready_workers_running, num_workers_running, excess ) self._kill_old_workers(excess) return @@ -262,7 +267,7 @@ def _check_workers(self) -> None: ) # log at info since we are trying fix an error logged just above self.log.info( - '[%d / %d] Spawning %d workers', + "[%d / %d] Spawning %d workers", num_ready_workers_running, num_workers_running, new_worker_count, @@ -279,7 +284,7 @@ def _check_workers(self) -> None: if self.worker_refresh_interval < last_refresh_diff: num_new_workers = self.worker_refresh_batch_size self.log.debug( - '[%d / %d] Starting doing a refresh. Starting %d workers.', + "[%d / %d] Starting doing a refresh. 
Starting %d workers.", num_ready_workers_running, num_workers_running, num_new_workers, @@ -295,8 +300,8 @@ def _check_workers(self) -> None: # If changed, wait until its content is fully saved. if new_state != self._last_plugin_state: self.log.debug( - '[%d / %d] Plugins folder changed. The gunicorn will be restarted the next time the ' - 'plugin directory is checked, if there is no change in it.', + "[%d / %d] Plugins folder changed. The gunicorn will be restarted the next time the " + "plugin directory is checked, if there is no change in it.", num_ready_workers_running, num_workers_running, ) @@ -304,7 +309,7 @@ def _check_workers(self) -> None: self._last_plugin_state = new_state elif self._restart_on_next_plugin_check: self.log.debug( - '[%d / %d] Starts reloading the gunicorn configuration.', + "[%d / %d] Starts reloading the gunicorn configuration.", num_ready_workers_running, num_workers_running, ) @@ -315,11 +320,11 @@ def _check_workers(self) -> None: @cli_utils.action_cli def webserver(args): - """Starts Airflow Webserver""" + """Starts Airflow Webserver.""" print(settings.HEADER) # Check for old/insecure config, and fail safe (i.e. don't launch) if the config is wildly insecure. - if conf.get('webserver', 'secret_key') == 'temporary_key': + if conf.get("webserver", "secret_key") == "temporary_key": from rich import print as rich_print rich_print( @@ -331,24 +336,26 @@ def webserver(args): ) sys.exit(1) - access_logfile = args.access_logfile or conf.get('webserver', 'access_logfile') - error_logfile = args.error_logfile or conf.get('webserver', 'error_logfile') - access_logformat = args.access_logformat or conf.get('webserver', 'access_logformat') - num_workers = args.workers or conf.get('webserver', 'workers') - worker_timeout = args.worker_timeout or conf.get('webserver', 'web_server_worker_timeout') - ssl_cert = args.ssl_cert or conf.get('webserver', 'web_server_ssl_cert') - ssl_key = args.ssl_key or conf.get('webserver', 'web_server_ssl_key') + access_logfile = args.access_logfile or conf.get("webserver", "access_logfile") + error_logfile = args.error_logfile or conf.get("webserver", "error_logfile") + access_logformat = args.access_logformat or conf.get("webserver", "access_logformat") + num_workers = args.workers or conf.get("webserver", "workers") + worker_timeout = args.worker_timeout or conf.get("webserver", "web_server_worker_timeout") + ssl_cert = args.ssl_cert or conf.get("webserver", "web_server_ssl_cert") + ssl_key = args.ssl_key or conf.get("webserver", "web_server_ssl_key") if not ssl_cert and ssl_key: - raise AirflowException('An SSL certificate must also be provided for use with ' + ssl_key) + raise AirflowException("An SSL certificate must also be provided for use with " + ssl_key) if ssl_cert and not ssl_key: - raise AirflowException('An SSL key must also be provided for use with ' + ssl_cert) + raise AirflowException("An SSL key must also be provided for use with " + ssl_cert) + + from airflow.www.app import create_app if args.debug: print(f"Starting the web server on port {args.port} and host {args.hostname}.") - app = create_app(testing=conf.getboolean('core', 'unit_test_mode')) + app = create_app(testing=conf.getboolean("core", "unit_test_mode")) app.run( debug=True, - use_reloader=not app.config['TESTING'], + use_reloader=not app.config["TESTING"], port=args.port, host=args.hostname, ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None, @@ -364,54 +371,60 @@ def webserver(args): print( textwrap.dedent( - f'''\ + f"""\ Running the Gunicorn 
Server with: Workers: {num_workers} {args.workerclass} Host: {args.hostname}:{args.port} Timeout: {worker_timeout} Logfiles: {access_logfile} {error_logfile} Access Logformat: {access_logformat} - =================================================================''' + =================================================================""" ) ) run_args = [ sys.executable, - '-m', - 'gunicorn', - '--workers', + "-m", + "gunicorn", + "--workers", str(num_workers), - '--worker-class', + "--worker-class", str(args.workerclass), - '--timeout', + "--timeout", str(worker_timeout), - '--bind', - args.hostname + ':' + str(args.port), - '--name', - 'airflow-webserver', - '--pid', + "--bind", + args.hostname + ":" + str(args.port), + "--name", + "airflow-webserver", + "--pid", pid_file, - '--config', - 'python:airflow.www.gunicorn_config', + "--config", + "python:airflow.www.gunicorn_config", ] if args.access_logfile: - run_args += ['--access-logfile', str(args.access_logfile)] + run_args += ["--access-logfile", str(args.access_logfile)] if args.error_logfile: - run_args += ['--error-logfile', str(args.error_logfile)] + run_args += ["--error-logfile", str(args.error_logfile)] if args.access_logformat and args.access_logformat.strip(): - run_args += ['--access-logformat', str(args.access_logformat)] + run_args += ["--access-logformat", str(args.access_logformat)] if args.daemon: - run_args += ['--daemon'] + run_args += ["--daemon"] if ssl_cert: - run_args += ['--certfile', ssl_cert, '--keyfile', ssl_key] + run_args += ["--certfile", ssl_cert, "--keyfile", ssl_key] run_args += ["airflow.www.app:cached_app()"] + # To prevent different workers creating the web app and + # all writing to the database at the same time, we use the --preload option. + # With the preload option, the app is loaded before the workers are forked, and each worker will + # then have a copy of the app + run_args += ["--preload"] + gunicorn_master_proc = None def kill_proc(signum, _): @@ -432,29 +445,33 @@ def monitor_gunicorn(gunicorn_master_pid: int): GunicornMonitor( gunicorn_master_pid=gunicorn_master_pid, num_workers_expected=num_workers, - master_timeout=conf.getint('webserver', 'web_server_master_timeout'), - worker_refresh_interval=conf.getint('webserver', 'worker_refresh_interval', fallback=30), - worker_refresh_batch_size=conf.getint('webserver', 'worker_refresh_batch_size', fallback=1), + master_timeout=conf.getint("webserver", "web_server_master_timeout"), + worker_refresh_interval=conf.getint("webserver", "worker_refresh_interval", fallback=30), + worker_refresh_batch_size=conf.getint("webserver", "worker_refresh_batch_size", fallback=1), reload_on_plugin_change=conf.getboolean( - 'webserver', 'reload_on_plugin_change', fallback=False + "webserver", "reload_on_plugin_change", fallback=False ), ).start() if args.daemon: # This makes possible errors get reported before daemonization - os.environ['SKIP_DAGS_PARSING'] = 'True' + os.environ["SKIP_DAGS_PARSING"] = "True" app = create_app(None) - os.environ.pop('SKIP_DAGS_PARSING') + os.environ.pop("SKIP_DAGS_PARSING") handle = setup_logging(log_file) base, ext = os.path.splitext(pid_file) - with open(stdout, 'w+') as stdout, open(stderr, 'w+') as stderr: + with open(stdout, "a") as stdout, open(stderr, "a") as stderr: + stdout.truncate(0) + stderr.truncate(0) + ctx = daemon.DaemonContext( pidfile=TimeoutPIDLockFile(f"{base}-monitor{ext}", -1), files_preserve=[handle], stdout=stdout, stderr=stderr, + umask=int(settings.DAEMON_UMASK, 8), ) with ctx: subprocess.Popen(run_args, 
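The new --preload flag above makes gunicorn import the application once in the master before forking workers, so import-time side effects (such as the webserver writing to the database) happen a single time. A sketch of the resulting argument list (worker count and bind address are placeholders):

import shlex

run_args = [
    "gunicorn",
    "--workers", "4",
    "--preload",                      # import the app before forking workers
    "--bind", "0.0.0.0:8080",
    "airflow.www.app:cached_app()",   # same app entrypoint as in the patch
]
print(shlex.join(run_args))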
close_fds=True) diff --git a/airflow/cli/simple_table.py b/airflow/cli/simple_table.py index efd418d21fcbe..87f7ce9f3b853 100644 --- a/airflow/cli/simple_table.py +++ b/airflow/cli/simple_table.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import inspect import json -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable from rich.box import ASCII_DOUBLE_HEAD from rich.console import Console @@ -30,7 +32,7 @@ class AirflowConsole(Console): - """Airflow rich console""" + """Airflow rich console.""" def __init__(self, show_header: bool = True, *args, **kwargs): super().__init__(*args, **kwargs) @@ -40,18 +42,18 @@ def __init__(self, show_header: bool = True, *args, **kwargs): # If show header in tables self.show_header = show_header - def print_as_json(self, data: Dict): - """Renders dict as json text representation""" + def print_as_json(self, data: dict): + """Renders dict as json text representation.""" json_content = json.dumps(data) self.print(Syntax(json_content, "json", theme="ansi_dark"), soft_wrap=True) - def print_as_yaml(self, data: Dict): - """Renders dict as yaml text representation""" + def print_as_yaml(self, data: dict): + """Renders dict as yaml text representation.""" yaml_content = yaml.dump(data) self.print(Syntax(yaml_content, "yaml", theme="ansi_dark"), soft_wrap=True) - def print_as_table(self, data: List[Dict]): - """Renders list of dictionaries as table""" + def print_as_table(self, data: list[dict]): + """Renders list of dictionaries as table.""" if not data: self.print("No data found") return @@ -64,8 +66,8 @@ def print_as_table(self, data: List[Dict]): table.add_row(*(str(d) for d in row.values())) self.print(table) - def print_as_plain_table(self, data: List[Dict]): - """Renders list of dictionaries as a simple table than can be easily piped""" + def print_as_plain_table(self, data: list[dict]): + """Renders list of dictionaries as a simple table than can be easily piped.""" if not data: self.print("No data found") return @@ -73,7 +75,7 @@ def print_as_plain_table(self, data: List[Dict]): output = tabulate(rows, tablefmt="plain", headers=list(data[0].keys())) print(output) - def _normalize_data(self, value: Any, output: str) -> Optional[Union[list, str, dict]]: + def _normalize_data(self, value: Any, output: str) -> list | str | dict | None: if isinstance(value, (tuple, list)): if output == "table": return ",".join(str(self._normalize_data(x, output)) for x in value) @@ -86,9 +88,9 @@ def _normalize_data(self, value: Any, output: str) -> Optional[Union[list, str, return None return str(value) - def print_as(self, data: List[Union[Dict, Any]], output: str, mapper: Optional[Callable] = None): - """Prints provided using format specified by output argument""" - output_to_renderer: Dict[str, Callable[[Any], None]] = { + def print_as(self, data: list[dict | Any], output: str, mapper: Callable | None = None): + """Prints provided using format specified by output argument.""" + output_to_renderer: dict[str, Callable[[Any], None]] = { "json": self.print_as_json, "yaml": self.print_as_yaml, "table": self.print_as_table, @@ -104,7 +106,7 @@ def print_as(self, data: List[Union[Dict, Any]], output: str, mapper: Optional[C raise ValueError("To tabulate non-dictionary data you need to provide `mapper` function") if mapper: - dict_data: List[Dict] = [mapper(d) for d in data] + dict_data: 
list[dict] = [mapper(d) for d in data] else: dict_data = data dict_data = [{k: self._normalize_data(v, output) for k, v in d.items()} for d in dict_data] @@ -125,6 +127,6 @@ def __init__(self, *args, **kwargs): self.caption = kwargs.get("caption", " ") def add_column(self, *args, **kwargs) -> None: - """Add a column to the table. We use different default""" + """Add a column to the table. We use different default.""" kwargs["overflow"] = kwargs.get("overflow") # to avoid truncating super().add_column(*args, **kwargs) diff --git a/airflow/compat/functools.py b/airflow/compat/functools.py index e3dea0a660bbc..dc0c520b796c0 100644 --- a/airflow/compat/functools.py +++ b/airflow/compat/functools.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import sys if sys.version_info >= (3, 8): diff --git a/airflow/compat/sqlalchemy.py b/airflow/compat/sqlalchemy.py deleted file mode 100644 index 427db90a73d67..0000000000000 --- a/airflow/compat/sqlalchemy.py +++ /dev/null @@ -1,35 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from sqlalchemy import Table -from sqlalchemy.engine import Connection - -try: - from sqlalchemy import inspect -except AttributeError: - from sqlalchemy.engine.reflection import Inspector - - inspect = Inspector.from_engine - -__all__ = ["has_table", "inspect"] - - -def has_table(conn: Connection, table: Table): - try: - return inspect(conn).has_table(table) - except AttributeError: - return table.exists(conn) diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py index b2752c2be7c25..01edea7520f33 100644 --- a/airflow/config_templates/airflow_local_settings.py +++ b/airflow/config_templates/airflow_local_settings.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Airflow logging settings""" +"""Airflow logging settings.""" +from __future__ import annotations import os from pathlib import Path -from typing import Any, Dict, Optional, Union -from urllib.parse import urlparse +from typing import Any +from urllib.parse import urlsplit from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -29,123 +30,147 @@ # in this file instead of from airflow.cfg. Currently # there are other log format and level configurations in # settings.py and cli.py. Please see AIRFLOW-1455. 
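AirflowConsole.print_as above dispatches on the output argument through a dict of renderer callables and reports the valid choices when the format is unknown. A reduced sketch of that dispatch pattern (the renderers here are simplified, not the rich-based ones):

from __future__ import annotations

import json
from typing import Any, Callable

def render_json(rows: list[dict[str, Any]]) -> None:
    print(json.dumps(rows, indent=2))

def render_plain(rows: list[dict[str, Any]]) -> None:
    for row in rows:
        print("  ".join(str(v) for v in row.values()))

renderers: dict[str, Callable[[list[dict[str, Any]]], None]] = {
    "json": render_json,
    "plain": render_plain,
}

def print_as(rows: list[dict[str, Any]], output: str) -> None:
    try:
        renderer = renderers[output]
    except KeyError:
        raise ValueError(f"Unknown formatter: {output}. Allowed: {list(renderers)}") from None
    renderer(rows)

print_as([{"key": "my_var", "value": "42"}], "plain")
print_as([{"key": "my_var", "value": "42"}], "json")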
-LOG_LEVEL: str = conf.get_mandatory_value('logging', 'LOGGING_LEVEL').upper() +LOG_LEVEL: str = conf.get_mandatory_value("logging", "LOGGING_LEVEL").upper() # Flask appbuilder's info level log is very verbose, # so it's set to 'WARN' by default. -FAB_LOG_LEVEL: str = conf.get_mandatory_value('logging', 'FAB_LOGGING_LEVEL').upper() +FAB_LOG_LEVEL: str = conf.get_mandatory_value("logging", "FAB_LOGGING_LEVEL").upper() -LOG_FORMAT: str = conf.get_mandatory_value('logging', 'LOG_FORMAT') +LOG_FORMAT: str = conf.get_mandatory_value("logging", "LOG_FORMAT") +DAG_PROCESSOR_LOG_FORMAT: str = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_FORMAT") -COLORED_LOG_FORMAT: str = conf.get_mandatory_value('logging', 'COLORED_LOG_FORMAT') +LOG_FORMATTER_CLASS: str = conf.get_mandatory_value( + "logging", "LOG_FORMATTER_CLASS", fallback="airflow.utils.log.timezone_aware.TimezoneAware" +) + +COLORED_LOG_FORMAT: str = conf.get_mandatory_value("logging", "COLORED_LOG_FORMAT") + +COLORED_LOG: bool = conf.getboolean("logging", "COLORED_CONSOLE_LOG") -COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG') +COLORED_FORMATTER_CLASS: str = conf.get_mandatory_value("logging", "COLORED_FORMATTER_CLASS") -COLORED_FORMATTER_CLASS: str = conf.get_mandatory_value('logging', 'COLORED_FORMATTER_CLASS') +DAG_PROCESSOR_LOG_TARGET: str = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_TARGET") -BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER') +BASE_LOG_FOLDER: str = conf.get_mandatory_value("logging", "BASE_LOG_FOLDER") -PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY') +PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value("scheduler", "CHILD_PROCESS_LOG_DIRECTORY") DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get_mandatory_value( - 'logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION' + "logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION" ) -FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_FILENAME_TEMPLATE') +# FILENAME_TEMPLATE only uses in Remote Logging Handlers since Airflow 2.3.3 +# All of these handlers inherited from FileTaskHandler and providing any value rather than None +# would raise deprecation warning. 
+FILENAME_TEMPLATE: str | None = None -PROCESSOR_FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE') +PROCESSOR_FILENAME_TEMPLATE: str = conf.get_mandatory_value("logging", "LOG_PROCESSOR_FILENAME_TEMPLATE") -DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'airflow': {'format': LOG_FORMAT}, - 'airflow_coloured': { - 'format': COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT, - 'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter', +DEFAULT_LOGGING_CONFIG: dict[str, Any] = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "airflow": { + "format": LOG_FORMAT, + "class": LOG_FORMATTER_CLASS, + }, + "airflow_coloured": { + "format": COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT, + "class": COLORED_FORMATTER_CLASS if COLORED_LOG else LOG_FORMATTER_CLASS, + }, + "source_processor": { + "format": DAG_PROCESSOR_LOG_FORMAT, + "class": LOG_FORMATTER_CLASS, }, }, - 'filters': { - 'mask_secrets': { - '()': 'airflow.utils.log.secrets_masker.SecretsMasker', + "filters": { + "mask_secrets": { + "()": "airflow.utils.log.secrets_masker.SecretsMasker", }, }, - 'handlers': { - 'console': { - 'class': 'airflow.utils.log.logging_mixin.RedirectStdHandler', - 'formatter': 'airflow_coloured', - 'stream': 'sys.stdout', - 'filters': ['mask_secrets'], + "handlers": { + "console": { + "class": "airflow.utils.log.logging_mixin.RedirectStdHandler", + "formatter": "airflow_coloured", + "stream": "sys.stdout", + "filters": ["mask_secrets"], + }, + "task": { + "class": "airflow.utils.log.file_task_handler.FileTaskHandler", + "formatter": "airflow", + "base_log_folder": os.path.expanduser(BASE_LOG_FOLDER), + "filters": ["mask_secrets"], }, - 'task': { - 'class': 'airflow.utils.log.file_task_handler.FileTaskHandler', - 'formatter': 'airflow', - 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), - 'filename_template': FILENAME_TEMPLATE, - 'filters': ['mask_secrets'], + "processor": { + "class": "airflow.utils.log.file_processor_handler.FileProcessorHandler", + "formatter": "airflow", + "base_log_folder": os.path.expanduser(PROCESSOR_LOG_FOLDER), + "filename_template": PROCESSOR_FILENAME_TEMPLATE, + "filters": ["mask_secrets"], }, - 'processor': { - 'class': 'airflow.utils.log.file_processor_handler.FileProcessorHandler', - 'formatter': 'airflow', - 'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER), - 'filename_template': PROCESSOR_FILENAME_TEMPLATE, - 'filters': ['mask_secrets'], + "processor_to_stdout": { + "class": "airflow.utils.log.logging_mixin.RedirectStdHandler", + "formatter": "source_processor", + "stream": "sys.stdout", + "filters": ["mask_secrets"], }, }, - 'loggers': { - 'airflow.processor': { - 'handlers': ['processor'], - 'level': LOG_LEVEL, - 'propagate': False, + "loggers": { + "airflow.processor": { + "handlers": ["processor_to_stdout" if DAG_PROCESSOR_LOG_TARGET == "stdout" else "processor"], + "level": LOG_LEVEL, + # Set to true here (and reset via set_context) so that if no file is configured we still get logs! + "propagate": True, }, - 'airflow.task': { - 'handlers': ['task'], - 'level': LOG_LEVEL, - 'propagate': False, - 'filters': ['mask_secrets'], + "airflow.task": { + "handlers": ["task"], + "level": LOG_LEVEL, + # Set to true here (and reset via set_context) so that if no file is configured we still get logs! 
+ "propagate": True, + "filters": ["mask_secrets"], }, - 'flask_appbuilder': { - 'handlers': ['console'], - 'level': FAB_LOG_LEVEL, - 'propagate': True, + "flask_appbuilder": { + "handlers": ["console"], + "level": FAB_LOG_LEVEL, + "propagate": True, }, }, - 'root': { - 'handlers': ['console'], - 'level': LOG_LEVEL, - 'filters': ['mask_secrets'], + "root": { + "handlers": ["console"], + "level": LOG_LEVEL, + "filters": ["mask_secrets"], }, } -EXTRA_LOGGER_NAMES: Optional[str] = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None) +EXTRA_LOGGER_NAMES: str | None = conf.get("logging", "EXTRA_LOGGER_NAMES", fallback=None) if EXTRA_LOGGER_NAMES: new_loggers = { logger_name.strip(): { - 'handlers': ['console'], - 'level': LOG_LEVEL, - 'propagate': True, + "handlers": ["console"], + "level": LOG_LEVEL, + "propagate": True, } for logger_name in EXTRA_LOGGER_NAMES.split(",") } - DEFAULT_LOGGING_CONFIG['loggers'].update(new_loggers) - -DEFAULT_DAG_PARSING_LOGGING_CONFIG: Dict[str, Dict[str, Dict[str, Any]]] = { - 'handlers': { - 'processor_manager': { - 'class': 'logging.handlers.RotatingFileHandler', - 'formatter': 'airflow', - 'filename': DAG_PROCESSOR_MANAGER_LOG_LOCATION, - 'mode': 'a', - 'maxBytes': 104857600, # 100MB - 'backupCount': 5, + DEFAULT_LOGGING_CONFIG["loggers"].update(new_loggers) + +DEFAULT_DAG_PARSING_LOGGING_CONFIG: dict[str, dict[str, dict[str, Any]]] = { + "handlers": { + "processor_manager": { + "class": "airflow.utils.log.non_caching_file_handler.NonCachingRotatingFileHandler", + "formatter": "airflow", + "filename": DAG_PROCESSOR_MANAGER_LOG_LOCATION, + "mode": "a", + "maxBytes": 104857600, # 100MB + "backupCount": 5, } }, - 'loggers': { - 'airflow.processor_manager': { - 'handlers': ['processor_manager'], - 'level': LOG_LEVEL, - 'propagate': False, + "loggers": { + "airflow.processor_manager": { + "handlers": ["processor_manager"], + "level": LOG_LEVEL, + "propagate": False, } }, } @@ -153,27 +178,27 @@ # Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set. # This is to avoid exceptions when initializing RotatingFileHandler multiple times # in multiple processes. -if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True': - DEFAULT_LOGGING_CONFIG['handlers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']) - DEFAULT_LOGGING_CONFIG['loggers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers']) +if os.environ.get("CONFIG_PROCESSOR_MANAGER_LOGGER") == "True": + DEFAULT_LOGGING_CONFIG["handlers"].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG["handlers"]) + DEFAULT_LOGGING_CONFIG["loggers"].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG["loggers"]) # Manually create log directory for processor_manager handler as RotatingFileHandler # will only create file but not the directory. 
- processor_manager_handler_config: Dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][ - 'processor_manager' + processor_manager_handler_config: dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG["handlers"][ + "processor_manager" ] - directory: str = os.path.dirname(processor_manager_handler_config['filename']) + directory: str = os.path.dirname(processor_manager_handler_config["filename"]) Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755) ################## # Remote logging # ################## -REMOTE_LOGGING: bool = conf.getboolean('logging', 'remote_logging') +REMOTE_LOGGING: bool = conf.getboolean("logging", "remote_logging") if REMOTE_LOGGING: - ELASTICSEARCH_HOST: Optional[str] = conf.get('elasticsearch', 'HOST') + ELASTICSEARCH_HOST: str | None = conf.get("elasticsearch", "HOST") # Storage bucket URL for remote logging # S3 buckets should start with "s3://" @@ -181,115 +206,115 @@ # GCS buckets should start with "gs://" # WASB buckets should start with "wasb" # just to help Airflow select correct handler - REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'REMOTE_BASE_LOG_FOLDER') - - if REMOTE_BASE_LOG_FOLDER.startswith('s3://'): - S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { - 'task': { - 'class': 'airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler', - 'formatter': 'airflow', - 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), - 's3_log_folder': REMOTE_BASE_LOG_FOLDER, - 'filename_template': FILENAME_TEMPLATE, + REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value("logging", "REMOTE_BASE_LOG_FOLDER") + + if REMOTE_BASE_LOG_FOLDER.startswith("s3://"): + S3_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = { + "task": { + "class": "airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler", + "formatter": "airflow", + "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), + "s3_log_folder": REMOTE_BASE_LOG_FOLDER, + "filename_template": FILENAME_TEMPLATE, }, } - DEFAULT_LOGGING_CONFIG['handlers'].update(S3_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith('cloudwatch://'): - url_parts = urlparse(REMOTE_BASE_LOG_FOLDER) - CLOUDWATCH_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { - 'task': { - 'class': 'airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler', - 'formatter': 'airflow', - 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), - 'log_group_arn': url_parts.netloc + url_parts.path, - 'filename_template': FILENAME_TEMPLATE, + DEFAULT_LOGGING_CONFIG["handlers"].update(S3_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith("cloudwatch://"): + url_parts = urlsplit(REMOTE_BASE_LOG_FOLDER) + CLOUDWATCH_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = { + "task": { + "class": "airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler", + "formatter": "airflow", + "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), + "log_group_arn": url_parts.netloc + url_parts.path, + "filename_template": FILENAME_TEMPLATE, }, } - DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'): - key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None) - GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { - 'task': { - 'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler', - 'formatter': 'airflow', - 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), - 'gcs_log_folder': REMOTE_BASE_LOG_FOLDER, - 'filename_template': 
FILENAME_TEMPLATE, - 'gcp_key_path': key_path, + DEFAULT_LOGGING_CONFIG["handlers"].update(CLOUDWATCH_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith("gs://"): + key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None) + GCS_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = { + "task": { + "class": "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler", + "formatter": "airflow", + "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), + "gcs_log_folder": REMOTE_BASE_LOG_FOLDER, + "filename_template": FILENAME_TEMPLATE, + "gcp_key_path": key_path, }, } - DEFAULT_LOGGING_CONFIG['handlers'].update(GCS_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith('wasb'): - WASB_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = { - 'task': { - 'class': 'airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler', - 'formatter': 'airflow', - 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), - 'wasb_log_folder': REMOTE_BASE_LOG_FOLDER, - 'wasb_container': 'airflow-logs', - 'filename_template': FILENAME_TEMPLATE, - 'delete_local_copy': False, + DEFAULT_LOGGING_CONFIG["handlers"].update(GCS_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith("wasb"): + WASB_REMOTE_HANDLERS: dict[str, dict[str, str | bool | None]] = { + "task": { + "class": "airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler", + "formatter": "airflow", + "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), + "wasb_log_folder": REMOTE_BASE_LOG_FOLDER, + "wasb_container": "airflow-logs", + "filename_template": FILENAME_TEMPLATE, + "delete_local_copy": False, }, } - DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'): - key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None) + DEFAULT_LOGGING_CONFIG["handlers"].update(WASB_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith("stackdriver://"): + key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None) # stackdriver:///airflow-tasks => airflow-tasks - log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:] + log_name = urlsplit(REMOTE_BASE_LOG_FOLDER).path[1:] STACKDRIVER_REMOTE_HANDLERS = { - 'task': { - 'class': 'airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler', - 'formatter': 'airflow', - 'name': log_name, - 'gcp_key_path': key_path, + "task": { + "class": "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler", + "formatter": "airflow", + "name": log_name, + "gcp_key_path": key_path, } } - DEFAULT_LOGGING_CONFIG['handlers'].update(STACKDRIVER_REMOTE_HANDLERS) - elif REMOTE_BASE_LOG_FOLDER.startswith('oss://'): + DEFAULT_LOGGING_CONFIG["handlers"].update(STACKDRIVER_REMOTE_HANDLERS) + elif REMOTE_BASE_LOG_FOLDER.startswith("oss://"): OSS_REMOTE_HANDLERS = { - 'task': { - 'class': 'airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler', - 'formatter': 'airflow', - 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), - 'oss_log_folder': REMOTE_BASE_LOG_FOLDER, - 'filename_template': FILENAME_TEMPLATE, + "task": { + "class": "airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler", + "formatter": "airflow", + "base_log_folder": os.path.expanduser(BASE_LOG_FOLDER), + "oss_log_folder": REMOTE_BASE_LOG_FOLDER, + "filename_template": FILENAME_TEMPLATE, }, } - DEFAULT_LOGGING_CONFIG['handlers'].update(OSS_REMOTE_HANDLERS) + 
DEFAULT_LOGGING_CONFIG["handlers"].update(OSS_REMOTE_HANDLERS) elif ELASTICSEARCH_HOST: - ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get_mandatory_value('elasticsearch', 'LOG_ID_TEMPLATE') - ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value('elasticsearch', 'END_OF_LOG_MARK') - ELASTICSEARCH_FRONTEND: str = conf.get_mandatory_value('elasticsearch', 'frontend') - ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean('elasticsearch', 'WRITE_STDOUT') - ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean('elasticsearch', 'JSON_FORMAT') - ELASTICSEARCH_JSON_FIELDS: str = conf.get_mandatory_value('elasticsearch', 'JSON_FIELDS') - ELASTICSEARCH_HOST_FIELD: str = conf.get_mandatory_value('elasticsearch', 'HOST_FIELD') - ELASTICSEARCH_OFFSET_FIELD: str = conf.get_mandatory_value('elasticsearch', 'OFFSET_FIELD') - - ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = { - 'task': { - 'class': 'airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler', - 'formatter': 'airflow', - 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), - 'log_id_template': ELASTICSEARCH_LOG_ID_TEMPLATE, - 'filename_template': FILENAME_TEMPLATE, - 'end_of_log_mark': ELASTICSEARCH_END_OF_LOG_MARK, - 'host': ELASTICSEARCH_HOST, - 'frontend': ELASTICSEARCH_FRONTEND, - 'write_stdout': ELASTICSEARCH_WRITE_STDOUT, - 'json_format': ELASTICSEARCH_JSON_FORMAT, - 'json_fields': ELASTICSEARCH_JSON_FIELDS, - 'host_field': ELASTICSEARCH_HOST_FIELD, - 'offset_field': ELASTICSEARCH_OFFSET_FIELD, + ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get_mandatory_value("elasticsearch", "LOG_ID_TEMPLATE") + ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value("elasticsearch", "END_OF_LOG_MARK") + ELASTICSEARCH_FRONTEND: str = conf.get_mandatory_value("elasticsearch", "frontend") + ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean("elasticsearch", "WRITE_STDOUT") + ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean("elasticsearch", "JSON_FORMAT") + ELASTICSEARCH_JSON_FIELDS: str = conf.get_mandatory_value("elasticsearch", "JSON_FIELDS") + ELASTICSEARCH_HOST_FIELD: str = conf.get_mandatory_value("elasticsearch", "HOST_FIELD") + ELASTICSEARCH_OFFSET_FIELD: str = conf.get_mandatory_value("elasticsearch", "OFFSET_FIELD") + + ELASTIC_REMOTE_HANDLERS: dict[str, dict[str, str | bool | None]] = { + "task": { + "class": "airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler", + "formatter": "airflow", + "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)), + "log_id_template": ELASTICSEARCH_LOG_ID_TEMPLATE, + "filename_template": FILENAME_TEMPLATE, + "end_of_log_mark": ELASTICSEARCH_END_OF_LOG_MARK, + "host": ELASTICSEARCH_HOST, + "frontend": ELASTICSEARCH_FRONTEND, + "write_stdout": ELASTICSEARCH_WRITE_STDOUT, + "json_format": ELASTICSEARCH_JSON_FORMAT, + "json_fields": ELASTICSEARCH_JSON_FIELDS, + "host_field": ELASTICSEARCH_HOST_FIELD, + "offset_field": ELASTICSEARCH_OFFSET_FIELD, }, } - DEFAULT_LOGGING_CONFIG['handlers'].update(ELASTIC_REMOTE_HANDLERS) + DEFAULT_LOGGING_CONFIG["handlers"].update(ELASTIC_REMOTE_HANDLERS) else: raise AirflowException( "Incorrect remote log configuration. Please check the configuration of option 'host' in " diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index f884959a88840..2c0721aa89fcd 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -33,15 +33,15 @@ Hostname by providing a path to a callable, which will resolve the hostname. 
The format is "package.function". - For example, default value "socket.getfqdn" means that result from getfqdn() of "socket" - package will be used as hostname. + For example, default value "airflow.utils.net.getfqdn" means that result from patched + version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254. No argument should be required in the function specified. If using IP address as hostname is preferred, use value ``airflow.utils.net.get_host_ip_address`` version_added: ~ type: string example: ~ - default: "socket.getfqdn" + default: "airflow.utils.net.getfqdn" - name: default_timezone description: | Default timezone in case supplied date times are naive @@ -99,6 +99,17 @@ type: string example: ~ default: "16" + - name: mp_start_method + description: | + The name of the method used in order to start Python processes via the multiprocessing module. + This corresponds directly with the options available in the Python docs: + https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method. + Must be one of the values returned by: + https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_all_start_methods. + version_added: "2.0.0" + type: string + default: ~ + example: "fork" - name: load_examples description: | Whether to load the DAG examples that ship with Airflow. It's good to @@ -210,6 +221,15 @@ example: ~ default: "False" see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json" + - name: allowed_deserialization_classes + description: | + What classes can be imported during deserialization. This is a multi line value. + The individual items will be parsed as regexp. Python built-in classes (like dict) + are always allowed + version_added: 2.5.0 + type: string + default: 'airflow\..*' + example: ~ - name: killed_task_cleanup_time description: | When a task is killed forcefully, this is the amount of time in seconds that @@ -253,7 +273,7 @@ description: | The number of seconds each task is going to wait by default between retries. Can be overridden at dag or task level. - version_added: 2.3.2 + version_added: 2.4.0 type: integer example: ~ default: "300" @@ -374,6 +394,31 @@ example: ~ default: "1024" + - name: daemon_umask + description: | + The default umask to use for process when run in daemon mode (scheduler, worker, etc.) + + This controls the file-creation mode mask which determines the initial value of file permission bits + for newly created files. + + This value is treated as an octal-integer. + version_added: 2.3.4 + type: string + default: "0o077" + example: ~ + - name: dataset_manager_class + description: Class to use as dataset manager. + version_added: 2.4.0 + type: string + default: ~ + example: 'airflow.datasets.manager.DatasetManager' + - name: dataset_manager_kwargs + description: Kwargs to supply to dataset manager. + version_added: 2.4.0 + type: string + default: ~ + example: '{"some_param": "some_value"}' + - name: database description: ~ options: @@ -405,7 +450,8 @@ default: "utf-8" - name: sql_engine_collation_for_ids description: | - Collation for ``dag_id``, ``task_id``, ``key`` columns in case they have different encoding. + Collation for ``dag_id``, ``task_id``, ``key``, ``external_executor_id`` columns + in case they have different encoding. 
By default this collation is the same as the database collation, however for ``mysql`` and ``mariadb`` the default is ``utf8mb3_bin`` so that the index sizes of our index keys will not exceed the maximum size of allowed index when collation is set to ``utf8mb4`` variant @@ -459,7 +505,7 @@ Check connection at the start of each connection pool checkout. Typically, this is a simple statement like "SELECT 1". More information here: - https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic + https://docs.sqlalchemy.org/en/14/core/pooling.html#disconnect-handling-pessimistic version_added: 2.3.0 type: string example: ~ @@ -477,7 +523,7 @@ Import path for connect args in SqlAlchemy. Defaults to an empty dict. This is useful when you want to configure db engine args that SqlAlchemy won't parse in connection string. - See https://docs.sqlalchemy.org/en/13/core/engines.html#sqlalchemy.create_engine.params.connect_args + See https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.connect_args version_added: 2.3.0 type: string example: ~ @@ -633,6 +679,27 @@ type: string example: ~ default: "%%(asctime)s %%(levelname)s - %%(message)s" + - name: dag_processor_log_target + description: Where to send dag parser logs. If "file", + logs are sent to log files defined by child_process_log_directory. + version_added: 2.4.0 + type: string + example: ~ + default: "file" + - name: dag_processor_log_format + description: | + Format of Dag Processor Log line + version_added: 2.4.0 + type: string + example: ~ + default: "[%%(asctime)s] [SOURCE:DAG_PROCESSOR] + {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s" + - name: log_formatter_class + description: ~ + version_added: 2.3.4 + type: string + example: ~ + default: "airflow.utils.log.timezone_aware.TimezoneAware" - name: task_log_prefix_template description: | Specify prefix pattern like mentioned below with stream handler TaskHandlerWithCustomFormatter @@ -1135,7 +1202,9 @@ - name: worker_class description: | The worker class gunicorn should use. Choices include - sync (default), eventlet, gevent + sync (default), eventlet, gevent. Note when using gevent you might also want to set the + "_AIRFLOW_PATCH_GEVENT" environment variable to "1" to make sure gevent patching is done as + early as possible. version_added: ~ type: string example: ~ @@ -1165,7 +1234,9 @@ default: "" - name: expose_config description: | - Expose the configuration file in the web server + Expose the configuration file in the web server. Set to "non-sensitive-only" to show all values + except those that have security implications. "True" shows all values. "False" hides the + configuration completely. version_added: ~ type: string example: ~ @@ -1183,7 +1254,7 @@ version_added: 1.10.8 type: string example: ~ - default: "True" + default: "False" - name: dag_default_view description: | Default DAG view. Valid values are: ``grid``, ``graph``, ``duration``, ``gantt``, ``landing_times`` @@ -1645,15 +1716,6 @@ type: boolean example: ~ default: "true" - - name: worker_umask - description: | - Umask that will be used when starting workers with the ``airflow celery worker`` - in daemon mode. This control the file-creation mode mask which determines the initial - value of file permission bits for newly created files. - version_added: 2.0.0 - type: string - example: ~ - default: "0o077" - name: broker_url description: | The Celery broker URL. 
Celery supports RabbitMQ, Redis and experimentally @@ -1670,12 +1732,13 @@ or insert it into a database (depending of the backend) This status is used by the scheduler to update the state of the task The use of a database is highly recommended + When not specified, sql_alchemy_conn with a db+ scheme prefix will be used http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-result-backend-settings version_added: ~ type: string sensitive: true - example: ~ - default: "db+postgresql://postgres:airflow@postgres/airflow" + example: "db+postgresql://postgres:airflow@postgres/airflow" + default: ~ - name: flower_host description: | Celery Flower is a sweet UI for Celery. Airflow has a shortcut to start @@ -1905,11 +1968,12 @@ type: string example: ~ default: "30" - - name: deactivate_stale_dags_interval + - name: parsing_cleanup_interval description: | How often (in seconds) to check for stale DAGs (DAGs which are no longer present in - the expected files) which should be deactivated. - version_added: 2.2.5 + the expected files) which should be deactivated, as well as datasets that are no longer + referenced and should be marked as orphaned. + version_added: 2.5.0 type: integer example: ~ default: "60" @@ -1943,6 +2007,22 @@ type: string example: ~ default: "30" + - name: enable_health_check + description: | + When you start a scheduler, airflow starts a tiny web server + subprocess to serve a health check if this is set to True + version_added: 2.4.0 + type: boolean + example: ~ + default: "False" + - name: scheduler_health_check_server_port + description: | + When you start a scheduler, airflow starts a tiny web server + subprocess to serve a health check on this port + version_added: 2.4.0 + type: string + example: ~ + default: "8974" - name: orphaned_tasks_check_interval description: | How often (in seconds) should the scheduler check for orphaned tasks and SchedulerJobs @@ -2081,6 +2161,14 @@ type: integer example: ~ default: "20" + - name: dag_stale_not_seen_duration + description: | + Only applicable if `[scheduler]standalone_dag_processor` is true. + Time in seconds after which dags, which were not updated by Dag Processor are deactivated. + version_added: 2.4.0 + type: integer + example: ~ + default: "600" - name: use_job_schedule description: | Turn off scheduler use of cron intervals by setting this to False. @@ -2097,12 +2185,6 @@ type: string example: ~ default: "False" - - name: dependency_detector - description: DAG dependency detector class to use - version_added: 2.1.0 - type: string - example: ~ - default: "airflow.serialization.serialized_objects.DependencyDetector" - name: trigger_timeout_check_interval description: | How often to check for expired trigger requests that have not run yet. @@ -2252,7 +2334,7 @@ type: string example: ~ default: "True" -- name: kubernetes +- name: kubernetes_executor description: ~ options: - name: pod_template_file @@ -2445,36 +2527,3 @@ type: float example: ~ default: "604800" -- name: smart_sensor - description: ~ - options: - - name: use_smart_sensor - description: | - When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to - smart sensor task. - version_added: 2.0.0 - type: boolean - example: ~ - default: "False" - - name: shard_code_upper_limit - description: | - `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated - by `hashcode % shard_code_upper_limit`. 
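Two renames above, [scheduler] deactivate_stale_dags_interval becoming parsing_cleanup_interval and the [kubernetes] section becoming [kubernetes_executor], are backed by compatibility shims added to configuration.py later in this diff (deprecated_options and deprecated_sections). A hedged sketch of what reading the new names is expected to look like; the fallback values are placeholders:

from airflow.configuration import conf

# Read via the new names; if airflow.cfg still uses the old ones, the
# deprecation shims are expected to resolve them and emit a DeprecationWarning.
cleanup_interval = conf.getint("scheduler", "parsing_cleanup_interval", fallback=60)
pod_template = conf.get("kubernetes_executor", "pod_template_file", fallback="")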
- version_added: 2.0.0 - type: integer - example: ~ - default: "10000" - - name: shards - description: | - The number of running smart sensor processes for each service. - version_added: 2.0.0 - type: integer - example: ~ - default: "5" - - name: sensors_enabled - description: | - comma separated sensor classes support in smart_sensor. - version_added: 2.0.0 - type: string - example: ~ - default: "NamedHivePartitionSensor" diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index b5c1d4290bd60..80915d6643a56 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. - # This is the template for Airflow's default configuration. When Airflow is # imported, it looks for a configuration file at $AIRFLOW_HOME/airflow.cfg. If # it doesn't exist, Airflow uses this template to generate it by replacing @@ -36,12 +35,12 @@ dags_folder = {AIRFLOW_HOME}/dags # Hostname by providing a path to a callable, which will resolve the hostname. # The format is "package.function". # -# For example, default value "socket.getfqdn" means that result from getfqdn() of "socket" -# package will be used as hostname. +# For example, default value "airflow.utils.net.getfqdn" means that result from patched +# version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254. # # No argument should be required in the function specified. # If using IP address as hostname is preferred, use value ``airflow.utils.net.get_host_ip_address`` -hostname_callable = socket.getfqdn +hostname_callable = airflow.utils.net.getfqdn # Default timezone in case supplied date times are naive # can be utc (default), system, or any IANA timezone string (e.g. Europe/Amsterdam) @@ -76,6 +75,14 @@ dags_are_paused_at_creation = True # which is defaulted as ``max_active_runs_per_dag``. max_active_runs_per_dag = 16 +# The name of the method used in order to start Python processes via the multiprocessing module. +# This corresponds directly with the options available in the Python docs: +# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method. +# Must be one of the values returned by: +# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_all_start_methods. +# Example: mp_start_method = fork +# mp_start_method = + # Whether to load the DAG examples that ship with Airflow. It's good to # get started, but you probably want to set this to ``False`` in a production # environment @@ -128,6 +135,11 @@ unit_test_mode = False # RCE exploits). enable_xcom_pickling = False +# What classes can be imported during deserialization. This is a multi line value. +# The individual items will be parsed as regexp. Python built-in classes (like dict) +# are always allowed +allowed_deserialization_classes = airflow\..* + # When a task is killed forcefully, this is the amount of time in seconds that # it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED killed_task_cleanup_time = 60 @@ -212,6 +224,22 @@ default_pool_task_slot_count = 128 # mapped tasks from clogging the scheduler. max_map_length = 1024 +# The default umask to use for process when run in daemon mode (scheduler, worker, etc.) +# +# This controls the file-creation mode mask which determines the initial value of file permission bits +# for newly created files. +# +# This value is treated as an octal-integer. 
+daemon_umask = 0o077 + +# Class to use as dataset manager. +# Example: dataset_manager_class = airflow.datasets.manager.DatasetManager +# dataset_manager_class = + +# Kwargs to supply to dataset manager. +# Example: dataset_manager_kwargs = {{"some_param": "some_value"}} +# dataset_manager_kwargs = + [database] # The SqlAlchemy connection string to the metadata database. # SqlAlchemy supports many different database engines. @@ -226,7 +254,8 @@ sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/airflow.db # The encoding for the databases sql_engine_encoding = utf-8 -# Collation for ``dag_id``, ``task_id``, ``key`` columns in case they have different encoding. +# Collation for ``dag_id``, ``task_id``, ``key``, ``external_executor_id`` columns +# in case they have different encoding. # By default this collation is the same as the database collation, however for ``mysql`` and ``mariadb`` # the default is ``utf8mb3_bin`` so that the index sizes of our index keys will not exceed # the maximum size of allowed index when collation is set to ``utf8mb4`` variant @@ -260,7 +289,7 @@ sql_alchemy_pool_recycle = 1800 # Check connection at the start of each connection pool checkout. # Typically, this is a simple statement like "SELECT 1". # More information here: -# https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic +# https://docs.sqlalchemy.org/en/14/core/pooling.html#disconnect-handling-pessimistic sql_alchemy_pool_pre_ping = True # The schema to use for the metadata database. @@ -270,7 +299,7 @@ sql_alchemy_schema = # Import path for connect args in SqlAlchemy. Defaults to an empty dict. # This is useful when you want to configure db engine args that SqlAlchemy won't parse # in connection string. -# See https://docs.sqlalchemy.org/en/13/core/engines.html#sqlalchemy.create_engine.params.connect_args +# See https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.connect_args # sql_alchemy_connect_args = # Whether to load the default connections that ship with Airflow. It's good to @@ -350,6 +379,13 @@ colored_formatter_class = airflow.utils.log.colored_log.CustomTTYColoredFormatte log_format = [%%(asctime)s] {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s +# Where to send dag parser logs. If "file", logs are sent to log files defined by child_process_log_directory. +dag_processor_log_target = file + +# Format of Dag Processor Log line +dag_processor_log_format = [%%(asctime)s] [SOURCE:DAG_PROCESSOR] {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s +log_formatter_class = airflow.utils.log.timezone_aware.TimezoneAware + # Specify prefix pattern like mentioned below with stream handler TaskHandlerWithCustomFormatter # Example: task_log_prefix_template = {{ti.dag_id}}-{{ti.task_id}}-{{execution_date}}-{{try_number}} task_log_prefix_template = @@ -585,7 +621,9 @@ secret_key = {SECRET_KEY} workers = 4 # The worker class gunicorn should use. Choices include -# sync (default), eventlet, gevent +# sync (default), eventlet, gevent. Note when using gevent you might also want to set the +# "_AIRFLOW_PATCH_GEVENT" environment variable to "1" to make sure gevent patching is done as +# early as possible. worker_class = sync # Log files for the gunicorn webserver. '-' means log to stderr. 
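The new [core] allowed_deserialization_classes option above is documented as a multi-line list of regular expressions matched against class paths, with Python built-ins always allowed. A small illustration of that matching rule; this sketches only the documented semantics, not the serializer's actual implementation:

import re

# Two patterns as they might appear in airflow.cfg, one per line.
patterns = [re.compile(line) for line in r"""airflow\..*
tests\..*""".splitlines() if line]


def is_allowed(qualname: str) -> bool:
    # Each configured line is a regexp tried against the fully qualified
    # class name encountered during deserialization.
    return any(p.match(qualname) for p in patterns)


assert is_allowed("airflow.models.dag.DAG")
assert not is_allowed("my_plugin.CustomThing")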
@@ -599,14 +637,16 @@ error_logfile = - # documentation - https://docs.gunicorn.org/en/stable/settings.html#access-log-format access_logformat = -# Expose the configuration file in the web server +# Expose the configuration file in the web server. Set to "non-sensitive-only" to show all values +# except those that have security implications. "True" shows all values. "False" hides the +# configuration completely. expose_config = False # Expose hostname in the web server expose_hostname = True # Expose stacktrace in the web server -expose_stacktrace = True +expose_stacktrace = False # Default DAG view. Valid values are: ``grid``, ``graph``, ``duration``, ``gantt``, ``landing_times`` dag_default_view = grid @@ -832,11 +872,6 @@ worker_prefetch_multiplier = 1 # prevent this by setting this to false. However, with this disabled Flower won't work. worker_enable_remote_control = true -# Umask that will be used when starting workers with the ``airflow celery worker`` -# in daemon mode. This control the file-creation mode mask which determines the initial -# value of file permission bits for newly created files. -worker_umask = 0o077 - # The Celery broker URL. Celery supports RabbitMQ, Redis and experimentally # a sqlalchemy database. Refer to the Celery documentation for more information. broker_url = redis://redis:6379/0 @@ -846,8 +881,10 @@ broker_url = redis://redis:6379/0 # or insert it into a database (depending of the backend) # This status is used by the scheduler to update the state of the task # The use of a database is highly recommended +# When not specified, sql_alchemy_conn with a db+ scheme prefix will be used # http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-result-backend-settings -result_backend = db+postgresql://postgres:airflow@postgres/airflow +# Example: result_backend = db+postgresql://postgres:airflow@postgres/airflow +# result_backend = # Celery Flower is a sweet UI for Celery. Airflow has a shortcut to start # it ``airflow celery flower``. This defines the IP that Celery Flower runs on @@ -963,8 +1000,9 @@ scheduler_idle_sleep_time = 1 min_file_process_interval = 30 # How often (in seconds) to check for stale DAGs (DAGs which are no longer present in -# the expected files) which should be deactivated. -deactivate_stale_dags_interval = 60 +# the expected files) which should be deactivated, as well as datasets that are no longer +# referenced and should be marked as orphaned. +parsing_cleanup_interval = 60 # How often (in seconds) to scan the DAGs directory for new files. Default to 5 minutes. dag_dir_list_interval = 300 @@ -980,6 +1018,14 @@ pool_metrics_interval = 5.0 # This is used by the health check in the "/health" endpoint scheduler_health_check_threshold = 30 +# When you start a scheduler, airflow starts a tiny web server +# subprocess to serve a health check if this is set to True +enable_health_check = False + +# When you start a scheduler, airflow starts a tiny web server +# subprocess to serve a health check on this port +scheduler_health_check_server_port = 8974 + # How often (in seconds) should the scheduler check for orphaned tasks and SchedulerJobs orphaned_tasks_check_interval = 300.0 child_process_log_directory = {AIRFLOW_HOME}/logs/scheduler @@ -1054,6 +1100,10 @@ standalone_dag_processor = False # in database. Contains maximum number of callbacks that are fetched during a single loop. max_callbacks_per_loop = 20 +# Only applicable if `[scheduler]standalone_dag_processor` is true. 
+# Time in seconds after which dags, which were not updated by Dag Processor are deactivated. +dag_stale_not_seen_duration = 600 + # Turn off scheduler use of cron intervals by setting this to False. # DAGs submitted manually in the web UI or with trigger_dag will still run. use_job_schedule = True @@ -1062,9 +1112,6 @@ use_job_schedule = True # Only has effect if schedule_interval is set to None in DAG allow_trigger_in_future = False -# DAG dependency detector class to use -dependency_detector = airflow.serialization.serialized_objects.DependencyDetector - # How often to check for expired trigger requests that have not run yet. trigger_timeout_check_interval = 15 @@ -1122,7 +1169,7 @@ offset_field = offset use_ssl = False verify_certs = True -[kubernetes] +[kubernetes_executor] # Path to the YAML pod file that forms the basis for KubernetesExecutor workers. pod_template_file = @@ -1218,18 +1265,3 @@ worker_pods_pending_timeout_batch_size = 100 [sensors] # Sensor default timeout, 7 days by default (7 * 24 * 60 * 60). default_timeout = 604800 - -[smart_sensor] -# When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to -# smart sensor task. -use_smart_sensor = False - -# `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated -# by `hashcode % shard_code_upper_limit`. -shard_code_upper_limit = 10000 - -# The number of running smart sensor processes for each service. -shards = 5 - -# comma separated sensor classes support in smart_sensor. -sensors_enabled = NamedHivePartitionSensor diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py index 9d81c6353fba2..d3d5a4adf11e2 100644 --- a/airflow/config_templates/default_celery.py +++ b/airflow/config_templates/default_celery.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """Default celery configuration.""" +from __future__ import annotations + import logging import ssl @@ -29,71 +31,76 @@ def _broker_supports_visibility_timeout(url): log = logging.getLogger(__name__) -broker_url = conf.get('celery', 'BROKER_URL') +broker_url = conf.get("celery", "BROKER_URL") -broker_transport_options = conf.getsection('celery_broker_transport_options') or {} -if 'visibility_timeout' not in broker_transport_options: +broker_transport_options = conf.getsection("celery_broker_transport_options") or {} +if "visibility_timeout" not in broker_transport_options: if _broker_supports_visibility_timeout(broker_url): - broker_transport_options['visibility_timeout'] = 21600 + broker_transport_options["visibility_timeout"] = 21600 + +if conf.has_option("celery", "RESULT_BACKEND"): + result_backend = conf.get_mandatory_value("celery", "RESULT_BACKEND") +else: + log.debug("Value for celery result_backend not found. 
Using sql_alchemy_conn with db+ prefix.") + result_backend = f'db+{conf.get("database", "SQL_ALCHEMY_CONN")}' DEFAULT_CELERY_CONFIG = { - 'accept_content': ['json'], - 'event_serializer': 'json', - 'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier'), - 'task_acks_late': True, - 'task_default_queue': conf.get('operators', 'DEFAULT_QUEUE'), - 'task_default_exchange': conf.get('operators', 'DEFAULT_QUEUE'), - 'task_track_started': conf.getboolean('celery', 'task_track_started'), - 'broker_url': broker_url, - 'broker_transport_options': broker_transport_options, - 'result_backend': conf.get('celery', 'RESULT_BACKEND'), - 'worker_concurrency': conf.getint('celery', 'WORKER_CONCURRENCY'), - 'worker_enable_remote_control': conf.getboolean('celery', 'worker_enable_remote_control'), + "accept_content": ["json"], + "event_serializer": "json", + "worker_prefetch_multiplier": conf.getint("celery", "worker_prefetch_multiplier"), + "task_acks_late": True, + "task_default_queue": conf.get("operators", "DEFAULT_QUEUE"), + "task_default_exchange": conf.get("operators", "DEFAULT_QUEUE"), + "task_track_started": conf.getboolean("celery", "task_track_started"), + "broker_url": broker_url, + "broker_transport_options": broker_transport_options, + "result_backend": result_backend, + "worker_concurrency": conf.getint("celery", "WORKER_CONCURRENCY"), + "worker_enable_remote_control": conf.getboolean("celery", "worker_enable_remote_control"), } celery_ssl_active = False try: - celery_ssl_active = conf.getboolean('celery', 'SSL_ACTIVE') + celery_ssl_active = conf.getboolean("celery", "SSL_ACTIVE") except AirflowConfigException: log.warning("Celery Executor will run without SSL") try: if celery_ssl_active: - if broker_url and 'amqp://' in broker_url: + if broker_url and "amqp://" in broker_url: broker_use_ssl = { - 'keyfile': conf.get('celery', 'SSL_KEY'), - 'certfile': conf.get('celery', 'SSL_CERT'), - 'ca_certs': conf.get('celery', 'SSL_CACERT'), - 'cert_reqs': ssl.CERT_REQUIRED, + "keyfile": conf.get("celery", "SSL_KEY"), + "certfile": conf.get("celery", "SSL_CERT"), + "ca_certs": conf.get("celery", "SSL_CACERT"), + "cert_reqs": ssl.CERT_REQUIRED, } - elif broker_url and 'redis://' in broker_url: + elif broker_url and "redis://" in broker_url: broker_use_ssl = { - 'ssl_keyfile': conf.get('celery', 'SSL_KEY'), - 'ssl_certfile': conf.get('celery', 'SSL_CERT'), - 'ssl_ca_certs': conf.get('celery', 'SSL_CACERT'), - 'ssl_cert_reqs': ssl.CERT_REQUIRED, + "ssl_keyfile": conf.get("celery", "SSL_KEY"), + "ssl_certfile": conf.get("celery", "SSL_CERT"), + "ssl_ca_certs": conf.get("celery", "SSL_CACERT"), + "ssl_cert_reqs": ssl.CERT_REQUIRED, } else: raise AirflowException( - 'The broker you configured does not support SSL_ACTIVE to be True. ' - 'Please use RabbitMQ or Redis if you would like to use SSL for broker.' + "The broker you configured does not support SSL_ACTIVE to be True. " + "Please use RabbitMQ or Redis if you would like to use SSL for broker." ) - DEFAULT_CELERY_CONFIG['broker_use_ssl'] = broker_use_ssl + DEFAULT_CELERY_CONFIG["broker_use_ssl"] = broker_use_ssl except AirflowConfigException: raise AirflowException( - 'AirflowConfigException: SSL_ACTIVE is True, ' - 'please ensure SSL_KEY, ' - 'SSL_CERT and SSL_CACERT are set' + "AirflowConfigException: SSL_ACTIVE is True, " + "please ensure SSL_KEY, " + "SSL_CERT and SSL_CACERT are set" ) except Exception as e: raise AirflowException( - f'Exception: There was an unknown Celery SSL Error. 
Please ensure you want to use SSL and/or have ' - f'all necessary certs and key ({e}).' + f"Exception: There was an unknown Celery SSL Error. Please ensure you want to use SSL and/or have " + f"all necessary certs and key ({e})." ) -result_backend = str(DEFAULT_CELERY_CONFIG['result_backend']) -if 'amqp://' in result_backend or 'redis://' in result_backend or 'rpc://' in result_backend: +if "amqp://" in result_backend or "redis://" in result_backend or "rpc://" in result_backend: log.warning( "You have configured a result_backend of %s, it is highly recommended " "to use an alternative result_backend (i.e. a database).", diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index 2f9b6fa264b13..523f52cb69a04 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -32,100 +32,37 @@ unit_test_mode = True dags_folder = {TEST_DAGS_FOLDER} plugins_folder = {TEST_PLUGINS_FOLDER} -executor = SequentialExecutor -load_examples = True -donot_pickle = True -max_active_tasks_per_dag = 16 dags_are_paused_at_creation = False fernet_key = {FERNET_KEY} -enable_xcom_pickling = False killed_task_cleanup_time = 5 -hostname_callable = socket.getfqdn -default_task_retries = 0 -# This is a hack, too many tests assume DAGs are already in the DB. We need to fix those tests instead -store_serialized_dags = False +allowed_deserialization_classes = airflow\..* + tests\..* [database] sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db -load_default_connections = True [logging] -base_log_folder = {AIRFLOW_HOME}/logs -logging_level = INFO celery_logging_level = WARN -fab_logging_level = WARN -log_filename_template = {{{{ ti.dag_id }}}}/{{{{ ti.task_id }}}}/{{{{ ts }}}}/{{{{ try_number }}}}.log -log_processor_filename_template = {{{{ filename }}}}.log -dag_processor_manager_log_location = {AIRFLOW_HOME}/logs/dag_processor_manager/dag_processor_manager.log -worker_log_server_port = 8793 - -[cli] -api_client = airflow.api.client.local_client -endpoint_url = http://localhost:8080 [api] auth_backends = airflow.api.auth.backend.default -[operators] -default_owner = airflow - - [hive] default_hive_mapred_queue = airflow -[webserver] -base_url = http://localhost:8080 -web_server_host = 0.0.0.0 -web_server_port = 8080 -dag_orientation = LR -dag_default_view = tree -log_fetch_timeout_sec = 5 -hide_paused_dags_by_default = False -page_size = 100 - -[email] -email_backend = airflow.utils.email.send_email_smtp -email_conn_id = smtp_default - [smtp] -smtp_host = localhost smtp_user = airflow -smtp_port = 25 smtp_password = airflow -smtp_mail_from = airflow@example.com -smtp_retry_limit = 5 -smtp_timeout = 30 [celery] -celery_app_name = airflow.executors.celery_executor -worker_concurrency = 16 broker_url = sqla+mysql://airflow:airflow@localhost:3306/airflow result_backend = db+mysql://airflow:airflow@localhost:3306/airflow -flower_host = 0.0.0.0 -flower_port = 5555 -sync_parallelism = 0 -worker_precheck = False [scheduler] job_heartbeat_sec = 1 schedule_after_task_execution = False -scheduler_heartbeat_sec = 5 -scheduler_health_check_threshold = 30 -parsing_processes = 2 -catchup_by_default = True -scheduler_zombie_task_threshold = 300 +scheduler_health_check_server_port = 8794 dag_dir_list_interval = 0 -max_tis_per_query = 512 [elasticsearch] -host = -log_id_template = {{dag_id}}-{{task_id}}-{{execution_date}}-{{try_number}} -end_of_log_mark = end_of_log - -[elasticsearch_configs] - -use_ssl = False -verify_certs = True - 
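default_celery.py above now derives the Celery result backend from [database] sql_alchemy_conn whenever [celery] result_backend is left unset. An illustration of the value that fallback produces, using a made-up connection string:

# Values are illustrative; the real code reads them through conf.
sql_alchemy_conn = "postgresql://airflow:airflow@postgres/airflow"  # [database] sql_alchemy_conn
result_backend = f"db+{sql_alchemy_conn}"
print(result_backend)  # db+postgresql://airflow:airflow@postgres/airflow

The db+ prefix tells Celery to treat the URL as an SQLAlchemy database result backend rather than a broker-style transport.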
-[kubernetes] -dags_volume_claim = default +log_id_template = {{dag_id}}-{{task_id}}-{{run_id}}-{{map_index}}-{{try_number}} diff --git a/airflow/config_templates/default_webserver_config.py b/airflow/config_templates/default_webserver_config.py index 5c7eaa1b16a39..aa22b125fa98c 100644 --- a/airflow/config_templates/default_webserver_config.py +++ b/airflow/config_templates/default_webserver_config.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Default configuration for the Airflow webserver""" +"""Default configuration for the Airflow webserver.""" +from __future__ import annotations + import os from airflow.www.fab_security.manager import AUTH_DB @@ -30,6 +32,7 @@ # Flask-WTF flag for CSRF WTF_CSRF_ENABLED = True +WTF_CSRF_TIME_LIMIT = None # ---------------------------------------------------- # AUTHENTICATION CONFIG diff --git a/airflow/configuration.py b/airflow/configuration.py index 93612160c8290..ce55aa45c6075 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import datetime import functools import json @@ -31,14 +33,15 @@ # Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore -from contextlib import suppress +from contextlib import contextmanager, suppress from json.decoder import JSONDecodeError from re import Pattern -from typing import IO, Any, Dict, Iterable, List, Optional, Set, Tuple, Union -from urllib.parse import urlparse +from typing import IO, Any, Dict, Iterable, Tuple, Union +from urllib.parse import urlsplit from typing_extensions import overload +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowConfigException from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend from airflow.utils import yaml @@ -49,8 +52,8 @@ # show Airflow's deprecation warnings if not sys.warnoptions: - warnings.filterwarnings(action='default', category=DeprecationWarning, module='airflow') - warnings.filterwarnings(action='default', category=PendingDeprecationWarning, module='airflow') + warnings.filterwarnings(action="default", category=DeprecationWarning, module="airflow") + warnings.filterwarnings(action="default", category=PendingDeprecationWarning, module="airflow") _SQLITE3_VERSION_PATTERN = re.compile(r"(?P^\d+(?:\.\d+)*)\D?.*$") @@ -59,10 +62,10 @@ ConfigSectionSourcesType = Dict[str, Union[str, Tuple[str, str]]] ConfigSourcesType = Dict[str, ConfigSectionSourcesType] -ENV_VAR_PREFIX = 'AIRFLOW__' +ENV_VAR_PREFIX = "AIRFLOW__" -def _parse_sqlite_version(s: str) -> Tuple[int, ...]: +def _parse_sqlite_version(s: str) -> tuple[int, ...]: match = _SQLITE3_VERSION_PATTERN.match(s) if match is None: return () @@ -79,11 +82,12 @@ def expand_env_var(env_var: str) -> str: ... -def expand_env_var(env_var: Union[str, None]) -> Optional[Union[str, None]]: +def expand_env_var(env_var: str | None) -> str | None: """ - Expands (potentially nested) env vars by repeatedly applying - `expandvars` and `expanduser` until interpolation stops having - any effect. + Expands (potentially nested) env vars. 
+ + Repeat and apply `expandvars` and `expanduser` until + interpolation stops having any effect. """ if not env_var: return env_var @@ -96,11 +100,11 @@ def expand_env_var(env_var: Union[str, None]) -> Optional[Union[str, None]]: def run_command(command: str) -> str: - """Runs command and returns stdout""" + """Runs command and returns stdout.""" process = subprocess.Popen( shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True ) - output, stderr = (stream.decode(sys.getdefaultencoding(), 'ignore') for stream in process.communicate()) + output, stderr = (stream.decode(sys.getdefaultencoding(), "ignore") for stream in process.communicate()) if process.returncode != 0: raise AirflowConfigException( @@ -111,8 +115,8 @@ def run_command(command: str) -> str: return output -def _get_config_value_from_secret_backend(config_key: str) -> Optional[str]: - """Get Config option values from Secret Backend""" +def _get_config_value_from_secret_backend(config_key: str) -> str | None: + """Get Config option values from Secret Backend.""" try: secrets_client = get_custom_secret_backend() if not secrets_client: @@ -120,163 +124,183 @@ def _get_config_value_from_secret_backend(config_key: str) -> Optional[str]: return secrets_client.get_config(config_key) except Exception as e: raise AirflowConfigException( - 'Cannot retrieve config from alternative secrets backend. ' - 'Make sure it is configured properly and that the Backend ' - 'is accessible.\n' - f'{e}' + "Cannot retrieve config from alternative secrets backend. " + "Make sure it is configured properly and that the Backend " + "is accessible.\n" + f"{e}" ) def _default_config_file_path(file_name: str) -> str: - templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates') + templates_dir = os.path.join(os.path.dirname(__file__), "config_templates") return os.path.join(templates_dir, file_name) -def default_config_yaml() -> List[Dict[str, Any]]: +def default_config_yaml() -> list[dict[str, Any]]: """ - Read Airflow configs from YAML file + Read Airflow configs from YAML file. :return: Python dictionary containing configs & their info """ - with open(_default_config_file_path('config.yml')) as config_file: + with open(_default_config_file_path("config.yml")) as config_file: return yaml.safe_load(config_file) +SENSITIVE_CONFIG_VALUES = { + ("database", "sql_alchemy_conn"), + ("core", "fernet_key"), + ("celery", "broker_url"), + ("celery", "flower_basic_auth"), + ("celery", "result_backend"), + ("atlas", "password"), + ("smtp", "smtp_password"), + ("webserver", "secret_key"), + # The following options are deprecated + ("core", "sql_alchemy_conn"), +} + + class AirflowConfigParser(ConfigParser): - """Custom Airflow Configparser supporting defaults and deprecated options""" + """Custom Airflow Configparser supporting defaults and deprecated options.""" # These configuration elements can be fetched as the stdout of commands # following the "{section}__{name}_cmd" pattern, the idea behind this # is to not store password on boxes in text files. 
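expand_env_var above documents a fixed-point expansion: expandvars and expanduser are applied repeatedly until the value stops changing. A runnable illustration of that loop; the environment variable names are made up, and the real function additionally passes None and empty values through unchanged:

import os

os.environ["AIRFLOW_EXAMPLE_INNER"] = "/tmp/airflow"
os.environ["AIRFLOW_EXAMPLE_OUTER"] = "$AIRFLOW_EXAMPLE_INNER/logs"

value = "$AIRFLOW_EXAMPLE_OUTER"
while True:
    interpolated = os.path.expanduser(os.path.expandvars(value))
    if interpolated == value:
        break
    value = interpolated
print(value)  # /tmp/airflow/logs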
# These configs can also be fetched from Secrets backend # following the "{section}__{name}__secret" pattern - sensitive_config_values: Set[Tuple[str, str]] = { - ('database', 'sql_alchemy_conn'), - ('core', 'fernet_key'), - ('celery', 'broker_url'), - ('celery', 'flower_basic_auth'), - ('celery', 'result_backend'), - ('atlas', 'password'), - ('smtp', 'smtp_password'), - ('webserver', 'secret_key'), - # The following options are deprecated - ('core', 'sql_alchemy_conn'), - } + + sensitive_config_values: set[tuple[str, str]] = SENSITIVE_CONFIG_VALUES # A mapping of (new section, new option) -> (old section, old option, since_version). # When reading new option, the old option will be checked to see if it exists. If it does a # DeprecationWarning will be issued and the old option will be used instead - deprecated_options: Dict[Tuple[str, str], Tuple[str, str, str]] = { - ('celery', 'worker_precheck'): ('core', 'worker_precheck', '2.0.0'), - ('logging', 'base_log_folder'): ('core', 'base_log_folder', '2.0.0'), - ('logging', 'remote_logging'): ('core', 'remote_logging', '2.0.0'), - ('logging', 'remote_log_conn_id'): ('core', 'remote_log_conn_id', '2.0.0'), - ('logging', 'remote_base_log_folder'): ('core', 'remote_base_log_folder', '2.0.0'), - ('logging', 'encrypt_s3_logs'): ('core', 'encrypt_s3_logs', '2.0.0'), - ('logging', 'logging_level'): ('core', 'logging_level', '2.0.0'), - ('logging', 'fab_logging_level'): ('core', 'fab_logging_level', '2.0.0'), - ('logging', 'logging_config_class'): ('core', 'logging_config_class', '2.0.0'), - ('logging', 'colored_console_log'): ('core', 'colored_console_log', '2.0.0'), - ('logging', 'colored_log_format'): ('core', 'colored_log_format', '2.0.0'), - ('logging', 'colored_formatter_class'): ('core', 'colored_formatter_class', '2.0.0'), - ('logging', 'log_format'): ('core', 'log_format', '2.0.0'), - ('logging', 'simple_log_format'): ('core', 'simple_log_format', '2.0.0'), - ('logging', 'task_log_prefix_template'): ('core', 'task_log_prefix_template', '2.0.0'), - ('logging', 'log_filename_template'): ('core', 'log_filename_template', '2.0.0'), - ('logging', 'log_processor_filename_template'): ('core', 'log_processor_filename_template', '2.0.0'), - ('logging', 'dag_processor_manager_log_location'): ( - 'core', - 'dag_processor_manager_log_location', - '2.0.0', + deprecated_options: dict[tuple[str, str], tuple[str, str, str]] = { + ("celery", "worker_precheck"): ("core", "worker_precheck", "2.0.0"), + ("logging", "base_log_folder"): ("core", "base_log_folder", "2.0.0"), + ("logging", "remote_logging"): ("core", "remote_logging", "2.0.0"), + ("logging", "remote_log_conn_id"): ("core", "remote_log_conn_id", "2.0.0"), + ("logging", "remote_base_log_folder"): ("core", "remote_base_log_folder", "2.0.0"), + ("logging", "encrypt_s3_logs"): ("core", "encrypt_s3_logs", "2.0.0"), + ("logging", "logging_level"): ("core", "logging_level", "2.0.0"), + ("logging", "fab_logging_level"): ("core", "fab_logging_level", "2.0.0"), + ("logging", "logging_config_class"): ("core", "logging_config_class", "2.0.0"), + ("logging", "colored_console_log"): ("core", "colored_console_log", "2.0.0"), + ("logging", "colored_log_format"): ("core", "colored_log_format", "2.0.0"), + ("logging", "colored_formatter_class"): ("core", "colored_formatter_class", "2.0.0"), + ("logging", "log_format"): ("core", "log_format", "2.0.0"), + ("logging", "simple_log_format"): ("core", "simple_log_format", "2.0.0"), + ("logging", "task_log_prefix_template"): ("core", "task_log_prefix_template", "2.0.0"), + 
("logging", "log_filename_template"): ("core", "log_filename_template", "2.0.0"), + ("logging", "log_processor_filename_template"): ("core", "log_processor_filename_template", "2.0.0"), + ("logging", "dag_processor_manager_log_location"): ( + "core", + "dag_processor_manager_log_location", + "2.0.0", ), - ('logging', 'task_log_reader'): ('core', 'task_log_reader', '2.0.0'), - ('metrics', 'statsd_on'): ('scheduler', 'statsd_on', '2.0.0'), - ('metrics', 'statsd_host'): ('scheduler', 'statsd_host', '2.0.0'), - ('metrics', 'statsd_port'): ('scheduler', 'statsd_port', '2.0.0'), - ('metrics', 'statsd_prefix'): ('scheduler', 'statsd_prefix', '2.0.0'), - ('metrics', 'statsd_allow_list'): ('scheduler', 'statsd_allow_list', '2.0.0'), - ('metrics', 'stat_name_handler'): ('scheduler', 'stat_name_handler', '2.0.0'), - ('metrics', 'statsd_datadog_enabled'): ('scheduler', 'statsd_datadog_enabled', '2.0.0'), - ('metrics', 'statsd_datadog_tags'): ('scheduler', 'statsd_datadog_tags', '2.0.0'), - ('metrics', 'statsd_custom_client_path'): ('scheduler', 'statsd_custom_client_path', '2.0.0'), - ('scheduler', 'parsing_processes'): ('scheduler', 'max_threads', '1.10.14'), - ('scheduler', 'scheduler_idle_sleep_time'): ('scheduler', 'processor_poll_interval', '2.2.0'), - ('operators', 'default_queue'): ('celery', 'default_queue', '2.1.0'), - ('core', 'hide_sensitive_var_conn_fields'): ('admin', 'hide_sensitive_variable_fields', '2.1.0'), - ('core', 'sensitive_var_conn_names'): ('admin', 'sensitive_variable_fields', '2.1.0'), - ('core', 'default_pool_task_slot_count'): ('core', 'non_pooled_task_slot_count', '1.10.4'), - ('core', 'max_active_tasks_per_dag'): ('core', 'dag_concurrency', '2.2.0'), - ('logging', 'worker_log_server_port'): ('celery', 'worker_log_server_port', '2.2.0'), - ('api', 'access_control_allow_origins'): ('api', 'access_control_allow_origin', '2.2.0'), - ('api', 'auth_backends'): ('api', 'auth_backend', '2.3.0'), - ('database', 'sql_alchemy_conn'): ('core', 'sql_alchemy_conn', '2.3.0'), - ('database', 'sql_engine_encoding'): ('core', 'sql_engine_encoding', '2.3.0'), - ('database', 'sql_engine_collation_for_ids'): ('core', 'sql_engine_collation_for_ids', '2.3.0'), - ('database', 'sql_alchemy_pool_enabled'): ('core', 'sql_alchemy_pool_enabled', '2.3.0'), - ('database', 'sql_alchemy_pool_size'): ('core', 'sql_alchemy_pool_size', '2.3.0'), - ('database', 'sql_alchemy_max_overflow'): ('core', 'sql_alchemy_max_overflow', '2.3.0'), - ('database', 'sql_alchemy_pool_recycle'): ('core', 'sql_alchemy_pool_recycle', '2.3.0'), - ('database', 'sql_alchemy_pool_pre_ping'): ('core', 'sql_alchemy_pool_pre_ping', '2.3.0'), - ('database', 'sql_alchemy_schema'): ('core', 'sql_alchemy_schema', '2.3.0'), - ('database', 'sql_alchemy_connect_args'): ('core', 'sql_alchemy_connect_args', '2.3.0'), - ('database', 'load_default_connections'): ('core', 'load_default_connections', '2.3.0'), - ('database', 'max_db_retries'): ('core', 'max_db_retries', '2.3.0'), + ("logging", "task_log_reader"): ("core", "task_log_reader", "2.0.0"), + ("metrics", "statsd_on"): ("scheduler", "statsd_on", "2.0.0"), + ("metrics", "statsd_host"): ("scheduler", "statsd_host", "2.0.0"), + ("metrics", "statsd_port"): ("scheduler", "statsd_port", "2.0.0"), + ("metrics", "statsd_prefix"): ("scheduler", "statsd_prefix", "2.0.0"), + ("metrics", "statsd_allow_list"): ("scheduler", "statsd_allow_list", "2.0.0"), + ("metrics", "stat_name_handler"): ("scheduler", "stat_name_handler", "2.0.0"), + ("metrics", "statsd_datadog_enabled"): ("scheduler", 
"statsd_datadog_enabled", "2.0.0"), + ("metrics", "statsd_datadog_tags"): ("scheduler", "statsd_datadog_tags", "2.0.0"), + ("metrics", "statsd_custom_client_path"): ("scheduler", "statsd_custom_client_path", "2.0.0"), + ("scheduler", "parsing_processes"): ("scheduler", "max_threads", "1.10.14"), + ("scheduler", "scheduler_idle_sleep_time"): ("scheduler", "processor_poll_interval", "2.2.0"), + ("operators", "default_queue"): ("celery", "default_queue", "2.1.0"), + ("core", "hide_sensitive_var_conn_fields"): ("admin", "hide_sensitive_variable_fields", "2.1.0"), + ("core", "sensitive_var_conn_names"): ("admin", "sensitive_variable_fields", "2.1.0"), + ("core", "default_pool_task_slot_count"): ("core", "non_pooled_task_slot_count", "1.10.4"), + ("core", "max_active_tasks_per_dag"): ("core", "dag_concurrency", "2.2.0"), + ("logging", "worker_log_server_port"): ("celery", "worker_log_server_port", "2.2.0"), + ("api", "access_control_allow_origins"): ("api", "access_control_allow_origin", "2.2.0"), + ("api", "auth_backends"): ("api", "auth_backend", "2.3.0"), + ("database", "sql_alchemy_conn"): ("core", "sql_alchemy_conn", "2.3.0"), + ("database", "sql_engine_encoding"): ("core", "sql_engine_encoding", "2.3.0"), + ("database", "sql_engine_collation_for_ids"): ("core", "sql_engine_collation_for_ids", "2.3.0"), + ("database", "sql_alchemy_pool_enabled"): ("core", "sql_alchemy_pool_enabled", "2.3.0"), + ("database", "sql_alchemy_pool_size"): ("core", "sql_alchemy_pool_size", "2.3.0"), + ("database", "sql_alchemy_max_overflow"): ("core", "sql_alchemy_max_overflow", "2.3.0"), + ("database", "sql_alchemy_pool_recycle"): ("core", "sql_alchemy_pool_recycle", "2.3.0"), + ("database", "sql_alchemy_pool_pre_ping"): ("core", "sql_alchemy_pool_pre_ping", "2.3.0"), + ("database", "sql_alchemy_schema"): ("core", "sql_alchemy_schema", "2.3.0"), + ("database", "sql_alchemy_connect_args"): ("core", "sql_alchemy_connect_args", "2.3.0"), + ("database", "load_default_connections"): ("core", "load_default_connections", "2.3.0"), + ("database", "max_db_retries"): ("core", "max_db_retries", "2.3.0"), + ("scheduler", "parsing_cleanup_interval"): ("scheduler", "deactivate_stale_dags_interval", "2.5.0"), } + # A mapping of new section -> (old section, since_version). + deprecated_sections: dict[str, tuple[str, str]] = {"kubernetes_executor": ("kubernetes", "2.5.0")} + + # Now build the inverse so we can go from old_section/old_key to new_section/new_key + # if someone tries to retrieve it based on old_section/old_key + @cached_property + def inversed_deprecated_options(self): + return {(sec, name): key for key, (sec, name, ver) in self.deprecated_options.items()} + + @cached_property + def inversed_deprecated_sections(self): + return { + old_section: new_section for new_section, (old_section, ver) in self.deprecated_sections.items() + } + # A mapping of old default values that we want to change and warn the user # about. 
Mapping of section -> setting -> { old, replace, by_version } - deprecated_values: Dict[str, Dict[str, Tuple[Pattern, str, str]]] = { - 'core': { - 'hostname_callable': (re.compile(r':'), r'.', '2.1'), + deprecated_values: dict[str, dict[str, tuple[Pattern, str, str]]] = { + "core": { + "hostname_callable": (re.compile(r":"), r".", "2.1"), }, - 'webserver': { - 'navbar_color': (re.compile(r'\A#007A87\Z', re.IGNORECASE), '#fff', '2.1'), - 'dag_default_view': (re.compile(r'^tree$'), 'grid', '3.0'), + "webserver": { + "navbar_color": (re.compile(r"\A#007A87\Z", re.IGNORECASE), "#fff", "2.1"), + "dag_default_view": (re.compile(r"^tree$"), "grid", "3.0"), }, - 'email': { - 'email_backend': ( - re.compile(r'^airflow\.contrib\.utils\.sendgrid\.send_email$'), - r'airflow.providers.sendgrid.utils.emailer.send_email', - '2.1', + "email": { + "email_backend": ( + re.compile(r"^airflow\.contrib\.utils\.sendgrid\.send_email$"), + r"airflow.providers.sendgrid.utils.emailer.send_email", + "2.1", ), }, - 'logging': { - 'log_filename_template': ( + "logging": { + "log_filename_template": ( re.compile(re.escape("{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log")), "XX-set-after-default-config-loaded-XX", - '3.0', + "3.0", ), }, - 'api': { - 'auth_backends': ( - re.compile(r'^airflow\.api\.auth\.backend\.deny_all$|^$'), - 'airflow.api.auth.backend.session', - '3.0', + "api": { + "auth_backends": ( + re.compile(r"^airflow\.api\.auth\.backend\.deny_all$|^$"), + "airflow.api.auth.backend.session", + "3.0", ), }, - 'elasticsearch': { - 'log_id_template': ( - re.compile('^' + re.escape('{dag_id}-{task_id}-{run_id}-{try_number}') + '$'), - '{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}', - '3.0', + "elasticsearch": { + "log_id_template": ( + re.compile("^" + re.escape("{dag_id}-{task_id}-{execution_date}-{try_number}") + "$"), + "{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}", + "3.0", ) }, } - _available_logging_levels = ['CRITICAL', 'FATAL', 'ERROR', 'WARN', 'WARNING', 'INFO', 'DEBUG'] + _available_logging_levels = ["CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"] enums_options = { ("core", "default_task_weight_rule"): sorted(WeightRule.all_weight_rules()), ("core", "dag_ignore_file_syntax"): ["regexp", "glob"], - ('core', 'mp_start_method'): multiprocessing.get_all_start_methods(), + ("core", "mp_start_method"): multiprocessing.get_all_start_methods(), ("scheduler", "file_parsing_sort_mode"): ["modified_time", "random_seeded_by_host", "alphabetical"], ("logging", "logging_level"): _available_logging_levels, ("logging", "fab_logging_level"): _available_logging_levels, # celery_logging_level can be empty, which uses logging_level as fallback - ("logging", "celery_logging_level"): _available_logging_levels + [''], - ("webserver", "analytical_tool"): ['google_analytics', 'metarouter', 'segment', ''], + ("logging", "celery_logging_level"): _available_logging_levels + [""], + ("webserver", "analytical_tool"): ["google_analytics", "metarouter", "segment", ""], } - upgraded_values: Dict[Tuple[str, str], str] + upgraded_values: dict[tuple[str, str], str] """Mapping of (section,option) to the old value that was upgraded""" # This method transforms option names on every read, get, or set operation. 
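# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): a minimal, self-contained
# model of how the mappings in this hunk are consumed. `deprecated_options`
# maps the *new* (section, key) to its old location, and the inverted dict lets
# a lookup made with the *old* name be redirected to the new one. The `resolve`
# helper is a hypothetical stand-in, not the real AirflowConfigParser API, and
# warning emission is omitted for brevity.
from __future__ import annotations

deprecated_options: dict[tuple[str, str], tuple[str, str, str]] = {
    ("database", "sql_alchemy_conn"): ("core", "sql_alchemy_conn", "2.3.0"),
    ("logging", "worker_log_server_port"): ("celery", "worker_log_server_port", "2.2.0"),
}

# Mirrors `inversed_deprecated_options`: old (section, key) -> new (section, key).
inversed = {(old_sec, old_key): new for new, (old_sec, old_key, _ver) in deprecated_options.items()}

def resolve(section: str, key: str) -> tuple[str, str]:
    """Return the canonical (section, key) for a possibly-deprecated lookup."""
    return inversed.get((section, key), (section, key))

# A caller still using the pre-2.3 name is redirected to the new section:
assert resolve("core", "sql_alchemy_conn") == ("database", "sql_alchemy_conn")
assert resolve("database", "sql_alchemy_conn") == ("database", "sql_alchemy_conn")
# ---------------------------------------------------------------------------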
@@ -285,7 +309,7 @@ class AirflowConfigParser(ConfigParser): def optionxform(self, optionstr: str) -> str: return optionstr - def __init__(self, default_config: Optional[str] = None, *args, **kwargs): + def __init__(self, default_config: str | None = None, *args, **kwargs): super().__init__(*args, **kwargs) self.upgraded_values = {} @@ -293,10 +317,10 @@ def __init__(self, default_config: Optional[str] = None, *args, **kwargs): if default_config is not None: self.airflow_defaults.read_string(default_config) # Set the upgrade value based on the current loaded default - default = self.airflow_defaults.get('logging', 'log_filename_template', fallback=None) + default = self.airflow_defaults.get("logging", "log_filename_template", fallback=None) if default: - replacement = self.deprecated_values['logging']['log_filename_template'] - self.deprecated_values['logging']['log_filename_template'] = ( + replacement = self.deprecated_values["logging"]["log_filename_template"] + self.deprecated_values["logging"]["log_filename_template"] = ( replacement[0], default, replacement[2], @@ -304,12 +328,13 @@ def __init__(self, default_config: Optional[str] = None, *args, **kwargs): else: # In case of tests it might not exist with suppress(KeyError): - del self.deprecated_values['logging']['log_filename_template'] + del self.deprecated_values["logging"]["log_filename_template"] else: with suppress(KeyError): - del self.deprecated_values['logging']['log_filename_template'] + del self.deprecated_values["logging"]["log_filename_template"] self.is_validated = False + self._suppress_future_warnings = False def validate(self): self._validate_config_dependencies() @@ -337,14 +362,15 @@ def validate(self): def _upgrade_auth_backends(self): """ - Ensure a custom auth_backends setting contains session, - which is needed by the UI for ajax queries. + Ensure a custom auth_backends setting contains session. + + This is required by the UI for ajax queries. """ old_value = self.get("api", "auth_backends", fallback="") - if old_value in ('airflow.api.auth.backend.default', ''): + if old_value in ("airflow.api.auth.backend.default", ""): # handled by deprecated_values pass - elif old_value.find('airflow.api.auth.backend.session') == -1: + elif old_value.find("airflow.api.auth.backend.session") == -1: new_value = old_value + ",airflow.api.auth.backend.session" self._update_env_var(section="api", name="auth_backends", new_value=new_value) self.upgraded_values[("api", "auth_backends")] = old_value @@ -355,28 +381,33 @@ def _upgrade_auth_backends(self): os.environ.pop(old_env_var, None) warnings.warn( - 'The auth_backends setting in [api] has had airflow.api.auth.backend.session added ' - 'in the running config, which is needed by the UI. Please update your config before ' - 'Apache Airflow 3.0.', + "The auth_backends setting in [api] has had airflow.api.auth.backend.session added " + "in the running config, which is needed by the UI. Please update your config before " + "Apache Airflow 3.0.", FutureWarning, ) def _upgrade_postgres_metastore_conn(self): - """As of sqlalchemy 1.4, scheme `postgres+psycopg2` must be replaced with `postgresql`""" - section, key = 'database', 'sql_alchemy_conn' - old_value = self.get(section, key) - bad_scheme = 'postgres+psycopg2' - good_scheme = 'postgresql' - parsed = urlparse(old_value) - if parsed.scheme == bad_scheme: + """ + Upgrade SQL schemas. + + As of SQLAlchemy 1.4, schemes `postgres+psycopg2` and `postgres` + must be replaced with `postgresql`. 
+ """ + section, key = "database", "sql_alchemy_conn" + old_value = self.get(section, key, _extra_stacklevel=1) + bad_schemes = ["postgres+psycopg2", "postgres"] + good_scheme = "postgresql" + parsed = urlsplit(old_value) + if parsed.scheme in bad_schemes: warnings.warn( - f"Bad scheme in Airflow configuration core > sql_alchemy_conn: `{bad_scheme}`. " - "As of SqlAlchemy 1.4 (adopted in Airflow 2.3) this is no longer supported. You must " + f"Bad scheme in Airflow configuration core > sql_alchemy_conn: `{parsed.scheme}`. " + "As of SQLAlchemy 1.4 (adopted in Airflow 2.3) this is no longer supported. You must " f"change to `{good_scheme}` before the next Airflow release.", FutureWarning, ) self.upgraded_values[(section, key)] = old_value - new_value = re.sub('^' + re.escape(f"{bad_scheme}://"), f"{good_scheme}://", old_value) + new_value = re.sub("^" + re.escape(f"{parsed.scheme}://"), f"{good_scheme}://", old_value) self._update_env_var(section=section, name=key, new_value=new_value) # if the old value is set via env var, we need to wipe it @@ -385,7 +416,7 @@ def _upgrade_postgres_metastore_conn(self): os.environ.pop(old_env_var, None) def _validate_enums(self): - """Validate that enum type config has an accepted value""" + """Validate that enum type config has an accepted value.""" for (section_key, option_key), enum_options in self.enums_options.items(): if self.has_option(section_key, option_key): value = self.get(section_key, option_key) @@ -397,14 +428,16 @@ def _validate_enums(self): def _validate_config_dependencies(self): """ - Validate that config values aren't invalid given other config values + Validate that config based on condition. + + Values are considered invalid when they conflict with other config values or system-level limitations and requirements. """ is_executor_without_sqlite_support = self.get("core", "executor") not in ( - 'DebugExecutor', - 'SequentialExecutor', + "DebugExecutor", + "SequentialExecutor", ) - is_sqlite = "sqlite" in self.get('database', 'sql_alchemy_conn') + is_sqlite = "sqlite" in self.get("database", "sql_alchemy_conn") if is_sqlite and is_executor_without_sqlite_support: raise AirflowConfigException(f"error: cannot use sqlite with the {self.get('core', 'executor')}") if is_sqlite: @@ -424,7 +457,7 @@ def _validate_config_dependencies(self): def _using_old_value(self, old: Pattern, current_value: str) -> bool: return old.search(current_value) is not None - def _update_env_var(self, section: str, name: str, new_value: Union[str]): + def _update_env_var(self, section: str, name: str, new_value: str): env_var = self._env_var_name(section, name) # Set it as an env var so that any subprocesses keep the same override! os.environ[env_var] = new_value @@ -432,14 +465,14 @@ def _update_env_var(self, section: str, name: str, new_value: Union[str]): @staticmethod def _create_future_warning(name: str, section: str, current_value: Any, new_value: Any, version: str): warnings.warn( - f'The {name!r} setting in [{section}] has the old default value of {current_value!r}. ' - f'This value has been changed to {new_value!r} in the running config, but ' - f'please update your config before Apache Airflow {version}.', + f"The {name!r} setting in [{section}] has the old default value of {current_value!r}. 
" + f"This value has been changed to {new_value!r} in the running config, but " + f"please update your config before Apache Airflow {version}.", FutureWarning, ) def _env_var_name(self, section: str, key: str) -> str: - return f'{ENV_VAR_PREFIX}{section.upper()}__{key.upper()}' + return f"{ENV_VAR_PREFIX}{section.upper()}__{key.upper()}" def _get_env_var_option(self, section: str, key: str): # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore) @@ -447,13 +480,13 @@ def _get_env_var_option(self, section: str, key: str): if env_var in os.environ: return expand_env_var(os.environ[env_var]) # alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command) - env_var_cmd = env_var + '_CMD' + env_var_cmd = env_var + "_CMD" if env_var_cmd in os.environ: # if this is a valid command key... if (section, key) in self.sensitive_config_values: return run_command(os.environ[env_var_cmd]) # alternatively AIRFLOW__{SECTION}__{KEY}_SECRET (to get from Secrets Backend) - env_var_secret_path = env_var + '_SECRET' + env_var_secret_path = env_var + "_SECRET" if env_var_secret_path in os.environ: # if this is a valid secret path... if (section, key) in self.sensitive_config_values: @@ -461,7 +494,7 @@ def _get_env_var_option(self, section: str, key: str): return None def _get_cmd_option(self, section: str, key: str): - fallback_key = key + '_cmd' + fallback_key = key + "_cmd" if (section, key) in self.sensitive_config_values: if super().has_option(section, fallback_key): command = super().get(section, fallback_key) @@ -470,8 +503,8 @@ def _get_cmd_option(self, section: str, key: str): def _get_cmd_option_from_config_sources( self, config_sources: ConfigSourcesType, section: str, key: str - ) -> Optional[str]: - fallback_key = key + '_cmd' + ) -> str | None: + fallback_key = key + "_cmd" if (section, key) in self.sensitive_config_values: section_dict = config_sources.get(section) if section_dict is not None: @@ -484,9 +517,9 @@ def _get_cmd_option_from_config_sources( return run_command(command) return None - def _get_secret_option(self, section: str, key: str) -> Optional[str]: - """Get Config option values from Secret Backend""" - fallback_key = key + '_secret' + def _get_secret_option(self, section: str, key: str) -> str | None: + """Get Config option values from Secret Backend.""" + fallback_key = key + "_secret" if (section, key) in self.sensitive_config_values: if super().has_option(section, fallback_key): secrets_path = super().get(section, fallback_key) @@ -495,8 +528,8 @@ def _get_secret_option(self, section: str, key: str) -> Optional[str]: def _get_secret_option_from_config_sources( self, config_sources: ConfigSourcesType, section: str, key: str - ) -> Optional[str]: - fallback_key = key + '_secret' + ) -> str | None: + fallback_key = key + "_secret" if (section, key) in self.sensitive_config_values: section_dict = config_sources.get(section) if section_dict is not None: @@ -510,115 +543,219 @@ def _get_secret_option_from_config_sources( return None def get_mandatory_value(self, section: str, key: str, **kwargs) -> str: - value = self.get(section, key, **kwargs) + value = self.get(section, key, _extra_stacklevel=1, **kwargs) if value is None: raise ValueError(f"The value {section}/{key} should be set!") return value - def get(self, section: str, key: str, **kwargs) -> Optional[str]: # type: ignore[override] + @overload # type: ignore[override] + def get(self, section: str, key: str, fallback: str = ..., **kwargs) -> str: # type: ignore[override] + + ... 
+ + @overload # type: ignore[override] + def get(self, section: str, key: str, **kwargs) -> str | None: # type: ignore[override] + + ... + + def get( # type: ignore[override, misc] + self, + section: str, + key: str, + _extra_stacklevel: int = 0, + **kwargs, + ) -> str | None: section = str(section).lower() key = str(key).lower() + warning_emitted = False + deprecated_section: str | None + deprecated_key: str | None + + # For when we rename whole sections + if section in self.inversed_deprecated_sections: + deprecated_section, deprecated_key = (section, key) + section = self.inversed_deprecated_sections[section] + if not self._suppress_future_warnings: + warnings.warn( + f"The config section [{deprecated_section}] has been renamed to " + f"[{section}]. Please update your `conf.get*` call to use the new name", + FutureWarning, + stacklevel=2 + _extra_stacklevel, + ) + # Don't warn about individual rename if the whole section is renamed + warning_emitted = True + elif (section, key) in self.inversed_deprecated_options: + # Handle using deprecated section/key instead of the new section/key + new_section, new_key = self.inversed_deprecated_options[(section, key)] + if not self._suppress_future_warnings and not warning_emitted: + warnings.warn( + f"section/key [{section}/{key}] has been deprecated, you should use" + f"[{new_section}/{new_key}] instead. Please update your `conf.get*` call to use the " + "new name", + FutureWarning, + stacklevel=2 + _extra_stacklevel, + ) + warning_emitted = True + deprecated_section, deprecated_key = section, key + section, key = (new_section, new_key) + elif section in self.deprecated_sections: + # When accessing the new section name, make sure we check under the old config name + deprecated_key = key + deprecated_section = self.deprecated_sections[section][0] + else: + deprecated_section, deprecated_key, _ = self.deprecated_options.get( + (section, key), (None, None, None) + ) - deprecated_section, deprecated_key, _ = self.deprecated_options.get( - (section, key), (None, None, None) + # first check environment variables + option = self._get_environment_variables( + deprecated_key, + deprecated_section, + key, + section, + issue_warning=not warning_emitted, + extra_stacklevel=_extra_stacklevel, ) - - option = self._get_environment_variables(deprecated_key, deprecated_section, key, section) if option is not None: return option - option = self._get_option_from_config_file(deprecated_key, deprecated_section, key, kwargs, section) + # ...then the config file + option = self._get_option_from_config_file( + deprecated_key, + deprecated_section, + key, + kwargs, + section, + issue_warning=not warning_emitted, + extra_stacklevel=_extra_stacklevel, + ) if option is not None: return option - option = self._get_option_from_commands(deprecated_key, deprecated_section, key, section) + # ...then commands + option = self._get_option_from_commands( + deprecated_key, + deprecated_section, + key, + section, + issue_warning=not warning_emitted, + extra_stacklevel=_extra_stacklevel, + ) if option is not None: return option - option = self._get_option_from_secrets(deprecated_key, deprecated_section, key, section) + # ...then from secret backends + option = self._get_option_from_secrets( + deprecated_key, + deprecated_section, + key, + section, + issue_warning=not warning_emitted, + extra_stacklevel=_extra_stacklevel, + ) if option is not None: return option - return self._get_option_from_default_config(section, key, **kwargs) - - def _get_option_from_default_config(self, section: 
str, key: str, **kwargs) -> Optional[str]: # ...then the default config - if self.airflow_defaults.has_option(section, key) or 'fallback' in kwargs: + if self.airflow_defaults.has_option(section, key) or "fallback" in kwargs: return expand_env_var(self.airflow_defaults.get(section, key, **kwargs)) - else: - log.warning("section/key [%s/%s] not found in config", section, key) + log.warning("section/key [%s/%s] not found in config", section, key) - raise AirflowConfigException(f"section/key [{section}/{key}] not found in config") + raise AirflowConfigException(f"section/key [{section}/{key}] not found in config") def _get_option_from_secrets( - self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str - ) -> Optional[str]: - # ...then from secret backends + self, + deprecated_key: str | None, + deprecated_section: str | None, + key: str, + section: str, + issue_warning: bool = True, + extra_stacklevel: int = 0, + ) -> str | None: option = self._get_secret_option(section, key) if option: return option if deprecated_section and deprecated_key: - option = self._get_secret_option(deprecated_section, deprecated_key) + with self.suppress_future_warnings(): + option = self._get_secret_option(deprecated_section, deprecated_key) if option: - self._warn_deprecate(section, key, deprecated_section, deprecated_key) + if issue_warning: + self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel) return option return None def _get_option_from_commands( - self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str - ) -> Optional[str]: - # ...then commands + self, + deprecated_key: str | None, + deprecated_section: str | None, + key: str, + section: str, + issue_warning: bool = True, + extra_stacklevel: int = 0, + ) -> str | None: option = self._get_cmd_option(section, key) if option: return option if deprecated_section and deprecated_key: - option = self._get_cmd_option(deprecated_section, deprecated_key) + with self.suppress_future_warnings(): + option = self._get_cmd_option(deprecated_section, deprecated_key) if option: - self._warn_deprecate(section, key, deprecated_section, deprecated_key) + if issue_warning: + self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel) return option return None def _get_option_from_config_file( self, - deprecated_key: Optional[str], - deprecated_section: Optional[str], + deprecated_key: str | None, + deprecated_section: str | None, key: str, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], section: str, - ) -> Optional[str]: - # ...then the config file + issue_warning: bool = True, + extra_stacklevel: int = 0, + ) -> str | None: if super().has_option(section, key): # Use the parent's methods to get the actual config here to be able to # separate the config from default config. 
return expand_env_var(super().get(section, key, **kwargs)) if deprecated_section and deprecated_key: if super().has_option(deprecated_section, deprecated_key): - self._warn_deprecate(section, key, deprecated_section, deprecated_key) - return expand_env_var(super().get(deprecated_section, deprecated_key, **kwargs)) + if issue_warning: + self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel) + with self.suppress_future_warnings(): + return expand_env_var(super().get(deprecated_section, deprecated_key, **kwargs)) return None def _get_environment_variables( - self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str - ) -> Optional[str]: - # first check environment variables + self, + deprecated_key: str | None, + deprecated_section: str | None, + key: str, + section: str, + issue_warning: bool = True, + extra_stacklevel: int = 0, + ) -> str | None: option = self._get_env_var_option(section, key) if option is not None: return option if deprecated_section and deprecated_key: - option = self._get_env_var_option(deprecated_section, deprecated_key) + with self.suppress_future_warnings(): + option = self._get_env_var_option(deprecated_section, deprecated_key) if option is not None: - self._warn_deprecate(section, key, deprecated_section, deprecated_key) + if issue_warning: + self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel) return option return None def getboolean(self, section: str, key: str, **kwargs) -> bool: # type: ignore[override] - val = str(self.get(section, key, **kwargs)).lower().strip() - if '#' in val: - val = val.split('#')[0].strip() - if val in ('t', 'true', '1'): + val = str(self.get(section, key, _extra_stacklevel=1, **kwargs)).lower().strip() + if "#" in val: + val = val.split("#")[0].strip() + if val in ("t", "true", "1"): return True - elif val in ('f', 'false', '0'): + elif val in ("f", "false", "0"): return False else: raise AirflowConfigException( @@ -627,10 +764,10 @@ def getboolean(self, section: str, key: str, **kwargs) -> bool: # type: ignore[ ) def getint(self, section: str, key: str, **kwargs) -> int: # type: ignore[override] - val = self.get(section, key, **kwargs) + val = self.get(section, key, _extra_stacklevel=1, **kwargs) if val is None: raise AirflowConfigException( - f'Failed to convert value None to int. ' + f"Failed to convert value None to int. " f'Please check "{key}" key in "{section}" section is set.' ) try: @@ -642,10 +779,10 @@ def getint(self, section: str, key: str, **kwargs) -> int: # type: ignore[overr ) def getfloat(self, section: str, key: str, **kwargs) -> float: # type: ignore[override] - val = self.get(section, key, **kwargs) + val = self.get(section, key, _extra_stacklevel=1, **kwargs) if val is None: raise AirflowConfigException( - f'Failed to convert value None to float. ' + f"Failed to convert value None to float. " f'Please check "{key}" key in "{section}" section is set.' ) try: @@ -679,7 +816,7 @@ def getimport(self, section: str, key: str, **kwargs) -> Any: def getjson( self, section: str, key: str, fallback=_UNSET, **kwargs - ) -> Union[dict, list, str, int, float, None]: + ) -> dict | list | str | int | float | None: """ Return a config value parsed from a JSON string. 
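# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): a stripped-down model of
# the value coercion in the getboolean/getjson hunks — boolean parsing
# tolerates an inline "#" comment and the usual truthy/falsy spellings, and
# JSON parsing turns a decode error into a config error. `parse_bool`,
# `parse_json` and `ConfigError` are hypothetical stand-ins, not Airflow APIs.
import json

class ConfigError(Exception):
    pass

def parse_bool(raw: str) -> bool:
    val = str(raw).lower().strip()
    if "#" in val:  # allow trailing inline comments, e.g. "True  # enabled"
        val = val.split("#")[0].strip()
    if val in ("t", "true", "1"):
        return True
    if val in ("f", "false", "0"):
        return False
    raise ConfigError(f"Failed to convert value to bool: {raw!r}")

def parse_json(raw: str):
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        raise ConfigError(f"Unable to parse {raw!r} as valid json") from e

assert parse_bool("True  # enabled in dev") is True
assert parse_json('{"pool_size": 5}') == {"pool_size": 5}
# ---------------------------------------------------------------------------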
@@ -693,7 +830,7 @@ def getjson( fallback = _UNSET try: - data = self.get(section=section, key=key, fallback=fallback, **kwargs) + data = self.get(section=section, key=key, fallback=fallback, _extra_stacklevel=1, **kwargs) except (NoSectionError, NoOptionError): return default @@ -703,13 +840,14 @@ def getjson( try: return json.loads(data) except JSONDecodeError as e: - raise AirflowConfigException(f'Unable to parse [{section}] {key!r} as valid json') from e + raise AirflowConfigException(f"Unable to parse [{section}] {key!r} as valid json") from e def gettimedelta( self, section: str, key: str, fallback: Any = None, **kwargs - ) -> Optional[datetime.timedelta]: + ) -> datetime.timedelta | None: """ Gets the config value for the given section and key, and converts it into datetime.timedelta object. + If the key is missing, then it is considered as `None`. :param section: the section from the config @@ -718,7 +856,7 @@ def gettimedelta( :raises AirflowConfigException: raised because ValueError or OverflowError :return: datetime.timedelta(seconds=) or None """ - val = self.get(section, key, fallback=fallback, **kwargs) + val = self.get(section, key, fallback=fallback, _extra_stacklevel=1, **kwargs) if val: # the given value must be convertible to integer @@ -734,8 +872,8 @@ def gettimedelta( return datetime.timedelta(seconds=int_val) except OverflowError as err: raise AirflowConfigException( - f'Failed to convert value to timedelta in `seconds`. ' - f'{err}. ' + f"Failed to convert value to timedelta in `seconds`. " + f"{err}. " f'Please check "{key}" key in "{section}" section. Current value: "{val}".' ) @@ -743,12 +881,7 @@ def gettimedelta( def read( self, - filenames: Union[ - str, - bytes, - os.PathLike, - Iterable[Union[str, bytes, os.PathLike]], - ], + filenames: (str | bytes | os.PathLike | Iterable[str | bytes | os.PathLike]), encoding=None, ): super().read(filenames=filenames, encoding=encoding) @@ -756,7 +889,7 @@ def read( # The RawConfigParser defines "Mapping" from abc.collections is not subscriptable - so we have # to use Dict here. def read_dict( # type: ignore[override] - self, dictionary: Dict[str, Dict[str, Any]], source: str = '' + self, dictionary: dict[str, dict[str, Any]], source: str = "" ): super().read_dict(dictionary=dictionary, source=source) @@ -765,16 +898,17 @@ def has_option(self, section: str, option: str) -> bool: # Using self.get() to avoid reimplementing the priority order # of config variables (env, config, cmd, defaults) # UNSET to avoid logging a warning about missing values - self.get(section, option, fallback=_UNSET) + self.get(section, option, fallback=_UNSET, _extra_stacklevel=1) return True except (NoOptionError, NoSectionError): return False def remove_option(self, section: str, option: str, remove_default: bool = True): """ - Remove an option if it exists in config from a file or - default config. If both of config have the same option, this removes - the option in both configs unless remove_default=False. + Remove an option if it exists in config from a file or default config. + + If both of config have the same option, this removes the option + in both configs unless remove_default=False. 
""" if super().has_option(section, option): super().remove_option(section, option) @@ -782,13 +916,13 @@ def remove_option(self, section: str, option: str, remove_default: bool = True): if self.airflow_defaults.has_option(section, option) and remove_default: self.airflow_defaults.remove_option(section, option) - def getsection(self, section: str) -> Optional[ConfigOptionsDictType]: + def getsection(self, section: str) -> ConfigOptionsDictType | None: """ - Returns the section as a dict. Values are converted to int, float, bool - as required. + Returns the section as a dict. + + Values are converted to int, float, bool as required. :param section: section from the config - :rtype: dict """ if not self.has_section(section) and not self.airflow_defaults.has_section(section): return None @@ -800,10 +934,10 @@ def getsection(self, section: str) -> Optional[ConfigOptionsDictType]: if self.has_section(section): _section.update(OrderedDict(self.items(section))) - section_prefix = self._env_var_name(section, '') + section_prefix = self._env_var_name(section, "") for env_var in sorted(os.environ.keys()): if env_var.startswith(section_prefix): - key = env_var.replace(section_prefix, '') + key = env_var.replace(section_prefix, "") if key.endswith("_CMD"): key = key[:-4] key = key.lower() @@ -812,7 +946,7 @@ def getsection(self, section: str) -> Optional[ConfigOptionsDictType]: for key, val in _section.items(): if val is None: raise AirflowConfigException( - f'Failed to convert value automatically. ' + f"Failed to convert value automatically. " f'Please check "{key}" key in "{section}" section is set.' ) try: @@ -821,9 +955,9 @@ def getsection(self, section: str) -> Optional[ConfigOptionsDictType]: try: _section[key] = float(val) except ValueError: - if isinstance(val, str) and val.lower() in ('t', 'true'): + if isinstance(val, str) and val.lower() in ("t", "true"): _section[key] = True - elif isinstance(val, str) and val.lower() in ('f', 'false'): + elif isinstance(val, str) and val.lower() in ("f", "false"): _section[key] = False return _section @@ -880,14 +1014,22 @@ def as_dict( :param include_secret: Should the result of calling any *_secret config be set (True, default), or should the _secret options be left as the path to get the secret from (False) - :rtype: Dict[str, Dict[str, str]] :return: Dictionary, where the key is the name of the section and the content is the dictionary with the name of the parameter and its value. 
""" + if not display_sensitive: + # We want to hide the sensitive values at the appropriate methods + # since envs from cmds, secrets can be read at _include_envs method + if not all([include_env, include_cmds, include_secret]): + raise ValueError( + "If display_sensitive is false, then include_env, " + "include_cmds, include_secret must all be set as True" + ) + config_sources: ConfigSourcesType = {} configs = [ - ('default', self.airflow_defaults), - ('airflow.cfg', self), + ("default", self.airflow_defaults), + ("airflow.cfg", self), ] self._replace_config_with_display_sources( @@ -919,6 +1061,20 @@ def as_dict( else: self._filter_by_source(config_sources, display_source, self._get_secret_option) + if not display_sensitive: + # This ensures the ones from config file is hidden too + # if they are not provided through env, cmd and secret + hidden = "< hidden >" + for (section, key) in self.sensitive_config_values: + if not config_sources.get(section): + continue + if config_sources[section].get(key, None): + if display_source: + source = config_sources[section][key][1] + config_sources[section][key] = (hidden, source) + else: + config_sources[section][key] = hidden + return config_sources def _include_secrets( @@ -929,18 +1085,18 @@ def _include_secrets( raw: bool, ): for (section, key) in self.sensitive_config_values: - value: Optional[str] = self._get_secret_option_from_config_sources(config_sources, section, key) + value: str | None = self._get_secret_option_from_config_sources(config_sources, section, key) if value: if not display_sensitive: - value = '< hidden >' + value = "< hidden >" if display_source: - opt: Union[str, Tuple[str, str]] = (value, 'secret') + opt: str | tuple[str, str] = (value, "secret") elif raw: - opt = value.replace('%', '%%') + opt = value.replace("%", "%%") else: opt = value config_sources.setdefault(section, OrderedDict()).update({key: opt}) - del config_sources[section][key + '_secret'] + del config_sources[section][key + "_secret"] def _include_commands( self, @@ -953,17 +1109,17 @@ def _include_commands( opt = self._get_cmd_option_from_config_sources(config_sources, section, key) if not opt: continue - opt_to_set: Union[str, Tuple[str, str], None] = opt + opt_to_set: str | tuple[str, str] | None = opt if not display_sensitive: - opt_to_set = '< hidden >' + opt_to_set = "< hidden >" if display_source: - opt_to_set = (str(opt_to_set), 'cmd') + opt_to_set = (str(opt_to_set), "cmd") elif raw: - opt_to_set = str(opt_to_set).replace('%', '%%') + opt_to_set = str(opt_to_set).replace("%", "%%") if opt_to_set is not None: - dict_to_update: Dict[str, Union[str, Tuple[str, str]]] = {key: opt_to_set} + dict_to_update: dict[str, str | tuple[str, str]] = {key: opt_to_set} config_sources.setdefault(section, OrderedDict()).update(dict_to_update) - del config_sources[section][key + '_cmd'] + del config_sources[section][key + "_cmd"] def _include_envs( self, @@ -976,26 +1132,29 @@ def _include_envs( os_environment for os_environment in os.environ if os_environment.startswith(ENV_VAR_PREFIX) ]: try: - _, section, key = env_var.split('__', 2) + _, section, key = env_var.split("__", 2) opt = self._get_env_var_option(section, key) except ValueError: continue if opt is None: log.warning("Ignoring unknown env var '%s'", env_var) continue - if not display_sensitive and env_var != self._env_var_name('core', 'unit_test_mode'): - opt = '< hidden >' + if not display_sensitive and env_var != self._env_var_name("core", "unit_test_mode"): + # Don't hide cmd/secret values here + if not 
env_var.lower().endswith("cmd") and not env_var.lower().endswith("secret"): + if (section, key) in self.sensitive_config_values: + opt = "< hidden >" elif raw: - opt = opt.replace('%', '%%') + opt = opt.replace("%", "%%") if display_source: - opt = (opt, 'env var') + opt = (opt, "env var") section = section.lower() # if we lower key for kubernetes_environment_variables section, # then we won't be able to set any Airflow environment # variables. Airflow only parse environment variables starts # with AIRFLOW_. Therefore, we need to make it a special case. - if section != 'kubernetes_environment_variables': + if section != "kubernetes_environment_variables": key = key.lower() config_sources.setdefault(section, OrderedDict()).update({key: opt}) @@ -1006,8 +1165,9 @@ def _filter_by_source( getter_func, ): """ - Deletes default configs from current configuration (an OrderedDict of - OrderedDicts) if it would conflict with special sensitive_config_values. + Deletes default configs from current configuration. + + An OrderedDict of OrderedDicts, if it would conflict with special sensitive_config_values. This is necessary because bare configs take precedence over the command or secret key equivalents so if the current running config is @@ -1020,7 +1180,6 @@ def _filter_by_source( Source is either 'airflow.cfg', 'default', 'env var', or 'cmd'. :param getter_func: A callback function that gets the user configured override value for a particular sensitive_config_values config. - :rtype: None :return: None, the given config_sources is filtered if necessary, otherwise untouched. """ @@ -1050,10 +1209,10 @@ def _filter_by_source( @staticmethod def _replace_config_with_display_sources( config_sources: ConfigSourcesType, - configs: Iterable[Tuple[str, ConfigParser]], + configs: Iterable[tuple[str, ConfigParser]], display_source: bool, raw: bool, - deprecated_options: Dict[Tuple[str, str], Tuple[str, str, str]], + deprecated_options: dict[tuple[str, str], tuple[str, str, str]], include_env: bool, include_cmds: bool, include_secret: bool, @@ -1078,10 +1237,10 @@ def _replace_config_with_display_sources( def _deprecated_value_is_set_in_config( deprecated_section: str, deprecated_key: str, - configs: Iterable[Tuple[str, ConfigParser]], + configs: Iterable[tuple[str, ConfigParser]], ) -> bool: for config_type, config in configs: - if config_type == 'default': + if config_type == "default": continue try: deprecated_section_array = config.items(section=deprecated_section, raw=True) @@ -1095,13 +1254,13 @@ def _deprecated_value_is_set_in_config( @staticmethod def _deprecated_variable_is_set(deprecated_section: str, deprecated_key: str) -> bool: return ( - os.environ.get(f'{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}') + os.environ.get(f"{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}") is not None ) @staticmethod def _deprecated_command_is_set_in_config( - deprecated_section: str, deprecated_key: str, configs: Iterable[Tuple[str, ConfigParser]] + deprecated_section: str, deprecated_key: str, configs: Iterable[tuple[str, ConfigParser]] ) -> bool: return AirflowConfigParser._deprecated_value_is_set_in_config( deprecated_section=deprecated_section, deprecated_key=deprecated_key + "_cmd", configs=configs @@ -1110,13 +1269,13 @@ def _deprecated_command_is_set_in_config( @staticmethod def _deprecated_variable_command_is_set(deprecated_section: str, deprecated_key: str) -> bool: return ( - 
os.environ.get(f'{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_CMD') + os.environ.get(f"{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_CMD") is not None ) @staticmethod def _deprecated_secret_is_set_in_config( - deprecated_section: str, deprecated_key: str, configs: Iterable[Tuple[str, ConfigParser]] + deprecated_section: str, deprecated_key: str, configs: Iterable[tuple[str, ConfigParser]] ) -> bool: return AirflowConfigParser._deprecated_value_is_set_in_config( deprecated_section=deprecated_section, deprecated_key=deprecated_key + "_secret", configs=configs @@ -1125,10 +1284,17 @@ def _deprecated_secret_is_set_in_config( @staticmethod def _deprecated_variable_secret_is_set(deprecated_section: str, deprecated_key: str) -> bool: return ( - os.environ.get(f'{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_SECRET') + os.environ.get(f"{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_SECRET") is not None ) + @contextmanager + def suppress_future_warnings(self): + suppress_future_warnings = self._suppress_future_warnings + self._suppress_future_warnings = True + yield self + self._suppress_future_warnings = suppress_future_warnings + @staticmethod def _replace_section_config_with_display_sources( config: ConfigParser, @@ -1137,17 +1303,22 @@ def _replace_section_config_with_display_sources( raw: bool, section: str, source_name: str, - deprecated_options: Dict[Tuple[str, str], Tuple[str, str, str]], - configs: Iterable[Tuple[str, ConfigParser]], + deprecated_options: dict[tuple[str, str], tuple[str, str, str]], + configs: Iterable[tuple[str, ConfigParser]], include_env: bool, include_cmds: bool, include_secret: bool, ): sect = config_sources.setdefault(section, OrderedDict()) - for (k, val) in config.items(section=section, raw=raw): + if isinstance(config, AirflowConfigParser): + with config.suppress_future_warnings(): + items = config.items(section=section, raw=raw) + else: + items = config.items(section=section, raw=raw) + for k, val in items: deprecated_section, deprecated_key, _ = deprecated_options.get((section, k), (None, None, None)) if deprecated_section and deprecated_key: - if source_name == 'default': + if source_name == "default": # If deprecated entry has some non-default value set for any of the sources requested, # We should NOT set default for the new entry (because it will override anything # coming from the deprecated ones) @@ -1194,62 +1365,64 @@ def load_test_config(self): # then read test config - path = _default_config_file_path('default_test.cfg') + path = _default_config_file_path("default_test.cfg") log.info("Reading default test configuration from %s", path) - self.read_string(_parameterized_config_from_template('default_test.cfg')) + self.read_string(_parameterized_config_from_template("default_test.cfg")) # then read any "custom" test settings log.info("Reading test configuration from %s", TEST_CONFIG_FILE) self.read(TEST_CONFIG_FILE) @staticmethod - def _warn_deprecate(section: str, key: str, deprecated_section: str, deprecated_name: str): + def _warn_deprecate( + section: str, key: str, deprecated_section: str, deprecated_name: str, extra_stacklevel: int + ): if section == deprecated_section: warnings.warn( - f'The {deprecated_name} option in [{section}] has been renamed to {key} - ' - f'the old setting has been used, but please update your config.', + f"The {deprecated_name} option in [{section}] has been renamed to {key} - " + f"the old setting has been used, but please 
update your config.", DeprecationWarning, - stacklevel=3, + stacklevel=4 + extra_stacklevel, ) else: warnings.warn( - f'The {deprecated_name} option in [{deprecated_section}] has been moved to the {key} option ' - f'in [{section}] - the old setting has been used, but please update your config.', + f"The {deprecated_name} option in [{deprecated_section}] has been moved to the {key} option " + f"in [{section}] - the old setting has been used, but please update your config.", DeprecationWarning, - stacklevel=3, + stacklevel=4 + extra_stacklevel, ) def __getstate__(self): return { name: getattr(self, name) for name in [ - '_sections', - 'is_validated', - 'airflow_defaults', + "_sections", + "is_validated", + "airflow_defaults", ] } def __setstate__(self, state): self.__init__() - config = state.pop('_sections') + config = state.pop("_sections") self.read_dict(config) self.__dict__.update(state) def get_airflow_home() -> str: - """Get path to Airflow Home""" - return expand_env_var(os.environ.get('AIRFLOW_HOME', '~/airflow')) + """Get path to Airflow Home.""" + return expand_env_var(os.environ.get("AIRFLOW_HOME", "~/airflow")) def get_airflow_config(airflow_home) -> str: - """Get Path to airflow.cfg path""" - airflow_config_var = os.environ.get('AIRFLOW_CONFIG') + """Get Path to airflow.cfg path.""" + airflow_config_var = os.environ.get("AIRFLOW_CONFIG") if airflow_config_var is None: - return os.path.join(airflow_home, 'airflow.cfg') + return os.path.join(airflow_home, "airflow.cfg") return expand_env_var(airflow_config_var) def _parameterized_config_from_template(filename) -> str: - TEMPLATE_START = '# ----------------------- TEMPLATE BEGINS HERE -----------------------\n' + TEMPLATE_START = "# ----------------------- TEMPLATE BEGINS HERE -----------------------\n" path = _default_config_file_path(filename) with open(path) as fh: @@ -1262,8 +1435,7 @@ def _parameterized_config_from_template(filename) -> str: def parameterized_config(template) -> str: """ - Generates a configuration from the provided template + variables defined in - current scope + Generates configuration from provided template & variables defined in current scope. 
:param template: a config content templated with {{variables}} """ @@ -1272,11 +1444,11 @@ def parameterized_config(template) -> str: def get_airflow_test_config(airflow_home) -> str: - """Get path to unittests.cfg""" - if 'AIRFLOW_TEST_CONFIG' not in os.environ: - return os.path.join(airflow_home, 'unittests.cfg') + """Get path to unittests.cfg.""" + if "AIRFLOW_TEST_CONFIG" not in os.environ: + return os.path.join(airflow_home, "unittests.cfg") # It will never return None - return expand_env_var(os.environ['AIRFLOW_TEST_CONFIG']) # type: ignore[return-value] + return expand_env_var(os.environ["AIRFLOW_TEST_CONFIG"]) # type: ignore[return-value] def _generate_fernet_key() -> str: @@ -1293,22 +1465,22 @@ def initialize_config() -> AirflowConfigParser: """ global FERNET_KEY, AIRFLOW_HOME - default_config = _parameterized_config_from_template('default_airflow.cfg') + default_config = _parameterized_config_from_template("default_airflow.cfg") local_conf = AirflowConfigParser(default_config=default_config) - if local_conf.getboolean('core', 'unit_test_mode'): + if local_conf.getboolean("core", "unit_test_mode"): # Load test config only if not os.path.isfile(TEST_CONFIG_FILE): from cryptography.fernet import Fernet - log.info('Creating new Airflow config file for unit tests in: %s', TEST_CONFIG_FILE) + log.info("Creating new Airflow config file for unit tests in: %s", TEST_CONFIG_FILE) pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True) FERNET_KEY = Fernet.generate_key().decode() - with open(TEST_CONFIG_FILE, 'w') as file: - cfg = _parameterized_config_from_template('default_test.cfg') + with open(TEST_CONFIG_FILE, "w") as file: + cfg = _parameterized_config_from_template("default_test.cfg") file.write(cfg) local_conf.load_test_config() @@ -1317,58 +1489,58 @@ def initialize_config() -> AirflowConfigParser: if not os.path.isfile(AIRFLOW_CONFIG): from cryptography.fernet import Fernet - log.info('Creating new Airflow config file in: %s', AIRFLOW_CONFIG) + log.info("Creating new Airflow config file in: %s", AIRFLOW_CONFIG) pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True) FERNET_KEY = Fernet.generate_key().decode() - with open(AIRFLOW_CONFIG, 'w') as file: + with open(AIRFLOW_CONFIG, "w") as file: file.write(default_config) log.info("Reading the config from %s", AIRFLOW_CONFIG) local_conf.read(AIRFLOW_CONFIG) - if local_conf.has_option('core', 'AIRFLOW_HOME'): + if local_conf.has_option("core", "AIRFLOW_HOME"): msg = ( - 'Specifying both AIRFLOW_HOME environment variable and airflow_home ' - 'in the config file is deprecated. Please use only the AIRFLOW_HOME ' - 'environment variable and remove the config file entry.' + "Specifying both AIRFLOW_HOME environment variable and airflow_home " + "in the config file is deprecated. Please use only the AIRFLOW_HOME " + "environment variable and remove the config file entry." ) - if 'AIRFLOW_HOME' in os.environ: + if "AIRFLOW_HOME" in os.environ: warnings.warn(msg, category=DeprecationWarning) - elif local_conf.get('core', 'airflow_home') == AIRFLOW_HOME: + elif local_conf.get("core", "airflow_home") == AIRFLOW_HOME: warnings.warn( - 'Specifying airflow_home in the config file is deprecated. As you ' - 'have left it at the default value you should remove the setting ' - 'from your airflow.cfg and suffer no change in behaviour.', + "Specifying airflow_home in the config file is deprecated. 
As you " + "have left it at the default value you should remove the setting " + "from your airflow.cfg and suffer no change in behaviour.", category=DeprecationWarning, ) else: # there - AIRFLOW_HOME = local_conf.get('core', 'airflow_home') # type: ignore[assignment] + AIRFLOW_HOME = local_conf.get("core", "airflow_home") # type: ignore[assignment] warnings.warn(msg, category=DeprecationWarning) # They _might_ have set unit_test_mode in the airflow.cfg, we still # want to respect that and then load the unittests.cfg - if local_conf.getboolean('core', 'unit_test_mode'): + if local_conf.getboolean("core", "unit_test_mode"): local_conf.load_test_config() # Make it no longer a proxy variable, just set it to an actual string global WEBSERVER_CONFIG - WEBSERVER_CONFIG = AIRFLOW_HOME + '/webserver_config.py' + WEBSERVER_CONFIG = AIRFLOW_HOME + "/webserver_config.py" if not os.path.isfile(WEBSERVER_CONFIG): import shutil - log.info('Creating new FAB webserver config file in: %s', WEBSERVER_CONFIG) - shutil.copy(_default_config_file_path('default_webserver_config.py'), WEBSERVER_CONFIG) + log.info("Creating new FAB webserver config file in: %s", WEBSERVER_CONFIG) + shutil.copy(_default_config_file_path("default_webserver_config.py"), WEBSERVER_CONFIG) return local_conf # Historical convenience functions to access config entries def load_test_config(): - """Historical load_test_config""" + """Historical load_test_config.""" warnings.warn( "Accessing configuration method 'load_test_config' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1379,8 +1551,8 @@ def load_test_config(): conf.load_test_config() -def get(*args, **kwargs) -> Optional[ConfigType]: - """Historical get""" +def get(*args, **kwargs) -> ConfigType | None: + """Historical get.""" warnings.warn( "Accessing configuration method 'get' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1392,7 +1564,7 @@ def get(*args, **kwargs) -> Optional[ConfigType]: def getboolean(*args, **kwargs) -> bool: - """Historical getboolean""" + """Historical getboolean.""" warnings.warn( "Accessing configuration method 'getboolean' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1404,7 +1576,7 @@ def getboolean(*args, **kwargs) -> bool: def getfloat(*args, **kwargs) -> float: - """Historical getfloat""" + """Historical getfloat.""" warnings.warn( "Accessing configuration method 'getfloat' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1416,7 +1588,7 @@ def getfloat(*args, **kwargs) -> float: def getint(*args, **kwargs) -> int: - """Historical getint""" + """Historical getint.""" warnings.warn( "Accessing configuration method 'getint' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1427,8 +1599,8 @@ def getint(*args, **kwargs) -> int: return conf.getint(*args, **kwargs) -def getsection(*args, **kwargs) -> Optional[ConfigOptionsDictType]: - """Historical getsection""" +def getsection(*args, **kwargs) -> ConfigOptionsDictType | None: + """Historical getsection.""" warnings.warn( "Accessing configuration method 'getsection' directly from the configuration module is " "deprecated. 
Please access the configuration from the 'configuration.conf' object via " @@ -1440,7 +1612,7 @@ def getsection(*args, **kwargs) -> Optional[ConfigOptionsDictType]: def has_option(*args, **kwargs) -> bool: - """Historical has_option""" + """Historical has_option.""" warnings.warn( "Accessing configuration method 'has_option' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1452,7 +1624,7 @@ def has_option(*args, **kwargs) -> bool: def remove_option(*args, **kwargs) -> bool: - """Historical remove_option""" + """Historical remove_option.""" warnings.warn( "Accessing configuration method 'remove_option' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1464,7 +1636,7 @@ def remove_option(*args, **kwargs) -> bool: def as_dict(*args, **kwargs) -> ConfigSourcesType: - """Historical as_dict""" + """Historical as_dict.""" warnings.warn( "Accessing configuration method 'as_dict' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1476,7 +1648,7 @@ def as_dict(*args, **kwargs) -> ConfigSourcesType: def set(*args, **kwargs) -> None: - """Historical set""" + """Historical set.""" warnings.warn( "Accessing configuration method 'set' directly from the configuration module is " "deprecated. Please access the configuration from the 'configuration.conf' object via " @@ -1487,7 +1659,7 @@ def set(*args, **kwargs) -> None: conf.set(*args, **kwargs) -def ensure_secrets_loaded() -> List[BaseSecretsBackend]: +def ensure_secrets_loaded() -> list[BaseSecretsBackend]: """ Ensure that all secrets backends are loaded. If the secrets_backend_list contains only 2 default backends, reload it. @@ -1498,23 +1670,33 @@ def ensure_secrets_loaded() -> List[BaseSecretsBackend]: return secrets_backend_list -def get_custom_secret_backend() -> Optional[BaseSecretsBackend]: - """Get Secret Backend if defined in airflow.cfg""" - secrets_backend_cls = conf.getimport(section='secrets', key='backend') - - if secrets_backend_cls: - try: - backends: Any = conf.get(section='secrets', key='backend_kwargs', fallback='{}') - alternative_secrets_config_dict = json.loads(backends) - except JSONDecodeError: - alternative_secrets_config_dict = {} - - return secrets_backend_cls(**alternative_secrets_config_dict) - return None +def get_custom_secret_backend() -> BaseSecretsBackend | None: + """Get Secret Backend if defined in airflow.cfg.""" + secrets_backend_cls = conf.getimport(section="secrets", key="backend") + if not secrets_backend_cls: + return None -def initialize_secrets_backends() -> List[BaseSecretsBackend]: + try: + backend_kwargs = conf.getjson(section="secrets", key="backend_kwargs") + if not backend_kwargs: + backend_kwargs = {} + elif not isinstance(backend_kwargs, dict): + raise ValueError("not a dict") + except AirflowConfigException: + log.warning("Failed to parse [secrets] backend_kwargs as JSON, defaulting to no kwargs.") + backend_kwargs = {} + except ValueError: + log.warning("Failed to parse [secrets] backend_kwargs into a dict, defaulting to no kwargs.") + backend_kwargs = {} + + return secrets_backend_cls(**backend_kwargs) + + +def initialize_secrets_backends() -> list[BaseSecretsBackend]: """ + Initialize secrets backend. 
+ * import secrets backend classes * instantiate them and return them in a list """ @@ -1534,23 +1716,23 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]: @functools.lru_cache(maxsize=None) def _DEFAULT_CONFIG() -> str: - path = _default_config_file_path('default_airflow.cfg') + path = _default_config_file_path("default_airflow.cfg") with open(path) as fh: return fh.read() @functools.lru_cache(maxsize=None) def _TEST_CONFIG() -> str: - path = _default_config_file_path('default_test.cfg') + path = _default_config_file_path("default_test.cfg") with open(path) as fh: return fh.read() _deprecated = { - 'DEFAULT_CONFIG': _DEFAULT_CONFIG, - 'TEST_CONFIG': _TEST_CONFIG, - 'TEST_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, 'default_test.cfg'), - 'DEFAULT_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, 'default_airflow.cfg'), + "DEFAULT_CONFIG": _DEFAULT_CONFIG, + "TEST_CONFIG": _TEST_CONFIG, + "TEST_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_test.cfg"), + "DEFAULT_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_airflow.cfg"), } @@ -1574,28 +1756,28 @@ def __getattr__(name): # Set up dags folder for unit tests # this directory won't exist if users install via pip _TEST_DAGS_FOLDER = os.path.join( - os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'dags' + os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "tests", "dags" ) if os.path.exists(_TEST_DAGS_FOLDER): TEST_DAGS_FOLDER = _TEST_DAGS_FOLDER else: - TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, 'dags') + TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, "dags") # Set up plugins folder for unit tests _TEST_PLUGINS_FOLDER = os.path.join( - os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'plugins' + os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "tests", "plugins" ) if os.path.exists(_TEST_PLUGINS_FOLDER): TEST_PLUGINS_FOLDER = _TEST_PLUGINS_FOLDER else: - TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, 'plugins') + TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, "plugins") TEST_CONFIG_FILE = get_airflow_test_config(AIRFLOW_HOME) -SECRET_KEY = b64encode(os.urandom(16)).decode('utf-8') -FERNET_KEY = '' # Set only if needed when generating a new file -WEBSERVER_CONFIG = '' # Set by initialize_config +SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8") +FERNET_KEY = "" # Set only if needed when generating a new file +WEBSERVER_CONFIG = "" # Set by initialize_config conf = initialize_config() secrets_backend_list = initialize_secrets_backends() diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py index c861d649e4a49..cffc937097306 100644 --- a/airflow/contrib/hooks/__init__.py +++ b/airflow/contrib/hooks/__init__.py @@ -15,13 +15,323 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This package is deprecated. Please use `airflow.hooks` or `airflow.providers.*.hooks`.""" +from __future__ import annotations import warnings +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.utils.deprecation_tools import add_deprecated_classes + warnings.warn( "This package is deprecated. 
Please use `airflow.hooks` or `airflow.providers.*.hooks`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) + +__deprecated_classes = { + 'aws_athena_hook': { + 'AWSAthenaHook': 'airflow.providers.amazon.aws.hooks.athena.AthenaHook', + }, + 'aws_datasync_hook': { + 'AWSDataSyncHook': 'airflow.providers.amazon.aws.hooks.datasync.DataSyncHook', + }, + 'aws_dynamodb_hook': { + 'AwsDynamoDBHook': 'airflow.providers.amazon.aws.hooks.dynamodb.DynamoDBHook', + }, + 'aws_firehose_hook': { + 'FirehoseHook': 'airflow.providers.amazon.aws.hooks.kinesis.FirehoseHook', + }, + 'aws_glue_catalog_hook': { + 'AwsGlueCatalogHook': 'airflow.providers.amazon.aws.hooks.glue_catalog.GlueCatalogHook', + }, + 'aws_hook': { + 'AwsBaseHook': 'airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook', + '_parse_s3_config': 'airflow.providers.amazon.aws.hooks.base_aws._parse_s3_config', + 'boto3': 'airflow.providers.amazon.aws.hooks.base_aws.boto3', + 'AwsHook': 'airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook', + }, + 'aws_lambda_hook': { + 'AwsLambdaHook': 'airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook', + }, + 'aws_logs_hook': { + 'AwsLogsHook': 'airflow.providers.amazon.aws.hooks.logs.AwsLogsHook', + }, + 'aws_sns_hook': { + 'AwsSnsHook': 'airflow.providers.amazon.aws.hooks.sns.SnsHook', + }, + 'aws_sqs_hook': { + 'SqsHook': 'airflow.providers.amazon.aws.hooks.sqs.SqsHook', + 'SQSHook': 'airflow.providers.amazon.aws.hooks.sqs.SqsHook', + }, + 'azure_container_instance_hook': { + 'AzureContainerInstanceHook': ( + 'airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook' + ), + }, + 'azure_container_registry_hook': { + 'AzureContainerRegistryHook': ( + 'airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook' + ), + }, + 'azure_container_volume_hook': { + 'AzureContainerVolumeHook': ( + 'airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook' + ), + }, + 'azure_cosmos_hook': { + 'AzureCosmosDBHook': 'airflow.providers.microsoft.azure.hooks.cosmos.AzureCosmosDBHook', + }, + 'azure_data_lake_hook': { + 'AzureDataLakeHook': 'airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeHook', + }, + 'azure_fileshare_hook': { + 'AzureFileShareHook': 'airflow.providers.microsoft.azure.hooks.fileshare.AzureFileShareHook', + }, + 'bigquery_hook': { + 'BigQueryBaseCursor': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryBaseCursor', + 'BigQueryConnection': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryConnection', + 'BigQueryCursor': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor', + 'BigQueryHook': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryHook', + 'GbqConnector': 'airflow.providers.google.cloud.hooks.bigquery.GbqConnector', + }, + 'cassandra_hook': { + 'CassandraHook': 'airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook', + }, + 'cloudant_hook': { + 'CloudantHook': 'airflow.providers.cloudant.hooks.cloudant.CloudantHook', + }, + 'databricks_hook': { + 'CANCEL_RUN_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.CANCEL_RUN_ENDPOINT', + 'GET_RUN_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.GET_RUN_ENDPOINT', + 'RESTART_CLUSTER_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.RESTART_CLUSTER_ENDPOINT', + 'RUN_LIFE_CYCLE_STATES': 'airflow.providers.databricks.hooks.databricks.RUN_LIFE_CYCLE_STATES', + 'RUN_NOW_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.RUN_NOW_ENDPOINT', + 
'START_CLUSTER_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.START_CLUSTER_ENDPOINT', + 'SUBMIT_RUN_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.SUBMIT_RUN_ENDPOINT', + 'TERMINATE_CLUSTER_ENDPOINT': ( + 'airflow.providers.databricks.hooks.databricks.TERMINATE_CLUSTER_ENDPOINT' + ), + 'DatabricksHook': 'airflow.providers.databricks.hooks.databricks.DatabricksHook', + 'RunState': 'airflow.providers.databricks.hooks.databricks.RunState', + }, + 'datadog_hook': { + 'DatadogHook': 'airflow.providers.datadog.hooks.datadog.DatadogHook', + }, + 'datastore_hook': { + 'DatastoreHook': 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook', + }, + 'dingding_hook': { + 'DingdingHook': 'airflow.providers.dingding.hooks.dingding.DingdingHook', + 'requests': 'airflow.providers.dingding.hooks.dingding.requests', + }, + 'discord_webhook_hook': { + 'DiscordWebhookHook': 'airflow.providers.discord.hooks.discord_webhook.DiscordWebhookHook', + }, + 'emr_hook': { + 'EmrHook': 'airflow.providers.amazon.aws.hooks.emr.EmrHook', + }, + 'fs_hook': { + 'FSHook': 'airflow.hooks.filesystem.FSHook', + }, + 'ftp_hook': { + 'FTPHook': 'airflow.providers.ftp.hooks.ftp.FTPHook', + 'FTPSHook': 'airflow.providers.ftp.hooks.ftp.FTPSHook', + }, + 'gcp_api_base_hook': { + 'GoogleBaseHook': 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook', + 'GoogleCloudBaseHook': 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook', + }, + 'gcp_bigtable_hook': { + 'BigtableHook': 'airflow.providers.google.cloud.hooks.bigtable.BigtableHook', + }, + 'gcp_cloud_build_hook': { + 'CloudBuildHook': 'airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook', + }, + 'gcp_compute_hook': { + 'ComputeEngineHook': 'airflow.providers.google.cloud.hooks.compute.ComputeEngineHook', + 'GceHook': 'airflow.providers.google.cloud.hooks.compute.ComputeEngineHook', + }, + 'gcp_container_hook': { + 'GKEHook': 'airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook', + 'GKEClusterHook': 'airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook', + }, + 'gcp_dataflow_hook': { + 'DataflowHook': 'airflow.providers.google.cloud.hooks.dataflow.DataflowHook', + 'DataFlowHook': 'airflow.providers.google.cloud.hooks.dataflow.DataflowHook', + }, + 'gcp_dataproc_hook': { + 'DataprocHook': 'airflow.providers.google.cloud.hooks.dataproc.DataprocHook', + 'DataProcHook': 'airflow.providers.google.cloud.hooks.dataproc.DataprocHook', + }, + 'gcp_dlp_hook': { + 'CloudDLPHook': 'airflow.providers.google.cloud.hooks.dlp.CloudDLPHook', + 'DlpJob': 'airflow.providers.google.cloud.hooks.dlp.DlpJob', + }, + 'gcp_function_hook': { + 'CloudFunctionsHook': 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook', + 'GcfHook': 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook', + }, + 'gcp_kms_hook': { + 'CloudKMSHook': 'airflow.providers.google.cloud.hooks.kms.CloudKMSHook', + 'GoogleCloudKMSHook': 'airflow.providers.google.cloud.hooks.kms.CloudKMSHook', + }, + 'gcp_mlengine_hook': { + 'MLEngineHook': 'airflow.providers.google.cloud.hooks.mlengine.MLEngineHook', + }, + 'gcp_natural_language_hook': { + 'CloudNaturalLanguageHook': ( + 'airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook' + ), + }, + 'gcp_pubsub_hook': { + 'PubSubException': 'airflow.providers.google.cloud.hooks.pubsub.PubSubException', + 'PubSubHook': 'airflow.providers.google.cloud.hooks.pubsub.PubSubHook', + }, + 'gcp_spanner_hook': { + 'SpannerHook': 
'airflow.providers.google.cloud.hooks.spanner.SpannerHook', + 'CloudSpannerHook': 'airflow.providers.google.cloud.hooks.spanner.SpannerHook', + }, + 'gcp_speech_to_text_hook': { + 'CloudSpeechToTextHook': 'airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook', + 'GCPSpeechToTextHook': 'airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook', + }, + 'gcp_sql_hook': { + 'CloudSQLDatabaseHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook', + 'CloudSQLHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook', + 'CloudSqlDatabaseHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook', + 'CloudSqlHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook', + }, + 'gcp_tasks_hook': { + 'CloudTasksHook': 'airflow.providers.google.cloud.hooks.tasks.CloudTasksHook', + }, + 'gcp_text_to_speech_hook': { + 'CloudTextToSpeechHook': 'airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook', + 'GCPTextToSpeechHook': 'airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook', + }, + 'gcp_transfer_hook': { + 'CloudDataTransferServiceHook': ( + 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceHook' + ), + 'GCPTransferServiceHook': ( + 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceHook' + ), + }, + 'gcp_translate_hook': { + 'CloudTranslateHook': 'airflow.providers.google.cloud.hooks.translate.CloudTranslateHook', + }, + 'gcp_video_intelligence_hook': { + 'CloudVideoIntelligenceHook': ( + 'airflow.providers.google.cloud.hooks.video_intelligence.CloudVideoIntelligenceHook' + ), + }, + 'gcp_vision_hook': { + 'CloudVisionHook': 'airflow.providers.google.cloud.hooks.vision.CloudVisionHook', + }, + 'gcs_hook': { + 'GCSHook': 'airflow.providers.google.cloud.hooks.gcs.GCSHook', + 'GoogleCloudStorageHook': 'airflow.providers.google.cloud.hooks.gcs.GCSHook', + }, + 'gdrive_hook': { + 'GoogleDriveHook': 'airflow.providers.google.suite.hooks.drive.GoogleDriveHook', + }, + 'grpc_hook': { + 'GrpcHook': 'airflow.providers.grpc.hooks.grpc.GrpcHook', + }, + 'imap_hook': { + 'ImapHook': 'airflow.providers.imap.hooks.imap.ImapHook', + 'Mail': 'airflow.providers.imap.hooks.imap.Mail', + 'MailPart': 'airflow.providers.imap.hooks.imap.MailPart', + }, + 'jenkins_hook': { + 'JenkinsHook': 'airflow.providers.jenkins.hooks.jenkins.JenkinsHook', + }, + 'jira_hook': { + 'JiraHook': 'airflow.providers.atlassian.jira.hooks.jira.JiraHook', + }, + 'mongo_hook': { + 'MongoHook': 'airflow.providers.mongo.hooks.mongo.MongoHook', + }, + 'openfaas_hook': { + 'OK_STATUS_CODE': 'airflow.providers.openfaas.hooks.openfaas.OK_STATUS_CODE', + 'OpenFaasHook': 'airflow.providers.openfaas.hooks.openfaas.OpenFaasHook', + 'requests': 'airflow.providers.openfaas.hooks.openfaas.requests', + }, + 'opsgenie_alert_hook': { + 'OpsgenieAlertHook': 'airflow.providers.opsgenie.hooks.opsgenie.OpsgenieAlertHook', + }, + 'pagerduty_hook': { + 'PagerdutyHook': 'airflow.providers.pagerduty.hooks.pagerduty.PagerdutyHook', + }, + 'pinot_hook': { + 'PinotAdminHook': 'airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook', + 'PinotDbApiHook': 'airflow.providers.apache.pinot.hooks.pinot.PinotDbApiHook', + }, + 'qubole_check_hook': { + 'QuboleCheckHook': 'airflow.providers.qubole.hooks.qubole_check.QuboleCheckHook', + }, + 'qubole_hook': { + 'QuboleHook': 'airflow.providers.qubole.hooks.qubole.QuboleHook', + }, + 'redis_hook': { + 'RedisHook': 
'airflow.providers.redis.hooks.redis.RedisHook', + }, + 'redshift_hook': { + 'RedshiftHook': 'airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook', + }, + 'sagemaker_hook': { + 'LogState': 'airflow.providers.amazon.aws.hooks.sagemaker.LogState', + 'Position': 'airflow.providers.amazon.aws.hooks.sagemaker.Position', + 'SageMakerHook': 'airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook', + 'argmin': 'airflow.providers.amazon.aws.hooks.sagemaker.argmin', + 'secondary_training_status_changed': ( + 'airflow.providers.amazon.aws.hooks.sagemaker.secondary_training_status_changed' + ), + 'secondary_training_status_message': ( + 'airflow.providers.amazon.aws.hooks.sagemaker.secondary_training_status_message' + ), + }, + 'salesforce_hook': { + 'SalesforceHook': 'airflow.providers.salesforce.hooks.salesforce.SalesforceHook', + 'pd': 'airflow.providers.salesforce.hooks.salesforce.pd', + }, + 'segment_hook': { + 'SegmentHook': 'airflow.providers.segment.hooks.segment.SegmentHook', + 'analytics': 'airflow.providers.segment.hooks.segment.analytics', + }, + 'sftp_hook': { + 'SFTPHook': 'airflow.providers.sftp.hooks.sftp.SFTPHook', + }, + 'slack_webhook_hook': { + 'SlackWebhookHook': 'airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook', + }, + 'snowflake_hook': { + 'SnowflakeHook': 'airflow.providers.snowflake.hooks.snowflake.SnowflakeHook', + }, + 'spark_jdbc_hook': { + 'SparkJDBCHook': 'airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook', + }, + 'spark_sql_hook': { + 'SparkSqlHook': 'airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook', + }, + 'spark_submit_hook': { + 'SparkSubmitHook': 'airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook', + }, + 'sqoop_hook': { + 'SqoopHook': 'airflow.providers.apache.sqoop.hooks.sqoop.SqoopHook', + }, + 'ssh_hook': { + 'SSHHook': 'airflow.providers.ssh.hooks.ssh.SSHHook', + }, + 'vertica_hook': { + 'VerticaHook': 'airflow.providers.vertica.hooks.vertica.VerticaHook', + }, + 'wasb_hook': { + 'WasbHook': 'airflow.providers.microsoft.azure.hooks.wasb.WasbHook', + }, + 'winrm_hook': { + 'WinRMHook': 'airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook', + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/contrib/hooks/aws_athena_hook.py b/airflow/contrib/hooks/aws_athena_hook.py deleted file mode 100644 index db1ecdfdbf3c5..0000000000000 --- a/airflow/contrib/hooks/aws_athena_hook.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.athena`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.hooks.athena`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_datasync_hook.py b/airflow/contrib/hooks/aws_datasync_hook.py deleted file mode 100644 index 0d485475b0310..0000000000000 --- a/airflow/contrib/hooks/aws_datasync_hook.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.datasync`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.datasync`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_dynamodb_hook.py b/airflow/contrib/hooks/aws_dynamodb_hook.py deleted file mode 100644 index dedb80073e3e5..0000000000000 --- a/airflow/contrib/hooks/aws_dynamodb_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.dynamodb`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_firehose_hook.py b/airflow/contrib/hooks/aws_firehose_hook.py deleted file mode 100644 index c6d39cd795b79..0000000000000 --- a/airflow/contrib/hooks/aws_firehose_hook.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.kinesis`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.kinesis import AwsFirehoseHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.kinesis`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_glue_catalog_hook.py b/airflow/contrib/hooks/aws_glue_catalog_hook.py deleted file mode 100644 index 703ba47b81bf3..0000000000000 --- a/airflow/contrib/hooks/aws_glue_catalog_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.glue_catalog`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py deleted file mode 100644 index c40e32c0305f1..0000000000000 --- a/airflow/contrib/hooks/aws_hook.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.base_aws`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook, _parse_s3_config, boto3 # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.hooks.base_aws`.", - DeprecationWarning, - stacklevel=2, -) - - -class AwsHook(AwsBaseHook): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/aws_lambda_hook.py b/airflow/contrib/hooks/aws_lambda_hook.py deleted file mode 100644 index 379aaf5486655..0000000000000 --- a/airflow/contrib/hooks/aws_lambda_hook.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.hooks.lambda_function`. -""" - -import warnings - -from airflow.providers.amazon.aws.hooks.lambda_function import AwsLambdaHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.lambda_function`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_logs_hook.py b/airflow/contrib/hooks/aws_logs_hook.py deleted file mode 100644 index 9b9c449f8415e..0000000000000 --- a/airflow/contrib/hooks/aws_logs_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.logs`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.hooks.logs`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_sns_hook.py b/airflow/contrib/hooks/aws_sns_hook.py deleted file mode 100644 index b1318f52add2d..0000000000000 --- a/airflow/contrib/hooks/aws_sns_hook.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.sns`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.sns import AwsSnsHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sns`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/aws_sqs_hook.py b/airflow/contrib/hooks/aws_sqs_hook.py deleted file mode 100644 index aafd8379ab5dd..0000000000000 --- a/airflow/contrib/hooks/aws_sqs_hook.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.sqs`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.sqs import SqsHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sqs`.", - DeprecationWarning, - stacklevel=2, -) - - -class SQSHook(SqsHook): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.sqs.SqsHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.sqs.SqsHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/azure_container_instance_hook.py b/airflow/contrib/hooks/azure_container_instance_hook.py deleted file mode 100644 index 9fefa5c679d38..0000000000000 --- a/airflow/contrib/hooks/azure_container_instance_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_instance`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_instance`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/azure_container_registry_hook.py b/airflow/contrib/hooks/azure_container_registry_hook.py deleted file mode 100644 index 14e55ef820737..0000000000000 --- a/airflow/contrib/hooks/azure_container_registry_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/azure_container_volume_hook.py b/airflow/contrib/hooks/azure_container_volume_hook.py deleted file mode 100644 index facfdaca4cc8a..0000000000000 --- a/airflow/contrib/hooks/azure_container_volume_hook.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.microsoft.azure.hooks.container_volume`. -""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_volume`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/azure_cosmos_hook.py b/airflow/contrib/hooks/azure_cosmos_hook.py deleted file mode 100644 index 4152f15e9d7b1..0000000000000 --- a/airflow/contrib/hooks/azure_cosmos_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.cosmos`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.cosmos`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/azure_data_lake_hook.py b/airflow/contrib/hooks/azure_data_lake_hook.py deleted file mode 100644 index 3442d1345078c..0000000000000 --- a/airflow/contrib/hooks/azure_data_lake_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.data_lake`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.hooks.data_lake`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/azure_fileshare_hook.py b/airflow/contrib/hooks/azure_fileshare_hook.py deleted file mode 100644 index f0a5b2ec4c9bb..0000000000000 --- a/airflow/contrib/hooks/azure_fileshare_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.fileshare`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.fileshare`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py deleted file mode 100644 index e9e10402943a5..0000000000000 --- a/airflow/contrib/hooks/bigquery_hook.py +++ /dev/null @@ -1,34 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.bigquery`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.bigquery import ( # noqa - BigQueryBaseCursor, - BigQueryConnection, - BigQueryCursor, - BigQueryHook, - GbqConnector, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py deleted file mode 100644 index ea4c748c2bfc5..0000000000000 --- a/airflow/contrib/hooks/cassandra_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.cassandra.hooks.cassandra`.""" - -import warnings - -from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.cassandra.hooks.cassandra`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/cloudant_hook.py b/airflow/contrib/hooks/cloudant_hook.py deleted file mode 100644 index ab7a1fa398fa7..0000000000000 --- a/airflow/contrib/hooks/cloudant_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.cloudant.hooks.cloudant`.""" - -import warnings - -from airflow.providers.cloudant.hooks.cloudant import CloudantHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.cloudant.hooks.cloudant`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py deleted file mode 100644 index 84746d39da8ae..0000000000000 --- a/airflow/contrib/hooks/databricks_hook.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.databricks.hooks.databricks`.""" - -import warnings - -from airflow.providers.databricks.hooks.databricks import ( # noqa - CANCEL_RUN_ENDPOINT, - GET_RUN_ENDPOINT, - RESTART_CLUSTER_ENDPOINT, - RUN_LIFE_CYCLE_STATES, - RUN_NOW_ENDPOINT, - START_CLUSTER_ENDPOINT, - SUBMIT_RUN_ENDPOINT, - TERMINATE_CLUSTER_ENDPOINT, - DatabricksHook, - RunState, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.databricks.hooks.databricks`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/datadog_hook.py b/airflow/contrib/hooks/datadog_hook.py deleted file mode 100644 index be275e9adf81d..0000000000000 --- a/airflow/contrib/hooks/datadog_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.datadog.hooks.datadog`.""" - -import warnings - -from airflow.providers.datadog.hooks.datadog import DatadogHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.datadog.hooks.datadog`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/datastore_hook.py b/airflow/contrib/hooks/datastore_hook.py deleted file mode 100644 index 9898e2fb16814..0000000000000 --- a/airflow/contrib/hooks/datastore_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.datastore`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.datastore import DatastoreHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.datastore`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/dingding_hook.py b/airflow/contrib/hooks/dingding_hook.py deleted file mode 100644 index deff0414baf94..0000000000000 --- a/airflow/contrib/hooks/dingding_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.dingding.hooks.dingding`.""" - -import warnings - -from airflow.providers.dingding.hooks.dingding import DingdingHook, requests # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.dingding.hooks.dingding`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/discord_webhook_hook.py b/airflow/contrib/hooks/discord_webhook_hook.py deleted file mode 100644 index a907d2115a724..0000000000000 --- a/airflow/contrib/hooks/discord_webhook_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.discord.hooks.discord_webhook`.""" - -import warnings - -from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.discord.hooks.discord_webhook`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/emr_hook.py b/airflow/contrib/hooks/emr_hook.py deleted file mode 100644 index 1a15ee3edcd17..0000000000000 --- a/airflow/contrib/hooks/emr_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.emr import EmrHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/fs_hook.py b/airflow/contrib/hooks/fs_hook.py deleted file mode 100644 index bc247c1948074..0000000000000 --- a/airflow/contrib/hooks/fs_hook.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.hooks.filesystem`.""" - -import warnings - -from airflow.hooks.filesystem import FSHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.hooks.filesystem`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/contrib/hooks/ftp_hook.py b/airflow/contrib/hooks/ftp_hook.py deleted file mode 100644 index 8d2e9cb06c843..0000000000000 --- a/airflow/contrib/hooks/ftp_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.ftp.hooks.ftp`.""" - -import warnings - -from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.ftp.hooks.ftp`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_api_base_hook.py b/airflow/contrib/hooks/gcp_api_base_hook.py deleted file mode 100644 index 3226a8683d7a7..0000000000000 --- a/airflow/contrib/hooks/gcp_api_base_hook.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.common.hooks.base_google`.""" -import warnings - -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.common.hooks.base_google`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudBaseHook(GoogleBaseHook): - """ - This class is deprecated. Please use - `airflow.providers.google.common.hooks.base_google.GoogleBaseHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. Please use " - "`airflow.providers.google.common.hooks.base_google.GoogleBaseHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_bigtable_hook.py b/airflow/contrib/hooks/gcp_bigtable_hook.py deleted file mode 100644 index 47ccd2414a839..0000000000000 --- a/airflow/contrib/hooks/gcp_bigtable_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.bigtable`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.bigtable import BigtableHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.bigtable`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_cloud_build_hook.py b/airflow/contrib/hooks/gcp_cloud_build_hook.py deleted file mode 100644 index 691ae728a46d1..0000000000000 --- a/airflow/contrib/hooks/gcp_cloud_build_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.cloud_build`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.cloud_build`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_compute_hook.py b/airflow/contrib/hooks/gcp_compute_hook.py deleted file mode 100644 index 5ac9df1378561..0000000000000 --- a/airflow/contrib/hooks/gcp_compute_hook.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.compute`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook - -warnings.warn( - "This module is deprecated. Please use airflow.providers.google.cloud.hooks.compute`", - DeprecationWarning, - stacklevel=2, -) - - -class GceHook(ComputeEngineHook): - """ - This class is deprecated. - Please use :class:`airflow.providers.google.cloud.hooks.compute.ComputeEngineHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. 
Please use `airflow.providers.google.cloud.hooks.compute`.", - DeprecationWarning, - stacklevel=2, - ) - - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_container_hook.py b/airflow/contrib/hooks/gcp_container_hook.py deleted file mode 100644 index e825dbe5492c2..0000000000000 --- a/airflow/contrib/hooks/gcp_container_hook.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.hooks.kubernetes_engine`. -""" - -import warnings - -from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kubernetes_engine`", - DeprecationWarning, - stacklevel=2, -) - - -class GKEClusterHook(GKEHook): - """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.container.GKEHook`.""" - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.container.GKEHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py deleted file mode 100644 index f489ffad82006..0000000000000 --- a/airflow/contrib/hooks/gcp_dataflow_hook.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.dataflow`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.dataflow import DataflowHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataflow`.", - DeprecationWarning, - stacklevel=2, -) - - -class DataFlowHook(DataflowHook): - """ - This class is deprecated. - Please use :class:`airflow.providers.google.cloud.hooks.dataflow.DataflowHook`. 
- """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.google.cloud.hooks.dataflow.DataflowHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py deleted file mode 100644 index 1f02c6de42f7e..0000000000000 --- a/airflow/contrib/hooks/gcp_dataproc_hook.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.dataproc`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.dataproc import DataprocHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataproc`.", - DeprecationWarning, - stacklevel=2, -) - - -class DataProcHook(DataprocHook): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.hooks.dataproc.DataprocHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.hooks.dataproc.DataprocHook`.""", - DeprecationWarning, - stacklevel=2, - ) - - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_dlp_hook.py b/airflow/contrib/hooks/gcp_dlp_hook.py deleted file mode 100644 index 77a9da6b1a8a8..0000000000000 --- a/airflow/contrib/hooks/gcp_dlp_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.dlp`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.dlp import CloudDLPHook, DlpJob # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.dlp`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_function_hook.py b/airflow/contrib/hooks/gcp_function_hook.py deleted file mode 100644 index dc6464fc07ce9..0000000000000 --- a/airflow/contrib/hooks/gcp_function_hook.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.functions`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.functions`.", - DeprecationWarning, - stacklevel=2, -) - - -class GcfHook(CloudFunctionsHook): - """ - This class is deprecated. Please use - `airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook`.", - DeprecationWarning, - stacklevel=2, - ) - - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_kms_hook.py b/airflow/contrib/hooks/gcp_kms_hook.py deleted file mode 100644 index 1b409be99f121..0000000000000 --- a/airflow/contrib/hooks/gcp_kms_hook.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.kms`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.kms import CloudKMSHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kms`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudKMSHook(CloudKMSHook): - """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.kms.CloudKMSHook`.""" - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. 
Please use `airflow.providers.google.cloud.hooks.kms.CloudKMSHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py deleted file mode 100644 index 57978e008b375..0000000000000 --- a/airflow/contrib/hooks/gcp_mlengine_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.mlengine`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.mlengine`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_natural_language_hook.py b/airflow/contrib/hooks/gcp_natural_language_hook.py deleted file mode 100644 index 86ee9f8675d57..0000000000000 --- a/airflow/contrib/hooks/gcp_natural_language_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.natural_language`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.natural_language import CloudNaturalLanguageHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.natural_language`", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_pubsub_hook.py b/airflow/contrib/hooks/gcp_pubsub_hook.py deleted file mode 100644 index 677a0f03fa506..0000000000000 --- a/airflow/contrib/hooks/gcp_pubsub_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.pubsub`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.pubsub import PubSubException, PubSubHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.pubsub`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_spanner_hook.py b/airflow/contrib/hooks/gcp_spanner_hook.py deleted file mode 100644 index d3608ce5d0994..0000000000000 --- a/airflow/contrib/hooks/gcp_spanner_hook.py +++ /dev/null @@ -1,36 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.spanner`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.spanner import SpannerHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.spanner`.", - DeprecationWarning, - stacklevel=2, -) - - -class CloudSpannerHook(SpannerHook): - """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.spanner.SpannerHook`.""" - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_speech_to_text_hook.py b/airflow/contrib/hooks/gcp_speech_to_text_hook.py deleted file mode 100644 index 1c4f4e16a79c0..0000000000000 --- a/airflow/contrib/hooks/gcp_speech_to_text_hook.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.speech_to_text`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.speech_to_text`", - DeprecationWarning, - stacklevel=2, -) - - -class GCPSpeechToTextHook(CloudSpeechToTextHook): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook`.", - DeprecationWarning, - stacklevel=2, - ) - - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_sql_hook.py b/airflow/contrib/hooks/gcp_sql_hook.py deleted file mode 100644 index d3e85a7c3a57c..0000000000000 --- a/airflow/contrib/hooks/gcp_sql_hook.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.cloud_sql`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.cloud_sql`", - DeprecationWarning, - stacklevel=2, -) - - -class CloudSqlDatabaseHook(CloudSQLDatabaseHook): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlHook(CloudSQLHook): - """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.sql.CloudSQLHook`.""" - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_tasks_hook.py b/airflow/contrib/hooks/gcp_tasks_hook.py deleted file mode 100644 index 1753b2a1842d4..0000000000000 --- a/airflow/contrib/hooks/gcp_tasks_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.tasks`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.tasks`", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_text_to_speech_hook.py b/airflow/contrib/hooks/gcp_text_to_speech_hook.py deleted file mode 100644 index 9e53ec16f5f0e..0000000000000 --- a/airflow/contrib/hooks/gcp_text_to_speech_hook.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.text_to_speech`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.text_to_speech`", - DeprecationWarning, - stacklevel=2, -) - - -class GCPTextToSpeechHook(CloudTextToSpeechHook): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook`.", - DeprecationWarning, - stacklevel=2, - ) - - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_transfer_hook.py b/airflow/contrib/hooks/gcp_transfer_hook.py deleted file mode 100644 index 57c098d52b7f9..0000000000000 --- a/airflow/contrib/hooks/gcp_transfer_hook.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. Please use -`airflow.providers.google.cloud.hooks.cloud_storage_transfer_service`. -""" - -import warnings - -from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import CloudDataTransferServiceHook - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service`", - DeprecationWarning, - stacklevel=2, -) - - -class GCPTransferServiceHook(CloudDataTransferServiceHook): - """ - This class is deprecated. Please use - `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service - .CloudDataTransferServiceHook`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gcp_translate_hook.py b/airflow/contrib/hooks/gcp_translate_hook.py deleted file mode 100644 index 1b0cec8b5ec9c..0000000000000 --- a/airflow/contrib/hooks/gcp_translate_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.translate`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.translate`", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_video_intelligence_hook.py b/airflow/contrib/hooks/gcp_video_intelligence_hook.py deleted file mode 100644 index a71ef4649e909..0000000000000 --- a/airflow/contrib/hooks/gcp_video_intelligence_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.video_intelligence`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.video_intelligence import CloudVideoIntelligenceHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.video_intelligence`", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcp_vision_hook.py b/airflow/contrib/hooks/gcp_vision_hook.py deleted file mode 100644 index 52f47f42bfdf5..0000000000000 --- a/airflow/contrib/hooks/gcp_vision_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.vision`.""" - -import warnings - -from airflow.providers.google.cloud.hooks.vision import CloudVisionHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.vision`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py deleted file mode 100644 index 910206acf9286..0000000000000 --- a/airflow/contrib/hooks/gcs_hook.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.gcs`.""" -import warnings - -from airflow.providers.google.cloud.hooks.gcs import GCSHook - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.hooks.gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageHook(GCSHook): - """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs.GCSHook`.""" - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs.GCSHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/hooks/gdrive_hook.py b/airflow/contrib/hooks/gdrive_hook.py deleted file mode 100644 index dad8459c58394..0000000000000 --- a/airflow/contrib/hooks/gdrive_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.suite.hooks.drive`.""" - -import warnings - -from airflow.providers.google.suite.hooks.drive import GoogleDriveHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.suite.hooks.drive`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/grpc_hook.py b/airflow/contrib/hooks/grpc_hook.py deleted file mode 100644 index f7aa6e2216fa9..0000000000000 --- a/airflow/contrib/hooks/grpc_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.grpc.hooks.grpc`.""" - -import warnings - -from airflow.providers.grpc.hooks.grpc import GrpcHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.grpc.hooks.grpc`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/imap_hook.py b/airflow/contrib/hooks/imap_hook.py deleted file mode 100644 index 57703966ab3c9..0000000000000 --- a/airflow/contrib/hooks/imap_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.imap.hooks.imap`.""" - -import warnings - -from airflow.providers.imap.hooks.imap import ImapHook, Mail, MailPart # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.imap.hooks.imap`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/jenkins_hook.py b/airflow/contrib/hooks/jenkins_hook.py deleted file mode 100644 index 178474ea77991..0000000000000 --- a/airflow/contrib/hooks/jenkins_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.jenkins.hooks.jenkins`.""" - -import warnings - -from airflow.providers.jenkins.hooks.jenkins import JenkinsHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.jenkins.hooks.jenkins`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/jira_hook.py b/airflow/contrib/hooks/jira_hook.py deleted file mode 100644 index 8f9d4670154a2..0000000000000 --- a/airflow/contrib/hooks/jira_hook.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. 
Please use :mod:`airflow.providers.jira.hooks.jira`.""" - -import warnings - -from airflow.providers.jira.hooks.jira import JiraHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.jira.hooks.jira`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/mongo_hook.py b/airflow/contrib/hooks/mongo_hook.py deleted file mode 100644 index 63f6eea36ba51..0000000000000 --- a/airflow/contrib/hooks/mongo_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.mongo.hooks.mongo`.""" - -import warnings - -from airflow.providers.mongo.hooks.mongo import MongoHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.mongo.hooks.mongo`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/openfaas_hook.py b/airflow/contrib/hooks/openfaas_hook.py deleted file mode 100644 index a0e71ffe5e564..0000000000000 --- a/airflow/contrib/hooks/openfaas_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.openfaas.hooks.openfaas`.""" - -import warnings - -from airflow.providers.openfaas.hooks.openfaas import OK_STATUS_CODE, OpenFaasHook, requests # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.openfaas.hooks.openfaas`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/opsgenie_alert_hook.py b/airflow/contrib/hooks/opsgenie_alert_hook.py deleted file mode 100644 index abd9c89e09308..0000000000000 --- a/airflow/contrib/hooks/opsgenie_alert_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.opsgenie.hooks.opsgenie`.""" - -import warnings - -from airflow.providers.opsgenie.hooks.opsgenie import OpsgenieAlertHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.opsgenie.hooks.opsgenie`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/pagerduty_hook.py b/airflow/contrib/hooks/pagerduty_hook.py deleted file mode 100644 index 33797b0e60ea4..0000000000000 --- a/airflow/contrib/hooks/pagerduty_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.pagerduty.hooks.pagerduty`.""" - -import warnings - -from airflow.providers.pagerduty.hooks.pagerduty import PagerdutyHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.pagerduty.hooks.pagerduty`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/pinot_hook.py b/airflow/contrib/hooks/pinot_hook.py deleted file mode 100644 index 159677fdec52a..0000000000000 --- a/airflow/contrib/hooks/pinot_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.apache.pinot.hooks.pinot`.""" - -import warnings - -from airflow.providers.apache.pinot.hooks.pinot import PinotAdminHook, PinotDbApiHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.pinot.hooks.pinot`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/qubole_check_hook.py b/airflow/contrib/hooks/qubole_check_hook.py deleted file mode 100644 index 0a674d7c763ad..0000000000000 --- a/airflow/contrib/hooks/qubole_check_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.qubole.hooks.qubole_check`.""" - -import warnings - -from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole_check`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py deleted file mode 100644 index 6a695bca76462..0000000000000 --- a/airflow/contrib/hooks/qubole_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.qubole.hooks.qubole`.""" - -import warnings - -from airflow.providers.qubole.hooks.qubole import QuboleHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/redis_hook.py b/airflow/contrib/hooks/redis_hook.py deleted file mode 100644 index 57bdab5aead08..0000000000000 --- a/airflow/contrib/hooks/redis_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.redis.hooks.redis`.""" - -import warnings - -from airflow.providers.redis.hooks.redis import RedisHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.redis.hooks.redis`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/redshift_hook.py b/airflow/contrib/hooks/redshift_hook.py deleted file mode 100644 index f33515f6cd954..0000000000000 --- a/airflow/contrib/hooks/redshift_hook.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.redshift_cluster`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.redshift_cluster`.", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["RedshiftHook"] diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py deleted file mode 100644 index 321f25b022fa9..0000000000000 --- a/airflow/contrib/hooks/sagemaker_hook.py +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.amazon.aws.hooks.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.sagemaker import ( # noqa - LogState, - Position, - SageMakerHook, - argmin, - secondary_training_status_changed, - secondary_training_status_message, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/salesforce_hook.py b/airflow/contrib/hooks/salesforce_hook.py deleted file mode 100644 index a707a527b8449..0000000000000 --- a/airflow/contrib/hooks/salesforce_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.salesforce.hooks.salesforce`.""" - -import warnings - -from airflow.providers.salesforce.hooks.salesforce import SalesforceHook, pd # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.salesforce.hooks.salesforce`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/segment_hook.py b/airflow/contrib/hooks/segment_hook.py deleted file mode 100644 index 6da62578cee8b..0000000000000 --- a/airflow/contrib/hooks/segment_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.segment.hooks.segment`.""" - -import warnings - -from airflow.providers.segment.hooks.segment import SegmentHook, analytics # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.segment.hooks.segment`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/sftp_hook.py b/airflow/contrib/hooks/sftp_hook.py deleted file mode 100644 index 0153e8e54d062..0000000000000 --- a/airflow/contrib/hooks/sftp_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.sftp.hooks.sftp`.""" - -import warnings - -from airflow.providers.sftp.hooks.sftp import SFTPHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/slack_webhook_hook.py b/airflow/contrib/hooks/slack_webhook_hook.py deleted file mode 100644 index f438d11575b43..0000000000000 --- a/airflow/contrib/hooks/slack_webhook_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.slack.hooks.slack_webhook`.""" - -import warnings - -from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.slack.hooks.slack_webhook`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/snowflake_hook.py b/airflow/contrib/hooks/snowflake_hook.py deleted file mode 100644 index 804baccf66cdf..0000000000000 --- a/airflow/contrib/hooks/snowflake_hook.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. 
Please use :mod:`airflow.providers.snowflake.hooks.snowflake`.""" - -import warnings - -from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.snowflake.hooks.snowflake`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/spark_jdbc_hook.py b/airflow/contrib/hooks/spark_jdbc_hook.py deleted file mode 100644 index 1b48d094acbb6..0000000000000 --- a/airflow/contrib/hooks/spark_jdbc_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.hooks.spark_jdbc`.""" - -import warnings - -from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_jdbc`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/spark_sql_hook.py b/airflow/contrib/hooks/spark_sql_hook.py deleted file mode 100644 index 6b262ed3efce5..0000000000000 --- a/airflow/contrib/hooks/spark_sql_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.hooks.spark_sql`.""" - -import warnings - -from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_sql`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py deleted file mode 100644 index fbdbf4f9f0d53..0000000000000 --- a/airflow/contrib/hooks/spark_submit_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.hooks.spark_submit`.""" - -import warnings - -from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_submit`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/sqoop_hook.py b/airflow/contrib/hooks/sqoop_hook.py deleted file mode 100644 index f231c0f26a8dc..0000000000000 --- a/airflow/contrib/hooks/sqoop_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.sqoop.hooks.sqoop`.""" - -import warnings - -from airflow.providers.apache.sqoop.hooks.sqoop import SqoopHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.sqoop.hooks.sqoop`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py deleted file mode 100644 index ef3000d888c50..0000000000000 --- a/airflow/contrib/hooks/ssh_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.ssh.hooks.ssh`.""" - -import warnings - -from airflow.providers.ssh.hooks.ssh import SSHHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.ssh.hooks.ssh`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/vertica_hook.py b/airflow/contrib/hooks/vertica_hook.py deleted file mode 100644 index fc84b222d0a5a..0000000000000 --- a/airflow/contrib/hooks/vertica_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.vertica.hooks.vertica`.""" - -import warnings - -from airflow.providers.vertica.hooks.vertica import VerticaHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.vertica.hooks.vertica`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/wasb_hook.py b/airflow/contrib/hooks/wasb_hook.py deleted file mode 100644 index 3b5eb650934af..0000000000000 --- a/airflow/contrib/hooks/wasb_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.wasb`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.wasb import WasbHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.wasb`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/hooks/winrm_hook.py b/airflow/contrib/hooks/winrm_hook.py deleted file mode 100644 index 35e7db2bc7294..0000000000000 --- a/airflow/contrib/hooks/winrm_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.winrm.hooks.winrm`.""" - -import warnings - -from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.winrm.hooks.winrm`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/__init__.py b/airflow/contrib/operators/__init__.py index 2041adb9ea24a..8e02c871b84d2 100644 --- a/airflow/contrib/operators/__init__.py +++ b/airflow/contrib/operators/__init__.py @@ -15,5 +15,1128 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This package is deprecated. Please use `airflow.operators` or `airflow.providers.*.operators`.""" +from __future__ import annotations + +import warnings + +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.utils.deprecation_tools import add_deprecated_classes + +warnings.warn( + "This package is deprecated. Please use `airflow.operators` or `airflow.providers.*.operators`.", + RemovedInAirflow3Warning, + stacklevel=2, +) + +__deprecated_classes = { + 'adls_list_operator': { + 'ADLSListOperator': 'airflow.providers.microsoft.azure.operators.adls.ADLSListOperator', + 'AzureDataLakeStorageListOperator': ( + 'airflow.providers.microsoft.azure.operators.adls.ADLSListOperator' + ), + }, + 'adls_to_gcs': { + 'ADLSToGCSOperator': 'airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator', + 'AdlsToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator' + ), + }, + 'aws_athena_operator': { + 'AWSAthenaOperator': 'airflow.providers.amazon.aws.operators.athena.AthenaOperator', + }, + 'aws_sqs_publish_operator': { + 'SqsPublishOperator': 'airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator', + 'SQSPublishOperator': 'airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator', + }, + 'awsbatch_operator': { + 'BatchProtocol': 'airflow.providers.amazon.aws.hooks.batch_client.BatchProtocol', + 'BatchOperator': 'airflow.providers.amazon.aws.operators.batch.BatchOperator', + 'AWSBatchOperator': 'airflow.providers.amazon.aws.operators.batch.BatchOperator', + }, + 'azure_container_instances_operator': { + 'AzureContainerInstancesOperator': ( + 'airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstancesOperator' + ), + }, + 'azure_cosmos_operator': { + 'AzureCosmosInsertDocumentOperator': ( + 'airflow.providers.microsoft.azure.operators.cosmos.AzureCosmosInsertDocumentOperator' + ), + }, + 'bigquery_check_operator': { + 'BigQueryCheckOperator': 'airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator', + 'BigQueryIntervalCheckOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator' + ), + 'BigQueryValueCheckOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator' + ), + }, + 'bigquery_get_data': { + 'BigQueryGetDataOperator': ( + 
'airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataOperator' + ), + }, + 'bigquery_operator': { + 'BigQueryCreateEmptyDatasetOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryCreateEmptyDatasetOperator' + ), + 'BigQueryCreateEmptyTableOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryCreateEmptyTableOperator' + ), + 'BigQueryCreateExternalTableOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryCreateExternalTableOperator' + ), + 'BigQueryDeleteDatasetOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteDatasetOperator' + ), + 'BigQueryExecuteQueryOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator' + ), + 'BigQueryGetDatasetOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryGetDatasetOperator' + ), + 'BigQueryGetDatasetTablesOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryGetDatasetTablesOperator' + ), + 'BigQueryPatchDatasetOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryPatchDatasetOperator' + ), + 'BigQueryUpdateDatasetOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryUpdateDatasetOperator' + ), + 'BigQueryUpsertTableOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryUpsertTableOperator' + ), + 'BigQueryOperator': 'airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator', + }, + 'bigquery_table_delete_operator': { + 'BigQueryDeleteTableOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator' + ), + 'BigQueryTableDeleteOperator': ( + 'airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator' + ), + }, + 'bigquery_to_bigquery': { + 'BigQueryToBigQueryOperator': ( + 'airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryToBigQueryOperator' + ), + }, + 'bigquery_to_gcs': { + 'BigQueryToGCSOperator': ( + 'airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator' + ), + 'BigQueryToCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator' + ), + }, + 'bigquery_to_mysql_operator': { + 'BigQueryToMySqlOperator': ( + 'airflow.providers.google.cloud.transfers.bigquery_to_mysql.BigQueryToMySqlOperator' + ), + }, + 'cassandra_to_gcs': { + 'CassandraToGCSOperator': ( + 'airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator' + ), + 'CassandraToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator' + ), + }, + 'databricks_operator': { + 'DatabricksRunNowOperator': ( + 'airflow.providers.databricks.operators.databricks.DatabricksRunNowOperator' + ), + 'DatabricksSubmitRunOperator': ( + 'airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator' + ), + }, + 'dataflow_operator': { + 'DataflowCreateJavaJobOperator': ( + 'airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator' + ), + 'DataflowCreatePythonJobOperator': ( + 'airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator' + ), + 'DataflowTemplatedJobStartOperator': ( + 'airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator' + ), + 'DataFlowJavaOperator': ( + 'airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator' + ), + 'DataFlowPythonOperator': ( + 
'airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator' + ), + 'DataflowTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator' + ), + }, + 'dataproc_operator': { + 'DataprocCreateClusterOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator' + ), + 'DataprocDeleteClusterOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator' + ), + 'DataprocInstantiateInlineWorkflowTemplateOperator': + 'airflow.providers.google.cloud.operators.dataproc.' + 'DataprocInstantiateInlineWorkflowTemplateOperator', + 'DataprocInstantiateWorkflowTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator' + ), + 'DataprocJobBaseOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator' + ), + 'DataprocScaleClusterOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator' + ), + 'DataprocSubmitHadoopJobOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator' + ), + 'DataprocSubmitHiveJobOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator' + ), + 'DataprocSubmitPigJobOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator' + ), + 'DataprocSubmitPySparkJobOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator' + ), + 'DataprocSubmitSparkJobOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator' + ), + 'DataprocSubmitSparkSqlJobOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator' + ), + 'DataprocClusterCreateOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator' + ), + 'DataprocClusterDeleteOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator' + ), + 'DataprocClusterScaleOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator' + ), + 'DataProcHadoopOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator' + ), + 'DataProcHiveOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator' + ), + 'DataProcJobBaseOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator' + ), + 'DataProcPigOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator' + ), + 'DataProcPySparkOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator' + ), + 'DataProcSparkOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator' + ), + 'DataProcSparkSqlOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator' + ), + 'DataprocWorkflowTemplateInstantiateInlineOperator': + 'airflow.providers.google.cloud.operators.dataproc.' 
+ 'DataprocInstantiateInlineWorkflowTemplateOperator', + 'DataprocWorkflowTemplateInstantiateOperator': ( + 'airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator' + ), + }, + 'datastore_export_operator': { + 'CloudDatastoreExportEntitiesOperator': ( + 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator' + ), + 'DatastoreExportOperator': ( + 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator' + ), + }, + 'datastore_import_operator': { + 'CloudDatastoreImportEntitiesOperator': ( + 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator' + ), + 'DatastoreImportOperator': ( + 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator' + ), + }, + 'dingding_operator': { + 'DingdingOperator': 'airflow.providers.dingding.operators.dingding.DingdingOperator', + }, + 'discord_webhook_operator': { + 'DiscordWebhookOperator': ( + 'airflow.providers.discord.operators.discord_webhook.DiscordWebhookOperator' + ), + }, + 'docker_swarm_operator': { + 'DockerSwarmOperator': 'airflow.providers.docker.operators.docker_swarm.DockerSwarmOperator', + }, + 'druid_operator': { + 'DruidOperator': 'airflow.providers.apache.druid.operators.druid.DruidOperator', + }, + 'dynamodb_to_s3': { + 'DynamoDBToS3Operator': 'airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBToS3Operator', + }, + 'ecs_operator': { + 'EcsProtocol': 'airflow.providers.amazon.aws.hooks.ecs.EcsProtocol', + 'EcsRunTaskOperator': 'airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator', + 'EcsOperator': 'airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator', + }, + 'file_to_gcs': { + 'LocalFilesystemToGCSOperator': ( + 'airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator' + ), + 'FileToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator' + ), + }, + 'file_to_wasb': { + 'LocalFilesystemToWasbOperator': ( + 'airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator' + ), + 'FileToWasbOperator': ( + 'airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator' + ), + }, + 'gcp_bigtable_operator': { + 'BigtableCreateInstanceOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator' + ), + 'BigtableCreateTableOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator' + ), + 'BigtableDeleteInstanceOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator' + ), + 'BigtableDeleteTableOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator' + ), + 'BigtableUpdateClusterOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator' + ), + 'BigtableTableReplicationCompletedSensor': ( + 'airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor' + ), + 'BigtableClusterUpdateOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator' + ), + 'BigtableInstanceCreateOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator' + ), + 'BigtableInstanceDeleteOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator' + ), + 'BigtableTableCreateOperator': ( + 
'airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator' + ), + 'BigtableTableDeleteOperator': ( + 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator' + ), + 'BigtableTableWaitForReplicationSensor': ( + 'airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor' + ), + }, + 'gcp_cloud_build_operator': { + 'CloudBuildCreateBuildOperator': ( + 'airflow.providers.google.cloud.operators.cloud_build.CloudBuildCreateBuildOperator' + ), + }, + 'gcp_compute_operator': { + 'ComputeEngineBaseOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator' + ), + 'ComputeEngineCopyInstanceTemplateOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineCopyInstanceTemplateOperator' + ), + 'ComputeEngineInstanceGroupUpdateManagerTemplateOperator': + 'airflow.providers.google.cloud.operators.compute.' + 'ComputeEngineInstanceGroupUpdateManagerTemplateOperator', + 'ComputeEngineSetMachineTypeOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineSetMachineTypeOperator' + ), + 'ComputeEngineStartInstanceOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineStartInstanceOperator' + ), + 'ComputeEngineStopInstanceOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineStopInstanceOperator' + ), + 'GceBaseOperator': 'airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator', + 'GceInstanceGroupManagerUpdateTemplateOperator': + 'airflow.providers.google.cloud.operators.compute.' + 'ComputeEngineInstanceGroupUpdateManagerTemplateOperator', + 'GceInstanceStartOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineStartInstanceOperator' + ), + 'GceInstanceStopOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineStopInstanceOperator' + ), + 'GceInstanceTemplateCopyOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineCopyInstanceTemplateOperator' + ), + 'GceSetMachineTypeOperator': ( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineSetMachineTypeOperator' + ), + }, + 'gcp_container_operator': { + 'GKECreateClusterOperator': ( + 'airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator' + ), + 'GKEDeleteClusterOperator': ( + 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEDeleteClusterOperator' + ), + 'GKEStartPodOperator': ( + 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator' + ), + 'GKEClusterCreateOperator': ( + 'airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator' + ), + 'GKEClusterDeleteOperator': ( + 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEDeleteClusterOperator' + ), + 'GKEPodOperator': 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator', + }, + 'gcp_dlp_operator': { + 'CloudDLPCancelDLPJobOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPCancelDLPJobOperator' + ), + 'CloudDLPCreateDeidentifyTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateDeidentifyTemplateOperator' + ), + 'CloudDLPCreateDLPJobOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateDLPJobOperator' + ), + 'CloudDLPCreateInspectTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateInspectTemplateOperator' + ), + 'CloudDLPCreateJobTriggerOperator': ( + 
'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateJobTriggerOperator' + ), + 'CloudDLPCreateStoredInfoTypeOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateStoredInfoTypeOperator' + ), + 'CloudDLPDeidentifyContentOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeidentifyContentOperator' + ), + 'CloudDLPDeleteDeidentifyTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDeidentifyTemplateOperator' + ), + 'CloudDLPDeleteDLPJobOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator' + ), + 'CloudDLPDeleteInspectTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteInspectTemplateOperator' + ), + 'CloudDLPDeleteJobTriggerOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteJobTriggerOperator' + ), + 'CloudDLPDeleteStoredInfoTypeOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteStoredInfoTypeOperator' + ), + 'CloudDLPGetDeidentifyTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDeidentifyTemplateOperator' + ), + 'CloudDLPGetDLPJobOperator': 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator', + 'CloudDLPGetDLPJobTriggerOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator' + ), + 'CloudDLPGetInspectTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetInspectTemplateOperator' + ), + 'CloudDLPGetStoredInfoTypeOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetStoredInfoTypeOperator' + ), + 'CloudDLPInspectContentOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPInspectContentOperator' + ), + 'CloudDLPListDeidentifyTemplatesOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPListDeidentifyTemplatesOperator' + ), + 'CloudDLPListDLPJobsOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator' + ), + 'CloudDLPListInfoTypesOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPListInfoTypesOperator' + ), + 'CloudDLPListInspectTemplatesOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPListInspectTemplatesOperator' + ), + 'CloudDLPListJobTriggersOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPListJobTriggersOperator' + ), + 'CloudDLPListStoredInfoTypesOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPListStoredInfoTypesOperator' + ), + 'CloudDLPRedactImageOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator' + ), + 'CloudDLPReidentifyContentOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPReidentifyContentOperator' + ), + 'CloudDLPUpdateDeidentifyTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateDeidentifyTemplateOperator' + ), + 'CloudDLPUpdateInspectTemplateOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateInspectTemplateOperator' + ), + 'CloudDLPUpdateJobTriggerOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateJobTriggerOperator' + ), + 'CloudDLPUpdateStoredInfoTypeOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateStoredInfoTypeOperator' + ), + 'CloudDLPDeleteDlpJobOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator' + ), + 'CloudDLPGetDlpJobOperator': 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator', + 
'CloudDLPGetJobTripperOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator' + ), + 'CloudDLPListDlpJobsOperator': ( + 'airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator' + ), + }, + 'gcp_function_operator': { + 'CloudFunctionDeleteFunctionOperator': ( + 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeleteFunctionOperator' + ), + 'CloudFunctionDeployFunctionOperator': ( + 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeployFunctionOperator' + ), + 'GcfFunctionDeleteOperator': ( + 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeleteFunctionOperator' + ), + 'GcfFunctionDeployOperator': ( + 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeployFunctionOperator' + ), + }, + 'gcp_natural_language_operator': { + 'CloudNaturalLanguageAnalyzeEntitiesOperator': + 'airflow.providers.google.cloud.operators.natural_language.' + 'CloudNaturalLanguageAnalyzeEntitiesOperator', + 'CloudNaturalLanguageAnalyzeEntitySentimentOperator': + 'airflow.providers.google.cloud.operators.natural_language.' + 'CloudNaturalLanguageAnalyzeEntitySentimentOperator', + 'CloudNaturalLanguageAnalyzeSentimentOperator': + 'airflow.providers.google.cloud.operators.natural_language.' + 'CloudNaturalLanguageAnalyzeSentimentOperator', + 'CloudNaturalLanguageClassifyTextOperator': + 'airflow.providers.google.cloud.operators.natural_language.' + 'CloudNaturalLanguageClassifyTextOperator', + 'CloudLanguageAnalyzeEntitiesOperator': + 'airflow.providers.google.cloud.operators.natural_language.' + 'CloudNaturalLanguageAnalyzeEntitiesOperator', + 'CloudLanguageAnalyzeEntitySentimentOperator': + 'airflow.providers.google.cloud.operators.natural_language.' + 'CloudNaturalLanguageAnalyzeEntitySentimentOperator', + 'CloudLanguageAnalyzeSentimentOperator': + 'airflow.providers.google.cloud.operators.natural_language.' + 'CloudNaturalLanguageAnalyzeSentimentOperator', + 'CloudLanguageClassifyTextOperator': + 'airflow.providers.google.cloud.operators.natural_language.' 
+ 'CloudNaturalLanguageClassifyTextOperator', + }, + 'gcp_spanner_operator': { + 'SpannerDeleteDatabaseInstanceOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteDatabaseInstanceOperator' + ), + 'SpannerDeleteInstanceOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteInstanceOperator' + ), + 'SpannerDeployDatabaseInstanceOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeployDatabaseInstanceOperator' + ), + 'SpannerDeployInstanceOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeployInstanceOperator' + ), + 'SpannerQueryDatabaseInstanceOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerQueryDatabaseInstanceOperator' + ), + 'SpannerUpdateDatabaseInstanceOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerUpdateDatabaseInstanceOperator' + ), + 'CloudSpannerInstanceDatabaseDeleteOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteDatabaseInstanceOperator' + ), + 'CloudSpannerInstanceDatabaseDeployOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeployDatabaseInstanceOperator' + ), + 'CloudSpannerInstanceDatabaseQueryOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerQueryDatabaseInstanceOperator' + ), + 'CloudSpannerInstanceDatabaseUpdateOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerUpdateDatabaseInstanceOperator' + ), + 'CloudSpannerInstanceDeleteOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteInstanceOperator' + ), + 'CloudSpannerInstanceDeployOperator': ( + 'airflow.providers.google.cloud.operators.spanner.SpannerDeployInstanceOperator' + ), + }, + 'gcp_speech_to_text_operator': { + 'CloudSpeechToTextRecognizeSpeechOperator': ( + 'airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextRecognizeSpeechOperator' + ), + 'GcpSpeechToTextRecognizeSpeechOperator': ( + 'airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextRecognizeSpeechOperator' + ), + }, + 'gcp_sql_operator': { + 'CloudSQLBaseOperator': 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator', + 'CloudSQLCreateInstanceDatabaseOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceDatabaseOperator' + ), + 'CloudSQLCreateInstanceOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceOperator' + ), + 'CloudSQLDeleteInstanceDatabaseOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceDatabaseOperator' + ), + 'CloudSQLDeleteInstanceOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceOperator' + ), + 'CloudSQLExecuteQueryOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExecuteQueryOperator' + ), + 'CloudSQLExportInstanceOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExportInstanceOperator' + ), + 'CloudSQLImportInstanceOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLImportInstanceOperator' + ), + 'CloudSQLInstancePatchOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLInstancePatchOperator' + ), + 'CloudSQLPatchInstanceDatabaseOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLPatchInstanceDatabaseOperator' + ), + 'CloudSqlBaseOperator': 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator', + 'CloudSqlInstanceCreateOperator': ( + 
'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceOperator' + ), + 'CloudSqlInstanceDatabaseCreateOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceDatabaseOperator' + ), + 'CloudSqlInstanceDatabaseDeleteOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceDatabaseOperator' + ), + 'CloudSqlInstanceDatabasePatchOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLPatchInstanceDatabaseOperator' + ), + 'CloudSqlInstanceDeleteOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceOperator' + ), + 'CloudSqlInstanceExportOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExportInstanceOperator' + ), + 'CloudSqlInstanceImportOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLImportInstanceOperator' + ), + 'CloudSqlInstancePatchOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLInstancePatchOperator' + ), + 'CloudSqlQueryOperator': ( + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExecuteQueryOperator' + ), + }, + 'gcp_tasks_operator': { + 'CloudTasksQueueCreateOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueCreateOperator' + ), + 'CloudTasksQueueDeleteOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueDeleteOperator' + ), + 'CloudTasksQueueGetOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueGetOperator' + ), + 'CloudTasksQueuePauseOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueuePauseOperator' + ), + 'CloudTasksQueuePurgeOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueuePurgeOperator' + ), + 'CloudTasksQueueResumeOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueResumeOperator' + ), + 'CloudTasksQueuesListOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueuesListOperator' + ), + 'CloudTasksQueueUpdateOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueUpdateOperator' + ), + 'CloudTasksTaskCreateOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskCreateOperator' + ), + 'CloudTasksTaskDeleteOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskDeleteOperator' + ), + 'CloudTasksTaskGetOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskGetOperator' + ), + 'CloudTasksTaskRunOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskRunOperator' + ), + 'CloudTasksTasksListOperator': ( + 'airflow.providers.google.cloud.operators.tasks.CloudTasksTasksListOperator' + ), + }, + 'gcp_text_to_speech_operator': { + 'CloudTextToSpeechSynthesizeOperator': ( + 'airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator' + ), + 'GcpTextToSpeechSynthesizeOperator': ( + 'airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator' + ), + }, + 'gcp_transfer_operator': { + 'CloudDataTransferServiceCancelOperationOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceCancelOperationOperator', + 'CloudDataTransferServiceCreateJobOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' 
+ 'CloudDataTransferServiceCreateJobOperator', + 'CloudDataTransferServiceDeleteJobOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceDeleteJobOperator', + 'CloudDataTransferServiceGCSToGCSOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceGCSToGCSOperator', + 'CloudDataTransferServiceGetOperationOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceGetOperationOperator', + 'CloudDataTransferServiceListOperationsOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceListOperationsOperator', + 'CloudDataTransferServicePauseOperationOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServicePauseOperationOperator', + 'CloudDataTransferServiceResumeOperationOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceResumeOperationOperator', + 'CloudDataTransferServiceS3ToGCSOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceS3ToGCSOperator', + 'CloudDataTransferServiceUpdateJobOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceUpdateJobOperator', + 'GcpTransferServiceJobCreateOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceCreateJobOperator', + 'GcpTransferServiceJobDeleteOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceDeleteJobOperator', + 'GcpTransferServiceJobUpdateOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceUpdateJobOperator', + 'GcpTransferServiceOperationCancelOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceCancelOperationOperator', + 'GcpTransferServiceOperationGetOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceGetOperationOperator', + 'GcpTransferServiceOperationPauseOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServicePauseOperationOperator', + 'GcpTransferServiceOperationResumeOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceResumeOperationOperator', + 'GcpTransferServiceOperationsListOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceListOperationsOperator', + 'GoogleCloudStorageToGoogleCloudStorageTransferOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' + 'CloudDataTransferServiceGCSToGCSOperator', + 'S3ToGoogleCloudStorageTransferOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' 
+ 'CloudDataTransferServiceS3ToGCSOperator', + }, + 'gcp_translate_operator': { + 'CloudTranslateTextOperator': ( + 'airflow.providers.google.cloud.operators.translate.CloudTranslateTextOperator' + ), + }, + 'gcp_translate_speech_operator': { + 'CloudTranslateSpeechOperator': ( + 'airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator' + ), + 'GcpTranslateSpeechOperator': ( + 'airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator' + ), + }, + 'gcp_video_intelligence_operator': { + 'CloudVideoIntelligenceDetectVideoExplicitContentOperator': + 'airflow.providers.google.cloud.operators.video_intelligence.' + 'CloudVideoIntelligenceDetectVideoExplicitContentOperator', + 'CloudVideoIntelligenceDetectVideoLabelsOperator': + 'airflow.providers.google.cloud.operators.video_intelligence.' + 'CloudVideoIntelligenceDetectVideoLabelsOperator', + 'CloudVideoIntelligenceDetectVideoShotsOperator': + 'airflow.providers.google.cloud.operators.video_intelligence.' + 'CloudVideoIntelligenceDetectVideoShotsOperator', + }, + 'gcp_vision_operator': { + 'CloudVisionAddProductToProductSetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionAddProductToProductSetOperator' + ), + 'CloudVisionCreateProductOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator' + ), + 'CloudVisionCreateProductSetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator' + ), + 'CloudVisionCreateReferenceImageOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator' + ), + 'CloudVisionDeleteProductOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator' + ), + 'CloudVisionDeleteProductSetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator' + ), + 'CloudVisionDetectImageLabelsOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageLabelsOperator' + ), + 'CloudVisionDetectImageSafeSearchOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageSafeSearchOperator' + ), + 'CloudVisionDetectTextOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionDetectTextOperator' + ), + 'CloudVisionGetProductOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator' + ), + 'CloudVisionGetProductSetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator' + ), + 'CloudVisionImageAnnotateOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator' + ), + 'CloudVisionRemoveProductFromProductSetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionRemoveProductFromProductSetOperator' + ), + 'CloudVisionTextDetectOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator' + ), + 'CloudVisionUpdateProductOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator' + ), + 'CloudVisionUpdateProductSetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator' + ), + 'CloudVisionAnnotateImageOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator' + ), + 'CloudVisionDetectDocumentTextOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator' + ), + 
'CloudVisionProductCreateOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator' + ), + 'CloudVisionProductDeleteOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator' + ), + 'CloudVisionProductGetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator' + ), + 'CloudVisionProductSetCreateOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator' + ), + 'CloudVisionProductSetDeleteOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator' + ), + 'CloudVisionProductSetGetOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator' + ), + 'CloudVisionProductSetUpdateOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator' + ), + 'CloudVisionProductUpdateOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator' + ), + 'CloudVisionReferenceImageCreateOperator': ( + 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator' + ), + }, + 'gcs_acl_operator': { + 'GCSBucketCreateAclEntryOperator': ( + 'airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator' + ), + 'GCSObjectCreateAclEntryOperator': ( + 'airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator' + ), + 'GoogleCloudStorageBucketCreateAclEntryOperator': ( + 'airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator' + ), + 'GoogleCloudStorageObjectCreateAclEntryOperator': ( + 'airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator' + ), + }, + 'gcs_delete_operator': { + 'GCSDeleteObjectsOperator': 'airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator', + 'GoogleCloudStorageDeleteOperator': ( + 'airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator' + ), + }, + 'gcs_download_operator': { + 'GCSToLocalFilesystemOperator': ( + 'airflow.providers.google.cloud.transfers.gcs_to_local.GCSToLocalFilesystemOperator' + ), + 'GoogleCloudStorageDownloadOperator': ( + 'airflow.providers.google.cloud.transfers.gcs_to_local.GCSToLocalFilesystemOperator' + ), + }, + 'gcs_list_operator': { + 'GCSListObjectsOperator': 'airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator', + 'GoogleCloudStorageListOperator': ( + 'airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator' + ), + }, + 'gcs_operator': { + 'GCSCreateBucketOperator': 'airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator', + 'GoogleCloudStorageCreateBucketOperator': ( + 'airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator' + ), + }, + 'gcs_to_bq': { + 'GCSToBigQueryOperator': ( + 'airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator' + ), + 'GoogleCloudStorageToBigQueryOperator': ( + 'airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator' + ), + }, + 'gcs_to_gcs': { + 'GCSToGCSOperator': 'airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator', + 'GoogleCloudStorageToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator' + ), + }, + 'gcs_to_gdrive_operator': { + 'GCSToGoogleDriveOperator': ( + 'airflow.providers.google.suite.transfers.gcs_to_gdrive.GCSToGoogleDriveOperator' + ), + }, + 'gcs_to_s3': { + 'GCSToS3Operator': 
'airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator', + 'GoogleCloudStorageToS3Operator': 'airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator', + }, + 'grpc_operator': { + 'GrpcOperator': 'airflow.providers.grpc.operators.grpc.GrpcOperator', + }, + 'hive_to_dynamodb': { + 'HiveToDynamoDBOperator': ( + 'airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator' + ), + }, + 'imap_attachment_to_s3_operator': { + 'ImapAttachmentToS3Operator': ( + 'airflow.providers.amazon.aws.transfers.imap_attachment_to_s3.ImapAttachmentToS3Operator' + ), + }, + 'jenkins_job_trigger_operator': { + 'JenkinsJobTriggerOperator': ( + 'airflow.providers.jenkins.operators.jenkins_job_trigger.JenkinsJobTriggerOperator' + ), + }, + 'jira_operator': { + 'JiraOperator': 'airflow.providers.atlassian.jira.operators.jira.JiraOperator', + }, + 'kubernetes_pod_operator': { + 'KubernetesPodOperator': ( + 'airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator' + ), + }, + 'mlengine_operator': { + 'MLEngineManageModelOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator' + ), + 'MLEngineManageVersionOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator' + ), + 'MLEngineStartBatchPredictionJobOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator' + ), + 'MLEngineStartTrainingJobOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator' + ), + 'MLEngineBatchPredictionOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator' + ), + 'MLEngineModelOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator' + ), + 'MLEngineTrainingOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator' + ), + 'MLEngineVersionOperator': ( + 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator' + ), + }, + 'mongo_to_s3': { + 'MongoToS3Operator': 'airflow.providers.amazon.aws.transfers.mongo_to_s3.MongoToS3Operator', + }, + 'mssql_to_gcs': { + 'MSSQLToGCSOperator': 'airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator', + 'MsSqlToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator' + ), + }, + 'mysql_to_gcs': { + 'MySQLToGCSOperator': 'airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator', + 'MySqlToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator' + ), + }, + 'opsgenie_alert_operator': { + 'OpsgenieCreateAlertOperator': ( + 'airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator' + ), + 'OpsgenieAlertOperator': 'airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator', + }, + 'oracle_to_azure_data_lake_transfer': { + 'OracleToAzureDataLakeOperator': + 'airflow.providers.microsoft.azure.transfers.' 
+ 'oracle_to_azure_data_lake.OracleToAzureDataLakeOperator', + }, + 'oracle_to_oracle_transfer': { + 'OracleToOracleOperator': ( + 'airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator' + ), + 'OracleToOracleTransfer': ( + 'airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator' + ), + }, + 'postgres_to_gcs_operator': { + 'PostgresToGCSOperator': ( + 'airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator' + ), + 'PostgresToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator' + ), + }, + 'pubsub_operator': { + 'PubSubCreateSubscriptionOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator' + ), + 'PubSubCreateTopicOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator' + ), + 'PubSubDeleteSubscriptionOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator' + ), + 'PubSubDeleteTopicOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator' + ), + 'PubSubPublishMessageOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator' + ), + 'PubSubPublishOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator' + ), + 'PubSubSubscriptionCreateOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator' + ), + 'PubSubSubscriptionDeleteOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator' + ), + 'PubSubTopicCreateOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator' + ), + 'PubSubTopicDeleteOperator': ( + 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator' + ), + }, + 'qubole_check_operator': { + 'QuboleCheckOperator': 'airflow.providers.qubole.operators.qubole_check.QuboleCheckOperator', + 'QuboleValueCheckOperator': ( + 'airflow.providers.qubole.operators.qubole_check.QuboleValueCheckOperator' + ), + }, + 'qubole_operator': { + 'QuboleOperator': 'airflow.providers.qubole.operators.qubole.QuboleOperator', + }, + 'redis_publish_operator': { + 'RedisPublishOperator': 'airflow.providers.redis.operators.redis_publish.RedisPublishOperator', + }, + 's3_to_gcs_operator': { + 'S3ToGCSOperator': 'airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator', + }, + 's3_to_gcs_transfer_operator': { + 'CloudDataTransferServiceS3ToGCSOperator': + 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.' 
+ 'CloudDataTransferServiceS3ToGCSOperator', + }, + 's3_to_sftp_operator': { + 'S3ToSFTPOperator': 'airflow.providers.amazon.aws.transfers.s3_to_sftp.S3ToSFTPOperator', + }, + 'segment_track_event_operator': { + 'SegmentTrackEventOperator': ( + 'airflow.providers.segment.operators.segment_track_event.SegmentTrackEventOperator' + ), + }, + 'sftp_operator': { + 'SFTPOperator': 'airflow.providers.sftp.operators.sftp.SFTPOperator', + }, + 'sftp_to_s3_operator': { + 'SFTPToS3Operator': 'airflow.providers.amazon.aws.transfers.sftp_to_s3.SFTPToS3Operator', + }, + 'slack_webhook_operator': { + 'SlackWebhookOperator': 'airflow.providers.slack.operators.slack_webhook.SlackWebhookOperator', + }, + 'snowflake_operator': { + 'SnowflakeOperator': 'airflow.providers.snowflake.operators.snowflake.SnowflakeOperator', + }, + 'sns_publish_operator': { + 'SnsPublishOperator': 'airflow.providers.amazon.aws.operators.sns.SnsPublishOperator', + }, + 'spark_jdbc_operator': { + 'SparkJDBCOperator': 'airflow.providers.apache.spark.operators.spark_jdbc.SparkJDBCOperator', + 'SparkSubmitOperator': 'airflow.providers.apache.spark.operators.spark_jdbc.SparkSubmitOperator', + }, + 'spark_sql_operator': { + 'SparkSqlOperator': 'airflow.providers.apache.spark.operators.spark_sql.SparkSqlOperator', + }, + 'spark_submit_operator': { + 'SparkSubmitOperator': 'airflow.providers.apache.spark.operators.spark_submit.SparkSubmitOperator', + }, + 'sql_to_gcs': { + 'BaseSQLToGCSOperator': 'airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator', + 'BaseSQLToGoogleCloudStorageOperator': ( + 'airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator' + ), + }, + 'sqoop_operator': { + 'SqoopOperator': 'airflow.providers.apache.sqoop.operators.sqoop.SqoopOperator', + }, + 'ssh_operator': { + 'SSHOperator': 'airflow.providers.ssh.operators.ssh.SSHOperator', + }, + 'vertica_operator': { + 'VerticaOperator': 'airflow.providers.vertica.operators.vertica.VerticaOperator', + }, + 'vertica_to_hive': { + 'VerticaToHiveOperator': ( + 'airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator' + ), + 'VerticaToHiveTransfer': ( + 'airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator' + ), + }, + 'vertica_to_mysql': { + 'VerticaToMySqlOperator': 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator', + 'VerticaToMySqlTransfer': 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator', + }, + 'wasb_delete_blob_operator': { + 'WasbDeleteBlobOperator': ( + 'airflow.providers.microsoft.azure.operators.wasb_delete_blob.WasbDeleteBlobOperator' + ), + }, + 'winrm_operator': { + 'WinRMOperator': 'airflow.providers.microsoft.winrm.operators.winrm.WinRMOperator', + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/contrib/operators/adls_list_operator.py b/airflow/contrib/operators/adls_list_operator.py deleted file mode 100644 index d4d1394f89fe0..0000000000000 --- a/airflow/contrib/operators/adls_list_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.adls`.""" - -import warnings - -from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.adls`.", - DeprecationWarning, - stacklevel=2, -) - - -class AzureDataLakeStorageListOperator(ADLSListOperator): - """ - This class is deprecated. - Please use Please use :mod:`airflow.providers.microsoft.azure.operators.adls.ADLSListOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use Please use :mod:`airflow.providers.microsoft.azure.operators.adls.ADLSListOperator`""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/adls_to_gcs.py b/airflow/contrib/operators/adls_to_gcs.py deleted file mode 100644 index 0497d4c259023..0000000000000 --- a/airflow/contrib/operators/adls_to_gcs.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.adls_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.adls_to_gcs import ADLSToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.adls_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class AdlsToGoogleCloudStorageOperator(ADLSToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/aws_athena_operator.py b/airflow/contrib/operators/aws_athena_operator.py deleted file mode 100644 index e799c74635ed2..0000000000000 --- a/airflow/contrib/operators/aws_athena_operator.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.athena`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.athena`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/aws_sqs_publish_operator.py b/airflow/contrib/operators/aws_sqs_publish_operator.py deleted file mode 100644 index 0ecc8a64f359d..0000000000000 --- a/airflow/contrib/operators/aws_sqs_publish_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sqs`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sqs`.", - DeprecationWarning, - stacklevel=2, -) - - -class SQSPublishOperator(SqsPublishOperator): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py deleted file mode 100644 index a6be224cb93eb..0000000000000 --- a/airflow/contrib/operators/awsbatch_operator.py +++ /dev/null @@ -1,75 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -""" -This module is deprecated. Please use: - -- :mod:`airflow.providers.amazon.aws.operators.batch` -- :mod:`airflow.providers.amazon.aws.hooks.batch_client` -- :mod:`airflow.providers.amazon.aws.hooks.batch_waiters`` -""" - -import warnings - -from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchProtocol -from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator -from airflow.typing_compat import Protocol, runtime_checkable - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.batch`, " - "`airflow.providers.amazon.aws.hooks.batch_client`, and " - "`airflow.providers.amazon.aws.hooks.batch_waiters`", - DeprecationWarning, - stacklevel=2, -) - - -class AWSBatchOperator(AwsBatchOperator): - """ - This class is deprecated. Please use - `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -@runtime_checkable -class BatchProtocol(AwsBatchProtocol, Protocol): - """ - This class is deprecated. Please use - `airflow.providers.amazon.aws.hooks.batch_client.AwsBatchProtocol`. - """ - - # A Protocol cannot be instantiated - - def __new__(cls, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.amazon.aws.hooks.batch_client.AwsBatchProtocol`.""", - DeprecationWarning, - stacklevel=2, - ) diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py deleted file mode 100644 index d084748ca06df..0000000000000 --- a/airflow/contrib/operators/azure_container_instances_operator.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. Please use -`airflow.providers.microsoft.azure.operators.container_instances`. -""" -import warnings - -from airflow.providers.microsoft.azure.operators.container_instances import ( # noqa - AzureContainerInstancesOperator, -) - -warnings.warn( - "This module is deprecated. 
" - "Please use `airflow.providers.microsoft.azure.operators.container_instances`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/azure_cosmos_operator.py b/airflow/contrib/operators/azure_cosmos_operator.py deleted file mode 100644 index 269c8357c02d3..0000000000000 --- a/airflow/contrib/operators/azure_cosmos_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.cosmos`.""" - -import warnings - -from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.cosmos`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py deleted file mode 100644 index 39b658cacab8d..0000000000000 --- a/airflow/contrib/operators/bigquery_check_operator.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`.""" - -import warnings - -from airflow.providers.google.cloud.operators.bigquery import ( # noqa - BigQueryCheckOperator, - BigQueryIntervalCheckOperator, - BigQueryValueCheckOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py deleted file mode 100644 index 00c8575df3d70..0000000000000 --- a/airflow/contrib/operators/bigquery_get_data.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`.""" - -import warnings - -from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py deleted file mode 100644 index 6fe8f0816263f..0000000000000 --- a/airflow/contrib/operators/bigquery_operator.py +++ /dev/null @@ -1,55 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`.""" - -import warnings - -from airflow.providers.google.cloud.operators.bigquery import ( # noqa; noqa; noqa; noqa; noqa - BigQueryCreateEmptyDatasetOperator, - BigQueryCreateEmptyTableOperator, - BigQueryCreateExternalTableOperator, - BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, - BigQueryGetDatasetOperator, - BigQueryGetDatasetTablesOperator, - BigQueryPatchDatasetOperator, - BigQueryUpdateDatasetOperator, - BigQueryUpsertTableOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.", - DeprecationWarning, - stacklevel=2, -) - - -class BigQueryOperator(BigQueryExecuteQueryOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/bigquery_table_delete_operator.py b/airflow/contrib/operators/bigquery_table_delete_operator.py deleted file mode 100644 index 13822a1844409..0000000000000 --- a/airflow/contrib/operators/bigquery_table_delete_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`.""" - -import warnings - -from airflow.providers.google.cloud.operators.bigquery import BigQueryDeleteTableOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.", - DeprecationWarning, - stacklevel=2, -) - - -class BigQueryTableDeleteOperator(BigQueryDeleteTableOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py deleted file mode 100644 index 84c26fb718c91..0000000000000 --- a/airflow/contrib/operators/bigquery_to_bigquery.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.transfers.bigquery_to_bigquery`. -""" - -import warnings - -from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.transfers.bigquery_to_bigquery`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py deleted file mode 100644 index 702171187e885..0000000000000 --- a/airflow/contrib/operators/bigquery_to_gcs.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.bigquery_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class BigQueryToCloudStorageOperator(BigQueryToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/bigquery_to_mysql_operator.py b/airflow/contrib/operators/bigquery_to_mysql_operator.py deleted file mode 100644 index 401921cff502b..0000000000000 --- a/airflow/contrib/operators/bigquery_to_mysql_operator.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.transfers.bigquery_to_mysql`. -""" - -import warnings - -from airflow.providers.google.cloud.transfers.bigquery_to_mysql import BigQueryToMySqlOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.transfers.bigquery_to_mysql`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/cassandra_to_gcs.py b/airflow/contrib/operators/cassandra_to_gcs.py deleted file mode 100644 index bb4b244f0d4cf..0000000000000 --- a/airflow/contrib/operators/cassandra_to_gcs.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.transfers.cassandra_to_gcs`. -""" - -import warnings - -from airflow.providers.google.cloud.transfers.cassandra_to_gcs import CassandraToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class CassandraToGoogleCloudStorageOperator(CassandraToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py deleted file mode 100644 index b591dd63b8274..0000000000000 --- a/airflow/contrib/operators/databricks_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.databricks.operators.databricks`.""" - -import warnings - -from airflow.providers.databricks.operators.databricks import ( # noqa - DatabricksRunNowOperator, - DatabricksSubmitRunOperator, -) - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.databricks.operators.databricks`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py deleted file mode 100644 index d2e445add02f1..0000000000000 --- a/airflow/contrib/operators/dataflow_operator.py +++ /dev/null @@ -1,82 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.dataflow`.""" - -import warnings - -from airflow.providers.google.cloud.operators.dataflow import ( - DataflowCreateJavaJobOperator, - DataflowCreatePythonJobOperator, - DataflowTemplatedJobStartOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dataflow`.", - DeprecationWarning, - stacklevel=2, -) - - -class DataFlowJavaOperator(DataflowCreateJavaJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataFlowPythonOperator(DataflowCreatePythonJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataflowTemplateOperator(DataflowTemplatedJobStartOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py deleted file mode 100644 index b655ce630142a..0000000000000 --- a/airflow/contrib/operators/dataproc_operator.py +++ /dev/null @@ -1,244 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.dataproc`.""" - -import warnings - -from airflow.providers.google.cloud.operators.dataproc import ( - DataprocCreateClusterOperator, - DataprocDeleteClusterOperator, - DataprocInstantiateInlineWorkflowTemplateOperator, - DataprocInstantiateWorkflowTemplateOperator, - DataprocJobBaseOperator, - DataprocScaleClusterOperator, - DataprocSubmitHadoopJobOperator, - DataprocSubmitHiveJobOperator, - DataprocSubmitPigJobOperator, - DataprocSubmitPySparkJobOperator, - DataprocSubmitSparkJobOperator, - DataprocSubmitSparkSqlJobOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc`.", - DeprecationWarning, - stacklevel=2, -) - - -class DataprocClusterCreateOperator(DataprocCreateClusterOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataprocClusterDeleteOperator(DataprocDeleteClusterOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataprocClusterScaleOperator(DataprocScaleClusterOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataProcHadoopOperator(DataprocSubmitHadoopJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataProcHiveOperator(DataprocSubmitHiveJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use - `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataProcJobBaseOperator(DataprocJobBaseOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataProcPigOperator(DataprocSubmitPigJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataProcPySparkOperator(DataprocSubmitPySparkJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataProcSparkOperator(DataprocSubmitSparkJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataProcSparkSqlOperator(DataprocSubmitSparkSqlJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataprocWorkflowTemplateInstantiateInlineOperator(DataprocInstantiateInlineWorkflowTemplateOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataproc - .DataprocInstantiateInlineWorkflowTemplateOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class DataprocWorkflowTemplateInstantiateOperator(DataprocInstantiateWorkflowTemplateOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use - `airflow.providers.google.cloud.operators.dataproc - .DataprocInstantiateWorkflowTemplateOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/datastore_export_operator.py b/airflow/contrib/operators/datastore_export_operator.py deleted file mode 100644 index 085ee132607da..0000000000000 --- a/airflow/contrib/operators/datastore_export_operator.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.datastore`.""" - -import warnings - -from airflow.providers.google.cloud.operators.datastore import CloudDatastoreExportEntitiesOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.datastore`.", - DeprecationWarning, - stacklevel=2, -) - - -class DatastoreExportOperator(CloudDatastoreExportEntitiesOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated.l - Please use - `airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/datastore_import_operator.py b/airflow/contrib/operators/datastore_import_operator.py deleted file mode 100644 index 5b15cd23c9516..0000000000000 --- a/airflow/contrib/operators/datastore_import_operator.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.datastore`.""" - -import warnings - -from airflow.providers.google.cloud.operators.datastore import CloudDatastoreImportEntitiesOperator - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.operators.datastore`.", - DeprecationWarning, - stacklevel=2, -) - - -class DatastoreImportOperator(CloudDatastoreImportEntitiesOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/dingding_operator.py b/airflow/contrib/operators/dingding_operator.py deleted file mode 100644 index bfe91e8a72491..0000000000000 --- a/airflow/contrib/operators/dingding_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.dingding.operators.dingding`.""" - -import warnings - -from airflow.providers.dingding.operators.dingding import DingdingOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.dingding.operators.dingding`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/discord_webhook_operator.py b/airflow/contrib/operators/discord_webhook_operator.py deleted file mode 100644 index be5809afbfcbd..0000000000000 --- a/airflow/contrib/operators/discord_webhook_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.discord.operators.discord_webhook`.""" - -import warnings - -from airflow.providers.discord.operators.discord_webhook import DiscordWebhookOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.discord.operators.discord_webhook`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/docker_swarm_operator.py b/airflow/contrib/operators/docker_swarm_operator.py deleted file mode 100644 index b023da796a4c8..0000000000000 --- a/airflow/contrib/operators/docker_swarm_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.docker.operators.docker_swarm`.""" - -import warnings - -from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.docker.operators.docker_swarm`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/druid_operator.py b/airflow/contrib/operators/druid_operator.py deleted file mode 100644 index 20dff77192313..0000000000000 --- a/airflow/contrib/operators/druid_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.operators.druid`.""" - -import warnings - -from airflow.providers.apache.druid.operators.druid import DruidOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.druid.operators.druid`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/dynamodb_to_s3.py b/airflow/contrib/operators/dynamodb_to_s3.py deleted file mode 100644 index a2054007c0678..0000000000000 --- a/airflow/contrib/operators/dynamodb_to_s3.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
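# ---------------------------------------------------------------------------
# Editor's note (illustrative, stdlib only; not part of the diff): while migrating off
# the shims deleted in this change, DeprecationWarning can be escalated to an error so
# that any remaining airflow.contrib import fails loudly.  The docker_swarm path in the
# comment below is just an example taken from the shim deleted above; whether it still
# warns depends on the deprecation mapping behaving like the old shims did.
import warnings

warnings.simplefilter("error", DeprecationWarning)

# With the filter above (or the equivalent "python -W error::DeprecationWarning"), an
# import such as
#   from airflow.contrib.operators.docker_swarm_operator import DockerSwarmOperator
# raises DeprecationWarning as an exception instead of printing a warning.
# ---------------------------------------------------------------------------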
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.dynamodb_to_s3`.""" - -import warnings - -from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.dynamodb_to_s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py deleted file mode 100644 index 569df0c284774..0000000000000 --- a/airflow/contrib/operators/ecs_operator.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ecs`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.ecs import ECSOperator, ECSProtocol - -__all__ = ["ECSOperator", "ECSProtocol"] - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py deleted file mode 100644 index e53f284e447c6..0000000000000 --- a/airflow/contrib/operators/emr_add_steps_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.amazon.aws.operators.emr_add_steps`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_add_steps`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py deleted file mode 100644 index 16f1ce7f61f21..0000000000000 --- a/airflow/contrib/operators/emr_create_job_flow_operator.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr_create_job_flow`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/emr_terminate_job_flow_operator.py b/airflow/contrib/operators/emr_terminate_job_flow_operator.py deleted file mode 100644 index 7c73bc32dc8db..0000000000000 --- a/airflow/contrib/operators/emr_terminate_job_flow_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.operators.emr_terminate_job_flow`. -""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/file_to_gcs.py b/airflow/contrib/operators/file_to_gcs.py deleted file mode 100644 index 227cd74979c2f..0000000000000 --- a/airflow/contrib/operators/file_to_gcs.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.local_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.local_to_gcs`,", - DeprecationWarning, - stacklevel=2, -) - - -class FileToGoogleCloudStorageOperator(LocalFilesystemToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/file_to_wasb.py b/airflow/contrib/operators/file_to_wasb.py deleted file mode 100644 index eb6275d5bbd1f..0000000000000 --- a/airflow/contrib/operators/file_to_wasb.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.transfers.local_to_wasb`.""" - -import warnings - -from airflow.providers.microsoft.azure.transfers.local_to_wasb import LocalFilesystemToWasbOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.transfers.local_to_wasb`.", - DeprecationWarning, - stacklevel=2, -) - - -class FileToWasbOperator(LocalFilesystemToWasbOperator): - """ - This class is deprecated. 
- Please use `airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_bigtable_operator.py b/airflow/contrib/operators/gcp_bigtable_operator.py deleted file mode 100644 index f45fde7d5a388..0000000000000 --- a/airflow/contrib/operators/gcp_bigtable_operator.py +++ /dev/null @@ -1,136 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable` -or `airflow.providers.google.cloud.sensors.bigtable`. -""" - -import warnings - -from airflow.providers.google.cloud.operators.bigtable import ( - BigtableCreateInstanceOperator, - BigtableCreateTableOperator, - BigtableDeleteInstanceOperator, - BigtableDeleteTableOperator, - BigtableUpdateClusterOperator, -) -from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable`" - " or `airflow.providers.google.cloud.sensors.bigtable`.", - DeprecationWarning, - stacklevel=2, -) - - -class BigtableClusterUpdateOperator(BigtableUpdateClusterOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class BigtableInstanceCreateOperator(BigtableCreateInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class BigtableInstanceDeleteOperator(BigtableDeleteInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class BigtableTableCreateOperator(BigtableCreateTableOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class BigtableTableDeleteOperator(BigtableDeleteTableOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class BigtableTableWaitForReplicationSensor(BigtableTableReplicationCompletedSensor): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_cloud_build_operator.py b/airflow/contrib/operators/gcp_cloud_build_operator.py deleted file mode 100644 index 443fdbfdab73f..0000000000000 --- a/airflow/contrib/operators/gcp_cloud_build_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.cloud_build`.""" - -import warnings - -from airflow.providers.google.cloud.operators.cloud_build import CloudBuildCreateBuildOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_build`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/gcp_compute_operator.py b/airflow/contrib/operators/gcp_compute_operator.py deleted file mode 100644 index d943fdf5587dd..0000000000000 --- a/airflow/contrib/operators/gcp_compute_operator.py +++ /dev/null @@ -1,138 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.compute`.""" - -import warnings - -from airflow.providers.google.cloud.operators.compute import ( - ComputeEngineBaseOperator, - ComputeEngineCopyInstanceTemplateOperator, - ComputeEngineInstanceGroupUpdateManagerTemplateOperator, - ComputeEngineSetMachineTypeOperator, - ComputeEngineStartInstanceOperator, - ComputeEngineStopInstanceOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.compute`.", - DeprecationWarning, - stacklevel=2, -) - - -class GceBaseOperator(ComputeEngineBaseOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GceInstanceGroupManagerUpdateTemplateOperator(ComputeEngineInstanceGroupUpdateManagerTemplateOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute - .ComputeEngineInstanceGroupUpdateManagerTemplateOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. Please use - `airflow.providers.google.cloud.operators.compute - .ComputeEngineInstanceGroupUpdateManagerTemplateOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GceInstanceStartOperator(ComputeEngineStartInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators - .compute.ComputeEngineStartInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute - .ComputeEngineStartInstanceOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GceInstanceStopOperator(ComputeEngineStopInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineStopInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute - .ComputeEngineStopInstanceOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GceInstanceTemplateCopyOperator(ComputeEngineCopyInstanceTemplateOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineCopyInstanceTemplateOperator`. 
- """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """"This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute - .ComputeEngineCopyInstanceTemplateOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GceSetMachineTypeOperator(ComputeEngineSetMachineTypeOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineSetMachineTypeOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.compute - .ComputeEngineSetMachineTypeOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py deleted file mode 100644 index 8a26c40a8fe5d..0000000000000 --- a/airflow/contrib/operators/gcp_container_operator.py +++ /dev/null @@ -1,83 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.operators.kubernetes_engine`. -""" - -import warnings - -from airflow.providers.google.cloud.operators.kubernetes_engine import ( - GKECreateClusterOperator, - GKEDeleteClusterOperator, - GKEStartPodOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.kubernetes_engine`", - DeprecationWarning, - stacklevel=2, -) - - -class GKEClusterCreateOperator(GKECreateClusterOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.container.GKECreateClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.container.GKECreateClusterOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GKEClusterDeleteOperator(GKEDeleteClusterOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.container.GKEDeleteClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.container.GKEDeleteClusterOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GKEPodOperator(GKEStartPodOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.container.GKEStartPodOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.container.GKEStartPodOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_dlp_operator.py b/airflow/contrib/operators/gcp_dlp_operator.py deleted file mode 100644 index f5b4c07d4cc8d..0000000000000 --- a/airflow/contrib/operators/gcp_dlp_operator.py +++ /dev/null @@ -1,123 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.dlp`.""" - -import warnings - -from airflow.providers.google.cloud.operators.dlp import ( # noqa - CloudDLPCancelDLPJobOperator, - CloudDLPCreateDeidentifyTemplateOperator, - CloudDLPCreateDLPJobOperator, - CloudDLPCreateInspectTemplateOperator, - CloudDLPCreateJobTriggerOperator, - CloudDLPCreateStoredInfoTypeOperator, - CloudDLPDeidentifyContentOperator, - CloudDLPDeleteDeidentifyTemplateOperator, - CloudDLPDeleteDLPJobOperator, - CloudDLPDeleteInspectTemplateOperator, - CloudDLPDeleteJobTriggerOperator, - CloudDLPDeleteStoredInfoTypeOperator, - CloudDLPGetDeidentifyTemplateOperator, - CloudDLPGetDLPJobOperator, - CloudDLPGetDLPJobTriggerOperator, - CloudDLPGetInspectTemplateOperator, - CloudDLPGetStoredInfoTypeOperator, - CloudDLPInspectContentOperator, - CloudDLPListDeidentifyTemplatesOperator, - CloudDLPListDLPJobsOperator, - CloudDLPListInfoTypesOperator, - CloudDLPListInspectTemplatesOperator, - CloudDLPListJobTriggersOperator, - CloudDLPListStoredInfoTypesOperator, - CloudDLPRedactImageOperator, - CloudDLPReidentifyContentOperator, - CloudDLPUpdateDeidentifyTemplateOperator, - CloudDLPUpdateInspectTemplateOperator, - CloudDLPUpdateJobTriggerOperator, - CloudDLPUpdateStoredInfoTypeOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dlp`.", - DeprecationWarning, - stacklevel=2, -) - - -class CloudDLPDeleteDlpJobOperator(CloudDLPDeleteDLPJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudDLPGetDlpJobOperator(CloudDLPGetDLPJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudDLPGetJobTripperOperator(CloudDLPGetDLPJobTriggerOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudDLPListDlpJobsOperator(CloudDLPListDLPJobsOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py deleted file mode 100644 index 4ca96c5c8a648..0000000000000 --- a/airflow/contrib/operators/gcp_function_operator.py +++ /dev/null @@ -1,65 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.functions`.""" - -import warnings - -from airflow.providers.google.cloud.operators.functions import ( - CloudFunctionDeleteFunctionOperator, - CloudFunctionDeployFunctionOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.functions`.", - DeprecationWarning, - stacklevel=2, -) - - -class GcfFunctionDeleteOperator(CloudFunctionDeleteFunctionOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.function.CloudFunctionDeleteFunctionOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.function.CloudFunctionDeleteFunctionOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcfFunctionDeployOperator(CloudFunctionDeployFunctionOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.function.CloudFunctionDeployFunctionOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use - `airflow.providers.google.cloud.operators.function.CloudFunctionDeployFunctionOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_natural_language_operator.py b/airflow/contrib/operators/gcp_natural_language_operator.py deleted file mode 100644 index 8a98ccf6903f0..0000000000000 --- a/airflow/contrib/operators/gcp_natural_language_operator.py +++ /dev/null @@ -1,114 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.natural_language`.""" - -import warnings - -from airflow.providers.google.cloud.operators.natural_language import ( - CloudNaturalLanguageAnalyzeEntitiesOperator, - CloudNaturalLanguageAnalyzeEntitySentimentOperator, - CloudNaturalLanguageAnalyzeSentimentOperator, - CloudNaturalLanguageClassifyTextOperator, -) - -warnings.warn( - """This module is deprecated. - Please use `airflow.providers.google.cloud.operators.natural_language` - """, - DeprecationWarning, - stacklevel=2, -) - - -class CloudLanguageAnalyzeEntitiesOperator(CloudNaturalLanguageAnalyzeEntitiesOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.natural_language.CloudNaturalLanguageAnalyzeEntitiesOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.natural_language - .CloudNaturalLanguageAnalyzeEntitiesOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudLanguageAnalyzeEntitySentimentOperator(CloudNaturalLanguageAnalyzeEntitySentimentOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.natural_language - .CloudNaturalLanguageAnalyzeEntitySentimentOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.natural_language - .CloudNaturalLanguageAnalyzeEntitySentimentOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudLanguageAnalyzeSentimentOperator(CloudNaturalLanguageAnalyzeSentimentOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.natural_language.CloudNaturalLanguageAnalyzeSentimentOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.natural_language - .CloudNaturalLanguageAnalyzeSentimentOperator`. 
- """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudLanguageClassifyTextOperator(CloudNaturalLanguageClassifyTextOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.natural_language - .CloudNaturalLanguageClassifyTextOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.natural_language - .CloudNaturalLanguageClassifyTextOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_spanner_operator.py b/airflow/contrib/operators/gcp_spanner_operator.py deleted file mode 100644 index b2e50c3a80f6a..0000000000000 --- a/airflow/contrib/operators/gcp_spanner_operator.py +++ /dev/null @@ -1,125 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.spanner`.""" - -import warnings - -from airflow.providers.google.cloud.operators.spanner import ( - SpannerDeleteDatabaseInstanceOperator, - SpannerDeleteInstanceOperator, - SpannerDeployDatabaseInstanceOperator, - SpannerDeployInstanceOperator, - SpannerQueryDatabaseInstanceOperator, - SpannerUpdateDatabaseInstanceOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.spanner`.", - DeprecationWarning, - stacklevel=2, -) - - -class CloudSpannerInstanceDatabaseDeleteOperator(SpannerDeleteDatabaseInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeleteDatabaseInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - self.__doc__, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudSpannerInstanceDatabaseDeployOperator(SpannerDeployDatabaseInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeployDatabaseInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - self.__doc__, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudSpannerInstanceDatabaseQueryOperator(SpannerQueryDatabaseInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.spanner.SpannerQueryDatabaseInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - self.__doc__, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudSpannerInstanceDatabaseUpdateOperator(SpannerUpdateDatabaseInstanceOperator): - """ - This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.spanner.SpannerUpdateDatabaseInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - self.__doc__, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudSpannerInstanceDeleteOperator(SpannerDeleteInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeleteInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - self.__doc__, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudSpannerInstanceDeployOperator(SpannerDeployInstanceOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeployInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - self.__doc__, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_speech_to_text_operator.py b/airflow/contrib/operators/gcp_speech_to_text_operator.py deleted file mode 100644 index 499c8ab7f6e01..0000000000000 --- a/airflow/contrib/operators/gcp_speech_to_text_operator.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.speech_to_text`.""" - -import warnings - -from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.speech_to_text`", - DeprecationWarning, - stacklevel=2, -) - - -class GcpSpeechToTextRecognizeSpeechOperator(CloudSpeechToTextRecognizeSpeechOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextRecognizeSpeechOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.speech_to_text - .CloudSpeechToTextRecognizeSpeechOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_sql_operator.py b/airflow/contrib/operators/gcp_sql_operator.py deleted file mode 100644 index 5cbbf8a80a294..0000000000000 --- a/airflow/contrib/operators/gcp_sql_operator.py +++ /dev/null @@ -1,149 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.cloud_sql`.""" - -import warnings - -from airflow.providers.google.cloud.operators.cloud_sql import ( - CloudSQLBaseOperator, - CloudSQLCreateInstanceDatabaseOperator, - CloudSQLCreateInstanceOperator, - CloudSQLDeleteInstanceDatabaseOperator, - CloudSQLDeleteInstanceOperator, - CloudSQLExecuteQueryOperator, - CloudSQLExportInstanceOperator, - CloudSQLImportInstanceOperator, - CloudSQLInstancePatchOperator, - CloudSQLPatchInstanceDatabaseOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_sql`", - DeprecationWarning, - stacklevel=2, -) - - -class CloudSqlBaseOperator(CloudSQLBaseOperator): - """ - This class is deprecated. Please use - `airflow.providers.google.cloud.operators.sql.CloudSQLBaseOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstanceCreateOperator(CloudSQLCreateInstanceOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql - .CloudSQLCreateInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstanceDatabaseCreateOperator(CloudSQLCreateInstanceDatabaseOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql - .CloudSQLCreateInstanceDatabaseOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstanceDatabaseDeleteOperator(CloudSQLDeleteInstanceDatabaseOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql - .CloudSQLDeleteInstanceDatabaseOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstanceDatabasePatchOperator(CloudSQLPatchInstanceDatabaseOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql - .CloudSQLPatchInstanceDatabaseOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstanceDeleteOperator(CloudSQLDeleteInstanceOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql - .CloudSQLDeleteInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstanceExportOperator(CloudSQLExportInstanceOperator): - """ - This class is deprecated. 
Please use `airflow.providers.google.cloud.operators.sql - .CloudSQLExportInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstanceImportOperator(CloudSQLImportInstanceOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql - .CloudSQLImportInstanceOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlInstancePatchOperator(CloudSQLInstancePatchOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators - .sql.CloudSQLInstancePatchOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) - - -class CloudSqlQueryOperator(CloudSQLExecuteQueryOperator): - """ - This class is deprecated. Please use `airflow.providers.google.cloud.operators - .sql.CloudSQLExecuteQueryOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_tasks_operator.py b/airflow/contrib/operators/gcp_tasks_operator.py deleted file mode 100644 index 319ddb4fe2ad5..0000000000000 --- a/airflow/contrib/operators/gcp_tasks_operator.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.tasks`.""" - -import warnings - -from airflow.providers.google.cloud.operators.tasks import ( # noqa - CloudTasksQueueCreateOperator, - CloudTasksQueueDeleteOperator, - CloudTasksQueueGetOperator, - CloudTasksQueuePauseOperator, - CloudTasksQueuePurgeOperator, - CloudTasksQueueResumeOperator, - CloudTasksQueuesListOperator, - CloudTasksQueueUpdateOperator, - CloudTasksTaskCreateOperator, - CloudTasksTaskDeleteOperator, - CloudTasksTaskGetOperator, - CloudTasksTaskRunOperator, - CloudTasksTasksListOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.tasks`", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/gcp_text_to_speech_operator.py b/airflow/contrib/operators/gcp_text_to_speech_operator.py deleted file mode 100644 index af5ff2f39d122..0000000000000 --- a/airflow/contrib/operators/gcp_text_to_speech_operator.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.text_to_speech`.""" - -import warnings - -from airflow.providers.google.cloud.operators.text_to_speech import CloudTextToSpeechSynthesizeOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.text_to_speech`", - DeprecationWarning, - stacklevel=2, -) - - -class GcpTextToSpeechSynthesizeOperator(CloudTextToSpeechSynthesizeOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_transfer_operator.py b/airflow/contrib/operators/gcp_transfer_operator.py deleted file mode 100644 index bd7c7bc2cb752..0000000000000 --- a/airflow/contrib/operators/gcp_transfer_operator.py +++ /dev/null @@ -1,229 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`. -""" - -import warnings - -from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( - CloudDataTransferServiceCancelOperationOperator, - CloudDataTransferServiceCreateJobOperator, - CloudDataTransferServiceDeleteJobOperator, - CloudDataTransferServiceGCSToGCSOperator, - CloudDataTransferServiceGetOperationOperator, - CloudDataTransferServiceListOperationsOperator, - CloudDataTransferServicePauseOperationOperator, - CloudDataTransferServiceResumeOperationOperator, - CloudDataTransferServiceS3ToGCSOperator, - CloudDataTransferServiceUpdateJobOperator, -) - -warnings.warn( - "This module is deprecated. 
" - "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`", - DeprecationWarning, - stacklevel=2, -) - - -class GcpTransferServiceJobCreateOperator(CloudDataTransferServiceCreateJobOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.data_transfer.CloudDataTransferServiceCreateJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceCreateJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcpTransferServiceJobDeleteOperator(CloudDataTransferServiceDeleteJobOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.data_transfer.CloudDataTransferServiceDeleteJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceDeleteJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcpTransferServiceJobUpdateOperator(CloudDataTransferServiceUpdateJobOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.data_transfer.CloudDataTransferServiceUpdateJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceUpdateJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcpTransferServiceOperationCancelOperator(CloudDataTransferServiceCancelOperationOperator): - """ - This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.data_transfer.CloudDataTransferServiceCancelOperationOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceCancelOperationOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcpTransferServiceOperationGetOperator(CloudDataTransferServiceGetOperationOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceGetOperationOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceGetOperationOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcpTransferServiceOperationPauseOperator(CloudDataTransferServicePauseOperationOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServicePauseOperationOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServicePauseOperationOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcpTransferServiceOperationResumeOperator(CloudDataTransferServiceResumeOperationOperator): - """ - This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceResumeOperationOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceResumeOperationOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GcpTransferServiceOperationsListOperator(CloudDataTransferServiceListOperationsOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceListOperationsOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceListOperationsOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GoogleCloudStorageToGoogleCloudStorageTransferOperator(CloudDataTransferServiceGCSToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceGCSToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceGCSToGCSOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class S3ToGoogleCloudStorageTransferOperator(CloudDataTransferServiceS3ToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceS3ToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """"This class is deprecated. - Please use `airflow.providers.google.cloud.operators.data_transfer - .CloudDataTransferServiceS3ToGCSOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_translate_operator.py b/airflow/contrib/operators/gcp_translate_operator.py deleted file mode 100644 index c61cc8492c250..0000000000000 --- a/airflow/contrib/operators/gcp_translate_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.translate`.""" - -import warnings - -from airflow.providers.google.cloud.operators.translate import CloudTranslateTextOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.operators.translate`", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/gcp_translate_speech_operator.py b/airflow/contrib/operators/gcp_translate_speech_operator.py deleted file mode 100644 index 2e0bb70787b61..0000000000000 --- a/airflow/contrib/operators/gcp_translate_speech_operator.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.operators.translate_speech`. -""" - -import warnings - -from airflow.providers.google.cloud.operators.translate_speech import CloudTranslateSpeechOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.translate_speech`.", - DeprecationWarning, - stacklevel=2, -) - - -class GcpTranslateSpeechOperator(CloudTranslateSpeechOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcp_video_intelligence_operator.py b/airflow/contrib/operators/gcp_video_intelligence_operator.py deleted file mode 100644 index a82fc9aeb8411..0000000000000 --- a/airflow/contrib/operators/gcp_video_intelligence_operator.py +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.operators.video_intelligence`. 
-""" - -import warnings - -from airflow.providers.google.cloud.operators.video_intelligence import ( # noqa - CloudVideoIntelligenceDetectVideoExplicitContentOperator, - CloudVideoIntelligenceDetectVideoLabelsOperator, - CloudVideoIntelligenceDetectVideoShotsOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.video_intelligence`", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/gcp_vision_operator.py b/airflow/contrib/operators/gcp_vision_operator.py deleted file mode 100644 index 09a5b1e817e28..0000000000000 --- a/airflow/contrib/operators/gcp_vision_operator.py +++ /dev/null @@ -1,227 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.vision`.""" - -import warnings - -from airflow.providers.google.cloud.operators.vision import ( # noqa - CloudVisionAddProductToProductSetOperator, - CloudVisionCreateProductOperator, - CloudVisionCreateProductSetOperator, - CloudVisionCreateReferenceImageOperator, - CloudVisionDeleteProductOperator, - CloudVisionDeleteProductSetOperator, - CloudVisionDetectImageLabelsOperator, - CloudVisionDetectImageSafeSearchOperator, - CloudVisionDetectTextOperator, - CloudVisionGetProductOperator, - CloudVisionGetProductSetOperator, - CloudVisionImageAnnotateOperator, - CloudVisionRemoveProductFromProductSetOperator, - CloudVisionTextDetectOperator, - CloudVisionUpdateProductOperator, - CloudVisionUpdateProductSetOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.vision`.", - DeprecationWarning, - stacklevel=2, -) - - -class CloudVisionAnnotateImageOperator(CloudVisionImageAnnotateOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionDetectDocumentTextOperator(CloudVisionTextDetectOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductCreateOperator(CloudVisionCreateProductOperator): - """ - This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductDeleteOperator(CloudVisionDeleteProductOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductGetOperator(CloudVisionGetProductOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductSetCreateOperator(CloudVisionCreateProductSetOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductSetDeleteOperator(CloudVisionDeleteProductSetOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductSetGetOperator(CloudVisionGetProductSetOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductSetUpdateOperator(CloudVisionUpdateProductSetOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionProductUpdateOperator(CloudVisionUpdateProductOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`. 
- """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class CloudVisionReferenceImageCreateOperator(CloudVisionCreateReferenceImageOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_acl_operator.py b/airflow/contrib/operators/gcs_acl_operator.py deleted file mode 100644 index f3e6afa303602..0000000000000 --- a/airflow/contrib/operators/gcs_acl_operator.py +++ /dev/null @@ -1,63 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`.""" - -import warnings - -from airflow.providers.google.cloud.operators.gcs import ( - GCSBucketCreateAclEntryOperator, - GCSObjectCreateAclEntryOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageBucketCreateAclEntryOperator(GCSBucketCreateAclEntryOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GoogleCloudStorageObjectCreateAclEntryOperator(GCSObjectCreateAclEntryOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_delete_operator.py b/airflow/contrib/operators/gcs_delete_operator.py deleted file mode 100644 index 42fa2a89057d1..0000000000000 --- a/airflow/contrib/operators/gcs_delete_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`.""" - -import warnings - -from airflow.providers.google.cloud.operators.gcs import GCSDeleteObjectsOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageDeleteOperator(GCSDeleteObjectsOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py deleted file mode 100644 index a6904320bd2c2..0000000000000 --- a/airflow/contrib/operators/gcs_download_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageDownloadOperator(GCSToLocalFilesystemOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSToLocalFilesystemOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSToLocalFilesystemOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py deleted file mode 100644 index a18ec272a5907..0000000000000 --- a/airflow/contrib/operators/gcs_list_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`.""" - -import warnings - -from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageListOperator(GCSListObjectsOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_operator.py b/airflow/contrib/operators/gcs_operator.py deleted file mode 100644 index 4f0aecbf9c167..0000000000000 --- a/airflow/contrib/operators/gcs_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.google.cloud.operators.gcs`.""" - -import warnings - -from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageCreateBucketOperator(GCSCreateBucketOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py deleted file mode 100644 index ab71bb8da7989..0000000000000 --- a/airflow/contrib/operators/gcs_to_bq.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.gcs_to_bigquery`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageToBigQueryOperator(GCSToBigQueryOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_to_gcs.py b/airflow/contrib/operators/gcs_to_gcs.py deleted file mode 100644 index ab6f2d91c1ba9..0000000000000 --- a/airflow/contrib/operators/gcs_to_gcs.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.gcs_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageToGoogleCloudStorageOperator(GCSToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py b/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py deleted file mode 100644 index 75a672f2c18b2..0000000000000 --- a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`. -""" - -import warnings - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/gcs_to_gdrive_operator.py b/airflow/contrib/operators/gcs_to_gdrive_operator.py deleted file mode 100644 index 1fb55d1fc3dfe..0000000000000 --- a/airflow/contrib/operators/gcs_to_gdrive_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.suite.transfers.gcs_to_gdrive`.""" - -import warnings - -from airflow.providers.google.suite.transfers.gcs_to_gdrive import GCSToGoogleDriveOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.suite.transfers.gcs_to_gdrive.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py deleted file mode 100644 index 13aa005ab20a3..0000000000000 --- a/airflow/contrib/operators/gcs_to_s3.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.gcs_to_s3`.""" - -import warnings - -from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageToS3Operator(GCSToS3Operator): - """ - This class is deprecated. Please use - `airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/grpc_operator.py b/airflow/contrib/operators/grpc_operator.py deleted file mode 100644 index bd8cfbd6003ee..0000000000000 --- a/airflow/contrib/operators/grpc_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.grpc.operators.grpc`.""" - -import warnings - -from airflow.providers.grpc.operators.grpc import GrpcOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.grpc.operators.grpc`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py deleted file mode 100644 index ba4f8b967cc98..0000000000000 --- a/airflow/contrib/operators/hive_to_dynamodb.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.hive_to_dynamodb`.""" - -import warnings - -from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.hive_to_dynamodb`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/imap_attachment_to_s3_operator.py b/airflow/contrib/operators/imap_attachment_to_s3_operator.py deleted file mode 100644 index e82a8bc05aeae..0000000000000 --- a/airflow/contrib/operators/imap_attachment_to_s3_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.transfers.imap_attachment_to_s3`. -""" - -import warnings - -from airflow.providers.amazon.aws.transfers.imap_attachment_to_s3 import ImapAttachmentToS3Operator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.transfers.imap_attachment_to_s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py deleted file mode 100644 index 0b401d2b430a7..0000000000000 --- a/airflow/contrib/operators/jenkins_job_trigger_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.jenkins.operators.jenkins_job_trigger`.""" - -import warnings - -from airflow.providers.jenkins.operators.jenkins_job_trigger import JenkinsJobTriggerOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.jenkins.operators.jenkins_job_trigger`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py deleted file mode 100644 index b6e3b3e124a85..0000000000000 --- a/airflow/contrib/operators/jira_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.jira.operators.jira`.""" - -import warnings - -from airflow.providers.jira.operators.jira import JiraOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.jira.operators.jira`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py deleted file mode 100644 index 962fa22859013..0000000000000 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.cncf.kubernetes.operators.kubernetes_pod`. -""" - -import warnings - -from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.cncf.kubernetes.operators.kubernetes_pod`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py deleted file mode 100644 index d5d6b2fb1d24a..0000000000000 --- a/airflow/contrib/operators/mlengine_operator.py +++ /dev/null @@ -1,100 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.mlengine`.""" - -import warnings - -from airflow.providers.google.cloud.operators.mlengine import ( - MLEngineManageModelOperator, - MLEngineManageVersionOperator, - MLEngineStartBatchPredictionJobOperator, - MLEngineStartTrainingJobOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine`.", - DeprecationWarning, - stacklevel=2, -) - - -class MLEngineBatchPredictionOperator(MLEngineStartBatchPredictionJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class MLEngineModelOperator(MLEngineManageModelOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class MLEngineTrainingOperator(MLEngineStartTrainingJobOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class MLEngineVersionOperator(MLEngineManageVersionOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py deleted file mode 100644 index 17b0676952e91..0000000000000 --- a/airflow/contrib/operators/mongo_to_s3.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.mongo_to_s3`.""" - -import warnings - -from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.mongo_to_s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/mssql_to_gcs.py b/airflow/contrib/operators/mssql_to_gcs.py deleted file mode 100644 index 140457bb709f4..0000000000000 --- a/airflow/contrib/operators/mssql_to_gcs.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.mssql_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class MsSqlToGoogleCloudStorageOperator(MSSQLToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py deleted file mode 100644 index 50363b3a39bc6..0000000000000 --- a/airflow/contrib/operators/mysql_to_gcs.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.mysql_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class MySqlToGoogleCloudStorageOperator(MySQLToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/opsgenie_alert_operator.py b/airflow/contrib/operators/opsgenie_alert_operator.py deleted file mode 100644 index 008214b97e020..0000000000000 --- a/airflow/contrib/operators/opsgenie_alert_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.opsgenie.operators.opsgenie`.""" - -import warnings - -from airflow.providers.opsgenie.operators.opsgenie import OpsgenieCreateAlertOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.opsgenie.operators.opsgenie`.", - DeprecationWarning, - stacklevel=2, -) - - -class OpsgenieAlertOperator(OpsgenieCreateAlertOperator): - """ - This class is deprecated. - Please use :class:`airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use :class:`airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py deleted file mode 100644 index 3907b6f8e8b6b..0000000000000 --- a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py +++ /dev/null @@ -1,34 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use `airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake`. -""" - -import warnings - -from airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake import ( # noqa - OracleToAzureDataLakeOperator, -) - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/oracle_to_oracle_transfer.py b/airflow/contrib/operators/oracle_to_oracle_transfer.py deleted file mode 100644 index 2efbf4d0a89b4..0000000000000 --- a/airflow/contrib/operators/oracle_to_oracle_transfer.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.oracle.transfers.oracle_to_oracle`. -""" - -import warnings - -from airflow.providers.oracle.transfers.oracle_to_oracle import OracleToOracleOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.oracle.transfers.oracle_to_oracle`.", - DeprecationWarning, - stacklevel=2, -) - - -class OracleToOracleTransfer(OracleToOracleOperator): - """This class is deprecated. - - Please use: - `airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/postgres_to_gcs_operator.py b/airflow/contrib/operators/postgres_to_gcs_operator.py deleted file mode 100644 index 1ce5f42304ce5..0000000000000 --- a/airflow/contrib/operators/postgres_to_gcs_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.postgres_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class PostgresToGoogleCloudStorageOperator(PostgresToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/pubsub_operator.py b/airflow/contrib/operators/pubsub_operator.py deleted file mode 100644 index 9d85d32fefed1..0000000000000 --- a/airflow/contrib/operators/pubsub_operator.py +++ /dev/null @@ -1,118 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.operators.pubsub`. -""" - -import warnings - -from airflow.providers.google.cloud.operators.pubsub import ( - PubSubCreateSubscriptionOperator, - PubSubCreateTopicOperator, - PubSubDeleteSubscriptionOperator, - PubSubDeleteTopicOperator, - PubSubPublishMessageOperator, -) - -warnings.warn( - """This module is deprecated. - "Please use `airflow.providers.google.cloud.operators.pubsub`.""", - DeprecationWarning, - stacklevel=2, -) - - -class PubSubPublishOperator(PubSubPublishMessageOperator): - """This class is deprecated. - - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class PubSubSubscriptionCreateOperator(PubSubCreateSubscriptionOperator): - """This class is deprecated. - - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class PubSubSubscriptionDeleteOperator(PubSubDeleteSubscriptionOperator): - """This class is deprecated. - - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class PubSubTopicCreateOperator(PubSubCreateTopicOperator): - """This class is deprecated. - - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class PubSubTopicDeleteOperator(PubSubDeleteTopicOperator): - """This class is deprecated. - - Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/qubole_check_operator.py b/airflow/contrib/operators/qubole_check_operator.py deleted file mode 100644 index e42a9e0c8f13a..0000000000000 --- a/airflow/contrib/operators/qubole_check_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.qubole.operators.qubole_check`.""" - -import warnings - -from airflow.providers.qubole.operators.qubole_check import ( # noqa - QuboleCheckOperator, - QuboleValueCheckOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.qubole.operators.qubole_check`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py deleted file mode 100644 index e4a30748a03c6..0000000000000 --- a/airflow/contrib/operators/qubole_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.qubole.operators.qubole`.""" - -import warnings - -from airflow.providers.qubole.operators.qubole import QuboleOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.qubole.operators.qubole`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/redis_publish_operator.py b/airflow/contrib/operators/redis_publish_operator.py deleted file mode 100644 index 994d9323ad170..0000000000000 --- a/airflow/contrib/operators/redis_publish_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.redis.operators.redis_publish`.""" - -import warnings - -from airflow.providers.redis.operators.redis_publish import RedisPublishOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.redis.operators.redis_publish`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/s3_copy_object_operator.py b/airflow/contrib/operators/s3_copy_object_operator.py deleted file mode 100644 index cbe9c63440629..0000000000000 --- a/airflow/contrib/operators/s3_copy_object_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3_copy_object`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/s3_delete_objects_operator.py b/airflow/contrib/operators/s3_delete_objects_operator.py deleted file mode 100644 index a0ab210324b55..0000000000000 --- a/airflow/contrib/operators/s3_delete_objects_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. 
-Please use :mod:`airflow.providers.amazon.aws.operators.s3_delete_objects`. -""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py deleted file mode 100644 index 172b94cb116c2..0000000000000 --- a/airflow/contrib/operators/s3_list_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3_list`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py deleted file mode 100644 index d0ea8e09bdbb1..0000000000000 --- a/airflow/contrib/operators/s3_to_gcs_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.s3_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.google.cloud.transfers.s3_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py b/airflow/contrib/operators/s3_to_gcs_transfer_operator.py deleted file mode 100644 index 71df06229334b..0000000000000 --- a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`. -""" -import warnings - -from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( # noqa isort:skip - CloudDataTransferServiceS3ToGCSOperator, -) - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/s3_to_sftp_operator.py b/airflow/contrib/operators/s3_to_sftp_operator.py deleted file mode 100644 index e129af1471782..0000000000000 --- a/airflow/contrib/operators/s3_to_sftp_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.s3_to_sftp`.""" - -import warnings - -from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_sftp`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sagemaker_base_operator.py b/airflow/contrib/operators/sagemaker_base_operator.py deleted file mode 100644 index 4c2c8f6baf131..0000000000000 --- a/airflow/contrib/operators/sagemaker_base_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_base`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py deleted file mode 100644 index 43945b222734d..0000000000000 --- a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py +++ /dev/null @@ -1,34 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`. -""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import ( # noqa - SageMakerEndpointConfigOperator, -) - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sagemaker_endpoint_operator.py b/airflow/contrib/operators/sagemaker_endpoint_operator.py deleted file mode 100644 index fe175a67b29cf..0000000000000 --- a/airflow/contrib/operators/sagemaker_endpoint_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_endpoint`. -""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sagemaker_model_operator.py b/airflow/contrib/operators/sagemaker_model_operator.py deleted file mode 100644 index 9a003485606ab..0000000000000 --- a/airflow/contrib/operators/sagemaker_model_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_model`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sagemaker_training_operator.py b/airflow/contrib/operators/sagemaker_training_operator.py deleted file mode 100644 index d3749c68573d2..0000000000000 --- a/airflow/contrib/operators/sagemaker_training_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_training`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sagemaker_transform_operator.py b/airflow/contrib/operators/sagemaker_transform_operator.py deleted file mode 100644 index 93cf7070db220..0000000000000 --- a/airflow/contrib/operators/sagemaker_transform_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_transform`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sagemaker_tuning_operator.py b/airflow/contrib/operators/sagemaker_tuning_operator.py deleted file mode 100644 index 05760a74569cf..0000000000000 --- a/airflow/contrib/operators/sagemaker_tuning_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_tuning`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/segment_track_event_operator.py b/airflow/contrib/operators/segment_track_event_operator.py deleted file mode 100644 index 92419a1977405..0000000000000 --- a/airflow/contrib/operators/segment_track_event_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.segment.operators.segment_track_event`.""" - -import warnings - -from airflow.providers.segment.operators.segment_track_event import SegmentTrackEventOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.segment.operators.segment_track_event`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py deleted file mode 100644 index e73a84743c83b..0000000000000 --- a/airflow/contrib/operators/sftp_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.sftp.operators.sftp`.""" - -import warnings - -from airflow.providers.sftp.operators.sftp import SFTPOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.sftp.operators.sftp`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sftp_to_s3_operator.py b/airflow/contrib/operators/sftp_to_s3_operator.py deleted file mode 100644 index 7c13b1817d63e..0000000000000 --- a/airflow/contrib/operators/sftp_to_s3_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.sftp_to_s3`.""" - -import warnings - -from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.sftp_to_s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/slack_webhook_operator.py b/airflow/contrib/operators/slack_webhook_operator.py deleted file mode 100644 index f271102e14550..0000000000000 --- a/airflow/contrib/operators/slack_webhook_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.slack.operators.slack_webhook`.""" - -import warnings - -from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.slack.operators.slack_webhook`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py deleted file mode 100644 index f01cc72d1a531..0000000000000 --- a/airflow/contrib/operators/snowflake_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.snowflake.operators.snowflake`.""" - -import warnings - -from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.snowflake.operators.snowflake`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sns_publish_operator.py b/airflow/contrib/operators/sns_publish_operator.py deleted file mode 100644 index 104e240836f12..0000000000000 --- a/airflow/contrib/operators/sns_publish_operator.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sns`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sns`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/spark_jdbc_operator.py b/airflow/contrib/operators/spark_jdbc_operator.py deleted file mode 100644 index fc3cdc02704c6..0000000000000 --- a/airflow/contrib/operators/spark_jdbc_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.operators.spark_jdbc`.""" - -import warnings - -from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator, SparkSubmitOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_jdbc`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/spark_sql_operator.py b/airflow/contrib/operators/spark_sql_operator.py deleted file mode 100644 index 19e20d215dbf5..0000000000000 --- a/airflow/contrib/operators/spark_sql_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.operators.spark_sql`.""" - -import warnings - -from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_sql`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py deleted file mode 100644 index 103187e445fa4..0000000000000 --- a/airflow/contrib/operators/spark_submit_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.operators.spark_submit`.""" - -import warnings - -from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_submit`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/sql_to_gcs.py b/airflow/contrib/operators/sql_to_gcs.py deleted file mode 100644 index 13aa869c5d98d..0000000000000 --- a/airflow/contrib/operators/sql_to_gcs.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.sql_to_gcs`.""" - -import warnings - -from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.sql_to_gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class BaseSQLToGoogleCloudStorageOperator(BaseSQLToGCSOperator): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/sqoop_operator.py b/airflow/contrib/operators/sqoop_operator.py deleted file mode 100644 index 2757847abd65e..0000000000000 --- a/airflow/contrib/operators/sqoop_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.sqoop.operators.sqoop`.""" - -import warnings - -from airflow.providers.apache.sqoop.operators.sqoop import SqoopOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.sqoop.operators.sqoop`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py deleted file mode 100644 index 56f94b9b26b63..0000000000000 --- a/airflow/contrib/operators/ssh_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.ssh.operators.ssh`.""" - -import warnings - -from airflow.providers.ssh.operators.ssh import SSHOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.ssh.operators.ssh`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/vertica_operator.py b/airflow/contrib/operators/vertica_operator.py deleted file mode 100644 index e652512ad4056..0000000000000 --- a/airflow/contrib/operators/vertica_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.vertica.operators.vertica`.""" - -import warnings - -from airflow.providers.vertica.operators.vertica import VerticaOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.vertica.operators.vertica`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/vertica_to_hive.py b/airflow/contrib/operators/vertica_to_hive.py deleted file mode 100644 index 49ddea172f374..0000000000000 --- a/airflow/contrib/operators/vertica_to_hive.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.apache.hive.transfers.vertica_to_hive`. -""" - -import warnings - -from airflow.providers.apache.hive.transfers.vertica_to_hive import VerticaToHiveOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.vertica_to_hive`.", - DeprecationWarning, - stacklevel=2, -) - - -class VerticaToHiveTransfer(VerticaToHiveOperator): - """This class is deprecated. - - Please use: - `airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use - `airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/vertica_to_mysql.py b/airflow/contrib/operators/vertica_to_mysql.py deleted file mode 100644 index c85738f9091b2..0000000000000 --- a/airflow/contrib/operators/vertica_to_mysql.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.mysql.transfers.vertica_to_mysql`. -""" - -import warnings - -from airflow.providers.mysql.transfers.vertica_to_mysql import VerticaToMySqlOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.mysql.transfers.vertica_to_mysql`.", - DeprecationWarning, - stacklevel=2, -) - - -class VerticaToMySqlTransfer(VerticaToMySqlOperator): - """This class is deprecated. - - Please use: - `airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/operators/wasb_delete_blob_operator.py b/airflow/contrib/operators/wasb_delete_blob_operator.py deleted file mode 100644 index cbf11b38fbf86..0000000000000 --- a/airflow/contrib/operators/wasb_delete_blob_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.microsoft.azure.operators.wasb_delete_blob`. -""" - -import warnings - -from airflow.providers.microsoft.azure.operators.wasb_delete_blob import WasbDeleteBlobOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.operators.wasb_delete_blob`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py deleted file mode 100644 index fcc6213e71d6e..0000000000000 --- a/airflow/contrib/operators/winrm_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.winrm.operators.winrm`.""" - -import warnings - -from airflow.providers.microsoft.winrm.operators.winrm import WinRMOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.winrm.operators.winrm`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/secrets/__init__.py b/airflow/contrib/secrets/__init__.py index 31cf12a8e38db..9e2b4c74abdd7 100644 --- a/airflow/contrib/secrets/__init__.py +++ b/airflow/contrib/secrets/__init__.py @@ -16,3 +16,41 @@ # specific language governing permissions and limitations # under the License. """This package is deprecated. Please use `airflow.secrets` or `airflow.providers.*.secrets`.""" +from __future__ import annotations + +import warnings + +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.utils.deprecation_tools import add_deprecated_classes + +warnings.warn( + "This module is deprecated. Please use airflow.providers.*.secrets.", + RemovedInAirflow3Warning, + stacklevel=2 +) + +__deprecated_classes = { + 'aws_secrets_manager': { + 'SecretsManagerBackend': 'airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend', + }, + 'aws_systems_manager': { + 'SystemsManagerParameterStoreBackend': ( + 'airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend' + ), + }, + 'azure_key_vault': { + 'AzureKeyVaultBackend': 'airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend', + }, + 'gcp_secrets_manager': { + 'CloudSecretManagerBackend': ( + 'airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend' + ), + 'CloudSecretsManagerBackend': ( + 'airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend' + ), + }, + 'hashicorp_vault': { + 'VaultBackend': 'airflow.providers.hashicorp.secrets.vault.VaultBackend', + }, +} +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/contrib/secrets/aws_secrets_manager.py b/airflow/contrib/secrets/aws_secrets_manager.py deleted file mode 100644 index 833b03a59b06c..0000000000000 --- a/airflow/contrib/secrets/aws_secrets_manager.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.secrets.secrets_manager`.""" - -import warnings - -from airflow.providers.amazon.aws.secrets.secrets_manager import SecretsManagerBackend # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.secrets.secrets_manager`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/secrets/aws_systems_manager.py b/airflow/contrib/secrets/aws_systems_manager.py deleted file mode 100644 index 4c7a30cf05ab7..0000000000000 --- a/airflow/contrib/secrets/aws_systems_manager.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.secrets.systems_manager`.""" - -import warnings - -from airflow.providers.amazon.aws.secrets.systems_manager import SystemsManagerParameterStoreBackend # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.secrets.systems_manager`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/secrets/azure_key_vault.py b/airflow/contrib/secrets/azure_key_vault.py deleted file mode 100644 index 000ae92b3ac28..0000000000000 --- a/airflow/contrib/secrets/azure_key_vault.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.secrets.key_vault`.""" - -import warnings - -from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.secrets.key_vault`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/secrets/gcp_secrets_manager.py b/airflow/contrib/secrets/gcp_secrets_manager.py deleted file mode 100644 index 7caa7ea2e88a7..0000000000000 --- a/airflow/contrib/secrets/gcp_secrets_manager.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.secrets.secret_manager`.""" - -import warnings - -from airflow.providers.google.cloud.secrets.secret_manager import CloudSecretManagerBackend - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.secrets.secret_manager`.", - DeprecationWarning, - stacklevel=2, -) - - -class CloudSecretsManagerBackend(CloudSecretManagerBackend): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/secrets/hashicorp_vault.py b/airflow/contrib/secrets/hashicorp_vault.py deleted file mode 100644 index a3158d5d03abf..0000000000000 --- a/airflow/contrib/secrets/hashicorp_vault.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. 
Please use :mod:`airflow.providers.hashicorp.secrets.vault`.""" - -import warnings - -from airflow.providers.hashicorp.secrets.vault import VaultBackend # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.hashicorp.secrets.vault`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/__init__.py b/airflow/contrib/sensors/__init__.py index de90d6c210596..65739f1ed0eb5 100644 --- a/airflow/contrib/sensors/__init__.py +++ b/airflow/contrib/sensors/__init__.py @@ -16,11 +16,151 @@ # specific language governing permissions and limitations # under the License. """This package is deprecated. Please use `airflow.sensors` or `airflow.providers.*.sensors`.""" +from __future__ import annotations import warnings +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.utils.deprecation_tools import add_deprecated_classes + warnings.warn( "This package is deprecated. Please use `airflow.sensors` or `airflow.providers.*.sensors`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) + +__deprecated_classes = { + 'aws_athena_sensor': { + 'AthenaSensor': 'airflow.providers.amazon.aws.sensors.athena.AthenaSensor', + }, + 'aws_glue_catalog_partition_sensor': { + 'AwsGlueCatalogPartitionSensor': ( + 'airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor' + ), + }, + 'aws_redshift_cluster_sensor': { + 'AwsRedshiftClusterSensor': ( + 'airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor' + ), + }, + 'aws_sqs_sensor': { + 'SqsSensor': 'airflow.providers.amazon.aws.sensors.sqs.SqsSensor', + 'SQSSensor': 'airflow.providers.amazon.aws.sensors.sqs.SqsSensor', + }, + 'azure_cosmos_sensor': { + 'AzureCosmosDocumentSensor': ( + 'airflow.providers.microsoft.azure.sensors.cosmos.AzureCosmosDocumentSensor' + ), + }, + 'bash_sensor': { + 'STDOUT': 'airflow.sensors.bash.STDOUT', + 'BashSensor': 'airflow.sensors.bash.BashSensor', + 'Popen': 'airflow.sensors.bash.Popen', + 'TemporaryDirectory': 'airflow.sensors.bash.TemporaryDirectory', + 'gettempdir': 'airflow.sensors.bash.gettempdir', + }, + 'bigquery_sensor': { + 'BigQueryTableExistenceSensor': ( + 'airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor' + ), + 'BigQueryTableSensor': 'airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor', + }, + 'cassandra_record_sensor': { + 'CassandraRecordSensor': 'airflow.providers.apache.cassandra.sensors.record.CassandraRecordSensor', + }, + 'cassandra_table_sensor': { + 'CassandraTableSensor': 'airflow.providers.apache.cassandra.sensors.table.CassandraTableSensor', + }, + 'celery_queue_sensor': { + 'CeleryQueueSensor': 'airflow.providers.celery.sensors.celery_queue.CeleryQueueSensor', + }, + 'datadog_sensor': { + 'DatadogSensor': 'airflow.providers.datadog.sensors.datadog.DatadogSensor', + }, + 'file_sensor': { + 'FileSensor': 'airflow.sensors.filesystem.FileSensor', + }, + 'ftp_sensor': { + 'FTPSensor': 'airflow.providers.ftp.sensors.ftp.FTPSensor', + 'FTPSSensor': 'airflow.providers.ftp.sensors.ftp.FTPSSensor', + }, + 'gcp_transfer_sensor': { + 'CloudDataTransferServiceJobStatusSensor': + 'airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.' + 'CloudDataTransferServiceJobStatusSensor', + 'GCPTransferServiceWaitForJobStatusSensor': + 'airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.' 
+ 'CloudDataTransferServiceJobStatusSensor', + }, + 'gcs_sensor': { + 'GCSObjectExistenceSensor': 'airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor', + 'GCSObjectsWithPrefixExistenceSensor': ( + 'airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor' + ), + 'GCSObjectUpdateSensor': 'airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor', + 'GCSUploadSessionCompleteSensor': ( + 'airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor' + ), + 'GoogleCloudStorageObjectSensor': ( + 'airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor' + ), + 'GoogleCloudStorageObjectUpdatedSensor': ( + 'airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor' + ), + 'GoogleCloudStoragePrefixSensor': ( + 'airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor' + ), + 'GoogleCloudStorageUploadSessionCompleteSensor': ( + 'airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor' + ), + }, + 'hdfs_sensor': { + 'HdfsFolderSensor': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor', + 'HdfsRegexSensor': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor', + 'HdfsSensorFolder': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor', + 'HdfsSensorRegex': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor', + }, + 'imap_attachment_sensor': { + 'ImapAttachmentSensor': 'airflow.providers.imap.sensors.imap_attachment.ImapAttachmentSensor', + }, + 'jira_sensor': { + 'JiraSensor': 'airflow.providers.atlassian.jira.sensors.jira.JiraSensor', + 'JiraTicketSensor': 'airflow.providers.atlassian.jira.sensors.jira.JiraTicketSensor', + }, + 'mongo_sensor': { + 'MongoSensor': 'airflow.providers.mongo.sensors.mongo.MongoSensor', + }, + 'pubsub_sensor': { + 'PubSubPullSensor': 'airflow.providers.google.cloud.sensors.pubsub.PubSubPullSensor', + }, + 'python_sensor': { + 'PythonSensor': 'airflow.sensors.python.PythonSensor', + }, + 'qubole_sensor': { + 'QuboleFileSensor': 'airflow.providers.qubole.sensors.qubole.QuboleFileSensor', + 'QubolePartitionSensor': 'airflow.providers.qubole.sensors.qubole.QubolePartitionSensor', + 'QuboleSensor': 'airflow.providers.qubole.sensors.qubole.QuboleSensor', + }, + 'redis_key_sensor': { + 'RedisKeySensor': 'airflow.providers.redis.sensors.redis_key.RedisKeySensor', + }, + 'redis_pub_sub_sensor': { + 'RedisPubSubSensor': 'airflow.providers.redis.sensors.redis_pub_sub.RedisPubSubSensor', + }, + 'sagemaker_training_sensor': { + 'SageMakerHook': 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerHook', + 'SageMakerTrainingSensor': 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTrainingSensor', + }, + 'sftp_sensor': { + 'SFTPSensor': 'airflow.providers.sftp.sensors.sftp.SFTPSensor', + }, + 'wasb_sensor': { + 'WasbBlobSensor': 'airflow.providers.microsoft.azure.sensors.wasb.WasbBlobSensor', + 'WasbPrefixSensor': 'airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor', + }, + 'weekday_sensor': { + 'DayOfWeekSensor': 'airflow.sensors.weekday.DayOfWeekSensor', + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/contrib/sensors/aws_athena_sensor.py b/airflow/contrib/sensors/aws_athena_sensor.py deleted file mode 100644 index ddffc38bb63e8..0000000000000 --- a/airflow/contrib/sensors/aws_athena_sensor.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.athena`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.athena import AthenaSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.athena`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py deleted file mode 100644 index 66975a86f3e2f..0000000000000 --- a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.sensors.glue_catalog_partition`. -""" - -import warnings - -from airflow.providers.amazon.aws.sensors.glue_catalog_partition import AwsGlueCatalogPartitionSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py deleted file mode 100644 index dc0da1372b400..0000000000000 --- a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.redshift_cluster`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.redshift_cluster import ( - RedshiftClusterSensor as AwsRedshiftClusterSensor, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["AwsRedshiftClusterSensor"] diff --git a/airflow/contrib/sensors/aws_sqs_sensor.py b/airflow/contrib/sensors/aws_sqs_sensor.py deleted file mode 100644 index ed251f1ee9c22..0000000000000 --- a/airflow/contrib/sensors/aws_sqs_sensor.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sqs`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sqs import SqsSensor - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sqs`.", - DeprecationWarning, - stacklevel=2, -) - - -class SQSSensor(SqsSensor): - """ - This sensor is deprecated. - Please use :class:`airflow.providers.amazon.aws.sensors.sqs.SqsSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.sensors.sqs.SqsSensor`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/azure_cosmos_sensor.py b/airflow/contrib/sensors/azure_cosmos_sensor.py deleted file mode 100644 index fc3df4f26615e..0000000000000 --- a/airflow/contrib/sensors/azure_cosmos_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.microsoft.azure.sensors.cosmos`.""" - -import warnings - -from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.cosmos`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/bash_sensor.py b/airflow/contrib/sensors/bash_sensor.py deleted file mode 100644 index c3d9c814696c7..0000000000000 --- a/airflow/contrib/sensors/bash_sensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.bash`.""" - -import warnings - -from airflow.sensors.bash import STDOUT, BashSensor, Popen, TemporaryDirectory, gettempdir # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.bash`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/contrib/sensors/bigquery_sensor.py b/airflow/contrib/sensors/bigquery_sensor.py deleted file mode 100644 index d58445e3d174b..0000000000000 --- a/airflow/contrib/sensors/bigquery_sensor.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.sensors.bigquery`.""" - -import warnings - -from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.sensors.bigquery`.", - DeprecationWarning, - stacklevel=2, -) - - -class BigQueryTableSensor(BigQueryTableExistenceSensor): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/cassandra_record_sensor.py b/airflow/contrib/sensors/cassandra_record_sensor.py deleted file mode 100644 index cfc3b30107ead..0000000000000 --- a/airflow/contrib/sensors/cassandra_record_sensor.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.apache.cassandra.sensors.record`.""" - -import warnings - -from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.cassandra.sensors.record`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/cassandra_table_sensor.py b/airflow/contrib/sensors/cassandra_table_sensor.py deleted file mode 100644 index 0b7c7aa6eb73f..0000000000000 --- a/airflow/contrib/sensors/cassandra_table_sensor.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.apache.cassandra.sensors.table`.""" - -import warnings - -from airflow.providers.apache.cassandra.sensors.table import CassandraTableSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.cassandra.sensors.table`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/celery_queue_sensor.py b/airflow/contrib/sensors/celery_queue_sensor.py deleted file mode 100644 index 6ed2be1c93ac5..0000000000000 --- a/airflow/contrib/sensors/celery_queue_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.celery.sensors.celery_queue`.""" - -import warnings - -from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.celery.sensors.celery_queue`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/datadog_sensor.py b/airflow/contrib/sensors/datadog_sensor.py deleted file mode 100644 index d0377d11a8889..0000000000000 --- a/airflow/contrib/sensors/datadog_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.datadog.sensors.datadog`.""" - -import warnings - -from airflow.providers.datadog.sensors.datadog import DatadogSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.datadog.sensors.datadog`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/emr_base_sensor.py b/airflow/contrib/sensors/emr_base_sensor.py deleted file mode 100644 index 08d0efed81475..0000000000000 --- a/airflow/contrib/sensors/emr_base_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.amazon.aws.sensors.emr_base`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_base`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/emr_job_flow_sensor.py b/airflow/contrib/sensors/emr_job_flow_sensor.py deleted file mode 100644 index 429052a4ec969..0000000000000 --- a/airflow/contrib/sensors/emr_job_flow_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr_job_flow`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/emr_step_sensor.py b/airflow/contrib/sensors/emr_step_sensor.py deleted file mode 100644 index 9d4ac9b166ed6..0000000000000 --- a/airflow/contrib/sensors/emr_step_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr_step`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/file_sensor.py b/airflow/contrib/sensors/file_sensor.py deleted file mode 100644 index 6d75b657e6cef..0000000000000 --- a/airflow/contrib/sensors/file_sensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.filesystem`.""" - -import warnings - -from airflow.sensors.filesystem import FileSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.filesystem`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/contrib/sensors/ftp_sensor.py b/airflow/contrib/sensors/ftp_sensor.py deleted file mode 100644 index 76c47c4609854..0000000000000 --- a/airflow/contrib/sensors/ftp_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.ftp.sensors.ftp`.""" - -import warnings - -from airflow.providers.ftp.sensors.ftp import FTPSensor, FTPSSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.ftp.sensors.ftp`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/gcp_transfer_sensor.py b/airflow/contrib/sensors/gcp_transfer_sensor.py deleted file mode 100644 index 429ddb1a44b05..0000000000000 --- a/airflow/contrib/sensors/gcp_transfer_sensor.py +++ /dev/null @@ -1,51 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.sensors.cloud_storage_transfer_service`. 
-""" - -import warnings - -from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import ( - CloudDataTransferServiceJobStatusSensor, -) - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.google.cloud.sensors.cloud_storage_transfer_service`.", - DeprecationWarning, - stacklevel=2, -) - - -class GCPTransferServiceWaitForJobStatusSensor(CloudDataTransferServiceJobStatusSensor): - """This class is deprecated. - - Please use `airflow.providers.google.cloud.sensors.transfer.CloudDataTransferServiceJobStatusSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.google.cloud.sensors.transfer.CloudDataTransferServiceJobStatusSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/gcs_sensor.py b/airflow/contrib/sensors/gcs_sensor.py deleted file mode 100644 index 3df15f676c265..0000000000000 --- a/airflow/contrib/sensors/gcs_sensor.py +++ /dev/null @@ -1,97 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.sensors.gcs`.""" - -import warnings - -from airflow.providers.google.cloud.sensors.gcs import ( - GCSObjectExistenceSensor, - GCSObjectsWithPrefixExistenceSensor, - GCSObjectUpdateSensor, - GCSUploadSessionCompleteSensor, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.sensors.gcs`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleCloudStorageObjectSensor(GCSObjectExistenceSensor): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GoogleCloudStorageObjectUpdatedSensor(GCSObjectUpdateSensor): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GoogleCloudStoragePrefixSensor(GCSObjectsWithPrefixExistenceSensor): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class GoogleCloudStorageUploadSessionCompleteSensor(GCSUploadSessionCompleteSensor): - """ - This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/hdfs_sensor.py b/airflow/contrib/sensors/hdfs_sensor.py deleted file mode 100644 index d71ec8fc2f454..0000000000000 --- a/airflow/contrib/sensors/hdfs_sensor.py +++ /dev/null @@ -1,67 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.apache.hdfs.sensors.hdfs`. -""" - -import warnings - -from airflow.providers.apache.hdfs.sensors.hdfs import HdfsFolderSensor, HdfsRegexSensor - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs`.", - DeprecationWarning, - stacklevel=2, -) - - -class HdfsSensorFolder(HdfsFolderSensor): - """This class is deprecated. - - Please use: - `airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class HdfsSensorRegex(HdfsRegexSensor): - """This class is deprecated. - - Please use: - `airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/contrib/sensors/imap_attachment_sensor.py b/airflow/contrib/sensors/imap_attachment_sensor.py deleted file mode 100644 index 34d2d7f1402e8..0000000000000 --- a/airflow/contrib/sensors/imap_attachment_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.imap.sensors.imap_attachment`.""" - -import warnings - -from airflow.providers.imap.sensors.imap_attachment import ImapAttachmentSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.imap.sensors.imap_attachment`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/jira_sensor.py b/airflow/contrib/sensors/jira_sensor.py deleted file mode 100644 index e7c3785209616..0000000000000 --- a/airflow/contrib/sensors/jira_sensor.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.jira.sensors.jira`.""" - -import warnings - -from airflow.providers.jira.sensors.jira import JiraSensor, JiraTicketSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.jira.sensors.jira`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/mongo_sensor.py b/airflow/contrib/sensors/mongo_sensor.py deleted file mode 100644 index 13a5f0b65af4c..0000000000000 --- a/airflow/contrib/sensors/mongo_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.mongo.sensors.mongo`.""" - -import warnings - -from airflow.providers.mongo.sensors.mongo import MongoSensor # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.mongo.sensors.mongo`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/pubsub_sensor.py b/airflow/contrib/sensors/pubsub_sensor.py deleted file mode 100644 index eea404e216995..0000000000000 --- a/airflow/contrib/sensors/pubsub_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.sensors.pubsub`.""" - -import warnings - -from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.sensors.pubsub`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py deleted file mode 100644 index bc7543c2fd372..0000000000000 --- a/airflow/contrib/sensors/python_sensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.python`.""" - -import warnings - -from airflow.sensors.python import PythonSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.python`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/contrib/sensors/qubole_sensor.py b/airflow/contrib/sensors/qubole_sensor.py deleted file mode 100644 index 6b656249003e9..0000000000000 --- a/airflow/contrib/sensors/qubole_sensor.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.qubole.sensors.qubole`.""" - -import warnings - -from airflow.providers.qubole.sensors.qubole import ( # noqa - QuboleFileSensor, - QubolePartitionSensor, - QuboleSensor, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.qubole.sensors.qubole`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/redis_key_sensor.py b/airflow/contrib/sensors/redis_key_sensor.py deleted file mode 100644 index f500c86dac5f3..0000000000000 --- a/airflow/contrib/sensors/redis_key_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.redis.sensors.redis_key`.""" - -import warnings - -from airflow.providers.redis.sensors.redis_key import RedisKeySensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.redis.sensors.redis_key`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/redis_pub_sub_sensor.py b/airflow/contrib/sensors/redis_pub_sub_sensor.py deleted file mode 100644 index 16946ac40d198..0000000000000 --- a/airflow/contrib/sensors/redis_pub_sub_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.redis.sensors.redis_pub_sub`.""" - -import warnings - -from airflow.providers.redis.sensors.redis_pub_sub import RedisPubSubSensor # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.redis.sensors.redis_pub_sub`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/sagemaker_base_sensor.py b/airflow/contrib/sensors/sagemaker_base_sensor.py deleted file mode 100644 index 86e32330278e1..0000000000000 --- a/airflow/contrib/sensors/sagemaker_base_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_base`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py deleted file mode 100644 index 5107d6f542fd0..0000000000000 --- a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_endpoint`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/sagemaker_training_sensor.py b/airflow/contrib/sensors/sagemaker_training_sensor.py deleted file mode 100644 index 6a56516042a6f..0000000000000 --- a/airflow/contrib/sensors/sagemaker_training_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerHook, SageMakerTrainingSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/sagemaker_transform_sensor.py b/airflow/contrib/sensors/sagemaker_transform_sensor.py deleted file mode 100644 index 29fd18f8baaae..0000000000000 --- a/airflow/contrib/sensors/sagemaker_transform_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_transform`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/sagemaker_tuning_sensor.py b/airflow/contrib/sensors/sagemaker_tuning_sensor.py deleted file mode 100644 index 7079e4ccb774f..0000000000000 --- a/airflow/contrib/sensors/sagemaker_tuning_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_tuning`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/sftp_sensor.py b/airflow/contrib/sensors/sftp_sensor.py deleted file mode 100644 index d2700e814295a..0000000000000 --- a/airflow/contrib/sensors/sftp_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.sftp.sensors.sftp`.""" - -import warnings - -from airflow.providers.sftp.sensors.sftp import SFTPSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/wasb_sensor.py b/airflow/contrib/sensors/wasb_sensor.py deleted file mode 100644 index d8e0748907afe..0000000000000 --- a/airflow/contrib/sensors/wasb_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.sensors.wasb`.""" - -import warnings - -from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor, WasbPrefixSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.wasb`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/sensors/weekday_sensor.py b/airflow/contrib/sensors/weekday_sensor.py deleted file mode 100644 index 1f836e1dda936..0000000000000 --- a/airflow/contrib/sensors/weekday_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.weekday`.""" - -import warnings - -from airflow.sensors.weekday import DayOfWeekSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.weekday`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/task_runner/__init__.py b/airflow/contrib/task_runner/__init__.py index 77842e3dc5b4c..82bd6bf04148b 100644 --- a/airflow/contrib/task_runner/__init__.py +++ b/airflow/contrib/task_runner/__init__.py @@ -16,3 +16,21 @@ # specific language governing permissions and limitations # under the License. """This package is deprecated. Please use `airflow.task.task_runner`.""" +from __future__ import annotations + +import warnings + +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.utils.deprecation_tools import add_deprecated_classes + +warnings.warn( + "This module is deprecated. Please use airflow.task.task_runner.", RemovedInAirflow3Warning, stacklevel=2 +) + +__deprecated_classes = { + 'cgroup_task_runner': { + 'CgroupTaskRunner': 'airflow.task.task_runner.cgroup_task_runner.CgroupTaskRunner', + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py deleted file mode 100644 index f923126fe4475..0000000000000 --- a/airflow/contrib/task_runner/cgroup_task_runner.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.task.task_runner.cgroup_task_runner`.""" - -import warnings - -from airflow.task.task_runner.cgroup_task_runner import CgroupTaskRunner # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.task.task_runner.cgroup_task_runner`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/utils/__init__.py b/airflow/contrib/utils/__init__.py index 4e6cbb7d51e23..4057f36868d74 100644 --- a/airflow/contrib/utils/__init__.py +++ b/airflow/contrib/utils/__init__.py @@ -16,7 +16,50 @@ # specific language governing permissions and limitations # under the License. """This package is deprecated. 
Please use `airflow.utils`.""" +from __future__ import annotations import warnings -warnings.warn("This module is deprecated. Please use `airflow.utils`.", DeprecationWarning, stacklevel=2) +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.utils.deprecation_tools import add_deprecated_classes + +warnings.warn( + "This module is deprecated. Please use `airflow.utils`.", + RemovedInAirflow3Warning, + stacklevel=2 +) + +__deprecated_classes = { + 'gcp_field_sanitizer': { + 'GcpBodyFieldSanitizer': 'airflow.providers.google.cloud.utils.field_sanitizer.GcpBodyFieldSanitizer', + 'GcpFieldSanitizerException': ( + 'airflow.providers.google.cloud.utils.field_sanitizer.GcpFieldSanitizerException' + ), + }, + 'gcp_field_validator': { + 'GcpBodyFieldValidator': 'airflow.providers.google.cloud.utils.field_validator.GcpBodyFieldValidator', + 'GcpFieldValidationException': ( + 'airflow.providers.google.cloud.utils.field_validator.GcpFieldValidationException' + ), + 'GcpValidationSpecificationException': ( + 'airflow.providers.google.cloud.utils.field_validator.GcpValidationSpecificationException' + ), + }, + 'mlengine_operator_utils': { + 'create_evaluate_ops': ( + 'airflow.providers.google.cloud.utils.mlengine_operator_utils.create_evaluate_ops' + ), + }, + 'mlengine_prediction_summary': { + 'JsonCoder': 'airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder', + 'MakeSummary': 'airflow.providers.google.cloud.utils.mlengine_prediction_summary.MakeSummary', + }, + 'sendgrid': { + 'import_string': 'airflow.utils.module_loading.import_string', + }, + 'weekday': { + 'WeekDay': 'airflow.utils.weekday.WeekDay', + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/contrib/utils/gcp_field_sanitizer.py b/airflow/contrib/utils/gcp_field_sanitizer.py deleted file mode 100644 index 37c0aff2b76d6..0000000000000 --- a/airflow/contrib/utils/gcp_field_sanitizer.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.utils.field_sanitizer`""" - -import warnings - -from airflow.providers.google.cloud.utils.field_sanitizer import ( # noqa - GcpBodyFieldSanitizer, - GcpFieldSanitizerException, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_sanitizer`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/utils/gcp_field_validator.py b/airflow/contrib/utils/gcp_field_validator.py deleted file mode 100644 index fc42dca94be00..0000000000000 --- a/airflow/contrib/utils/gcp_field_validator.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.utils.field_validator`.""" - -import warnings - -from airflow.providers.google.cloud.utils.field_validator import ( # noqa - GcpBodyFieldValidator, - GcpFieldValidationException, - GcpValidationSpecificationException, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_validator`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/utils/log/__init__.py b/airflow/contrib/utils/log/__init__.py index ecc47d3d85122..baf3a481858c8 100644 --- a/airflow/contrib/utils/log/__init__.py +++ b/airflow/contrib/utils/log/__init__.py @@ -15,7 +15,20 @@ # specific language governing permissions and limitations # under the License. """This package is deprecated. Please use `airflow.utils.log`.""" +from __future__ import annotations import warnings +from airflow.utils.deprecation_tools import add_deprecated_classes + warnings.warn("This module is deprecated. Please use `airflow.utils.log`.", DeprecationWarning, stacklevel=2) + +__deprecated_classes = { + 'task_handler_with_custom_formatter': { + 'TaskHandlerWithCustomFormatter': ( + 'airflow.utils.log.task_handler_with_custom_formatter.TaskHandlerWithCustomFormatter' + ), + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py b/airflow/contrib/utils/log/task_handler_with_custom_formatter.py deleted file mode 100644 index 9bbdee3b2c8c3..0000000000000 --- a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.utils.log.task_handler_with_custom_formatter`.""" - -import warnings - -from airflow.utils.log.task_handler_with_custom_formatter import TaskHandlerWithCustomFormatter # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.utils.log.task_handler_with_custom_formatter`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/utils/mlengine_operator_utils.py b/airflow/contrib/utils/mlengine_operator_utils.py deleted file mode 100644 index ebd630c1912a2..0000000000000 --- a/airflow/contrib/utils/mlengine_operator_utils.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.utils.mlengine_operator_utils`. -""" - -import warnings - -from airflow.providers.google.cloud.utils.mlengine_operator_utils import create_evaluate_ops # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.google.cloud.utils.mlengine_operator_utils`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/utils/mlengine_prediction_summary.py b/airflow/contrib/utils/mlengine_prediction_summary.py deleted file mode 100644 index ea390525a359d..0000000000000 --- a/airflow/contrib/utils/mlengine_prediction_summary.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.google.cloud.utils.mlengine_prediction_summary`. -""" - -import warnings - -from airflow.providers.google.cloud.utils.mlengine_prediction_summary import JsonCoder, MakeSummary # noqa - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.google.cloud.utils.mlengine_prediction_summary`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/contrib/utils/sendgrid.py b/airflow/contrib/utils/sendgrid.py deleted file mode 100644 index 16408f92d3a5a..0000000000000 --- a/airflow/contrib/utils/sendgrid.py +++ /dev/null @@ -1,36 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -This module is deprecated. -Please use `airflow.providers.sendgrid.utils.emailer`. -""" - -import warnings - -from airflow.utils.module_loading import import_string - - -def send_email(*args, **kwargs): - """This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`.""" - warnings.warn( - "This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`.", - DeprecationWarning, - stacklevel=2, - ) - return import_string('airflow.providers.sendgrid.utils.emailer.send_email')(*args, **kwargs) diff --git a/airflow/contrib/utils/weekday.py b/airflow/contrib/utils/weekday.py deleted file mode 100644 index 2f2448c8896da..0000000000000 --- a/airflow/contrib/utils/weekday.py +++ /dev/null @@ -1,24 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.utils.weekday`.""" -import warnings - -from airflow.utils.weekday import WeekDay # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.utils.weekday`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 38f5ffdbbf73a..1d8a082869efc 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -16,6 +16,9 @@ # specific language governing permissions and limitations # under the License. 
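The per-class shim modules deleted above are being collapsed into a single __deprecated_classes mapping handed to add_deprecated_classes. As a rough, self-contained illustration of the underlying mechanism (a PEP 562 module-level __getattr__ that warns and forwards on first use) rather than Airflow's exact helper, a package __init__ could do:

from __future__ import annotations

import importlib
import warnings

# Simplified mapping: old attribute name -> fully qualified new location.
# (Airflow's real helper maps old *submodules* to proxy modules; this sketch
# resolves class names directly on the package to stay short.)
__deprecated_classes = {
    "TaskHandlerWithCustomFormatter": (
        "airflow.utils.log.task_handler_with_custom_formatter.TaskHandlerWithCustomFormatter"
    ),
}


def __getattr__(name: str):
    try:
        target = __deprecated_classes[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    warnings.warn(
        f"{__name__}.{name} is deprecated. Please use {target} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    module_path, _, class_name = target.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

The warning fires only when the deprecated attribute is actually accessed, which is why one mapping can replace a directory of shim files.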
"""Processes DAGs.""" +from __future__ import annotations + +import collections import enum import importlib import inspect @@ -31,17 +34,21 @@ from datetime import datetime, timedelta from importlib import import_module from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast +from pathlib import Path +from typing import Any, NamedTuple, cast from setproctitle import setproctitle from sqlalchemy.orm import Session from tabulate import tabulate import airflow.models -from airflow.callbacks.callback_requests import CallbackRequest +from airflow.callbacks.callback_requests import CallbackRequest, SlaCallbackRequest from airflow.configuration import conf from airflow.dag_processing.processor import DagFileProcessorProcess -from airflow.models import DagModel, DbCallbackRequest, errors +from airflow.models import errors +from airflow.models.dag import DagModel +from airflow.models.dagwarning import DagWarning +from airflow.models.db_callback_request import DbCallbackRequest from airflow.models.serialized_dag import SerializedDagModel from airflow.stats import Stats from airflow.utils import timezone @@ -57,9 +64,6 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks -if TYPE_CHECKING: - import pathlib - class DagParsingStat(NamedTuple): """Information on processing progress""" @@ -73,17 +77,17 @@ class DagFileStat(NamedTuple): num_dags: int import_errors: int - last_finish_time: Optional[datetime] - last_duration: Optional[timedelta] + last_finish_time: datetime | None + last_duration: timedelta | None run_count: int class DagParsingSignal(enum.Enum): """All signals sent to parser.""" - AGENT_RUN_ONCE = 'agent_run_once' - TERMINATE_MANAGER = 'terminate_manager' - END_MANAGER = 'end_manager' + AGENT_RUN_ONCE = "agent_run_once" + TERMINATE_MANAGER = "terminate_manager" + END_MANAGER = "end_manager" class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): @@ -107,30 +111,29 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): def __init__( self, - dag_directory: str, + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, - dag_ids: Optional[List[str]], + dag_ids: list[str] | None, pickle_dags: bool, async_mode: bool, ): super().__init__() - self._file_path_queue: List[str] = [] - self._dag_directory: str = dag_directory + self._dag_directory: os.PathLike = dag_directory self._max_runs = max_runs self._processor_timeout = processor_timeout self._dag_ids = dag_ids self._pickle_dags = pickle_dags self._async_mode = async_mode # Map from file path to the processor - self._processors: Dict[str, DagFileProcessorProcess] = {} + self._processors: dict[str, DagFileProcessorProcess] = {} # Pipe for communicating signals - self._process: Optional[multiprocessing.process.BaseProcess] = None + self._process: multiprocessing.process.BaseProcess | None = None self._done: bool = False # Initialized as true so we do not deactivate w/o any actual DAG parsing. 
self._all_files_processed = True - self._parent_signal_conn: Optional[MultiprocessingConnection] = None + self._parent_signal_conn: MultiprocessingConnection | None = None self._last_parsing_stat_received_at: float = time.monotonic() @@ -205,11 +208,11 @@ def wait_until_finished(self) -> None: @staticmethod def _run_processor_manager( - dag_directory: str, + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, signal_conn: MultiprocessingConnection, - dag_ids: Optional[List[str]], + dag_ids: list[str] | None, pickle_dags: bool, async_mode: bool, ) -> None: @@ -217,15 +220,15 @@ def _run_processor_manager( # Make this process start as a new process group - that makes it easy # to kill all sub-process of this at the OS-level, rather than having # to iterate the child processes - os.setpgid(0, 0) + set_new_process_group() setproctitle("airflow scheduler -- DagFileProcessorManager") # Reload configurations and settings to avoid collision with parent process. # Because this process may need custom configurations that cannot be shared, # e.g. RotatingFileHandler. And it can cause connection corruption if we # do not recreate the SQLA connection pool. - os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER'] = 'True' - os.environ['AIRFLOW__LOGGING__COLORED_CONSOLE_LOG'] = 'False' + os.environ["CONFIG_PROCESSOR_MANAGER_LOGGER"] = "True" + os.environ["AIRFLOW__LOGGING__COLORED_CONSOLE_LOG"] = "False" # Replicating the behavior of how logging module was loaded # in logging_config.py @@ -238,10 +241,10 @@ def _run_processor_manager( # The issue that describes the problem and possible remediation is # at https://github.com/apache/airflow/issues/19934 - importlib.reload(import_module(airflow.settings.LOGGING_CLASS_PATH.rsplit('.', 1)[0])) # type: ignore + importlib.reload(import_module(airflow.settings.LOGGING_CLASS_PATH.rsplit(".", 1)[0])) # type: ignore importlib.reload(airflow.settings) airflow.settings.initialize() - del os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER'] + del os.environ["CONFIG_PROCESSOR_MANAGER_LOGGER"] processor_manager = DagFileProcessorManager( dag_directory=dag_directory, max_runs=max_runs, @@ -251,7 +254,6 @@ def _run_processor_manager( signal_conn=signal_conn, async_mode=async_mode, ) - processor_manager.start() def heartbeat(self) -> None: @@ -295,7 +297,7 @@ def _heartbeat_manager(self): parsing_stat_age = time.monotonic() - self._last_parsing_stat_received_at if parsing_stat_age > self._processor_timeout.total_seconds(): - Stats.incr('dag_processing.manager_stalls') + Stats.incr("dag_processing.manager_stalls") self.log.error( "DagFileProcessorManager (PID=%d) last sent a heartbeat %.2f seconds ago! 
Restarting it", self._process.pid, @@ -338,7 +340,7 @@ def end(self): :return: """ if not self._process: - self.log.warning('Ending without manager process.') + self.log.warning("Ending without manager process.") return # Give the Manager some time to cleanly shut down, but not too long, as # it's better to finish sooner than wait for (non-critical) work to @@ -369,25 +371,25 @@ class DagFileProcessorManager(LoggingMixin): def __init__( self, - dag_directory: Union[str, "pathlib.Path"], + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, - dag_ids: Optional[List[str]], + dag_ids: list[str] | None, pickle_dags: bool, - signal_conn: Optional[MultiprocessingConnection] = None, + signal_conn: MultiprocessingConnection | None = None, async_mode: bool = True, ): super().__init__() - self._file_paths: List[str] = [] - self._file_path_queue: List[str] = [] - self._dag_directory = dag_directory + self._file_paths: list[str] = [] + self._file_path_queue: collections.deque[str] = collections.deque() self._max_runs = max_runs # signal_conn is None for dag_processor_standalone mode. self._direct_scheduler_conn = signal_conn self._pickle_dags = pickle_dags self._dag_ids = dag_ids self._async_mode = async_mode - self._parsing_start_time: Optional[int] = None + self._parsing_start_time: int | None = None + self._dag_directory = dag_directory # Set the signal conn in to non-blocking mode, so that attempting to # send when the buffer is full errors, rather than hangs for-ever @@ -398,9 +400,10 @@ def __init__( if self._async_mode and self._direct_scheduler_conn is not None: os.set_blocking(self._direct_scheduler_conn.fileno(), False) - self._parallelism = conf.getint('scheduler', 'parsing_processes') + self.standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") + self._parallelism = conf.getint("scheduler", "parsing_processes") if ( - conf.get_mandatory_value('database', 'sql_alchemy_conn').startswith('sqlite') + conf.get_mandatory_value("database", "sql_alchemy_conn").startswith("sqlite") and self._parallelism > 1 ): self.log.warning( @@ -411,18 +414,18 @@ def __init__( self._parallelism = 1 # Parse and schedule each file no faster than this interval. - self._file_process_interval = conf.getint('scheduler', 'min_file_process_interval') + self._file_process_interval = conf.getint("scheduler", "min_file_process_interval") # How often to print out DAG file processing stats to the log. Default to # 30 seconds. 
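Alongside the PEP 563/604 annotation modernization visible throughout manager.py, the file path queue switches from a plain list to collections.deque, which the rest of the diff depends on: popleft() and appendleft() are O(1), whereas list.pop(0) and list.insert(0, ...) shift every remaining element. A quick illustration:

from __future__ import annotations

import collections

file_path_queue: collections.deque[str] = collections.deque()

file_path_queue.extend(["/dags/a.py", "/dags/b.py"])   # normal parsing order
file_path_queue.appendleft("/dags/has_callback.py")    # callbacks can jump the queue

while file_path_queue:
    print("parsing", file_path_queue.popleft())        # O(1) pop from the front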
- self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval') + self.print_stats_interval = conf.getint("scheduler", "print_stats_interval") # Map from file path to the processor - self._processors: Dict[str, DagFileProcessorProcess] = {} + self._processors: dict[str, DagFileProcessorProcess] = {} self._num_run = 0 # Map from file path to stats about the file - self._file_stats: Dict[str, DagFileStat] = {} + self._file_stats: dict[str, DagFileStat] = {} # Last time that the DAG dir was traversed to look for files self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0)) @@ -431,18 +434,18 @@ def __init__( # Last time we cleaned up DAGs which are no longer in files self.last_deactivate_stale_dags_time = timezone.make_aware(datetime.fromtimestamp(0)) # How often to check for DAGs which are no longer in files - self.deactivate_stale_dags_interval = conf.getint('scheduler', 'deactivate_stale_dags_interval') + self.parsing_cleanup_interval = conf.getint("scheduler", "parsing_cleanup_interval") # How long to wait before timing out a process to parse a DAG file self._processor_timeout = processor_timeout # How often to scan the DAGs directory for new files. Default to 5 minutes. - self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval') + self.dag_dir_list_interval = conf.getint("scheduler", "dag_dir_list_interval") # Mapping file name and callbacks requests - self._callback_to_execute: Dict[str, List[CallbackRequest]] = defaultdict(list) + self._callback_to_execute: dict[str, list[CallbackRequest]] = defaultdict(list) - self._log = logging.getLogger('airflow.processor_manager') + self._log = logging.getLogger("airflow.processor_manager") - self.waitables: Dict[Any, Union[MultiprocessingConnection, DagFileProcessorProcess]] = ( + self.waitables: dict[Any, MultiprocessingConnection | DagFileProcessorProcess] = ( { self._direct_scheduler_conn: self._direct_scheduler_conn, } @@ -460,7 +463,7 @@ def register_exit_signals(self): def _exit_gracefully(self, signum, frame): """Helper method to clean up DAG file processors to avoid leaving orphan processes.""" self.log.info("Exiting gracefully upon receiving signal %s", signum) - self.log.debug("Current Stacktrace is: %s", '\n'.join(map(str, inspect.stack()))) + self.log.debug("Current Stacktrace is: %s", "\n".join(map(str, inspect.stack()))) self.terminate() self.end() self.log.debug("Finished terminating DAG processors.") @@ -494,16 +497,18 @@ def _deactivate_stale_dags(self, session=None): """ now = timezone.utcnow() elapsed_time_since_refresh = (now - self.last_deactivate_stale_dags_time).total_seconds() - if elapsed_time_since_refresh > self.deactivate_stale_dags_interval: + if elapsed_time_since_refresh > self.parsing_cleanup_interval: last_parsed = { fp: self.get_last_finish_time(fp) for fp in self.file_paths if self.get_last_finish_time(fp) } to_deactivate = set() - dags_parsed = ( - session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time) - .filter(DagModel.is_active) - .all() + query = session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).filter( + DagModel.is_active ) + if self.standalone_dag_processor: + query = query.filter(DagModel.processor_subdir == self.get_dag_directory()) + dags_parsed = query.all() + for dag in dags_parsed: # The largest valid difference between a DagFileStat's last_finished_time and a DAG's # last_parsed_time is _processor_timeout. 
Longer than that indicates that the DAG is @@ -541,7 +546,7 @@ def _run_parsing_loop(self): self._refresh_dag_dir() self.prepare_file_path_queue() max_callbacks_per_loop = conf.getint("scheduler", "max_callbacks_per_loop") - standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") + if self._async_mode: # If we're in async mode, we can start up straight away. If we're # in sync mode we need to be told to start a "loop" @@ -592,10 +597,11 @@ def _run_parsing_loop(self): self.waitables.pop(sentinel) self._processors.pop(processor.file_path) - if standalone_dag_processor: + if self.standalone_dag_processor: self._fetch_callbacks(max_callbacks_per_loop) self._deactivate_stale_dags() - self._refresh_dag_dir() + DagWarning.purge_inactive_dag_warnings() + refreshed_dag_dir = self._refresh_dag_dir() self._kill_timed_out_processors() @@ -604,6 +610,8 @@ def _run_parsing_loop(self): if not self._file_path_queue: self.emit_metrics() self.prepare_file_path_queue() + elif refreshed_dag_dir: + self.add_new_file_path_to_queue() self.start_new_processes() @@ -661,11 +669,12 @@ def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION): """Fetches callbacks from database and add them to the internal queue for execution.""" self.log.debug("Fetching callbacks from the database.") with prohibit_commit(session) as guard: - query = ( - session.query(DbCallbackRequest) - .order_by(DbCallbackRequest.priority_weight.asc()) - .limit(max_callbacks) - ) + query = session.query(DbCallbackRequest) + if self.standalone_dag_processor: + query = query.filter( + DbCallbackRequest.processor_subdir == self.get_dag_directory(), + ) + query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks) callbacks = with_row_locks( query, of=DbCallbackRequest, session=session, **skip_locked(session=session) ).all() @@ -678,16 +687,35 @@ def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION): guard.commit() def _add_callback_to_queue(self, request: CallbackRequest): - self._callback_to_execute[request.full_filepath].append(request) - # Callback has a higher priority over DAG Run scheduling - if request.full_filepath in self._file_path_queue: - # Remove file paths matching request.full_filepath from self._file_path_queue - # Since we are already going to use that filepath to run callback, - # there is no need to have same file path again in the queue - self._file_path_queue = [ - file_path for file_path in self._file_path_queue if file_path != request.full_filepath - ] - self._file_path_queue.insert(0, request.full_filepath) + + # requests are sent by dag processors. SLAs exist per-dag, but can be generated once per SLA-enabled + # task in the dag. 
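As I read the comment in the deactivation hunk above, a DAG is considered stale when its file finished parsing more than _processor_timeout after the database's last_parsed_time for that DAG. A small numeric sketch of that rule, keyed by file path instead of ORM rows to stay self-contained:

from __future__ import annotations

from datetime import datetime, timedelta

processor_timeout = timedelta(minutes=3)
last_parsed_in_db = {"/dags/example.py": datetime(2022, 8, 1, 12, 0)}      # DagModel.last_parsed_time
last_finish_per_file = {"/dags/example.py": datetime(2022, 8, 1, 12, 10)}  # DagFileStat.last_finish_time

stale_files = {
    fileloc
    for fileloc, finished in last_finish_per_file.items()
    # The file finished parsing well after the DB last saw the DAG, so the DAG
    # no longer comes from this file and can be deactivated.
    if fileloc in last_parsed_in_db and finished - last_parsed_in_db[fileloc] > processor_timeout
}
print(stale_files)  # {'/dags/example.py'}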
If treated like other callbacks, SLAs can cause feedback where a SLA arrives, + # goes to the front of the queue, gets processed, triggers more SLAs from the same DAG, which go to + # the front of the queue, and we never get round to picking stuff off the back of the queue + if isinstance(request, SlaCallbackRequest): + if request in self._callback_to_execute[request.full_filepath]: + self.log.debug("Skipping already queued SlaCallbackRequest") + return + + # not already queued, queue the file _at the back_, and add the request to the file's callbacks + self.log.debug("Queuing SlaCallbackRequest for %s", request.dag_id) + self._callback_to_execute[request.full_filepath].append(request) + if request.full_filepath not in self._file_path_queue: + self._file_path_queue.append(request.full_filepath) + + # Other callbacks have a higher priority over DAG Run scheduling, so those callbacks gazump, even if + # already in the queue + else: + self.log.debug("Queuing %s CallbackRequest: %s", type(request).__name__, request) + self._callback_to_execute[request.full_filepath].append(request) + if request.full_filepath in self._file_path_queue: + # Remove file paths matching request.full_filepath from self._file_path_queue + # Since we are already going to use that filepath to run callback, + # there is no need to have same file path again in the queue + self._file_path_queue = collections.deque( + file_path for file_path in self._file_path_queue if file_path != request.full_filepath + ) + self._file_path_queue.appendleft(request.full_filepath) def _refresh_dag_dir(self): """Refresh file paths from dag dir if we haven't done it for too long.""" @@ -713,24 +741,33 @@ def _refresh_dag_dir(self): dag_filelocs = [] for fileloc in self._file_paths: if not fileloc.endswith(".py") and zipfile.is_zipfile(fileloc): - with zipfile.ZipFile(fileloc) as z: - dag_filelocs.extend( - [ - os.path.join(fileloc, info.filename) - for info in z.infolist() - if might_contain_dag(info.filename, True, z) - ] - ) + try: + with zipfile.ZipFile(fileloc) as z: + dag_filelocs.extend( + [ + os.path.join(fileloc, info.filename) + for info in z.infolist() + if might_contain_dag(info.filename, True, z) + ] + ) + except zipfile.BadZipFile as err: + self.log.error("There was an err accessing %s, %s", fileloc, err) else: dag_filelocs.append(fileloc) - SerializedDagModel.remove_deleted_dags(dag_filelocs) + SerializedDagModel.remove_deleted_dags( + alive_dag_filelocs=dag_filelocs, + processor_subdir=self.get_dag_directory(), + ) DagModel.deactivate_deleted_dags(self._file_paths) from airflow.models.dagcode import DagCode DagCode.remove_deleted_code(dag_filelocs) + return True + return False + def _print_stat(self): """Occasionally print out stats about how fast the files are getting processed""" if 0 < self.print_stats_interval < time.monotonic() - self.last_stat_print_time: @@ -748,7 +785,7 @@ def clear_nonexistent_import_errors(self, session): query = session.query(errors.ImportError) if self._file_paths: query = query.filter(~errors.ImportError.filename.in_(self._file_paths)) - query.delete(synchronize_session='fetch') + query.delete(synchronize_session="fetch") session.commit() def _log_file_processing_stats(self, known_file_paths): @@ -776,7 +813,7 @@ def _log_file_processing_stats(self, known_file_paths): num_dags = self.get_last_dag_count(file_path) num_errors = self.get_last_error_count(file_path) file_name = os.path.basename(file_path) - file_name = os.path.splitext(file_name)[0].replace(os.sep, '.') + file_name = 
os.path.splitext(file_name)[0].replace(os.sep, ".") processor_pid = self.get_pid(file_path) processor_start_time = self.get_start_time(file_path) @@ -784,7 +821,7 @@ def _log_file_processing_stats(self, known_file_paths): last_run = self.get_last_finish_time(file_path) if last_run: seconds_ago = (now - last_run).total_seconds() - Stats.gauge(f'dag_processing.last_run.seconds_ago.{file_name}', seconds_ago) + Stats.gauge(f"dag_processing.last_run.seconds_ago.{file_name}", seconds_ago) rows.append((file_path, processor_pid, runtime, num_dags, num_errors, last_runtime, last_run)) @@ -816,84 +853,85 @@ def _log_file_processing_stats(self, known_file_paths): self.log.info(log_str) - def get_pid(self, file_path): + def get_pid(self, file_path) -> int | None: """ :param file_path: the path to the file that's being processed :return: the PID of the process processing the given file or None if the specified file is not being processed - :rtype: int """ if file_path in self._processors: return self._processors[file_path].pid return None - def get_all_pids(self): + def get_all_pids(self) -> list[int]: """ + Get all pids. + :return: a list of the PIDs for the processors that are running - :rtype: List[int] """ return [x.pid for x in self._processors.values()] - def get_last_runtime(self, file_path): + def get_last_runtime(self, file_path) -> float | None: """ :param file_path: the path to the file that was processed :return: the runtime (in seconds) of the process of the last run, or None if the file was never processed. - :rtype: float """ stat = self._file_stats.get(file_path) return stat.last_duration.total_seconds() if stat and stat.last_duration else None - def get_last_dag_count(self, file_path): + def get_last_dag_count(self, file_path) -> int | None: """ :param file_path: the path to the file that was processed :return: the number of dags loaded from that file, or None if the file was never processed. - :rtype: int """ stat = self._file_stats.get(file_path) return stat.num_dags if stat else None - def get_last_error_count(self, file_path): + def get_last_error_count(self, file_path) -> int | None: """ :param file_path: the path to the file that was processed :return: the number of import errors from processing, or None if the file was never processed. - :rtype: int """ stat = self._file_stats.get(file_path) return stat.import_errors if stat else None - def get_last_finish_time(self, file_path): + def get_last_finish_time(self, file_path) -> datetime | None: """ :param file_path: the path to the file that was processed :return: the finish time of the process of the last run, or None if the file was never processed. 
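The reworked _add_callback_to_queue above de-duplicates SlaCallbackRequests and queues their file at the back, while other callbacks still move their file to the front; that is what breaks the SLA feedback loop described in the comment. A stripped-down sketch of the policy, with plain strings standing in for CallbackRequest objects:

from __future__ import annotations

import collections

file_path_queue: collections.deque[str] = collections.deque(["/dags/a.py", "/dags/b.py"])
callbacks: dict[str, list[str]] = {"/dags/a.py": [], "/dags/b.py": []}


def add_callback(path: str, request: str, is_sla: bool) -> None:
    if is_sla:
        # SLA requests are de-duplicated and queued at the back, so one DAG's
        # SLAs cannot starve every other file of parsing time.
        if request in callbacks[path]:
            return
        callbacks[path].append(request)
        if path not in file_path_queue:
            file_path_queue.append(path)
    else:
        # Failure/success callbacks jump the queue: drop any existing entry
        # for the file and push it to the front.
        callbacks[path].append(request)
        if path in file_path_queue:
            file_path_queue.remove(path)
        file_path_queue.appendleft(path)


add_callback("/dags/b.py", "sla:task1", is_sla=True)
add_callback("/dags/a.py", "task_failed:task2", is_sla=False)
print(list(file_path_queue))   # ['/dags/a.py', '/dags/b.py']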
- :rtype: datetime """ stat = self._file_stats.get(file_path) return stat.last_finish_time if stat else None - def get_start_time(self, file_path): + def get_start_time(self, file_path) -> datetime | None: """ :param file_path: the path to the file that's being processed :return: the start time of the process that's processing the specified file or None if the file is not currently being processed - :rtype: datetime """ if file_path in self._processors: return self._processors[file_path].start_time return None - def get_run_count(self, file_path): + def get_run_count(self, file_path) -> int: """ :param file_path: the path to the file that's being processed :return: the number of times the given file has been parsed - :rtype: int """ stat = self._file_stats.get(file_path) return stat.run_count if stat else 0 + def get_dag_directory(self) -> str: + """Returns the dag_director as a string.""" + if isinstance(self._dag_directory, Path): + return str(self._dag_directory.resolve()) + else: + return str(self._dag_directory) + def set_file_paths(self, new_file_paths): """ Update this with a new set of paths to DAG definition files. @@ -902,7 +940,7 @@ def set_file_paths(self, new_file_paths): :return: None """ self._file_paths = new_file_paths - self._file_path_queue = [x for x in self._file_path_queue if x in new_file_paths] + self._file_path_queue = collections.deque(x for x in self._file_path_queue if x in new_file_paths) # Stop processors that are working on deleted files filtered_processors = {} for file_path, processor in self._processors.items(): @@ -910,9 +948,15 @@ def set_file_paths(self, new_file_paths): filtered_processors[file_path] = processor else: self.log.warning("Stopping processor for %s", file_path) - Stats.decr('dag_processing.processes') + Stats.decr("dag_processing.processes") processor.terminate() self._file_stats.pop(file_path) + + to_remove = set(self._file_stats.keys()) - set(self._file_paths) + for key in to_remove: + # Remove the stats for any dag files that don't exist anymore + del self._file_stats[key] + self._processors = filtered_processors def wait_until_finished(self): @@ -923,7 +967,7 @@ def wait_until_finished(self): def _collect_results_from_processor(self, processor) -> None: self.log.debug("Processor for %s finished", processor.file_path) - Stats.decr('dag_processing.processes') + Stats.decr("dag_processing.processes") last_finish_time = timezone.utcnow() if processor.result is not None: @@ -945,8 +989,8 @@ def _collect_results_from_processor(self, processor) -> None: ) self._file_stats[processor.file_path] = stat - file_name = os.path.splitext(os.path.basename(processor.file_path))[0].replace(os.sep, '.') - Stats.timing(f'dag_processing.last_duration.{file_name}', last_duration) + file_name = os.path.splitext(os.path.basename(processor.file_path))[0].replace(os.sep, ".") + Stats.timing(f"dag_processing.last_duration.{file_name}", last_duration) def collect_results(self) -> None: """Collect the result from any finished DAG processors""" @@ -967,33 +1011,51 @@ def collect_results(self) -> None: self.log.debug("%s file paths queued for processing", len(self._file_path_queue)) @staticmethod - def _create_process(file_path, pickle_dags, dag_ids, callback_requests): + def _create_process(file_path, pickle_dags, dag_ids, dag_directory, callback_requests): """Creates DagFileProcessorProcess instance.""" return DagFileProcessorProcess( - file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, callback_requests=callback_requests + file_path=file_path, + 
pickle_dags=pickle_dags, + dag_ids=dag_ids, + dag_directory=dag_directory, + callback_requests=callback_requests, ) def start_new_processes(self): """Start more processors if we have enough slots and files to process""" while self._parallelism - len(self._processors) > 0 and self._file_path_queue: - file_path = self._file_path_queue.pop(0) + file_path = self._file_path_queue.popleft() # Stop creating duplicate processor i.e. processor with the same filepath if file_path in self._processors.keys(): continue callback_to_execute_for_file = self._callback_to_execute[file_path] processor = self._create_process( - file_path, self._pickle_dags, self._dag_ids, callback_to_execute_for_file + file_path, + self._pickle_dags, + self._dag_ids, + self.get_dag_directory(), + callback_to_execute_for_file, ) del self._callback_to_execute[file_path] - Stats.incr('dag_processing.processes') + Stats.incr("dag_processing.processes") processor.start() self.log.debug("Started a process (PID: %s) to generate tasks for %s", processor.pid, file_path) self._processors[file_path] = processor self.waitables[processor.waitable_handle] = processor + def add_new_file_path_to_queue(self): + for file_path in self.file_paths: + if file_path not in self._file_stats: + # We found new file after refreshing dir. add to parsing queue at start + self.log.info("Adding new file %s to parsing queue", file_path) + self._file_stats[file_path] = DagFileStat( + num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0 + ) + self._file_path_queue.appendleft(file_path) + def prepare_file_path_queue(self): """Generate more file paths to process. Result are saved in _file_path_queue.""" self._parsing_start_time = time.perf_counter() @@ -1010,6 +1072,7 @@ def prepare_file_path_queue(self): is_mtime_mode = list_mode == "modified_time" file_paths_recently_processed = [] + file_paths_to_stop_watching = set() for file_path in self._file_paths: if is_mtime_mode: @@ -1017,6 +1080,8 @@ def prepare_file_path_queue(self): files_with_mtime[file_path] = os.path.getmtime(file_path) except FileNotFoundError: self.log.warning("Skipping processing of missing file: %s", file_path) + self._file_stats.pop(file_path, None) + file_paths_to_stop_watching.add(file_path) continue file_modified_time = timezone.make_aware(datetime.fromtimestamp(files_with_mtime[file_path])) else: @@ -1045,12 +1110,18 @@ def prepare_file_path_queue(self): # set of files. 
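start_new_processes() now drains the deque with popleft() and passes the dag directory into every DagFileProcessorProcess it creates. Reduced to its control flow (strings stand in for processor objects), the launcher is a bounded loop:

from __future__ import annotations

import collections

parallelism = 2
processors: dict[str, str] = {}     # file path -> stand-in for a running DagFileProcessorProcess
queue = collections.deque(["/dags/a.py", "/dags/b.py", "/dags/c.py"])

while parallelism - len(processors) > 0 and queue:
    file_path = queue.popleft()
    if file_path in processors:     # never start two processors for the same file
        continue
    processors[file_path] = f"processor-for-{file_path}"   # stand-in for processor.start()

print(sorted(processors))           # ['/dags/a.py', '/dags/b.py']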
Since we set the seed, the sort order will remain same per host random.Random(get_hostname()).shuffle(file_paths) + if file_paths_to_stop_watching: + self.set_file_paths( + [path for path in self._file_paths if path not in file_paths_to_stop_watching] + ) + files_paths_at_run_limit = [ file_path for file_path, stat in self._file_stats.items() if stat.run_count == self._max_runs ] file_paths_to_exclude = set(file_paths_in_progress).union( - file_paths_recently_processed, files_paths_at_run_limit + file_paths_recently_processed, + files_paths_at_run_limit, ) # Do not convert the following list to set as set does not preserve the order @@ -1068,12 +1139,11 @@ def prepare_file_path_queue(self): self.log.debug("Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue)) + default = DagFileStat( + num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0 + ) for file_path in files_paths_to_queue: - if file_path not in self._file_stats: - self._file_stats[file_path] = DagFileStat( - num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0 - ) - + self._file_stats.setdefault(file_path, default) self._file_path_queue.extend(files_paths_to_queue) def _kill_timed_out_processors(self): @@ -1089,16 +1159,25 @@ def _kill_timed_out_processors(self): processor.pid, processor.start_time.isoformat(), ) - Stats.decr('dag_processing.processes') - Stats.incr('dag_processing.processor_timeouts') - # TODO: Remove after Airflow 2.0 - Stats.incr('dag_file_processor_timeouts') + Stats.decr("dag_processing.processes") + Stats.incr("dag_processing.processor_timeouts") + # Deprecated; may be removed in a future Airflow release. + Stats.incr("dag_file_processor_timeouts") processor.kill() # Clean up processor references self.waitables.pop(processor.waitable_handle) processors_to_remove.append(file_path) + stat = DagFileStat( + num_dags=0, + import_errors=1, + last_finish_time=now, + last_duration=duration, + run_count=self.get_run_count(file_path) + 1, + ) + self._file_stats[processor.file_path] = stat + # Clean up `self._processors` after iterating over it for proc in processors_to_remove: self._processors.pop(proc) @@ -1120,7 +1199,7 @@ def terminate(self): :return: None """ for processor in self._processors.values(): - Stats.decr('dag_processing.processes') + Stats.decr("dag_processing.processes") processor.terminate() def end(self): @@ -1140,10 +1219,10 @@ def emit_metrics(self): all files have been parsed. """ parse_time = time.perf_counter() - self._parsing_start_time - Stats.gauge('dag_processing.total_parse_time', parse_time) - Stats.gauge('dagbag_size', sum(stat.num_dags for stat in self._file_stats.values())) + Stats.gauge("dag_processing.total_parse_time", parse_time) + Stats.gauge("dagbag_size", sum(stat.num_dags for stat in self._file_stats.values())) Stats.gauge( - 'dag_processing.import_errors', sum(stat.import_errors for stat in self._file_stats.values()) + "dag_processing.import_errors", sum(stat.import_errors for stat in self._file_stats.values()) ) @property diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 34821ffb92010..02bb5eeebf399 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
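When the random-seeded sort mode is used, seeding the shuffle with the hostname (as in the comment above) means a given scheduler host always walks the DAG folder in the same pseudo-random order, while different hosts start in different places. A demonstration, with socket.gethostname() standing in for Airflow's get_hostname():

from __future__ import annotations

import random
import socket

file_paths = ["/dags/a.py", "/dags/b.py", "/dags/c.py", "/dags/d.py"]

# Same host -> same seed -> same "random" order on every parsing loop.
random.Random(socket.gethostname()).shuffle(file_paths)
print(file_paths)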
+from __future__ import annotations import datetime import logging @@ -25,13 +26,13 @@ from contextlib import redirect_stderr, redirect_stdout, suppress from datetime import timedelta from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Iterator, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Iterator from setproctitle import setproctitle -from sqlalchemy import func, or_ +from sqlalchemy import exc, func, or_ from sqlalchemy.orm.session import Session -from airflow import models, settings +from airflow import settings from airflow.callbacks.callback_requests import ( CallbackRequest, DagCallbackRequest, @@ -43,6 +44,9 @@ from airflow.models import SlaMiss, errors from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun as DR +from airflow.models.dagwarning import DagWarning, DagWarningType +from airflow.models.taskinstance import TaskInstance as TI from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.email import get_email_address_list, send_email @@ -51,8 +55,8 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State -DR = models.DagRun -TI = models.TaskInstance +if TYPE_CHECKING: + from airflow.models.operator import Operator class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin): @@ -71,28 +75,30 @@ def __init__( self, file_path: str, pickle_dags: bool, - dag_ids: Optional[List[str]], - callback_requests: List[CallbackRequest], + dag_ids: list[str] | None, + dag_directory: str, + callback_requests: list[CallbackRequest], ): super().__init__() self._file_path = file_path self._pickle_dags = pickle_dags self._dag_ids = dag_ids + self._dag_directory = dag_directory self._callback_requests = callback_requests # The process that was launched to process the given . - self._process: Optional[multiprocessing.process.BaseProcess] = None + self._process: multiprocessing.process.BaseProcess | None = None # The result of DagFileProcessor.process_file(file_path). - self._result: Optional[Tuple[int, int]] = None + self._result: tuple[int, int] | None = None # Whether the process is done running. self._done = False # When the process started. - self._start_time: Optional[datetime.datetime] = None + self._start_time: datetime.datetime | None = None # This ID is use to uniquely name the process / thread that's launched # by this processor instance self._instance_id = DagFileProcessorProcess.class_creation_counter - self._parent_channel: Optional[MultiprocessingConnection] = None + self._parent_channel: MultiprocessingConnection | None = None DagFileProcessorProcess.class_creation_counter += 1 @property @@ -105,9 +111,10 @@ def _run_file_processor( parent_channel: MultiprocessingConnection, file_path: str, pickle_dags: bool, - dag_ids: Optional[List[str]], + dag_ids: list[str] | None, thread_name: str, - callback_requests: List[CallbackRequest], + dag_directory: str, + callback_requests: list[CallbackRequest], ) -> None: """ Process the given file. 
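processor.py replaces the module-level DR/TI aliases with direct imports and pulls Operator in only under TYPE_CHECKING, so type checkers see it without creating a runtime circular import. A minimal, standalone illustration of that pattern:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime, which avoids
    # circular-import problems between tightly coupled modules.
    from airflow.models.operator import Operator


def describe(task: Operator | None) -> str:
    # With postponed annotations, "Operator | None" is just a string at runtime,
    # so the real import is never needed to call this function.
    return "no task" if task is None else type(task).__name__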
@@ -122,42 +129,49 @@ def _run_file_processor( :param thread_name: the name to use for the process that is launched :param callback_requests: failure callback to execute :return: the process that was launched - :rtype: multiprocessing.Process """ # This helper runs in the newly created process log: logging.Logger = logging.getLogger("airflow.processor") # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in # the child, else it won't get closed properly until we exit. - log.info("Closing parent pipe") - parent_channel.close() del parent_channel set_context(log, file_path) setproctitle(f"airflow scheduler - DagFileProcessor {file_path}") + def _handle_dag_file_processing(): + # Re-configure the ORM engine as there are issues with multiple processes + settings.configure_orm() + + # Change the thread name to differentiate log lines. This is + # really a separate process, but changing the name of the + # process doesn't work, so changing the thread name instead. + threading.current_thread().name = thread_name + + log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) + dag_file_processor = DagFileProcessor(dag_ids=dag_ids, dag_directory=dag_directory, log=log) + result: tuple[int, int] = dag_file_processor.process_file( + file_path=file_path, + pickle_dags=pickle_dags, + callback_requests=callback_requests, + ) + result_channel.send(result) + try: - # redirect stdout/stderr to log - with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr( - StreamLogWriter(log, logging.WARN) - ), Stats.timer() as timer: - # Re-configure the ORM engine as there are issues with multiple processes - settings.configure_orm() - - # Change the thread name to differentiate log lines. This is - # really a separate process, but changing the name of the - # process doesn't work, so changing the thread name instead. - threading.current_thread().name = thread_name - - log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) - dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) - result: Tuple[int, int] = dag_file_processor.process_file( - file_path=file_path, - pickle_dags=pickle_dags, - callback_requests=callback_requests, - ) - result_channel.send(result) + DAG_PROCESSOR_LOG_TARGET = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_TARGET") + if DAG_PROCESSOR_LOG_TARGET == "stdout": + with Stats.timer() as timer: + _handle_dag_file_processing() + else: + # The following line ensures that stdout goes to the same destination as the logs. If stdout + # gets sent to logs and logs are sent to stdout, this leads to an infinite loop. This + # necessitates this conditional based on the value of DAG_PROCESSOR_LOG_TARGET. + with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr( + StreamLogWriter(log, logging.WARN) + ), Stats.timer() as timer: + _handle_dag_file_processing() log.info("Processing %s took %.3f seconds", file_path, timer.duration) except Exception: # Log exceptions through the logging framework. 
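The processor now checks DAG_PROCESSOR_LOG_TARGET and only wraps parsing in redirect_stdout/redirect_stderr when logs are not already going to stdout, avoiding the log-to-stdout-to-log loop called out in the comment. A self-contained sketch of routing print() output into a logger; StreamToLog is a minimal stand-in for Airflow's StreamLogWriter:

from __future__ import annotations

import logging
from contextlib import redirect_stderr, redirect_stdout

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("processor")


class StreamToLog:
    """A file-like object that forwards writes to a logger."""

    def __init__(self, logger: logging.Logger, level: int) -> None:
        self.logger, self.level = logger, level

    def write(self, message: str) -> None:
        if message.strip():
            self.logger.log(self.level, message.rstrip())

    def flush(self) -> None:  # required by the file protocol
        pass


with redirect_stdout(StreamToLog(log, logging.INFO)), redirect_stderr(StreamToLog(log, logging.WARNING)):
    print("this print statement ends up in the processor log")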
@@ -185,6 +199,7 @@ def start(self) -> None: self._pickle_dags, self._dag_ids, f"DagFileProcessor{self._instance_id}", + self._dag_directory, self._callback_requests, ), name=f"DagFileProcessor{self._instance_id}-Process", @@ -243,21 +258,17 @@ def _kill_process(self) -> None: @property def pid(self) -> int: - """ - :return: the PID of the process launched to process the given file - :rtype: int - """ + """PID of the process launched to process the given file.""" if self._process is None or self._process.pid is None: raise AirflowException("Tried to get PID before starting!") return self._process.pid @property - def exit_code(self) -> Optional[int]: + def exit_code(self) -> int | None: """ After the process is finished, this can be called to get the return code :return: the exit code of the process - :rtype: int """ if self._process is None: raise AirflowException("Tried to get exit code before starting!") @@ -271,7 +282,6 @@ def done(self) -> bool: Check if the process launched to process this file is done. :return: whether the process is finished running - :rtype: bool """ if self._process is None or self._parent_channel is None: raise AirflowException("Tried to see if it's done before starting!") @@ -309,21 +319,15 @@ def done(self) -> bool: return False @property - def result(self) -> Optional[Tuple[int, int]]: - """ - :return: result of running DagFileProcessor.process_file() - :rtype: tuple[int, int] or None - """ + def result(self) -> tuple[int, int] | None: + """Result of running ``DagFileProcessor.process_file()``.""" if not self.done: raise AirflowException("Tried to get the result before it's done!") return self._result @property def start_time(self) -> datetime.datetime: - """ - :return: when this started to process the file - :rtype: datetime - """ + """Time when this started to process the file.""" if self._start_time is None: raise AirflowException("Tried to get start time before it started!") return self._start_time @@ -351,12 +355,14 @@ class DagFileProcessor(LoggingMixin): :param log: Logger to save the processing process """ - UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE') + UNIT_TEST_MODE: bool = conf.getboolean("core", "UNIT_TEST_MODE") - def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger): + def __init__(self, dag_ids: list[str] | None, dag_directory: str, log: logging.Logger): super().__init__() self.dag_ids = dag_ids self._log = log + self._dag_directory = dag_directory + self.dag_warnings: set[tuple[str, str]] = set() @provide_session def manage_slas(self, dag: DAG, session: Session = None) -> None: @@ -373,14 +379,13 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: return qry = ( - session.query(TI.task_id, func.max(DR.execution_date).label('max_ti')) + session.query(TI.task_id, func.max(DR.execution_date).label("max_ti")) .join(TI.dag_run) - .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql') .filter(TI.dag_id == dag.dag_id) .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED)) .filter(TI.task_id.in_(dag.task_ids)) .group_by(TI.task_id) - .subquery('sq') + .subquery("sq") ) # get recorded SlaMiss recorded_slas_query = set( @@ -414,42 +419,40 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: sla_misses = [] next_info = dag.next_dagrun_info(dag.get_run_data_interval(ti.dag_run), restricted=False) - if next_info is None: - self.log.info("Skipping SLA check for %s because task does not have scheduled date", ti) - else: - while next_info.logical_date < ts: - next_info = 
dag.next_dagrun_info(next_info.data_interval, restricted=False) - - if next_info is None: - break - if (ti.dag_id, ti.task_id, next_info.logical_date) in recorded_slas_query: - break - if next_info.logical_date + task.sla < ts: - - sla_miss = SlaMiss( - task_id=ti.task_id, - dag_id=ti.dag_id, - execution_date=next_info.logical_date, - timestamp=ts, - ) - sla_misses.append(sla_miss) + while next_info and next_info.logical_date < ts: + next_info = dag.next_dagrun_info(next_info.data_interval, restricted=False) + + if next_info is None: + break + if (ti.dag_id, ti.task_id, next_info.logical_date) in recorded_slas_query: + continue + if next_info.logical_date + task.sla < ts: + + sla_miss = SlaMiss( + task_id=ti.task_id, + dag_id=ti.dag_id, + execution_date=next_info.logical_date, + timestamp=ts, + ) + sla_misses.append(sla_miss) + Stats.incr("sla_missed") if sla_misses: session.add_all(sla_misses) session.commit() - slas: List[SlaMiss] = ( + slas: list[SlaMiss] = ( session.query(SlaMiss) .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa .all() ) if slas: - sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas] - fetched_tis: List[TI] = ( + sla_dates: list[datetime.datetime] = [sla.execution_date for sla in slas] + fetched_tis: list[TI] = ( session.query(TI) .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id) .all() ) - blocking_tis: List[TI] = [] + blocking_tis: list[TI] = [] for ti in fetched_tis: if ti.task_id in dag.task_ids: ti.task = dag.get_task(ti.task_id) @@ -458,9 +461,9 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: session.delete(ti) session.commit() - task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas) + task_list = "\n".join(sla.task_id + " on " + sla.execution_date.isoformat() for sla in slas) blocking_task_list = "\n".join( - ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis + ti.task_id + " on " + ti.execution_date.isoformat() for ti in blocking_tis ) # Track whether email or any alert notification sent # We consider email or the alert callback as notifications @@ -468,12 +471,12 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: notification_sent = False if dag.sla_miss_callback: # Execute the alert callback - self.log.info('Calling SLA miss callback') + self.log.info("Calling SLA miss callback") try: dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis) notification_sent = True except Exception: - Stats.incr('sla_callback_notification_failure') + Stats.incr("sla_callback_notification_failure") self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id) email_content = f"""\ Here's a list of tasks that missed their SLAs: @@ -495,7 +498,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: continue tasks_missed_sla.append(task) - emails: Set[str] = set() + emails: set[str] = set() for task in tasks_missed_sla: if task.email: if isinstance(task.email, str): @@ -508,7 +511,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: email_sent = True notification_sent = True except Exception: - Stats.incr('sla_email_notification_failure') + Stats.incr("sla_email_notification_failure") self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id) # If we sent any notification, update the sla_miss table if notification_sent: @@ -545,7 +548,7 @@ def update_import_errors(session: Session, 
dagbag: DagBag) -> None: if filename in existing_import_error_files: session.query(errors.ImportError).filter(errors.ImportError.filename == filename).update( dict(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace), - synchronize_session='fetch', + synchronize_session="fetch", ) else: session.add( @@ -554,14 +557,63 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None: ( session.query(DagModel) .filter(DagModel.fileloc == filename) - .update({'has_import_errors': True}, synchronize_session='fetch') + .update({"has_import_errors": True}, synchronize_session="fetch") ) session.commit() + @provide_session + def _validate_task_pools(self, *, dagbag: DagBag, session: Session = NEW_SESSION): + """ + Validates and raise exception if any task in a dag is using a non-existent pool + :meta private: + """ + from airflow.models.pool import Pool + + def check_pools(dag): + task_pools = {task.pool for task in dag.tasks} + nonexistent_pools = task_pools - pools + if nonexistent_pools: + return ( + f"Dag '{dag.dag_id}' references non-existent pools: {list(sorted(nonexistent_pools))!r}" + ) + + pools = {p.pool for p in Pool.get_pools(session)} + for dag in dagbag.dags.values(): + message = check_pools(dag) + if message: + self.dag_warnings.add(DagWarning(dag.dag_id, DagWarningType.NONEXISTENT_POOL, message)) + for subdag in dag.subdags: + message = check_pools(subdag) + if message: + self.dag_warnings.add(DagWarning(subdag.dag_id, DagWarningType.NONEXISTENT_POOL, message)) + + def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None: + """ + For the DAGs in the given DagBag, record any associated configuration warnings and clear + warnings for files that no longer have them. These are usually displayed through the + Airflow UI so that users know that there are issues parsing DAGs. + + :param session: session for ORM operations + :param dagbag: DagBag containing DAGs with configuration warnings + """ + self._validate_task_pools(dagbag=dagbag) + + stored_warnings = set( + session.query(DagWarning).filter(DagWarning.dag_id.in_(dagbag.dags.keys())).all() + ) + + for warning_to_delete in stored_warnings - self.dag_warnings: + session.delete(warning_to_delete) + + for warning_to_add in self.dag_warnings: + session.merge(warning_to_add) + + session.commit() + @provide_session def execute_callbacks( - self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = NEW_SESSION + self, dagbag: DagBag, callback_requests: list[CallbackRequest], session: Session = NEW_SESSION ) -> None: """ Execute on failure callbacks. These objects can come from SchedulerJob or from @@ -575,7 +627,7 @@ def execute_callbacks( self.log.debug("Processing Callback Request: %s", request) try: if isinstance(request, TaskCallbackRequest): - self._execute_task_callbacks(dagbag, request) + self._execute_task_callbacks(dagbag, request, session=session) elif isinstance(request, SlaCallbackRequest): self.manage_slas(dagbag.get_dag(request.dag_id), session=session) elif isinstance(request, DagCallbackRequest): @@ -587,7 +639,27 @@ def execute_callbacks( request.full_filepath, ) - session.commit() + session.flush() + + def execute_callbacks_without_dag( + self, callback_requests: list[CallbackRequest], session: Session + ) -> None: + """ + Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors. + + This is so important so that tasks that failed when there is a parse + error don't get stuck in queued state. 
+ """ + for request in callback_requests: + self.log.debug("Processing Callback Request: %s", request) + if isinstance(request, TaskCallbackRequest): + self._execute_task_callbacks(None, request, session) + else: + self.log.info( + "Not executing %s callback for file %s as there was a dag parse error", + request.__class__.__name__, + request.full_filepath, + ) @provide_session def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): @@ -597,27 +669,58 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session ) - def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): + def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session): + if not request.is_failure_callback: + return + simple_ti = request.simple_task_instance - if simple_ti.dag_id in dagbag.dags: + ti: TI | None = ( + session.query(TI) + .filter_by( + dag_id=simple_ti.dag_id, + run_id=simple_ti.run_id, + task_id=simple_ti.task_id, + map_index=simple_ti.map_index, + ) + .one_or_none() + ) + if not ti: + return + + task: Operator | None = None + + if dagbag and simple_ti.dag_id in dagbag.dags: dag = dagbag.dags[simple_ti.dag_id] if simple_ti.task_id in dag.task_ids: task = dag.get_task(simple_ti.task_id) - if request.is_failure_callback: - ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index) - # TODO: Use simple_ti to improve performance here in the future - ti.refresh_from_db() - ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE) - self.log.info('Executed failure callback for %s in state %s', ti, ti.state) + else: + # We don't have the _real_ dag here (perhaps it had a parse error?) but we still want to run + # `handle_failure` so that the state of the TI gets progressed. + # + # Since handle_failure _really_ wants a task, we do our best effort to give it one + from airflow.models.serialized_dag import SerializedDagModel + + try: + model = session.query(SerializedDagModel).get(simple_ti.dag_id) + if model: + task = model.dag.get_task(simple_ti.task_id) + except (exc.NoResultFound, TaskNotFound): + pass + if task: + ti.refresh_from_task(task) + + ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session) + self.log.info("Executed failure callback for %s in state %s", ti, ti.state) + session.flush() @provide_session def process_file( self, file_path: str, - callback_requests: List[CallbackRequest], + callback_requests: list[CallbackRequest], pickle_dags: bool = False, - session: Session = None, - ) -> Tuple[int, int]: + session: Session = NEW_SESSION, + ) -> tuple[int, int]: """ Process a Python file containing Airflow DAGs. 
@@ -636,15 +739,14 @@ def process_file( save them to the db :param session: Sqlalchemy ORM Session :return: number of dags found, count of import errors - :rtype: Tuple[int, int] """ self.log.info("Processing file %s for tasks to queue", file_path) try: - dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False) + dagbag = DagBag(file_path, include_examples=False) except Exception: self.log.exception("Failed at reloading the DAG file %s", file_path) - Stats.incr('dag_file_refresh_error', 1, 1) + Stats.incr("dag_file_refresh_error", 1, 1) return 0, 0 if len(dagbag.dags) > 0: @@ -652,17 +754,24 @@ def process_file( else: self.log.warning("No viable dags retrieved from %s", file_path) self.update_import_errors(session, dagbag) + if callback_requests: + # If there were callback requests for this file but there was a + # parse error we still need to progress the state of TIs, + # otherwise they might be stuck in queued/running for ever! + self.execute_callbacks_without_dag(callback_requests, session) return 0, len(dagbag.import_errors) - self.execute_callbacks(dagbag, callback_requests) + self.execute_callbacks(dagbag, callback_requests, session) + session.commit() # Save individual DAGs in the ORM - dagbag.sync_to_db() + dagbag.sync_to_db(processor_subdir=self._dag_directory, session=session) + session.commit() if pickle_dags: paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) - unpaused_dags: List[DAG] = [ + unpaused_dags: list[DAG] = [ dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids ] @@ -675,4 +784,10 @@ def process_file( except Exception: self.log.exception("Error logging import errors!") + # Record DAG warnings in the metadatabase. + try: + self.update_dag_warnings(session=session, dagbag=dagbag) + except Exception: + self.log.exception("Error logging DAG warnings.") + return len(dagbag.dags), len(dagbag.import_errors) diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py new file mode 100644 index 0000000000000..aa50fd16ecd9e --- /dev/null +++ b/airflow/datasets/__init__.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
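The _validate_task_pools check added earlier in this file is essentially a set difference between the pools referenced by a DAG's tasks and the pools that exist, with each offender recorded as a NONEXISTENT_POOL DagWarning. The core of the check in isolation:

from __future__ import annotations

existing_pools = {"default_pool", "gpu_pool"}            # {p.pool for p in Pool.get_pools(session)}
task_pools = {"default_pool", "missing_pool"}            # {task.pool for task in dag.tasks}

nonexistent_pools = task_pools - existing_pools
if nonexistent_pools:
    message = f"Dag 'example_dag' references non-existent pools: {sorted(nonexistent_pools)!r}"
    print(message)   # would become a DagWarning(NONEXISTENT_POOL) row in the metadatabase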
+from __future__ import annotations + +from typing import Any, ClassVar +from urllib.parse import urlsplit + +import attr + + +@attr.define() +class Dataset: + """A Dataset is used for marking data dependencies between workflows.""" + + uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)]) + extra: dict[str, Any] | None = None + + version: ClassVar[int] = 1 + + @uri.validator + def _check_uri(self, attr, uri: str): + if uri.isspace(): + raise ValueError(f"{attr.name} cannot be just whitespace") + try: + uri.encode("ascii") + except UnicodeEncodeError: + raise ValueError(f"{attr.name!r} must be ascii") + parsed = urlsplit(uri) + if parsed.scheme and parsed.scheme.lower() == "airflow": + raise ValueError(f"{attr.name!r} scheme `airflow` is reserved") diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py new file mode 100644 index 0000000000000..1e6723ea64ecb --- /dev/null +++ b/airflow/datasets/manager.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import exc +from sqlalchemy.orm.session import Session + +from airflow.configuration import conf +from airflow.datasets import Dataset +from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel +from airflow.utils.log.logging_mixin import LoggingMixin + +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + + +class DatasetManager(LoggingMixin): + """ + A pluggable class that manages operations for datasets. + + The intent is to have one place to handle all Dataset-related operations, so different + Airflow deployments can use plugins that broadcast dataset events to each other. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def register_dataset_change( + self, *, task_instance: TaskInstance, dataset: Dataset, extra=None, session: Session, **kwargs + ) -> None: + """ + Register dataset related changes. 
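The new Dataset value object validates its URI: between 1 and 3000 characters, not just whitespace, ASCII only, and not using the reserved airflow scheme. The same rules as a dependency-free function (in the diff, the attrs validators handle the length bounds declaratively):

from __future__ import annotations

from urllib.parse import urlsplit


def check_dataset_uri(uri: str) -> str:
    if not uri or len(uri) > 3000:
        raise ValueError("uri must be between 1 and 3000 characters")
    if uri.isspace():
        raise ValueError("uri cannot be just whitespace")
    try:
        uri.encode("ascii")
    except UnicodeEncodeError:
        raise ValueError("uri must be ascii")
    if urlsplit(uri).scheme.lower() == "airflow":
        raise ValueError("scheme `airflow` is reserved")
    return uri


check_dataset_uri("s3://bucket/key/2022-08-01.csv")   # accepted
# check_dataset_uri("airflow://anything")             # would raise ValueError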
+ + For local datasets, look them up, record the dataset event, queue dagruns, and broadcast + the dataset event + """ + dataset_model = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).one_or_none() + if not dataset_model: + self.log.warning("DatasetModel %s not found", dataset) + return + session.add( + DatasetEvent( + dataset_id=dataset_model.id, + source_task_id=task_instance.task_id, + source_dag_id=task_instance.dag_id, + source_run_id=task_instance.run_id, + source_map_index=task_instance.map_index, + extra=extra, + ) + ) + session.flush() + if dataset_model.consuming_dags: + self._queue_dagruns(dataset_model, session) + session.flush() + + def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: + # Possible race condition: if multiple dags or multiple (usually + # mapped) tasks update the same dataset, this can fail with a unique + # constraint violation. + # + # If we support it, use ON CONFLICT to do nothing, otherwise + # "fallback" to running this in a nested transaction. This is needed + # so that the adding of these rows happens in the same transaction + # where `ti.state` is changed. + + if session.bind.dialect.name == "postgresql": + return self._postgres_queue_dagruns(dataset, session) + return self._slow_path_queue_dagruns(dataset, session) + + def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: + consuming_dag_ids = [x.dag_id for x in dataset.consuming_dags] + self.log.debug("consuming dag ids %s", consuming_dag_ids) + + # Don't error whole transaction when a single RunQueue item conflicts. + # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint + for dag_id in consuming_dag_ids: + item = DatasetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset.id) + try: + with session.begin_nested(): + session.merge(item) + except exc.IntegrityError: + self.log.debug("Skipping record %s", item, exc_info=True) + + def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset.id).on_conflict_do_nothing() + session.execute( + stmt, + [{"target_dag_id": target_dag.dag_id} for target_dag in dataset.consuming_dags], + ) + + +def resolve_dataset_manager() -> DatasetManager: + """Retrieve the dataset manager.""" + _dataset_manager_class = conf.getimport( + section="core", + key="dataset_manager_class", + fallback="airflow.datasets.manager.DatasetManager", + ) + _dataset_manager_kwargs = conf.getjson( + section="core", + key="dataset_manager_kwargs", + fallback={}, + ) + return _dataset_manager_class(**_dataset_manager_kwargs) + + +dataset_manager = resolve_dataset_manager() diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py index 40d9921ec0169..af478314e18e4 100644 --- a/airflow/decorators/__init__.py +++ b/airflow/decorators/__init__.py @@ -14,13 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
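A quick orientation on the two new files above: ``Dataset`` is the user-facing handle (a length- and ASCII-validated URI plus optional ``extra``), and ``DatasetManager`` is the pluggable backend that records ``DatasetEvent`` rows and queues downstream runs, with ``resolve_dataset_manager`` letting a deployment swap it out through ``[core] dataset_manager_class`` and ``dataset_manager_kwargs``. A minimal sketch of the authoring side follows; the DAG ids and URI are placeholders, and the dataset-driven ``outlets``/``schedule`` wiring is assumed from the queueing models referenced above rather than shown in this hunk.

# Sketch only: URI and DAG ids are placeholders; the dataset-aware
# `outlets` and `schedule` arguments are assumed from the models this
# manager writes to (DatasetEvent, DatasetDagRunQueue).
from datetime import datetime

from airflow import DAG, Dataset
from airflow.decorators import task

orders = Dataset("s3://my-bucket/orders.parquet", extra={"team": "data-eng"})

with DAG("orders_producer", start_date=datetime(2022, 1, 1), schedule=None):

    @task(outlets=[orders])
    def write_orders():
        ...  # writing the file is what "updates" the dataset

    write_orders()

with DAG("orders_consumer", start_date=datetime(2022, 1, 1), schedule=[orders]):

    @task
    def read_orders():
        ...

    read_orders()

When the producer task finishes, ``register_dataset_change`` records a ``DatasetEvent`` and, because a DAG consumes the dataset, adds a ``DatasetDagRunQueue`` row using either the Postgres ``ON CONFLICT DO NOTHING`` path or the savepoint fallback shown above.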
+from __future__ import annotations from typing import Any from airflow.decorators.base import TaskDecorator from airflow.decorators.branch_python import branch_task +from airflow.decorators.external_python import external_python_task from airflow.decorators.python import python_task from airflow.decorators.python_virtualenv import virtualenv_task +from airflow.decorators.sensor import sensor_task +from airflow.decorators.short_circuit import short_circuit_task from airflow.decorators.task_group import task_group from airflow.models.dag import dag from airflow.providers_manager import ProvidersManager @@ -34,18 +38,24 @@ "task_group", "python_task", "virtualenv_task", + "external_python_task", "branch_task", + "short_circuit_task", + "sensor_task", ] class TaskDecoratorCollection: """Implementation to provide the ``@task`` syntax.""" - python: Any = staticmethod(python_task) + python = staticmethod(python_task) virtualenv = staticmethod(virtualenv_task) + external_python = staticmethod(external_python_task) branch = staticmethod(branch_task) + short_circuit = staticmethod(short_circuit_task) + sensor = staticmethod(sensor_task) - __call__ = python # Alias '@task' to '@task.python'. + __call__: Any = python # Alias '@task' to '@task.python'. def __getattr__(self, name: str) -> TaskDecorator: """Dynamically get provider-registered task decorators, e.g. ``@task.docker``.""" diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index e970f61657c64..0a6d534b247fa 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -14,19 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - # This file provides better type hinting and editor autocompletion support for # dynamically generated task decorators. Functions declared in this stub do not # necessarily exist at run time. See "Creating Custom @task Decorators" # documentation for more details. -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload +from datetime import timedelta +from typing import Any, Callable, Iterable, Mapping, Union, overload + +from kubernetes.client import models as k8s -from airflow.decorators.base import Function, Task, TaskDecorator +from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator from airflow.decorators.branch_python import branch_task +from airflow.decorators.external_python import external_python_task from airflow.decorators.python import python_task from airflow.decorators.python_virtualenv import virtualenv_task +from airflow.decorators.sensor import sensor_task from airflow.decorators.task_group import task_group +from airflow.kubernetes.secret import Secret from airflow.models.dag import dag # Please keep this in sync with __init__.py's __all__. @@ -38,7 +43,10 @@ __all__ = [ "task_group", "python_task", "virtualenv_task", + "external_python_task", "branch_task", + "short_circuit_task", + "sensor_task", ] class TaskDecoratorCollection: @@ -46,10 +54,10 @@ class TaskDecoratorCollection: def python( self, *, - multiple_outputs: Optional[bool] = None, + multiple_outputs: bool | None = None, # 'python_callable', 'op_args' and 'op_kwargs' since they are filled by # _PythonDecoratedOperator. 
- templates_dict: Optional[Mapping[str, Any]] = None, + templates_dict: Mapping[str, Any] | None = None, show_return_value_in_logs: bool = True, **kwargs, ) -> TaskDecorator: @@ -60,7 +68,7 @@ class TaskDecoratorCollection: :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available - in your callable's context after the template has been applied + in your callable's context after the template has been applied. :param show_return_value_in_logs: a bool value whether to show return_value logs. Defaults to True, which allows return value log output. It can be set to False to prevent log output of return value when you return huge data @@ -68,33 +76,33 @@ class TaskDecoratorCollection: """ # [START mixin_for_typing] @overload - def python(self, python_callable: Function) -> Task[Function]: ... + def python(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... # [END mixin_for_typing] @overload def __call__( self, *, - multiple_outputs: Optional[bool] = None, - templates_dict: Optional[Mapping[str, Any]] = None, + multiple_outputs: bool | None = None, + templates_dict: Mapping[str, Any] | None = None, show_return_value_in_logs: bool = True, **kwargs, ) -> TaskDecorator: """Aliasing ``python``; signature should match exactly.""" @overload - def __call__(self, python_callable: Function) -> Task[Function]: + def __call__(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: """Aliasing ``python``; signature should match exactly.""" @overload def virtualenv( self, *, - multiple_outputs: Optional[bool] = None, + multiple_outputs: bool | None = None, # 'python_callable', 'op_args' and 'op_kwargs' since they are filled by # _PythonVirtualenvDecoratedOperator. requirements: Union[None, Iterable[str], str] = None, python_version: Union[None, str, int, float] = None, use_dill: bool = False, system_site_packages: bool = True, - templates_dict: Optional[Mapping[str, Any]] = None, + templates_dict: Mapping[str, Any] | None = None, show_return_value_in_logs: bool = True, **kwargs, ) -> TaskDecorator: @@ -115,71 +123,115 @@ class TaskDecoratorCollection: :param templates_dict: a dictionary where the values are templates that will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available - in your callable's context after the template has been applied + in your callable's context after the template has been applied. :param show_return_value_in_logs: a bool value whether to show return_value logs. Defaults to True, which allows return value log output. It can be set to False to prevent log output of return value when you return huge data such as transmission a large amount of XCom to TaskAPI. """ @overload - def virtualenv(self, python_callable: Function) -> Task[Function]: ... + def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... + def external_python( + self, + *, + python: str, + multiple_outputs: bool | None = None, + # 'python_callable', 'op_args' and 'op_kwargs' since they are filled by + # _PythonVirtualenvDecoratedOperator. + use_dill: bool = False, + templates_dict: Mapping[str, Any] | None = None, + show_return_value_in_logs: bool = True, + **kwargs, + ) -> TaskDecorator: + """Create a decorator to convert the decorated callable to a virtual environment task. 
+ + :param python: Full path string (file-system specific) that points to a Python binary inside + a virtualenv that should be used (in ``VENV/bin`` folder). Should be absolute path + (so usually start with "/" or "X:/" depending on the filesystem/os used). + :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values. + Dict will unroll to XCom values with keys as XCom keys. Defaults to False. + :param use_dill: Whether to use dill to serialize + the args and result (pickle is default). This allow more complex types + but requires you to include dill in your requirements. + :param templates_dict: a dictionary where the values are templates that + will get templated by the Airflow engine sometime between + ``__init__`` and ``execute`` takes place and are made available + in your callable's context after the template has been applied. + :param show_return_value_in_logs: a bool value whether to show return_value + logs. Defaults to True, which allows return value log output. + It can be set to False to prevent log output of return value when you return huge data + such as transmission a large amount of XCom to TaskAPI. + """ @overload - def branch( - self, python_callable: Optional[Callable] = None, multiple_outputs: Optional[bool] = None, **kwargs + def branch(self, *, multiple_outputs: bool | None = None, **kwargs) -> TaskDecorator: + """Create a decorator to wrap the decorated callable into a BranchPythonOperator. + + For more information on how to use this decorator, see :ref:`howto/operator:BranchPythonOperator`. + Accepts arbitrary for operator kwarg. Can be reused in a single DAG. + + :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values. + Dict will unroll to XCom values with keys as XCom keys. Defaults to False. + """ + @overload + def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... + @overload + def short_circuit( + self, + *, + multiple_outputs: bool | None = None, + ignore_downstream_trigger_rules: bool = True, + **kwargs, ) -> TaskDecorator: - """Wraps a python function into a BranchPythonOperator + """Create a decorator to wrap the decorated callable into a ShortCircuitOperator. - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:BranchPythonOperator` - Accepts kwargs for operator kwarg. Can be reused in a single DAG. - :param python_callable: Function to decorate - :type python_callable: Optional[Callable] - :param multiple_outputs: if set, function return value will be - unrolled to multiple XCom values. Dict will unroll to xcom values with keys as XCom keys. - Defaults to False. - :type multiple_outputs: bool + :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values. + Dict will unroll to XCom values with keys as XCom keys. Defaults to False. + :param ignore_downstream_trigger_rules: If set to True, all downstream tasks from this operator task + will be skipped. This is the default behavior. If set to False, the direct, downstream task(s) + will be skipped but the ``trigger_rule`` defined for a other downstream tasks will be respected. + Defaults to True. """ @overload - def branch(self, python_callable: Function) -> Task[Function]: ... + def short_circuit(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... 
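Read together, the ``external_python``, ``branch``, and ``short_circuit`` stubs above describe how these ``@task`` variants are meant to be called. A hedged sketch of the three in one DAG file; the interpreter path and task ids are placeholders, and this illustrates the declared signatures rather than a tested pipeline.

# Illustrative only: the virtualenv path is a placeholder and must exist on
# the worker; task ids default to the decorated function names.
from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG("decorator_variants", start_date=datetime(2022, 1, 1), schedule=None):

    @task.short_circuit
    def data_is_fresh() -> bool:
        return True  # returning False would skip the downstream tasks

    @task.branch
    def pick_branch() -> str:
        return "heavy_transform"  # task_id of the branch to follow

    @task.external_python(python="/opt/venvs/etl/bin/python")
    def heavy_transform():
        ...

    @task
    def light_transform():
        ...

    data_is_fresh() >> pick_branch() >> [heavy_transform(), light_transform()]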
# [START decorator_signature] def docker( self, *, - multiple_outputs: Optional[bool] = None, + multiple_outputs: bool | None = None, use_dill: bool = False, # Added by _DockerDecoratedOperator. python_command: str = "python3", # 'command', 'retrieve_output', and 'retrieve_output_path' are filled by # _DockerDecoratedOperator. image: str, - api_version: Optional[str] = None, - container_name: Optional[str] = None, + api_version: str | None = None, + container_name: str | None = None, cpus: float = 1.0, docker_url: str = "unix://var/run/docker.sock", - environment: Optional[Dict[str, str]] = None, - private_environment: Optional[Dict[str, str]] = None, + environment: dict[str, str] | None = None, + private_environment: dict[str, str] | None = None, force_pull: bool = False, - mem_limit: Optional[Union[float, str]] = None, - host_tmp_dir: Optional[str] = None, - network_mode: Optional[str] = None, - tls_ca_cert: Optional[str] = None, - tls_client_cert: Optional[str] = None, - tls_client_key: Optional[str] = None, - tls_hostname: Optional[Union[str, bool]] = None, - tls_ssl_version: Optional[str] = None, + mem_limit: float | str | None = None, + host_tmp_dir: str | None = None, + network_mode: str | None = None, + tls_ca_cert: str | None = None, + tls_client_cert: str | None = None, + tls_client_key: str | None = None, + tls_hostname: str | bool | None = None, + tls_ssl_version: str | None = None, tmp_dir: str = "/tmp/airflow", - user: Optional[Union[str, int]] = None, - mounts: Optional[List[str]] = None, - working_dir: Optional[str] = None, + user: str | int | None = None, + mounts: list[str] | None = None, + working_dir: str | None = None, xcom_all: bool = False, - docker_conn_id: Optional[str] = None, - dns: Optional[List[str]] = None, - dns_search: Optional[List[str]] = None, + docker_conn_id: str | None = None, + dns: list[str] | None = None, + dns_search: list[str] | None = None, auto_remove: bool = False, - shm_size: Optional[int] = None, + shm_size: int | None = None, tty: bool = False, privileged: bool = False, - cap_add: Optional[Iterable[str]] = None, - extra_hosts: Optional[Dict[str, str]] = None, + cap_add: str | None = None, + extra_hosts: dict[str, str] | None = None, **kwargs, ) -> TaskDecorator: """Create a decorator to convert the decorated callable to a Docker task. 
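Because ``@task.docker`` is provider-registered at runtime (via ``__getattr__`` on the collection), the signature above exists purely for type checking. A hedged usage sketch, assuming the Docker provider is installed; the image tag is a placeholder.

# Sketch of the @task.docker surface declared above; requires
# apache-airflow-providers-docker, and the image is a placeholder.
from airflow.decorators import task

@task.docker(image="python:3.9-slim", multiple_outputs=True, auto_remove=True)
def summarize(records: list) -> dict:
    # Runs inside the container; the dict is unrolled into per-key XComs.
    return {"count": len(records)}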
@@ -245,5 +297,157 @@ class TaskDecoratorCollection: :param cap_add: Include container capabilities """ # [END decorator_signature] + def kubernetes( + self, + *, + image: str, + kubernetes_conn_id: str = ..., + namespace: str = "default", + name: str = ..., + random_name_suffix: bool = True, + ports: list[k8s.V1ContainerPort] | None = None, + volume_mounts: list[k8s.V1VolumeMount] | None = None, + volumes: list[k8s.V1Volume] | None = None, + env_vars: list[k8s.V1EnvVar] | None = None, + env_from: list[k8s.V1EnvFromSource] | None = None, + secrets: list[Secret] | None = None, + in_cluster: bool | None = None, + cluster_context: str | None = None, + labels: dict | None = None, + reattach_on_restart: bool = True, + startup_timeout_seconds: int = 120, + get_logs: bool = True, + image_pull_policy: str | None = None, + annotations: dict | None = None, + container_resources: k8s.V1ResourceRequirements | None = None, + affinity: k8s.V1Affinity | None = None, + config_file: str = ..., + node_selector: dict | None = None, + image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None, + service_account_name: str | None = None, + is_delete_operator_pod: bool = True, + hostnetwork: bool = False, + tolerations: list[k8s.V1Toleration] | None = None, + security_context: dict | None = None, + dnspolicy: str | None = None, + schedulername: str | None = None, + init_containers: list[k8s.V1Container] | None = None, + log_events_on_failure: bool = False, + do_xcom_push: bool = False, + pod_template_file: str | None = None, + priority_class_name: str | None = None, + pod_runtime_info_envs: list[k8s.V1EnvVar] | None = None, + termination_grace_period: int | None = None, + configmaps: list[str] | None = None, + **kwargs, + ) -> TaskDecorator: + """Create a decorator to convert a callable to a Kubernetes Pod task. + + :param kubernetes_conn_id: The Kubernetes cluster's + :ref:`connection ID `. + :param namespace: Namespace to run within Kubernetes. Defaults to *default*. + :param image: Docker image to launch. Defaults to *hub.docker.com*, but + a fully qualified URL will point to a custom repository. (templated) + :param name: Name of the pod to run. This will be used (plus a random + suffix if *random_name_suffix* is *True*) to generate a pod ID + (DNS-1123 subdomain, containing only ``[a-z0-9.-]``). Defaults to + ``k8s_airflow_pod_{RANDOM_UUID}``. + :param random_name_suffix: If *True*, will generate a random suffix. + :param ports: Ports for the launched pod. + :param volume_mounts: *volumeMounts* for the launched pod. + :param volumes: Volumes for the launched pod. Includes *ConfigMaps* and + *PersistentVolumes*. + :param env_vars: Environment variables initialized in the container. + (templated) + :param env_from: List of sources to populate environment variables in + the container. + :param secrets: Kubernetes secrets to inject in the container. They can + be exposed as environment variables or files in a volume. + :param in_cluster: Run kubernetes client with *in_cluster* configuration. + :param cluster_context: Context that points to the Kubernetes cluster. + Ignored when *in_cluster* is *True*. If *None*, current-context will + be used. + :param reattach_on_restart: If the worker dies while the pod is running, + reattach and monitor during the next try. If *False*, always create + a new pod for each try. + :param labels: Labels to apply to the pod. (templated) + :param startup_timeout_seconds: Timeout in seconds to startup the pod. 
+ :param get_logs: Get the stdout of the container as logs of the tasks. + :param image_pull_policy: Specify a policy to cache or always pull an + image. + :param annotations: Non-identifying metadata you can attach to the pod. + Can be a large range of data, and can include characters that are + not permitted by labels. + :param container_resources: Resources for the launched pod. + :param affinity: Affinity scheduling rules for the launched pod. + :param config_file: The path to the Kubernetes config file. If not + specified, default value is ``~/.kube/config``. (templated) + :param node_selector: A dict containing a group of scheduling rules. + :param image_pull_secrets: Any image pull secrets to be given to the + pod. If more than one secret is required, provide a comma separated + list, e.g. ``secret_a,secret_b``. + :param service_account_name: Name of the service account. + :param is_delete_operator_pod: What to do when the pod reaches its final + state, or the execution is interrupted. If *True* (default), delete + the pod; otherwise leave the pod. + :param hostnetwork: If *True*, enable host networking on the pod. + :param tolerations: A list of Kubernetes tolerations. + :param security_context: Security options the pod should run with + (PodSecurityContext). + :param dnspolicy: DNS policy for the pod. + :param schedulername: Specify a scheduler name for the pod + :param init_containers: Init containers for the launched pod. + :param log_events_on_failure: Log the pod's events if a failure occurs. + :param do_xcom_push: If *True*, the content of + ``/airflow/xcom/return.json`` in the container will also be pushed + to an XCom when the container completes. + :param pod_template_file: Path to pod template file (templated) + :param priority_class_name: Priority class name for the launched pod. + :param pod_runtime_info_envs: A list of environment variables + to be set in the container. + :param termination_grace_period: Termination grace period if task killed + in UI, defaults to kubernetes default + :param configmaps: A list of names of config maps from which it collects + ConfigMaps to populate the environment variables with. The contents + of the target ConfigMap's Data field will represent the key-value + pairs as environment variables. Extends env_from. + """ + @overload + def sensor( + self, + *, + poke_interval: float = ..., + timeout: float = ..., + soft_fail: bool = False, + mode: str = ..., + exponential_backoff: bool = False, + max_wait: timedelta | float | None = None, + **kwargs, + ) -> TaskDecorator: + """ + Wraps a Python function into a sensor operator. + + :param poke_interval: Time in seconds that the job should wait in + between each try + :param timeout: Time, in seconds before the task times out and fails. + :param soft_fail: Set to true to mark the task as SKIPPED on failure + :param mode: How the sensor operates. + Options are: ``{ poke | reschedule }``, default is ``poke``. + When set to ``poke`` the sensor is taking up a worker slot for its + whole execution time and sleeps between pokes. Use this mode if the + expected runtime of the sensor is short or if a short poke interval + is required. Note that the sensor will hold onto a worker slot and + a pool slot for the duration of the sensor's runtime in this mode. + When set to ``reschedule`` the sensor task frees the worker slot when + the criteria is not yet met and it's rescheduled at a later time. Use + this mode if the time before the criteria is met is expected to be + quite long. 
The poke interval should be more than one minute to + prevent too much load on the scheduler. + :param exponential_backoff: allow progressive longer waits between + pokes by using exponential backoff algorithm + :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds + """ + @overload + def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ... task: TaskDecoratorCollection diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 1b5b5b760bf77..7da74e5514e42 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -14,23 +14,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -import functools import inspect import re +from itertools import chain from typing import ( - TYPE_CHECKING, Any, Callable, + ClassVar, Collection, Dict, Generic, Iterator, Mapping, - Optional, Sequence, - Set, - Type, TypeVar, cast, overload, @@ -38,8 +36,10 @@ import attr import typing_extensions +from sqlalchemy.orm import Session -from airflow.compat.functools import cache, cached_property +from airflow import Dataset +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY from airflow.models.baseoperator import ( @@ -50,50 +50,74 @@ parse_retries, ) from airflow.models.dag import DAG, DagContext -from airflow.models.mappedoperator import ( - MappedOperator, - ValidationSource, - ensure_xcomarg_return_value, - get_mappable_types, - prevent_duplicates, +from airflow.models.expandinput import ( + EXPAND_INPUT_EMPTY, + DictOfListsExpandInput, + ExpandInput, + ListOfDictsExpandInput, + OperatorExpandArgument, + OperatorExpandKwargsArgument, + is_mappable, ) +from airflow.models.mappedoperator import MappedOperator, ValidationSource, ensure_xcomarg_return_value from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg -from airflow.typing_compat import Protocol +from airflow.typing_compat import ParamSpec, Protocol from airflow.utils import timezone from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context +from airflow.utils.helpers import prevent_duplicates from airflow.utils.task_group import TaskGroup, TaskGroupContext from airflow.utils.types import NOTSET -if TYPE_CHECKING: - import jinja2 # Slow import. - from sqlalchemy.orm import Session - from airflow.models.mappedoperator import Mappable +class ExpandableFactory(Protocol): + """Protocol providing inspection against wrapped function. + This is used in ``validate_expand_kwargs`` and implemented by function + decorators like ``@task`` and ``@task_group``. -def validate_python_callable(python_callable: Any) -> None: + :meta private: """ - Validate that python callable can be wrapped by operator. - Raises exception if invalid.
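The ``sensor`` entry above is backed by ``DecoratedSensorOperator`` later in this diff, which simply calls the decorated function from ``poke``. A hedged sketch of the intended call pattern; the flag path is a placeholder.

# Sketch of @task.sensor as declared in the stub above; the path is a
# placeholder. The callable may return a bool or a PokeReturnValue.
import os

from airflow.decorators import task
from airflow.sensors.base import PokeReturnValue

@task.sensor(poke_interval=60, timeout=60 * 60, mode="reschedule")
def wait_for_flag() -> PokeReturnValue:
    path = "/tmp/data_ready.flag"  # placeholder
    return PokeReturnValue(is_done=os.path.exists(path), xcom_value=path)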
- :param python_callable: Python object to be validated - :raises: TypeError, AirflowException - """ - if not callable(python_callable): - raise TypeError('`python_callable` param must be callable') - if 'self' in inspect.signature(python_callable).parameters.keys(): - raise AirflowException('@task does not support methods') + function: Callable + + @cached_property + def function_signature(self) -> inspect.Signature: + return inspect.signature(self.function) + + @cached_property + def _mappable_function_argument_names(self) -> set[str]: + """Arguments that can be mapped against.""" + return set(self.function_signature.parameters) + + def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None: + """Ensure that all arguments passed to operator-mapping functions are accounted for.""" + parameters = self.function_signature.parameters + if any(v.kind == inspect.Parameter.VAR_KEYWORD for v in parameters.values()): + return + kwargs_left = kwargs.copy() + for arg_name in self._mappable_function_argument_names: + value = kwargs_left.pop(arg_name, NOTSET) + if func != "expand" or value is NOTSET or is_mappable(value): + continue + tname = type(value).__name__ + raise ValueError(f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}") + if len(kwargs_left) == 1: + raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}") + elif kwargs_left: + names = ", ".join(repr(n) for n in kwargs_left) + raise TypeError(f"{func}() got unexpected keyword arguments {names}") def get_unique_task_id( task_id: str, - dag: Optional[DAG] = None, - task_group: Optional[TaskGroup] = None, + dag: DAG | None = None, + task_group: TaskGroup | None = None, ) -> str: """ - Generate unique task id given a DAG (or if run in a DAG context) - Ids are generated by appending a unique number to the end of + Generate unique task id given a DAG (or if run in a DAG context). + + IDs are generated by appending a unique number to the end of the original task id. Example: @@ -144,38 +168,52 @@ class DecoratedOperator(BaseOperator): PythonOperator). This gives a user the option to upstream kwargs as needed. """ - template_fields: Sequence[str] = ('op_args', 'op_kwargs') + template_fields: Sequence[str] = ("op_args", "op_kwargs") template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects (e.g protobuf). - shallow_copy_attrs: Sequence[str] = ('python_callable',) + shallow_copy_attrs: Sequence[str] = ("python_callable",) def __init__( self, *, python_callable: Callable, task_id: str, - op_args: Optional[Collection[Any]] = None, - op_kwargs: Optional[Mapping[str, Any]] = None, + op_args: Collection[Any] | None = None, + op_kwargs: Mapping[str, Any] | None = None, multiple_outputs: bool = False, - kwargs_to_upstream: Optional[Dict[str, Any]] = None, + kwargs_to_upstream: dict[str, Any] | None = None, **kwargs, ) -> None: - task_id = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group')) + task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) self.python_callable = python_callable kwargs_to_upstream = kwargs_to_upstream or {} op_args = op_args or [] op_kwargs = op_kwargs or {} - # Check that arguments can be binded - inspect.signature(python_callable).bind(*op_args, **op_kwargs) + # Check that arguments can be binded. 
There's a slight difference when + # we do validation for task-mapping: Since there's no guarantee we can + # receive enough arguments at parse time, we use bind_partial to simply + # check all the arguments we know are valid. Whether these are enough + # can only be known at execution time, when unmapping happens, and this + # is called without the _airflow_mapped_validation_only flag. + if kwargs.get("_airflow_mapped_validation_only"): + inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs) + else: + inspect.signature(python_callable).bind(*op_args, **op_kwargs) + self.multiple_outputs = multiple_outputs self.op_args = op_args self.op_kwargs = op_kwargs super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs) def execute(self, context: Context): + # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators + # as well + for arg in chain(self.op_args, self.op_kwargs.values()): + if isinstance(arg, Dataset): + self.inlets.append(arg) return_value = super().execute(context) return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push) @@ -183,49 +221,61 @@ def _handle_output(self, return_value: Any, context: Context, xcom_push: Callabl """ Handles logic for whether a decorator needs to push a single return value or multiple return values. + It sets outlets if any datasets are found in the returned value(s) + :param return_value: :param context: :param xcom_push: """ + if isinstance(return_value, Dataset): + self.outlets.append(return_value) + if isinstance(return_value, list): + for item in return_value: + if isinstance(item, Dataset): + self.outlets.append(item) if not self.multiple_outputs: return return_value if isinstance(return_value, dict): for key in return_value.keys(): if not isinstance(key, str): raise AirflowException( - 'Returned dictionary keys must be strings when using ' - f'multiple_outputs, found {key} ({type(key)}) instead' + "Returned dictionary keys must be strings when using " + f"multiple_outputs, found {key} ({type(key)}) instead" ) for key, value in return_value.items(): + if isinstance(value, Dataset): + self.outlets.append(value) xcom_push(context, key, value) else: raise AirflowException( - f'Returned output was type {type(return_value)} expected dictionary for multiple_outputs' + f"Returned output was type {type(return_value)} expected dictionary for multiple_outputs" ) return return_value def _hook_apply_defaults(self, *args, **kwargs): - if 'python_callable' not in kwargs: + if "python_callable" not in kwargs: return args, kwargs - python_callable = kwargs['python_callable'] - default_args = kwargs.get('default_args') or {} - op_kwargs = kwargs.get('op_kwargs') or {} + python_callable = kwargs["python_callable"] + default_args = kwargs.get("default_args") or {} + op_kwargs = kwargs.get("op_kwargs") or {} f_sig = inspect.signature(python_callable) for arg in f_sig.parameters: if arg not in op_kwargs and arg in default_args: op_kwargs[arg] = default_args[arg] - kwargs['op_kwargs'] = op_kwargs + kwargs["op_kwargs"] = op_kwargs return args, kwargs -Function = TypeVar("Function", bound=Callable) +FParams = ParamSpec("FParams") + +FReturn = TypeVar("FReturn") OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator") @attr.define(slots=False) -class _TaskDecorator(Generic[Function, OperatorSubclass]): +class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]): """ Helper class for providing dynamic task mapping to decorated functions. 
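The ``execute``/``_handle_output`` changes above make TaskFlow operators lineage-aware: ``Dataset`` values passed as arguments are appended to ``inlets``, ``Dataset`` values found in the return value (bare, in a list, or as dict values) are appended to ``outlets``, and ``multiple_outputs`` still unrolls string-keyed dicts into separate XComs. A small sketch of the argument side of that behaviour; the URI and numbers are placeholders.

# Sketch of the inlet detection and multiple_outputs unrolling implemented
# above; the URI is a placeholder.
from airflow import Dataset
from airflow.decorators import task

raw = Dataset("s3://bucket/raw.csv")

@task(multiple_outputs=True)
def profile(source: Dataset) -> dict:
    # Called as profile(raw) in a DAG: `source` is a Dataset argument, so the
    # operator appends it to `inlets` before running the function. The dict
    # below is unrolled into one XCom per string key.
    return {"rows": 1000, "columns": 12}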
@@ -234,18 +284,20 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]): :meta private: """ - function: Function = attr.ib() - operator_class: Type[OperatorSubclass] + function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable()) + operator_class: type[OperatorSubclass] multiple_outputs: bool = attr.ib() - kwargs: Dict[str, Any] = attr.ib(factory=dict) + kwargs: dict[str, Any] = attr.ib(factory=dict) decorator_name: str = attr.ib(repr=False, default="task") + _airflow_is_task_decorator: ClassVar[bool] = True + @multiple_outputs.default def _infer_multiple_outputs(self): try: return_type = typing_extensions.get_type_hints(self.function).get("return", Any) - except Exception: # Can't evaluate retrurn type. + except TypeError: # Can't evaluate return type. return False ttype = getattr(return_type, "__origin__", return_type) return ttype == dict or ttype == Dict @@ -253,9 +305,9 @@ def _infer_multiple_outputs(self): def __attrs_post_init__(self): if "self" in self.function_signature.parameters: raise TypeError(f"@{self.decorator_name} does not support methods") - self.kwargs.setdefault('task_id', self.function.__name__) + self.kwargs.setdefault("task_id", self.function.__name__) - def __call__(self, *args, **kwargs) -> XComArg: + def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg: op = self.operator_class( python_callable=self.function, op_args=args, @@ -268,24 +320,10 @@ def __call__(self, *args, **kwargs) -> XComArg: return XComArg(op) @property - def __wrapped__(self) -> Function: + def __wrapped__(self) -> Callable[FParams, FReturn]: return self.function - @cached_property - def function_signature(self): - return inspect.signature(self.function) - - @cached_property - def _function_is_vararg(self): - parameters = self.function_signature.parameters - return any(v.kind == inspect.Parameter.VAR_KEYWORD for v in parameters.values()) - - @cached_property - def _mappable_function_argument_names(self) -> Set[str]: - """Arguments that can be mapped against.""" - return set(self.function_signature.parameters) - - def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]): + def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]): # Ensure that context variables are not shadowed. context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs) if len(context_keys_being_mapped) == 1: @@ -295,35 +333,34 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]): names = ", ".join(repr(n) for n in context_keys_being_mapped) raise ValueError(f"cannot call {func}() on task context variables {names}") - # Ensure that all arguments passed in are accounted for. 
- if self._function_is_vararg: - return - kwargs_left = kwargs.copy() - for arg_name in self._mappable_function_argument_names: - value = kwargs_left.pop(arg_name, NOTSET) - if func != "expand" or value is NOTSET or isinstance(value, get_mappable_types()): - continue - tname = type(value).__name__ - raise ValueError(f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}") - if len(kwargs_left) == 1: - raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}") - elif kwargs_left: - names = ", ".join(repr(n) for n in kwargs_left) - raise TypeError(f"{func}() got unexpected keyword arguments {names}") + super()._validate_arg_names(func, kwargs) - def expand(self, **map_kwargs: "Mappable") -> XComArg: + def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg: if not map_kwargs: raise TypeError("no arguments to expand against") - self._validate_arg_names("expand", map_kwargs) prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial") - ensure_xcomarg_return_value(map_kwargs) + # Since the input is already checked at parse time, we can set strict + # to False to skip the checks on execution. + return self._expand(DictOfListsExpandInput(map_kwargs), strict=False) + + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: + if isinstance(kwargs, Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) + + def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: + ensure_xcomarg_return_value(expand_input.value) task_kwargs = self.kwargs.copy() dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag() task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag) - partial_kwargs, default_params = get_merged_defaults( + partial_kwargs, partial_params = get_merged_defaults( dag=dag, task_group=task_group, task_params=task_kwargs.pop("params", None), @@ -332,7 +369,8 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg: partial_kwargs.update(task_kwargs) task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) - params = partial_kwargs.pop("params", None) or default_params + if task_group: + task_id = task_group.child_id(task_id) # Logic here should be kept in sync with BaseOperatorMeta.partial(). if "task_concurrency" in partial_kwargs: @@ -361,12 +399,18 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg: # Mypy does not work well with a subclassed attrs class :( _MappedOperator = cast(Any, DecoratedMappedOperator) + + try: + operator_name = self.operator_class.custom_operator_name # type: ignore + except AttributeError: + operator_name = self.operator_class.__name__ + operator = _MappedOperator( operator_class=self.operator_class, - mapped_kwargs={}, + expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input. 
partial_kwargs=partial_kwargs, task_id=task_id, - params=params, + params=partial_params, deps=MappedOperator.deps_for(self.operator_class), operator_extra_links=self.operator_class.operator_extra_links, template_ext=self.operator_class.template_ext, @@ -377,42 +421,33 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg: is_empty=False, task_module=self.operator_class.__module__, task_type=self.operator_class.__name__, + operator_name=operator_name, dag=dag, task_group=task_group, start_date=start_date, end_date=end_date, multiple_outputs=self.multiple_outputs, python_callable=self.function, - mapped_op_kwargs=map_kwargs, + op_kwargs_expand_input=expand_input, + disallow_kwargs_override=strict, # Different from classic operators, kwargs passed to a taskflow # task's expand() contribute to the op_kwargs operator argument, not # the operator arguments themselves, and should expand against it. - expansion_kwargs_attr="mapped_op_kwargs", + expand_input_attr="op_kwargs_expand_input", ) return XComArg(operator=operator) - def partial(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]": + def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: self._validate_arg_names("partial", kwargs) + old_kwargs = self.kwargs.get("op_kwargs", {}) + prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial") + kwargs.update(old_kwargs) + return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs}) - op_kwargs = self.kwargs.get("op_kwargs", {}) - op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial") - - return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": op_kwargs}) - - def override(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]": + def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: return attr.evolve(self, kwargs={**self.kwargs, **kwargs}) -def _merge_kwargs(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> Dict[str, Any]: - duplicated_keys = set(kwargs1).intersection(kwargs2) - if len(duplicated_keys) == 1: - raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}") - elif duplicated_keys: - duplicated_keys_display = ", ".join(sorted(duplicated_keys)) - raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") - return {**kwargs1, **kwargs2} - - @attr.define(kw_only=True, repr=False) class DecoratedMappedOperator(MappedOperator): """MappedOperator implementation for @task-decorated task function.""" @@ -420,81 +455,41 @@ class DecoratedMappedOperator(MappedOperator): multiple_outputs: bool python_callable: Callable - # We can't save these in mapped_kwargs because op_kwargs need to be present + # We can't save these in expand_input because op_kwargs need to be present # in partial_kwargs, and MappedOperator prevents duplication. - mapped_op_kwargs: Dict[str, "Mappable"] + op_kwargs_expand_input: ExpandInput def __hash__(self): return id(self) - @classmethod - @cache - def get_serialized_fields(cls): - # The magic super() doesn't work here, so we use the explicit form. - # Not using super(..., cls) to work around pyupgrade bug. - sup = super(DecoratedMappedOperator, DecoratedMappedOperator) - return sup.get_serialized_fields() | {"mapped_op_kwargs"} - def __attrs_post_init__(self): # The magic super() doesn't work here, so we use the explicit form. # Not using super(..., self) to work around pyupgrade bug. 
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) - XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs) - - def _get_unmap_kwargs(self) -> Dict[str, Any]: - partial_kwargs = self.partial_kwargs.copy() - op_kwargs = _merge_kwargs( - partial_kwargs.pop("op_kwargs"), - self.mapped_op_kwargs, - fail_reason="mapping already partial", - ) - self._combined_op_kwargs = op_kwargs - return { - "dag": self.dag, - "task_group": self.task_group, - "task_id": self.task_id, - "op_kwargs": op_kwargs, + XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value) + + def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: + # We only use op_kwargs_expand_input so this must always be empty. + assert self.expand_input is EXPAND_INPUT_EMPTY + op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session) + return {"op_kwargs": op_kwargs}, resolved_oids + + def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: + partial_op_kwargs = self.partial_kwargs["op_kwargs"] + mapped_op_kwargs = mapped_kwargs["op_kwargs"] + + if strict: + prevent_duplicates(partial_op_kwargs, mapped_op_kwargs, fail_reason="mapping already partial") + + kwargs = { "multiple_outputs": self.multiple_outputs, "python_callable": self.python_callable, - **partial_kwargs, - **self.mapped_kwargs, + "op_kwargs": {**partial_op_kwargs, **mapped_op_kwargs}, } + return super()._get_unmap_kwargs(kwargs, strict=False) - def _resolve_expansion_kwargs( - self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: "Session" - ) -> None: - expansion_kwargs = self._get_expansion_kwargs() - - self._already_resolved_op_kwargs = set() - for k, v in expansion_kwargs.items(): - if isinstance(v, XComArg): - self._already_resolved_op_kwargs.add(k) - v = v.resolve(context, session=session) - v = self._expand_mapped_field(k, v, context, session=session) - kwargs['op_kwargs'][k] = v - template_fields.discard(k) - - def render_template( - self, - value: Any, - context: Context, - jinja_env: Optional["jinja2.Environment"] = None, - seen_oids: Optional[Set] = None, - ) -> Any: - if hasattr(self, '_combined_op_kwargs') and value is self._combined_op_kwargs: - # Avoid rendering values that came out of resolved XComArgs - return { - k: v - if k in self._already_resolved_op_kwargs - else super(DecoratedMappedOperator, DecoratedMappedOperator).render_template( - self, v, context, jinja_env=jinja_env, seen_oids=seen_oids - ) - for k, v in value.items() - } - return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids) - - -class Task(Generic[Function]): + +class Task(Generic[FParams, FReturn]): """Declaration of a @task-decorated callable for type-checking. An instance of this type inherits the call signature of the decorated @@ -505,18 +500,24 @@ class Task(Generic[Function]): This type is implemented by ``_TaskDecorator`` at runtime. """ - __call__: Function + __call__: Callable[FParams, XComArg] - function: Function + function: Callable[FParams, FReturn] @property - def __wrapped__(self) -> Function: + def __wrapped__(self) -> Callable[FParams, FReturn]: ... - def expand(self, **kwargs: "Mappable") -> XComArg: + def partial(self, **kwargs: Any) -> Task[FParams, FReturn]: ... - def partial(self, **kwargs: Any) -> "Task[Function]": + def expand(self, **kwargs: OperatorExpandArgument) -> XComArg: + ... 
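``partial``, ``expand``, and the new ``expand_kwargs`` on the ``Task`` declaration above are the TaskFlow face of dynamic task mapping, now routed through ``DictOfListsExpandInput`` and ``ListOfDictsExpandInput``. A brief sketch of the two call shapes; the DAG id and values are illustrative.

# Sketch of the mapping API surfaced by the Task declaration above.
from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG("mapping_sketch", start_date=datetime(2022, 1, 1), schedule=None):

    @task
    def add(x: int, y: int) -> int:
        return x + y

    # Dict-of-lists expansion: one mapped task instance per element of x.
    added = add.partial(y=10).expand(x=[1, 2, 3])

    # List-of-dicts expansion: each dict supplies the kwargs for one instance.
    also_added = add.expand_kwargs([{"x": 1, "y": 2}, {"x": 3, "y": 4}])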
+ + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: + ... + + def override(self, **kwargs: Any) -> Task[FParams, FReturn]: ... @@ -524,36 +525,46 @@ class TaskDecorator(Protocol): """Type declaration for ``task_decorator_factory`` return type.""" @overload - def __call__(self, python_callable: Function) -> Task[Function]: + def __call__( # type: ignore[misc] + self, + python_callable: Callable[FParams, FReturn], + ) -> Task[FParams, FReturn]: """For the "bare decorator" ``@task`` case.""" @overload def __call__( self, *, - multiple_outputs: Optional[bool] = None, + multiple_outputs: bool | None = None, **kwargs: Any, - ) -> Callable[[Function], Task[Function]]: + ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]: """For the decorator factory ``@task()`` case.""" + def override(self, **kwargs: Any) -> Task[FParams, FReturn]: + ... + def task_decorator_factory( - python_callable: Optional[Callable] = None, + python_callable: Callable | None = None, *, - multiple_outputs: Optional[bool] = None, - decorated_operator_class: Type[BaseOperator], + multiple_outputs: bool | None = None, + decorated_operator_class: type[BaseOperator], **kwargs, ) -> TaskDecorator: - """ - A factory that generates a wrapper that wraps a function into an Airflow operator. - Accepts kwargs for operator kwarg. Can be reused in a single DAG. + """Generate a wrapper that wraps a function into an Airflow operator. - :param python_callable: Function to decorate - :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to - multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. - :param decorated_operator_class: The operator that executes the logic needed to run the python function in - the correct environment + Can be reused in a single DAG. + + :param python_callable: Function to decorate. + :param multiple_outputs: If set to True, the decorated function's return + value will be unrolled to multiple XCom values. Dict will unroll to XCom + values with its keys as XCom keys. If set to False (default), only at + most one XCom value is pushed. + :param decorated_operator_class: The operator that executes the logic needed + to run the python function in the correct environment. + Other kwargs are directly forwarded to the underlying operator class when + it's instantiated. """ if multiple_outputs is None: multiple_outputs = cast(bool, attr.NOTHING) @@ -566,11 +577,14 @@ def task_decorator_factory( ) return cast(TaskDecorator, decorator) elif python_callable is not None: - raise TypeError('No args allowed while using @task, use kwargs instead') - decorator_factory = functools.partial( - _TaskDecorator, - multiple_outputs=multiple_outputs, - operator_class=decorated_operator_class, - kwargs=kwargs, - ) + raise TypeError("No args allowed while using @task, use kwargs instead") + + def decorator_factory(python_callable): + return _TaskDecorator( + function=python_callable, + multiple_outputs=multiple_outputs, + operator_class=decorated_operator_class, + kwargs=kwargs, + ) + return cast(TaskDecorator, decorator_factory) diff --git a/airflow/decorators/branch_python.py b/airflow/decorators/branch_python.py index ac83132574402..b7e3a94826679 100644 --- a/airflow/decorators/branch_python.py +++ b/airflow/decorators/branch_python.py @@ -14,14 +14,15 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import inspect from textwrap import dedent -from typing import Callable, Optional, Sequence, TypeVar +from typing import Callable, Sequence from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.operators.python import BranchPythonOperator -from airflow.utils.python_virtualenv import remove_task_decorator +from airflow.utils.decorators import remove_task_decorator class _BranchPythonDecoratedOperator(DecoratedOperator, BranchPythonOperator): @@ -38,12 +39,14 @@ class _BranchPythonDecoratedOperator(DecoratedOperator, BranchPythonOperator): Defaults to False. """ - template_fields: Sequence[str] = ('op_args', 'op_kwargs') + template_fields: Sequence[str] = ("op_args", "op_kwargs") template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects (e.g protobuf). - shallow_copy_attrs: Sequence[str] = ('python_callable',) + shallow_copy_attrs: Sequence[str] = ("python_callable",) + + custom_operator_name: str = "@task.branch" def __init__( self, @@ -63,14 +66,12 @@ def get_python_source(self): return res -T = TypeVar("T", bound=Callable) - - def branch_task( - python_callable: Optional[Callable] = None, multiple_outputs: Optional[bool] = None, **kwargs + python_callable: Callable | None = None, multiple_outputs: bool | None = None, **kwargs ) -> TaskDecorator: """ - Wraps a python function into a BranchPythonOperator + Wraps a python function into a BranchPythonOperator. + For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:BranchPythonOperator` diff --git a/airflow/decorators/external_python.py b/airflow/decorators/external_python.py new file mode 100644 index 0000000000000..273bf95321d83 --- /dev/null +++ b/airflow/decorators/external_python.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import inspect +from textwrap import dedent +from typing import Callable, Sequence + +from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory +from airflow.operators.python import ExternalPythonOperator +from airflow.utils.decorators import remove_task_decorator + + +class _PythonExternalDecoratedOperator(DecoratedOperator, ExternalPythonOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python: Full path string (file-system specific) that points to a Python binary inside + a virtualenv that should be used (in ``VENV/bin`` folder). 
Should be absolute path + (so usually start with "/" or "X:/" depending on the filesystem/os used). + :param python_callable: A reference to an object that is callable + :param op_kwargs: a dictionary of keyword arguments that will get unpacked + in your function (templated) + :param op_args: a list of positional arguments that will get unpacked when + calling your callable (templated) + :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to + multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. + """ + + template_fields: Sequence[str] = ("op_args", "op_kwargs") + template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects (e.g protobuf). + shallow_copy_attrs: Sequence[str] = ("python_callable",) + + custom_operator_name: str = "@task.external_python" + + def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None: + kwargs_to_upstream = { + "python_callable": python_callable, + "op_args": op_args, + "op_kwargs": op_kwargs, + } + super().__init__( + kwargs_to_upstream=kwargs_to_upstream, + python_callable=python_callable, + op_args=op_args, + op_kwargs=op_kwargs, + **kwargs, + ) + + def get_python_source(self): + raw_source = inspect.getsource(self.python_callable) + res = dedent(raw_source) + res = remove_task_decorator(res, "@task.external_python") + return res + + +def external_python_task( + python: str | None = None, + python_callable: Callable | None = None, + multiple_outputs: bool | None = None, + **kwargs, +) -> TaskDecorator: + """Wraps a callable into an Airflow operator to run via a Python virtual environment. + + Accepts kwargs for operator kwarg. Can be reused in a single DAG. + + This function is only used during type checking or auto-completion. + + :meta private: + + :param python: Full path string (file-system specific) that points to a Python binary inside + a virtualenv that should be used (in ``VENV/bin`` folder). Should be absolute path + (so usually start with "/" or "X:/" depending on the filesystem/os used). + :param python_callable: Function to decorate + :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to + multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. + Defaults to False. + """ + return task_decorator_factory( + python=python, + python_callable=python_callable, + multiple_outputs=multiple_outputs, + decorated_operator_class=_PythonExternalDecoratedOperator, + **kwargs, + ) diff --git a/airflow/decorators/python.py b/airflow/decorators/python.py index c7c1a629a4dfb..af88e534bdae5 100644 --- a/airflow/decorators/python.py +++ b/airflow/decorators/python.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Callable, Optional, Sequence, TypeVar +from typing import Callable, Sequence from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.operators.python import PythonOperator @@ -34,12 +35,14 @@ class _PythonDecoratedOperator(DecoratedOperator, PythonOperator): multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. 
""" - template_fields: Sequence[str] = ('op_args', 'op_kwargs') - template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} + template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs") + template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"} # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects (e.g protobuf). - shallow_copy_attrs: Sequence[str] = ('python_callable',) + shallow_copy_attrs: Sequence[str] = ("python_callable",) + + custom_operator_name: str = "@task" def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None: kwargs_to_upstream = { @@ -56,12 +59,9 @@ def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None: ) -T = TypeVar("T", bound=Callable) - - def python_task( - python_callable: Optional[Callable] = None, - multiple_outputs: Optional[bool] = None, + python_callable: Callable | None = None, + multiple_outputs: bool | None = None, **kwargs, ) -> TaskDecorator: """Wraps a function into an Airflow operator. diff --git a/airflow/decorators/python_virtualenv.py b/airflow/decorators/python_virtualenv.py index e8fd681b8f3af..53ca04b8e8f78 100644 --- a/airflow/decorators/python_virtualenv.py +++ b/airflow/decorators/python_virtualenv.py @@ -14,14 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import inspect from textwrap import dedent -from typing import Callable, Optional, Sequence, TypeVar +from typing import Callable, Sequence from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.operators.python import PythonVirtualenvOperator -from airflow.utils.python_virtualenv import remove_task_decorator +from airflow.utils.decorators import remove_task_decorator class _PythonVirtualenvDecoratedOperator(DecoratedOperator, PythonVirtualenvOperator): @@ -37,12 +38,14 @@ class _PythonVirtualenvDecoratedOperator(DecoratedOperator, PythonVirtualenvOper multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. """ - template_fields: Sequence[str] = ('op_args', 'op_kwargs') + template_fields: Sequence[str] = ("op_args", "op_kwargs") template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects (e.g protobuf). - shallow_copy_attrs: Sequence[str] = ('python_callable',) + shallow_copy_attrs: Sequence[str] = ("python_callable",) + + custom_operator_name: str = "@task.virtualenv" def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None: kwargs_to_upstream = { @@ -65,12 +68,9 @@ def get_python_source(self): return res -T = TypeVar("T", bound=Callable) - - def virtualenv_task( - python_callable: Optional[Callable] = None, - multiple_outputs: Optional[bool] = None, + python_callable: Callable | None = None, + multiple_outputs: bool | None = None, **kwargs, ) -> TaskDecorator: """Wraps a callable into an Airflow operator to run via a Python virtual environment. diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py new file mode 100644 index 0000000000000..20339686201a9 --- /dev/null +++ b/airflow/decorators/sensor.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Callable, Sequence + +from airflow.decorators.base import TaskDecorator, get_unique_task_id, task_decorator_factory +from airflow.models.taskinstance import Context +from airflow.sensors.base import PokeReturnValue +from airflow.sensors.python import PythonSensor + + +class DecoratedSensorOperator(PythonSensor): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + :param python_callable: A reference to an object that is callable + :param task_id: task Id + :param op_args: a list of positional arguments that will get unpacked when + calling your callable (templated) + :param op_kwargs: a dictionary of keyword arguments that will get unpacked + in your function (templated) + :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments + that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the + PythonOperator). This gives a user the option to upstream kwargs as needed. + """ + + template_fields: Sequence[str] = ("op_args", "op_kwargs") + template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"} + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects (e.g protobuf). + shallow_copy_attrs: Sequence[str] = ("python_callable",) + + def __init__( + self, + *, + task_id: str, + **kwargs, + ) -> None: + kwargs.pop("multiple_outputs") + kwargs["task_id"] = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) + super().__init__(**kwargs) + + def poke(self, context: Context) -> PokeReturnValue | bool: + return self.python_callable(*self.op_args, **self.op_kwargs) + + +def sensor_task(python_callable: Callable | None = None, **kwargs) -> TaskDecorator: + """ + Wraps a function into an Airflow operator. + Accepts kwargs for operator kwarg. Can be reused in a single DAG. + :param python_callable: Function to decorate + """ + return task_decorator_factory( + python_callable=python_callable, + multiple_outputs=False, + decorated_operator_class=DecoratedSensorOperator, + **kwargs, + ) diff --git a/airflow/decorators/short_circuit.py b/airflow/decorators/short_circuit.py new file mode 100644 index 0000000000000..8422c4aa0844a --- /dev/null +++ b/airflow/decorators/short_circuit.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Callable, Sequence + +from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory +from airflow.operators.python import ShortCircuitOperator + + +class _ShortCircuitDecoratedOperator(DecoratedOperator, ShortCircuitOperator): + """ + Wraps a Python callable and captures args/kwargs when called for execution. + + :param python_callable: A reference to an object that is callable + :param op_kwargs: a dictionary of keyword arguments that will get unpacked + in your function (templated) + :param op_args: a list of positional arguments that will get unpacked when + calling your callable (templated) + :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to + multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. + """ + + template_fields: Sequence[str] = ("op_args", "op_kwargs") + template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects (e.g protobuf). + shallow_copy_attrs: Sequence[str] = ("python_callable",) + + custom_operator_name: str = "@task.short_circuit" + + def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None: + kwargs_to_upstream = { + "python_callable": python_callable, + "op_args": op_args, + "op_kwargs": op_kwargs, + } + super().__init__( + kwargs_to_upstream=kwargs_to_upstream, + python_callable=python_callable, + op_args=op_args, + op_kwargs=op_kwargs, + **kwargs, + ) + + +def short_circuit_task( + python_callable: Callable | None = None, + multiple_outputs: bool | None = None, + **kwargs, +) -> TaskDecorator: + """Wraps a function into a ShortCircuitOperator. + + Accepts kwargs for operator kwarg. Can be reused in a single DAG. + + This function is only used during type checking or auto-completion. + + :param python_callable: Function to decorate + :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to + multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. + + :meta private: + """ + return task_decorator_factory( + python_callable=python_callable, + multiple_outputs=multiple_outputs, + decorated_operator_class=_ShortCircuitDecoratedOperator, + **kwargs, + ) diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 674474947036e..2aa714be3b6c5 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -15,56 +15,89 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped +"""Implements the ``@task_group`` function decorator.
+ +When the decorated function is called, a task group will be created to represent +a collection of closely related tasks on the same DAG that should be grouped together when the DAG is displayed graphically. """ + +from __future__ import annotations + import functools -from inspect import signature -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, Union, cast, overload +import inspect +import warnings +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Mapping, Sequence, TypeVar, overload import attr -from airflow.utils.task_group import TaskGroup +from airflow.decorators.base import ExpandableFactory +from airflow.models.expandinput import ( + DictOfListsExpandInput, + ListOfDictsExpandInput, + MappedArgument, + OperatorExpandArgument, + OperatorExpandKwargsArgument, +) +from airflow.models.taskmixin import DAGNode +from airflow.models.xcom_arg import XComArg +from airflow.typing_compat import ParamSpec +from airflow.utils.helpers import prevent_duplicates +from airflow.utils.task_group import MappedTaskGroup, TaskGroup if TYPE_CHECKING: from airflow.models.dag import DAG - from airflow.models.mappedoperator import Mappable -F = TypeVar("F", bound=Callable) -R = TypeVar("R") +FParams = ParamSpec("FParams") +FReturn = TypeVar("FReturn", None, DAGNode) -task_group_sig = signature(TaskGroup.__init__) +task_group_sig = inspect.signature(TaskGroup.__init__) -@attr.define -class TaskGroupDecorator(Generic[R]): - """:meta private:""" +@attr.define() +class _TaskGroupFactory(ExpandableFactory, Generic[FParams, FReturn]): + function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable()) + tg_kwargs: dict[str, Any] = attr.ib(factory=dict) # Parameters forwarded to TaskGroup. + partial_kwargs: dict[str, Any] = attr.ib(factory=dict) # Parameters forwarded to 'function'. - function: Callable[..., Optional[R]] = attr.ib(validator=attr.validators.is_callable()) - kwargs: Dict[str, Any] = attr.ib(factory=dict) - """kwargs for the TaskGroup""" + _task_group_created: bool = attr.ib(False, init=False) - @function.validator - def _validate_function(self, _, f): - if 'self' in signature(f).parameters: - raise TypeError('@task_group does not support methods') + tg_class: ClassVar[type[TaskGroup]] = TaskGroup - @kwargs.validator + @tg_kwargs.validator def _validate(self, _, kwargs): task_group_sig.bind_partial(**kwargs) def __attrs_post_init__(self): - self.kwargs.setdefault('group_id', self.function.__name__) + self.tg_kwargs.setdefault("group_id", self.function.__name__) + + def __del__(self): + if self.partial_kwargs and not self._task_group_created: + try: + group_id = repr(self.tg_kwargs["group_id"]) + except KeyError: + group_id = f"at {hex(id(self))}" + warnings.warn(f"Partial task group {group_id} was never mapped!") + + def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> DAGNode: + """Instantiate the task group. + + This uses the wrapped function to create a task group. Depending on the + return type of the wrapped function, this either returns the last task + in the group, or the group itself, to support task chaining. 
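As a usage illustration for the ``@task.short_circuit`` decorator added in airflow/decorators/short_circuit.py above, a small sketch (DAG id and task names are hypothetical): when the decorated callable returns a falsy value, everything downstream is skipped, mirroring ShortCircuitOperator.

    import pendulum

    from airflow import DAG
    from airflow.decorators import task
    from airflow.operators.empty import EmptyOperator

    with DAG(
        dag_id="short_circuit_sketch",  # hypothetical id
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ):

        @task.short_circuit()
        def condition_is_met(threshold: int) -> bool:
            # A falsy return value short-circuits (skips) the downstream task.
            return threshold > 0

        condition_is_met(threshold=1) >> EmptyOperator(task_id="runs_only_when_true")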
+ """ + return self._create_task_group(TaskGroup, *args, **kwargs) + + def _create_task_group(self, tg_factory: Callable[..., TaskGroup], *args: Any, **kwargs: Any) -> DAGNode: + with tg_factory(add_suffix_on_collision=True, **self.tg_kwargs) as task_group: + if self.function.__doc__ and not task_group.tooltip: + task_group.tooltip = self.function.__doc__ - def _make_task_group(self, **kwargs) -> TaskGroup: - return TaskGroup(**kwargs) - - def __call__(self, *args, **kwargs) -> Union[R, TaskGroup]: - with self._make_task_group(add_suffix_on_collision=True, **self.kwargs) as task_group: # Invoke function to run Tasks inside the TaskGroup retval = self.function(*args, **kwargs) + self._task_group_created = True + # If the task-creating function returns a task, forward the return value # so dependencies bind to it. This is equivalent to # with TaskGroup(...) as tg: @@ -80,28 +113,55 @@ def __call__(self, *args, **kwargs) -> Union[R, TaskGroup]: # start >> tg >> end return task_group - -class Group(Generic[F]): - """Declaration of a @task_group-decorated callable for type-checking. - - An instance of this type inherits the call signature of the decorated - function wrapped in it (not *exactly* since it actually turns the function - into an XComArg-compatible, but there's no way to express that right now), - and provides two additional methods for task-mapping. - - This type is implemented by ``TaskGroupDecorator`` at runtime. - """ - - __call__: F - - function: F - - # Return value should match F's return type, but that's impossible to declare. - def expand(self, **kwargs: "Mappable") -> Any: - ... - - def partial(self, **kwargs: Any) -> "Group[F]": - ... + def override(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]: + return attr.evolve(self, tg_kwargs={**self.tg_kwargs, **kwargs}) + + def partial(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]: + self._validate_arg_names("partial", kwargs) + prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="duplicate partial") + kwargs.update(self.partial_kwargs) + return attr.evolve(self, partial_kwargs=kwargs) + + def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode: + if not kwargs: + raise TypeError("no arguments to expand against") + self._validate_arg_names("expand", kwargs) + prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="mapping already partial") + expand_input = DictOfListsExpandInput(kwargs) + return self._create_task_group( + functools.partial(MappedTaskGroup, expand_input=expand_input), + **self.partial_kwargs, + **{k: MappedArgument(input=expand_input, key=k) for k in kwargs}, + ) + + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode: + if isinstance(kwargs, Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + + # It's impossible to build a dict of stubs as keyword arguments if the + # function uses * or ** wildcard arguments. 
+ function_has_vararg = any( + v.kind == inspect.Parameter.VAR_POSITIONAL or v.kind == inspect.Parameter.VAR_KEYWORD + for v in self.function_signature.parameters.values() + ) + if function_has_vararg: + raise TypeError("calling expand_kwargs() on task group function with * or ** is not supported") + + # We can't be sure how each argument is used in the function (well + # technically we can with AST but let's not), so we have to create stubs + # for every argument, including those with default values. + map_kwargs = (k for k in self.function_signature.parameters if k not in self.partial_kwargs) + + expand_input = ListOfDictsExpandInput(kwargs) + return self._create_task_group( + functools.partial(MappedTaskGroup, expand_input=expand_input), + **self.partial_kwargs, + **{k: MappedArgument(input=expand_input, key=k) for k in map_kwargs}, + ) # This covers the @task_group() case. Annotations are copied from the TaskGroup @@ -113,28 +173,27 @@ def partial(self, **kwargs: Any) -> "Group[F]": # disastrous if they go out of sync with TaskGroup. @overload def task_group( - group_id: Optional[str] = None, + group_id: str | None = None, prefix_group_id: bool = True, - parent_group: Optional[TaskGroup] = None, - dag: Optional["DAG"] = None, - default_args: Optional[Dict[str, Any]] = None, + parent_group: TaskGroup | None = None, + dag: DAG | None = None, + default_args: dict[str, Any] | None = None, tooltip: str = "", ui_color: str = "CornflowerBlue", ui_fgcolor: str = "#000", add_suffix_on_collision: bool = False, -) -> Callable[[F], Group[F]]: +) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]: ... # This covers the @task_group case (no parentheses). @overload -def task_group(python_callable: F) -> Group[F]: +def task_group(python_callable: Callable[FParams, FReturn]) -> _TaskGroupFactory[FParams, FReturn]: ... def task_group(python_callable=None, **tg_kwargs): - """ - Python TaskGroup decorator. + """Python TaskGroup decorator. This wraps a function into an Airflow TaskGroup. When used as the ``@task_group()`` form, all arguments are forwarded to the underlying @@ -143,6 +202,6 @@ def task_group(python_callable=None, **tg_kwargs): :param python_callable: Function to decorate. :param tg_kwargs: Keyword arguments for the TaskGroup object. """ - if callable(python_callable): - return TaskGroupDecorator(function=python_callable, kwargs=tg_kwargs) - return cast(Callable[[F], F], functools.partial(TaskGroupDecorator, kwargs=tg_kwargs)) + if callable(python_callable) and not tg_kwargs: + return _TaskGroupFactory(function=python_callable, tg_kwargs=tg_kwargs) + return functools.partial(_TaskGroupFactory, tg_kwargs=tg_kwargs) diff --git a/airflow/example_dags/example_bash_operator.py b/airflow/example_dags/example_bash_operator.py index 335f3ad961d3b..2f0343e1584d4 100644 --- a/airflow/example_dags/example_bash_operator.py +++ b/airflow/example_dags/example_bash_operator.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
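To make the ``partial``/``expand`` additions on the task-group factory above concrete, a hedged sketch of mapping a decorated task group over a list (all ids and values are hypothetical; one group instance is created per element of ``value``, while ``base`` is shared across them):

    import pendulum

    from airflow import DAG
    from airflow.decorators import task, task_group

    with DAG(
        dag_id="task_group_mapping_sketch",  # hypothetical id
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ):

        @task
        def add(base: int, value: int) -> int:
            return base + value

        @task
        def report(total: int) -> None:
            print(f"group result: {total}")

        @task_group
        def process(base: int, value: int):
            # Returning the last task lets callers chain dependencies onto the group.
            return report(add(base=base, value=value))

        process.partial(base=10).expand(value=[1, 2, 3])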
- """Example DAG demonstrating the usage of the BashOperator.""" +from __future__ import annotations import datetime @@ -27,22 +27,22 @@ from airflow.operators.empty import EmptyOperator with DAG( - dag_id='example_bash_operator', - schedule_interval='0 0 * * *', + dag_id="example_bash_operator", + schedule="0 0 * * *", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, dagrun_timeout=datetime.timedelta(minutes=60), - tags=['example', 'example2'], + tags=["example", "example2"], params={"example_key": "example_value"}, ) as dag: run_this_last = EmptyOperator( - task_id='run_this_last', + task_id="run_this_last", ) # [START howto_operator_bash] run_this = BashOperator( - task_id='run_after_loop', - bash_command='echo 1', + task_id="run_after_loop", + bash_command="echo 1", ) # [END howto_operator_bash] @@ -50,22 +50,22 @@ for i in range(3): task = BashOperator( - task_id='runme_' + str(i), + task_id="runme_" + str(i), bash_command='echo "{{ task_instance_key_str }}" && sleep 1', ) task >> run_this # [START howto_operator_bash_template] also_run_this = BashOperator( - task_id='also_run_this', - bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', + task_id="also_run_this", + bash_command='echo "ti_key={{ task_instance_key_str }}"', ) # [END howto_operator_bash_template] also_run_this >> run_this_last # [START howto_operator_bash_skip] this_will_skip = BashOperator( - task_id='this_will_skip', + task_id="this_will_skip", bash_command='echo "hello world"; exit 99;', dag=dag, ) @@ -73,4 +73,4 @@ this_will_skip >> run_this_last if __name__ == "__main__": - dag.cli() + dag.test() diff --git a/airflow/example_dags/example_branch_datetime_operator.py b/airflow/example_dags/example_branch_datetime_operator.py index 3c86e40402aef..8a4d579b4c03a 100644 --- a/airflow/example_dags/example_branch_datetime_operator.py +++ b/airflow/example_dags/example_branch_datetime_operator.py @@ -15,64 +15,90 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example DAG demonstrating the usage of DateTimeBranchOperator with datetime as well as time objects as targets. 
""" +from __future__ import annotations + import pendulum from airflow import DAG from airflow.operators.datetime import BranchDateTimeOperator from airflow.operators.empty import EmptyOperator -dag = DAG( +dag1 = DAG( dag_id="example_branch_datetime_operator", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=["example"], - schedule_interval="@daily", + schedule="@daily", ) # [START howto_branch_datetime_operator] -empty_task_1 = EmptyOperator(task_id='date_in_range', dag=dag) -empty_task_2 = EmptyOperator(task_id='date_outside_range', dag=dag) +empty_task_11 = EmptyOperator(task_id="date_in_range", dag=dag1) +empty_task_21 = EmptyOperator(task_id="date_outside_range", dag=dag1) cond1 = BranchDateTimeOperator( - task_id='datetime_branch', - follow_task_ids_if_true=['date_in_range'], - follow_task_ids_if_false=['date_outside_range'], + task_id="datetime_branch", + follow_task_ids_if_true=["date_in_range"], + follow_task_ids_if_false=["date_outside_range"], target_upper=pendulum.datetime(2020, 10, 10, 15, 0, 0), target_lower=pendulum.datetime(2020, 10, 10, 14, 0, 0), - dag=dag, + dag=dag1, ) -# Run empty_task_1 if cond1 executes between 2020-10-10 14:00:00 and 2020-10-10 15:00:00 -cond1 >> [empty_task_1, empty_task_2] +# Run empty_task_11 if cond1 executes between 2020-10-10 14:00:00 and 2020-10-10 15:00:00 +cond1 >> [empty_task_11, empty_task_21] # [END howto_branch_datetime_operator] -dag = DAG( +dag2 = DAG( dag_id="example_branch_datetime_operator_2", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=["example"], - schedule_interval="@daily", + schedule="@daily", ) # [START howto_branch_datetime_operator_next_day] -empty_task_1 = EmptyOperator(task_id='date_in_range', dag=dag) -empty_task_2 = EmptyOperator(task_id='date_outside_range', dag=dag) +empty_task_12 = EmptyOperator(task_id="date_in_range", dag=dag2) +empty_task_22 = EmptyOperator(task_id="date_outside_range", dag=dag2) cond2 = BranchDateTimeOperator( - task_id='datetime_branch', - follow_task_ids_if_true=['date_in_range'], - follow_task_ids_if_false=['date_outside_range'], + task_id="datetime_branch", + follow_task_ids_if_true=["date_in_range"], + follow_task_ids_if_false=["date_outside_range"], target_upper=pendulum.time(0, 0, 0), target_lower=pendulum.time(15, 0, 0), - dag=dag, + dag=dag2, ) # Since target_lower happens after target_upper, target_upper will be moved to the following day -# Run empty_task_1 if cond2 executes between 15:00:00, and 00:00:00 of the following day -cond2 >> [empty_task_1, empty_task_2] +# Run empty_task_12 if cond2 executes between 15:00:00, and 00:00:00 of the following day +cond2 >> [empty_task_12, empty_task_22] # [END howto_branch_datetime_operator_next_day] + +dag3 = DAG( + dag_id="example_branch_datetime_operator_3", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["example"], + schedule="@daily", +) +# [START howto_branch_datetime_operator_logical_date] +empty_task_13 = EmptyOperator(task_id="date_in_range", dag=dag3) +empty_task_23 = EmptyOperator(task_id="date_outside_range", dag=dag3) + +cond3 = BranchDateTimeOperator( + task_id="datetime_branch", + use_task_logical_date=True, + follow_task_ids_if_true=["date_in_range"], + follow_task_ids_if_false=["date_outside_range"], + target_upper=pendulum.datetime(2020, 10, 10, 15, 0, 0), + target_lower=pendulum.datetime(2020, 10, 10, 14, 0, 0), + dag=dag3, +) + +# Run empty_task_13 if cond3 executes between 2020-10-10 14:00:00 and 2020-10-10 15:00:00 +cond3 >> [empty_task_13, 
empty_task_23] +# [END howto_branch_datetime_operator_logical_date] diff --git a/airflow/example_dags/example_branch_day_of_week_operator.py b/airflow/example_dags/example_branch_day_of_week_operator.py index 879824ab1c876..fb7dea56b07fe 100644 --- a/airflow/example_dags/example_branch_day_of_week_operator.py +++ b/airflow/example_dags/example_branch_day_of_week_operator.py @@ -15,26 +15,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example DAG demonstrating the usage of BranchDayOfWeekOperator. """ +from __future__ import annotations + import pendulum from airflow import DAG from airflow.operators.empty import EmptyOperator from airflow.operators.weekday import BranchDayOfWeekOperator +from airflow.utils.weekday import WeekDay with DAG( dag_id="example_weekday_branch_operator", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=["example"], - schedule_interval="@daily", + schedule="@daily", ) as dag: # [START howto_operator_day_of_week_branch] - empty_task_1 = EmptyOperator(task_id='branch_true') - empty_task_2 = EmptyOperator(task_id='branch_false') + empty_task_1 = EmptyOperator(task_id="branch_true") + empty_task_2 = EmptyOperator(task_id="branch_false") + empty_task_3 = EmptyOperator(task_id="branch_weekend") + empty_task_4 = EmptyOperator(task_id="branch_mid_week") branch = BranchDayOfWeekOperator( task_id="make_choice", @@ -42,7 +46,15 @@ follow_task_ids_if_false="branch_false", week_day="Monday", ) + branch_weekend = BranchDayOfWeekOperator( + task_id="make_weekend_choice", + follow_task_ids_if_true="branch_weekend", + follow_task_ids_if_false="branch_mid_week", + week_day={WeekDay.SATURDAY, WeekDay.SUNDAY}, + ) - # Run empty_task_1 if branch executes on Monday + # Run empty_task_1 if branch executes on Monday, empty_task_2 otherwise branch >> [empty_task_1, empty_task_2] + # Run empty_task_3 if it's a weekend, empty_task_4 otherwise + empty_task_2 >> branch_weekend >> [empty_task_3, empty_task_4] # [END howto_operator_day_of_week_branch] diff --git a/airflow/example_dags/example_branch_labels.py b/airflow/example_dags/example_branch_labels.py index 72337eb45c873..220bb445cf153 100644 --- a/airflow/example_dags/example_branch_labels.py +++ b/airflow/example_dags/example_branch_labels.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example DAG demonstrating the usage of labels with different branches. """ +from __future__ import annotations + import pendulum from airflow import DAG @@ -27,7 +28,7 @@ with DAG( "example_branch_labels", - schedule_interval="@daily", + schedule="@daily", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, ) as dag: diff --git a/airflow/example_dags/example_branch_operator.py b/airflow/example_dags/example_branch_operator.py index 8721c78bcbc8e..43bb34eeec68f 100644 --- a/airflow/example_dags/example_branch_operator.py +++ b/airflow/example_dags/example_branch_operator.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
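Related to the ``dag.cli()`` to ``dag.test()`` switch in example_bash_operator.py above, this sketch shows the debugging pattern the example DAGs now follow; the DAG itself is a throwaway stand-in:

    import pendulum

    from airflow import DAG
    from airflow.operators.bash import BashOperator

    with DAG(
        dag_id="debug_sketch",  # hypothetical id
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ) as dag:
        BashOperator(task_id="hello", bash_command="echo hello")

    if __name__ == "__main__":
        # ``python this_file.py`` runs a single DagRun in-process, without a
        # scheduler, which is convenient for quick local debugging.
        dag.test()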
- """Example DAG demonstrating the usage of the BranchPythonOperator.""" +from __future__ import annotations import random @@ -29,26 +29,26 @@ from airflow.utils.trigger_rule import TriggerRule with DAG( - dag_id='example_branch_operator', + dag_id="example_branch_operator", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval="@daily", - tags=['example', 'example2'], + schedule="@daily", + tags=["example", "example2"], ) as dag: run_this_first = EmptyOperator( - task_id='run_this_first', + task_id="run_this_first", ) - options = ['branch_a', 'branch_b', 'branch_c', 'branch_d'] + options = ["branch_a", "branch_b", "branch_c", "branch_d"] branching = BranchPythonOperator( - task_id='branching', + task_id="branching", python_callable=lambda: random.choice(options), ) run_this_first >> branching join = EmptyOperator( - task_id='join', + task_id="join", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, ) @@ -58,7 +58,7 @@ ) empty_follow = EmptyOperator( - task_id='follow_' + option, + task_id="follow_" + option, ) # Label is optional here, but it can help identify more complex branches diff --git a/airflow/example_dags/example_branch_operator_decorator.py b/airflow/example_dags/example_branch_operator_decorator.py index 0ab4f76cafa2a..91101dc775aa8 100644 --- a/airflow/example_dags/example_branch_operator_decorator.py +++ b/airflow/example_dags/example_branch_operator_decorator.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -"""Example DAG demonstrating the usage of the BranchPythonOperator.""" +"""Example DAG demonstrating the usage of the ``@task.branch`` TaskFlow API decorator.""" +from __future__ import annotations import random -from datetime import datetime + +import pendulum from airflow import DAG from airflow.decorators import task @@ -28,39 +29,30 @@ from airflow.utils.trigger_rule import TriggerRule with DAG( - dag_id='example_branch_python_operator_decorator', - start_date=datetime(2021, 1, 1), + dag_id="example_branch_python_operator_decorator", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval="@daily", - tags=['example', 'example2'], + schedule="@daily", + tags=["example", "example2"], ) as dag: - run_this_first = EmptyOperator( - task_id='run_this_first', - ) + run_this_first = EmptyOperator(task_id="run_this_first") - options = ['branch_a', 'branch_b', 'branch_c', 'branch_d'] + options = ["branch_a", "branch_b", "branch_c", "branch_d"] @task.branch(task_id="branching") - def random_choice(): - return random.choice(options) + def random_choice(choices: list[str]) -> str: + return random.choice(choices) - random_choice_instance = random_choice() + random_choice_instance = random_choice(choices=options) run_this_first >> random_choice_instance - join = EmptyOperator( - task_id='join', - trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, - ) + join = EmptyOperator(task_id="join", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) for option in options: - t = EmptyOperator( - task_id=option, - ) + t = EmptyOperator(task_id=option) - empty_follow = EmptyOperator( - task_id='follow_' + option, - ) + empty_follow = EmptyOperator(task_id="follow_" + option) # Label is optional here, but it can help identify more complex branches random_choice_instance >> Label(option) >> t >> empty_follow >> join diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py 
b/airflow/example_dags/example_branch_python_dop_operator_3.py index a8e0ce2c1c8af..f42abb46042be 100644 --- a/airflow/example_dags/example_branch_python_dop_operator_3.py +++ b/airflow/example_dags/example_branch_python_dop_operator_3.py @@ -15,48 +15,46 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ -Example DAG demonstrating the usage of BranchPythonOperator with depends_on_past=True, where tasks may be run -or skipped on alternating runs. +Example DAG demonstrating the usage of ``@task.branch`` TaskFlow API decorator with depends_on_past=True, +where tasks may be run or skipped on alternating runs. """ +from __future__ import annotations + import pendulum from airflow import DAG +from airflow.decorators import task from airflow.operators.empty import EmptyOperator -from airflow.operators.python import BranchPythonOperator -def should_run(**kwargs): +@task.branch() +def should_run(**kwargs) -> str: """ Determine which empty_task should be run based on if the execution date minute is even or odd. :param dict kwargs: Context :return: Id of the task to run - :rtype: str """ print( f"------------- exec dttm = {kwargs['execution_date']} and minute = {kwargs['execution_date'].minute}" ) - if kwargs['execution_date'].minute % 2 == 0: + if kwargs["execution_date"].minute % 2 == 0: return "empty_task_1" else: return "empty_task_2" with DAG( - dag_id='example_branch_dop_operator_v3', - schedule_interval='*/1 * * * *', + dag_id="example_branch_dop_operator_v3", + schedule="*/1 * * * *", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - default_args={'depends_on_past': True}, - tags=['example'], + default_args={"depends_on_past": True}, + tags=["example"], ) as dag: - cond = BranchPythonOperator( - task_id='condition', - python_callable=should_run, - ) + cond = should_run() - empty_task_1 = EmptyOperator(task_id='empty_task_1') - empty_task_2 = EmptyOperator(task_id='empty_task_2') + empty_task_1 = EmptyOperator(task_id="empty_task_1") + empty_task_2 = EmptyOperator(task_id="empty_task_2") cond >> [empty_task_1, empty_task_2] diff --git a/airflow/example_dags/example_complex.py b/airflow/example_dags/example_complex.py index 22e1906c042dd..f7c7e25aff62a 100644 --- a/airflow/example_dags/example_complex.py +++ b/airflow/example_dags/example_complex.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that shows the complex DAG structure. """ +from __future__ import annotations + import pendulum from airflow import models @@ -27,10 +28,10 @@ with models.DAG( dag_id="example_complex", - schedule_interval=None, + schedule=None, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example', 'example2', 'example3'], + tags=["example", "example2", "example3"], ) as dag: # Create create_entry_group = BashOperator(task_id="create_entry_group", bash_command="echo create_entry_group") diff --git a/airflow/example_dags/example_dag_decorator.py b/airflow/example_dags/example_dag_decorator.py index 88e0282016dd2..e8ee8a72997a2 100644 --- a/airflow/example_dags/example_dag_decorator.py +++ b/airflow/example_dags/example_dag_decorator.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict +from __future__ import annotations + +from typing import Any import httpx import pendulum @@ -39,33 +41,33 @@ def execute(self, context: Context): # [START dag_decorator_usage] @dag( - schedule_interval=None, + schedule=None, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example'], + tags=["example"], ) -def example_dag_decorator(email: str = 'example@example.com'): +def example_dag_decorator(email: str = "example@example.com"): """ DAG to send server IP to email. :param email: Email to send IP to. Defaults to example@example.com. """ - get_ip = GetRequestOperator(task_id='get_ip', url="http://httpbin.org/get") + get_ip = GetRequestOperator(task_id="get_ip", url="http://httpbin.org/get") @task(multiple_outputs=True) - def prepare_email(raw_json: Dict[str, Any]) -> Dict[str, str]: - external_ip = raw_json['origin'] + def prepare_email(raw_json: dict[str, Any]) -> dict[str, str]: + external_ip = raw_json["origin"] return { - 'subject': f'Server connected from {external_ip}', - 'body': f'Seems like today your server executing Airflow is connected from IP {external_ip}
', + "subject": f"Server connected from {external_ip}", + "body": f"Seems like today your server executing Airflow is connected from IP {external_ip}
", } email_info = prepare_email(get_ip.output) EmailOperator( - task_id='send_email', to=email, subject=email_info['subject'], html_content=email_info['body'] + task_id="send_email", to=email, subject=email_info["subject"], html_content=email_info["body"] ) -dag = example_dag_decorator() +example_dag = example_dag_decorator() # [END dag_decorator_usage] diff --git a/airflow/example_dags/example_datasets.py b/airflow/example_dags/example_datasets.py new file mode 100644 index 0000000000000..c2db9b1968614 --- /dev/null +++ b/airflow/example_dags/example_datasets.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG for demonstrating behavior of Datasets feature. + +Notes on usage: + +Turn on all the dags. + +DAG dataset_produces_1 should run because it's on a schedule. + +After dataset_produces_1 runs, dataset_consumes_1 should be triggered immediately +because its only dataset dependency is managed by dataset_produces_1. + +No other dags should be triggered. Note that even though dataset_consumes_1_and_2 depends on +the dataset in dataset_produces_1, it will not be triggered until dataset_produces_2 runs +(and dataset_produces_2 is left with no schedule so that we can trigger it manually). + +Next, trigger dataset_produces_2. After dataset_produces_2 finishes, +dataset_consumes_1_and_2 should run. + +Dags dataset_consumes_1_never_scheduled and dataset_consumes_unknown_never_scheduled should not run because +they depend on datasets that never get updated. 
+""" +from __future__ import annotations + +import pendulum + +from airflow import DAG, Dataset +from airflow.operators.bash import BashOperator + +# [START dataset_def] +dag1_dataset = Dataset("s3://dag1/output_1.txt", extra={"hi": "bye"}) +# [END dataset_def] +dag2_dataset = Dataset("s3://dag2/output_1.txt", extra={"hi": "bye"}) + +with DAG( + dag_id="dataset_produces_1", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule="@daily", + tags=["produces", "dataset-scheduled"], +) as dag1: + # [START task_outlet] + BashOperator(outlets=[dag1_dataset], task_id="producing_task_1", bash_command="sleep 5") + # [END task_outlet] + +with DAG( + dag_id="dataset_produces_2", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=None, + tags=["produces", "dataset-scheduled"], +) as dag2: + BashOperator(outlets=[dag2_dataset], task_id="producing_task_2", bash_command="sleep 5") + +# [START dag_dep] +with DAG( + dag_id="dataset_consumes_1", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[dag1_dataset], + tags=["consumes", "dataset-scheduled"], +) as dag3: + # [END dag_dep] + BashOperator( + outlets=[Dataset("s3://consuming_1_task/dataset_other.txt")], + task_id="consuming_1", + bash_command="sleep 5", + ) + +with DAG( + dag_id="dataset_consumes_1_and_2", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[dag1_dataset, dag2_dataset], + tags=["consumes", "dataset-scheduled"], +) as dag4: + BashOperator( + outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], + task_id="consuming_2", + bash_command="sleep 5", + ) + +with DAG( + dag_id="dataset_consumes_1_never_scheduled", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[ + dag1_dataset, + Dataset("s3://this-dataset-doesnt-get-triggered"), + ], + tags=["consumes", "dataset-scheduled"], +) as dag5: + BashOperator( + outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], + task_id="consuming_3", + bash_command="sleep 5", + ) + +with DAG( + dag_id="dataset_consumes_unknown_never_scheduled", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[ + Dataset("s3://unrelated/dataset3.txt"), + Dataset("s3://unrelated/dataset_other_unknown.txt"), + ], + tags=["dataset-scheduled"], +) as dag6: + BashOperator( + task_id="unrelated_task", + outlets=[Dataset("s3://unrelated_task/dataset_other_unknown.txt")], + bash_command="sleep 5", + ) diff --git a/airflow/example_dags/example_dynamic_task_mapping.py b/airflow/example_dags/example_dynamic_task_mapping.py new file mode 100644 index 0000000000000..dce6cda20972c --- /dev/null +++ b/airflow/example_dags/example_dynamic_task_mapping.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAG demonstrating the usage of dynamic task mapping.""" +from __future__ import annotations + +from datetime import datetime + +from airflow import DAG +from airflow.decorators import task + +with DAG(dag_id="example_dynamic_task_mapping", start_date=datetime(2022, 3, 4)) as dag: + + @task + def add_one(x: int): + return x + 1 + + @task + def sum_it(values): + total = sum(values) + print(f"Total was {total}") + + added_values = add_one.expand(x=[1, 2, 3]) + sum_it(added_values) diff --git a/airflow/example_dags/example_external_task_marker_dag.py b/airflow/example_dags/example_external_task_marker_dag.py index 0c4479a0d66f0..1b4f5d3ffd8f2 100644 --- a/airflow/example_dags/example_external_task_marker_dag.py +++ b/airflow/example_dags/example_external_task_marker_dag.py @@ -15,27 +15,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example DAG demonstrating setting up inter-DAG dependencies using ExternalTaskSensor and -ExternalTaskMarker +ExternalTaskMarker. In this example, child_task1 in example_external_task_marker_child depends on parent_task in -example_external_task_marker_parent. When parent_task is cleared with "Recursive" selected, -the presence of ExternalTaskMarker tells Airflow to clear child_task1 and its -downstream tasks. +example_external_task_marker_parent. When parent_task is cleared with 'Recursive' selected, +the presence of ExternalTaskMarker tells Airflow to clear child_task1 and its downstream tasks. ExternalTaskSensor will keep poking for the status of remote ExternalTaskMarker task at a regular interval till one of the following will happen: -1. ExternalTaskMarker reaches the states mentioned in the allowed_states list - In this case, ExternalTaskSensor will exit with a success status code -2. ExternalTaskMarker reaches the states mentioned in the failed_states list - In this case, ExternalTaskSensor will raise an AirflowException and user need to handle this - with multiple downstream tasks -3. ExternalTaskSensor times out - In this case, ExternalTaskSensor will raise AirflowSkipException or AirflowSensorTimeout - exception + +ExternalTaskMarker reaches the states mentioned in the allowed_states list. +In this case, ExternalTaskSensor will exit with a success status code + +ExternalTaskMarker reaches the states mentioned in the failed_states list +In this case, ExternalTaskSensor will raise an AirflowException and user need to handle this +with multiple downstream tasks + +ExternalTaskSensor times out. 
In this case, ExternalTaskSensor will raise AirflowSkipException +or AirflowSensorTimeout exception + """ +from __future__ import annotations import pendulum @@ -49,8 +51,8 @@ dag_id="example_external_task_marker_parent", start_date=start_date, catchup=False, - schedule_interval=None, - tags=['example2'], + schedule=None, + tags=["example2"], ) as parent_dag: # [START howto_operator_external_task_marker] parent_task = ExternalTaskMarker( @@ -63,9 +65,9 @@ with DAG( dag_id="example_external_task_marker_child", start_date=start_date, - schedule_interval=None, + schedule=None, catchup=False, - tags=['example2'], + tags=["example2"], ) as child_dag: # [START howto_operator_external_task_sensor] child_task1 = ExternalTaskSensor( @@ -73,10 +75,23 @@ external_dag_id=parent_dag.dag_id, external_task_id=parent_task.task_id, timeout=600, - allowed_states=['success'], - failed_states=['failed', 'skipped'], + allowed_states=["success"], + failed_states=["failed", "skipped"], mode="reschedule", ) # [END howto_operator_external_task_sensor] - child_task2 = EmptyOperator(task_id="child_task2") - child_task1 >> child_task2 + + # [START howto_operator_external_task_sensor_with_task_group] + child_task2 = ExternalTaskSensor( + task_id="child_task2", + external_dag_id=parent_dag.dag_id, + external_task_group_id="parent_dag_task_group_id", + timeout=600, + allowed_states=["success"], + failed_states=["failed", "skipped"], + mode="reschedule", + ) + # [END howto_operator_external_task_sensor_with_task_group] + + child_task3 = EmptyOperator(task_id="child_task3") + child_task1 >> child_task2 >> child_task3 diff --git a/airflow/example_dags/example_kubernetes_executor.py b/airflow/example_dags/example_kubernetes_executor.py index f278eb379b1e8..b9e6bdba35445 100644 --- a/airflow/example_dags/example_kubernetes_executor.py +++ b/airflow/example_dags/example_kubernetes_executor.py @@ -18,6 +18,8 @@ """ This is an example dag for using a Kubernetes Executor Configuration. """ +from __future__ import annotations + import logging import os @@ -30,8 +32,8 @@ log = logging.getLogger(__name__) -worker_container_repository = conf.get('kubernetes', 'worker_container_repository') -worker_container_tag = conf.get('kubernetes', 'worker_container_tag') +worker_container_repository = conf.get("kubernetes_executor", "worker_container_repository") +worker_container_tag = conf.get("kubernetes_executor", "worker_container_tag") try: from kubernetes.client import models as k8s @@ -45,11 +47,11 @@ if k8s: with DAG( - dag_id='example_kubernetes_executor', - schedule_interval=None, + dag_id="example_kubernetes_executor", + schedule=None, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example3'], + tags=["example3"], ) as dag: # You can use annotations on your kubernetes pods! start_task_executor_config = { @@ -88,8 +90,8 @@ def test_volume_mount(): Tests whether the volume has been mounted. 
""" - with open('/foo/volume_mount_test.txt', 'w') as foo: - foo.write('Hello') + with open("/foo/volume_mount_test.txt", "w") as foo: + foo.write("Hello") return_code = os.system("cat /foo/volume_mount_test.txt") if return_code != 0: @@ -110,7 +112,7 @@ def test_volume_mount(): k8s.V1Container( name="sidecar", image="ubuntu", - args=["echo \"retrieved from mount\" > /shared/test.txt"], + args=['echo "retrieved from mount" > /shared/test.txt'], command=["bash", "-cx"], volume_mounts=[k8s.V1VolumeMount(mount_path="/shared/", name="shared-empty-dir")], ), @@ -152,7 +154,7 @@ def non_root_task(): executor_config_other_ns = { "pod_override": k8s.V1Pod( - metadata=k8s.V1ObjectMeta(namespace="test-namespace", labels={'release': 'stable'}) + metadata=k8s.V1ObjectMeta(namespace="test-namespace", labels={"release": "stable"}) ) } @@ -191,21 +193,21 @@ def base_image_override_task(): k8s.V1PodAffinityTerm( label_selector=k8s.V1LabelSelector( match_expressions=[ - k8s.V1LabelSelectorRequirement(key='app', operator='In', values=['airflow']) + k8s.V1LabelSelectorRequirement(key="app", operator="In", values=["airflow"]) ] ), - topology_key='kubernetes.io/hostname', + topology_key="kubernetes.io/hostname", ) ] ) ) # Use k8s_client.V1Toleration to define node tolerations - k8s_tolerations = [k8s.V1Toleration(key='dedicated', operator='Equal', value='airflow')] + k8s_tolerations = [k8s.V1Toleration(key="dedicated", operator="Equal", value="airflow")] # Use k8s_client.V1ResourceRequirements to define resource limits k8s_resource_requirements = k8s.V1ResourceRequirements( - requests={'memory': '512Mi'}, limits={'memory': '512Mi'} + requests={"memory": "512Mi"}, limits={"memory": "512Mi"} ) kube_exec_config_resource_limits = { diff --git a/airflow/example_dags/example_latest_only.py b/airflow/example_dags/example_latest_only.py index 92ec1436a6951..3dc4e9137e350 100644 --- a/airflow/example_dags/example_latest_only.py +++ b/airflow/example_dags/example_latest_only.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Example of the LatestOnlyOperator""" +from __future__ import annotations import datetime as dt @@ -25,13 +25,13 @@ from airflow.operators.latest_only import LatestOnlyOperator with DAG( - dag_id='latest_only', - schedule_interval=dt.timedelta(hours=4), + dag_id="latest_only", + schedule=dt.timedelta(hours=4), start_date=dt.datetime(2021, 1, 1), catchup=False, - tags=['example2', 'example3'], + tags=["example2", "example3"], ) as dag: - latest_only = LatestOnlyOperator(task_id='latest_only') - task1 = EmptyOperator(task_id='task1') + latest_only = LatestOnlyOperator(task_id="latest_only") + task1 = EmptyOperator(task_id="task1") latest_only >> task1 diff --git a/airflow/example_dags/example_latest_only_with_trigger.py b/airflow/example_dags/example_latest_only_with_trigger.py index 56d6a24462217..e71d7e05c263c 100644 --- a/airflow/example_dags/example_latest_only_with_trigger.py +++ b/airflow/example_dags/example_latest_only_with_trigger.py @@ -18,6 +18,7 @@ """ Example LatestOnlyOperator and TriggerRule interactions """ +from __future__ import annotations # [START example] import datetime @@ -30,17 +31,17 @@ from airflow.utils.trigger_rule import TriggerRule with DAG( - dag_id='latest_only_with_trigger', - schedule_interval=datetime.timedelta(hours=4), + dag_id="latest_only_with_trigger", + schedule=datetime.timedelta(hours=4), start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example3'], + tags=["example3"], ) as dag: - latest_only = LatestOnlyOperator(task_id='latest_only') - task1 = EmptyOperator(task_id='task1') - task2 = EmptyOperator(task_id='task2') - task3 = EmptyOperator(task_id='task3') - task4 = EmptyOperator(task_id='task4', trigger_rule=TriggerRule.ALL_DONE) + latest_only = LatestOnlyOperator(task_id="latest_only") + task1 = EmptyOperator(task_id="task1") + task2 = EmptyOperator(task_id="task2") + task3 = EmptyOperator(task_id="task3") + task4 = EmptyOperator(task_id="task4", trigger_rule=TriggerRule.ALL_DONE) latest_only >> task1 >> [task3, task4] task2 >> [task3, task4] diff --git a/airflow/example_dags/example_local_kubernetes_executor.py b/airflow/example_dags/example_local_kubernetes_executor.py index e586cafda9416..db2d7acf0dbe5 100644 --- a/airflow/example_dags/example_local_kubernetes_executor.py +++ b/airflow/example_dags/example_local_kubernetes_executor.py @@ -18,6 +18,8 @@ """ This is an example dag for using a Local Kubernetes Executor Configuration. """ +from __future__ import annotations + import logging from datetime import datetime @@ -28,8 +30,8 @@ log = logging.getLogger(__name__) -worker_container_repository = conf.get('kubernetes', 'worker_container_repository') -worker_container_tag = conf.get('kubernetes', 'worker_container_tag') +worker_container_repository = conf.get("kubernetes_executor", "worker_container_repository") +worker_container_tag = conf.get("kubernetes_executor", "worker_container_tag") try: from kubernetes.client import models as k8s @@ -40,11 +42,11 @@ if k8s: with DAG( - dag_id='example_local_kubernetes_executor', - schedule_interval=None, + dag_id="example_local_kubernetes_executor", + schedule=None, start_date=datetime(2021, 1, 1), catchup=False, - tags=['example3'], + tags=["example3"], ) as dag: # You can use annotations on your kubernetes pods! 
start_task_executor_config = { @@ -53,17 +55,17 @@ @task( executor_config=start_task_executor_config, - queue='kubernetes', - task_id='task_with_kubernetes_executor', + queue="kubernetes", + task_id="task_with_kubernetes_executor", ) def task_with_template(): print_stuff() - @task(task_id='task_with_local_executor') + @task(task_id="task_with_local_executor") def task_with_local(ds=None, **kwargs): """Print the Airflow context and ds variable from the context.""" print(kwargs) print(ds) - return 'Whatever you return gets printed in the logs' + return "Whatever you return gets printed in the logs" task_with_local() >> task_with_template() diff --git a/airflow/example_dags/example_nested_branch_dag.py b/airflow/example_dags/example_nested_branch_dag.py index 14ac0a43ce245..7c46592455983 100644 --- a/airflow/example_dags/example_nested_branch_dag.py +++ b/airflow/example_dags/example_nested_branch_dag.py @@ -15,31 +15,38 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example DAG demonstrating a workflow with nested branching. The join tasks are created with ``none_failed_min_one_success`` trigger rule such that they are skipped whenever their corresponding -``BranchPythonOperator`` are skipped. +branching tasks are skipped. """ +from __future__ import annotations + import pendulum +from airflow.decorators import task from airflow.models import DAG from airflow.operators.empty import EmptyOperator -from airflow.operators.python import BranchPythonOperator from airflow.utils.trigger_rule import TriggerRule with DAG( dag_id="example_nested_branch_dag", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval="@daily", + schedule="@daily", tags=["example"], ) as dag: - branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1") + + @task.branch() + def branch(task_id_to_return: str) -> str: + return task_id_to_return + + branch_1 = branch.override(task_id="branch_1")(task_id_to_return="true_1") join_1 = EmptyOperator(task_id="join_1", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) true_1 = EmptyOperator(task_id="true_1") false_1 = EmptyOperator(task_id="false_1") - branch_2 = BranchPythonOperator(task_id="branch_2", python_callable=lambda: "true_2") + + branch_2 = branch.override(task_id="branch_2")(task_id_to_return="true_2") join_2 = EmptyOperator(task_id="join_2", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) true_2 = EmptyOperator(task_id="true_2") false_2 = EmptyOperator(task_id="false_2") diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py index 8057d5fd54a13..055c8639f90a4 100644 --- a/airflow/example_dags/example_passing_params_via_test_command.py +++ b/airflow/example_dags/example_passing_params_via_test_command.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Example DAG demonstrating the usage of the params arguments in templated arguments.""" +from __future__ import annotations import datetime import os @@ -59,11 +59,11 @@ def print_env_vars(test_mode=None): with DAG( "example_passing_params_via_test_command", - schedule_interval='*/1 * * * *', + schedule="*/1 * * * *", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, dagrun_timeout=datetime.timedelta(minutes=4), - tags=['example'], + tags=["example"], ) as dag: run_this = my_py_command(params={"miff": "agg"}) @@ -75,7 +75,7 @@ def print_env_vars(test_mode=None): ) also_run_this = BashOperator( - task_id='also_run_this', + task_id="also_run_this", bash_command=my_command, params={"miff": "agg"}, env={"FOO": "{{ params.foo }}", "MIFF": "{{ params.miff }}"}, diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py index 0f9a7fc476acb..4f891abe60d02 100644 --- a/airflow/example_dags/example_python_operator.py +++ b/airflow/example_dags/example_python_operator.py @@ -15,13 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example DAG demonstrating the usage of the TaskFlow API to execute Python functions natively and within a virtual environment. """ +from __future__ import annotations + import logging import shutil +import sys +import tempfile import time from pprint import pprint @@ -29,39 +32,58 @@ from airflow import DAG from airflow.decorators import task +from airflow.operators.python import ExternalPythonOperator, PythonVirtualenvOperator log = logging.getLogger(__name__) +PATH_TO_PYTHON_BINARY = sys.executable + +BASE_DIR = tempfile.gettempdir() + + +def x(): + pass + + with DAG( - dag_id='example_python_operator', - schedule_interval=None, + dag_id="example_python_operator", + schedule=None, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example'], + tags=["example"], ) as dag: + # [START howto_operator_python] @task(task_id="print_the_context") def print_context(ds=None, **kwargs): """Print the Airflow context and ds variable from the context.""" pprint(kwargs) print(ds) - return 'Whatever you return gets printed in the logs' + return "Whatever you return gets printed in the logs" run_this = print_context() # [END howto_operator_python] + # [START howto_operator_python_render_sql] + @task(task_id="log_sql_query", templates_dict={"query": "sql/sample.sql"}, templates_exts=[".sql"]) + def log_sql(**kwargs): + logging.info("Python task decorator query: %s", str(kwargs["templates_dict"]["query"])) + + log_the_sql = log_sql() + # [END howto_operator_python_render_sql] + # [START howto_operator_python_kwargs] # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively for i in range(5): - @task(task_id=f'sleep_for_{i}') + @task(task_id=f"sleep_for_{i}") def my_sleeping_function(random_base): """This is a function that will run within the DAG execution""" time.sleep(random_base) sleeping_task = my_sleeping_function(random_base=float(i) / 10) - run_this >> sleeping_task + run_this >> log_the_sql >> sleeping_task # [END howto_operator_python_kwargs] if not shutil.which("virtualenv"): @@ -82,14 +104,56 @@ def callable_virtualenv(): from colorama import Back, Fore, Style - print(Fore.RED + 'some red text') - print(Back.GREEN + 'and with a green background') - print(Style.DIM + 'and in dim text') + print(Fore.RED + "some red text") + print(Back.GREEN + "and with a green 
background") + print(Style.DIM + "and in dim text") print(Style.RESET_ALL) - for _ in range(10): - print(Style.DIM + 'Please wait...', flush=True) - sleep(10) - print('Finished') + for _ in range(4): + print(Style.DIM + "Please wait...", flush=True) + sleep(1) + print("Finished") virtualenv_task = callable_virtualenv() # [END howto_operator_python_venv] + + sleeping_task >> virtualenv_task + + # [START howto_operator_external_python] + @task.external_python(task_id="external_python", python=PATH_TO_PYTHON_BINARY) + def callable_external_python(): + """ + Example function that will be performed in a virtual environment. + + Importing at the module level ensures that it will not attempt to import the + library before it is installed. + """ + import sys + from time import sleep + + print(f"Running task via {sys.executable}") + print("Sleeping") + for _ in range(4): + print("Please wait...", flush=True) + sleep(1) + print("Finished") + + external_python_task = callable_external_python() + # [END howto_operator_external_python] + + # [START howto_operator_external_python_classic] + external_classic = ExternalPythonOperator( + task_id="external_python_classic", + python=PATH_TO_PYTHON_BINARY, + python_callable=x, + ) + # [END howto_operator_external_python_classic] + + # [START howto_operator_python_venv_classic] + virtual_classic = PythonVirtualenvOperator( + task_id="virtualenv_classic", + requirements="colorama==0.4.0", + python_callable=x, + ) + # [END howto_operator_python_venv_classic] + + run_this >> external_classic >> external_python_task >> virtual_classic diff --git a/airflow/example_dags/example_sensor_decorator.py b/airflow/example_dags/example_sensor_decorator.py new file mode 100644 index 0000000000000..2197a6c53af74 --- /dev/null +++ b/airflow/example_dags/example_sensor_decorator.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example DAG demonstrating the usage of the sensor decorator.""" + +from __future__ import annotations + +# [START tutorial] +# [START import_module] +import pendulum + +from airflow.decorators import dag, task +from airflow.sensors.base import PokeReturnValue + +# [END import_module] + + +# [START instantiate_dag] +@dag( + schedule_interval=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["example"], +) +def example_sensor_decorator(): + # [END instantiate_dag] + + # [START wait_function] + # Using a sensor operator to wait for the upstream data to be ready. 
+ @task.sensor(poke_interval=60, timeout=3600, mode="reschedule") + def wait_for_upstream() -> PokeReturnValue: + return PokeReturnValue(is_done=True, xcom_value="xcom_value") + + # [END wait_function] + + # [START dummy_function] + @task + def dummy_operator() -> None: + pass + + # [END dummy_function] + + # [START main_flow] + wait_for_upstream() >> dummy_operator() + # [END main_flow] + + +# [START dag_invocation] +tutorial_etl_dag = example_sensor_decorator() +# [END dag_invocation] + +# [END tutorial] diff --git a/airflow/example_dags/example_sensors.py b/airflow/example_dags/example_sensors.py new file mode 100644 index 0000000000000..d9e3158f544bc --- /dev/null +++ b/airflow/example_dags/example_sensors.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from datetime import datetime, timedelta + +import pendulum +from pytz import UTC + +from airflow.models import DAG +from airflow.operators.bash import BashOperator +from airflow.sensors.bash import BashSensor +from airflow.sensors.filesystem import FileSensor +from airflow.sensors.python import PythonSensor +from airflow.sensors.time_delta import TimeDeltaSensor, TimeDeltaSensorAsync +from airflow.sensors.time_sensor import TimeSensor, TimeSensorAsync +from airflow.sensors.weekday import DayOfWeekSensor +from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.weekday import WeekDay + + +# [START example_callables] +def success_callable(): + return True + + +def failure_callable(): + return False + + +# [END example_callables] + + +with DAG( + dag_id="example_sensors", + schedule=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["example"], +) as dag: + # [START example_time_delta_sensor] + t0 = TimeDeltaSensor(task_id="wait_some_seconds", delta=timedelta(seconds=2)) + # [END example_time_delta_sensor] + + # [START example_time_delta_sensor_async] + t0a = TimeDeltaSensorAsync(task_id="wait_some_seconds_async", delta=timedelta(seconds=2)) + # [END example_time_delta_sensor_async] + + # [START example_time_sensors] + t1 = TimeSensor(task_id="fire_immediately", target_time=datetime.now(tz=UTC).time()) + + t2 = TimeSensor( + task_id="timeout_after_second_date_in_the_future", + timeout=1, + soft_fail=True, + target_time=(datetime.now(tz=UTC) + timedelta(hours=1)).time(), + ) + # [END example_time_sensors] + + # [START example_time_sensors_async] + t1a = TimeSensorAsync(task_id="fire_immediately_async", target_time=datetime.now(tz=UTC).time()) + + t2a = TimeSensorAsync( + task_id="timeout_after_second_date_in_the_future_async", + timeout=1, + soft_fail=True, + target_time=(datetime.now(tz=UTC) + timedelta(hours=1)).time(), + ) + # [END example_time_sensors_async] + + # [START example_bash_sensors] + 
t3 = BashSensor(task_id="Sensor_succeeds", bash_command="exit 0") + + t4 = BashSensor(task_id="Sensor_fails_after_3_seconds", timeout=3, soft_fail=True, bash_command="exit 1") + # [END example_bash_sensors] + + t5 = BashOperator(task_id="remove_file", bash_command="rm -rf /tmp/temporary_file_for_testing") + + # [START example_file_sensor] + t6 = FileSensor(task_id="wait_for_file", filepath="/tmp/temporary_file_for_testing") + # [END example_file_sensor] + + t7 = BashOperator( + task_id="create_file_after_3_seconds", bash_command="sleep 3; touch /tmp/temporary_file_for_testing" + ) + + # [START example_python_sensors] + t8 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable) + + t9 = PythonSensor( + task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable + ) + # [END example_python_sensors] + + # [START example_day_of_week_sensor] + t10 = DayOfWeekSensor( + task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY + ) + # [END example_day_of_week_sensor] + + tx = BashOperator(task_id="print_date_in_bash", bash_command="date") + + tx.trigger_rule = TriggerRule.NONE_FAILED + [t0, t0a, t1, t1a, t2, t2a, t3, t4] >> tx + t5 >> t6 >> tx + t7 >> tx + [t8, t9] >> tx + t10 >> tx diff --git a/airflow/example_dags/example_short_circuit_decorator.py b/airflow/example_dags/example_short_circuit_decorator.py new file mode 100644 index 0000000000000..30f6cd0e012bb --- /dev/null +++ b/airflow/example_dags/example_short_circuit_decorator.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Example DAG demonstrating the usage of the `@task.short_circuit()` TaskFlow decorator.""" +from __future__ import annotations + +import pendulum + +from airflow.decorators import dag, task +from airflow.models.baseoperator import chain +from airflow.operators.empty import EmptyOperator +from airflow.utils.trigger_rule import TriggerRule + + +@dag(start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=["example"]) +def example_short_circuit_decorator(): + # [START howto_operator_short_circuit] + @task.short_circuit() + def check_condition(condition): + return condition + + ds_true = [EmptyOperator(task_id="true_" + str(i)) for i in [1, 2]] + ds_false = [EmptyOperator(task_id="false_" + str(i)) for i in [1, 2]] + + condition_is_true = check_condition.override(task_id="condition_is_true")(condition=True) + condition_is_false = check_condition.override(task_id="condition_is_false")(condition=False) + + chain(condition_is_true, *ds_true) + chain(condition_is_false, *ds_false) + # [END howto_operator_short_circuit] + + # [START howto_operator_short_circuit_trigger_rules] + [task_1, task_2, task_3, task_4, task_5, task_6] = [ + EmptyOperator(task_id=f"task_{i}") for i in range(1, 7) + ] + + task_7 = EmptyOperator(task_id="task_7", trigger_rule=TriggerRule.ALL_DONE) + + short_circuit = check_condition.override(task_id="short_circuit", ignore_downstream_trigger_rules=False)( + condition=False + ) + + chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7) + # [END howto_operator_short_circuit_trigger_rules] + + +example_dag = example_short_circuit_decorator() diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_operator.py index 2278de30e6294..77f976c502a16 100644 --- a/airflow/example_dags/example_short_circuit_operator.py +++ b/airflow/example_dags/example_short_circuit_operator.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Example DAG demonstrating the usage of the ShortCircuitOperator.""" +from __future__ import annotations + import pendulum from airflow import DAG @@ -26,30 +27,27 @@ from airflow.utils.trigger_rule import TriggerRule with DAG( - dag_id='example_short_circuit_operator', + dag_id="example_short_circuit_operator", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example'], + tags=["example"], ) as dag: - # [START howto_operator_short_circuit] cond_true = ShortCircuitOperator( - task_id='condition_is_True', + task_id="condition_is_True", python_callable=lambda: True, ) cond_false = ShortCircuitOperator( - task_id='condition_is_False', + task_id="condition_is_False", python_callable=lambda: False, ) - ds_true = [EmptyOperator(task_id='true_' + str(i)) for i in [1, 2]] - ds_false = [EmptyOperator(task_id='false_' + str(i)) for i in [1, 2]] + ds_true = [EmptyOperator(task_id="true_" + str(i)) for i in [1, 2]] + ds_false = [EmptyOperator(task_id="false_" + str(i)) for i in [1, 2]] chain(cond_true, *ds_true) chain(cond_false, *ds_false) - # [END howto_operator_short_circuit] - # [START howto_operator_short_circuit_trigger_rules] [task_1, task_2, task_3, task_4, task_5, task_6] = [ EmptyOperator(task_id=f"task_{i}") for i in range(1, 7) ] @@ -61,4 +59,3 @@ ) chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7) - # [END howto_operator_short_circuit_trigger_rules] diff --git a/airflow/example_dags/example_skip_dag.py b/airflow/example_dags/example_skip_dag.py index 00d3a3d91e11b..c8b8958446686 100644 --- a/airflow/example_dags/example_skip_dag.py +++ b/airflow/example_dags/example_skip_dag.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Example DAG demonstrating the EmptyOperator and a custom EmptySkipOperator which skips by default.""" +from __future__ import annotations import pendulum @@ -31,7 +31,7 @@ class EmptySkipOperator(EmptyOperator): """Empty operator which always skips the task.""" - ui_color = '#e8b7e4' + ui_color = "#e8b7e4" def execute(self, context: Context): raise AirflowSkipException @@ -45,10 +45,10 @@ def create_test_pipeline(suffix, trigger_rule): :param str trigger_rule: TriggerRule for the join task :param DAG dag_: The DAG to run the operators on """ - skip_operator = EmptySkipOperator(task_id=f'skip_operator_{suffix}') - always_true = EmptyOperator(task_id=f'always_true_{suffix}') + skip_operator = EmptySkipOperator(task_id=f"skip_operator_{suffix}") + always_true = EmptyOperator(task_id=f"always_true_{suffix}") join = EmptyOperator(task_id=trigger_rule, trigger_rule=trigger_rule) - final = EmptyOperator(task_id=f'final_{suffix}') + final = EmptyOperator(task_id=f"final_{suffix}") skip_operator >> join always_true >> join @@ -56,10 +56,10 @@ def create_test_pipeline(suffix, trigger_rule): with DAG( - dag_id='example_skip_dag', + dag_id="example_skip_dag", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example'], + tags=["example"], ) as dag: - create_test_pipeline('1', TriggerRule.ALL_SUCCESS) - create_test_pipeline('2', TriggerRule.ONE_SUCCESS) + create_test_pipeline("1", TriggerRule.ALL_SUCCESS) + create_test_pipeline("2", TriggerRule.ONE_SUCCESS) diff --git a/airflow/example_dags/example_sla_dag.py b/airflow/example_dags/example_sla_dag.py index 0db6bc1ba7fcc..e76f40f10ddad 100644 --- a/airflow/example_dags/example_sla_dag.py +++ b/airflow/example_dags/example_sla_dag.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Example DAG demonstrating SLA use in Tasks""" +from __future__ import annotations import datetime import time @@ -22,8 +24,6 @@ from airflow.decorators import dag, task -"""Example DAG demonstrating SLA use in Tasks""" - # [START howto_task_sla] def sla_callback(dag, task_list, blocking_task_list, slas, blocking_tis): @@ -40,11 +40,11 @@ def sla_callback(dag, task_list, blocking_task_list, slas, blocking_tis): @dag( - schedule_interval="*/2 * * * *", + schedule="*/2 * * * *", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, sla_miss_callback=sla_callback, - default_args={'email': "email@example.com"}, + default_args={"email": "email@example.com"}, ) def example_sla_dag(): @task(sla=datetime.timedelta(seconds=10)) @@ -60,6 +60,6 @@ def sleep_30(): sleep_20() >> sleep_30() -dag = example_sla_dag() +example_dag = example_sla_dag() # [END howto_task_sla] diff --git a/airflow/example_dags/example_subdag_operator.py b/airflow/example_dags/example_subdag_operator.py index 79d369d638d6a..f7c3b098135bb 100644 --- a/airflow/example_dags/example_subdag_operator.py +++ b/airflow/example_dags/example_subdag_operator.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Example DAG demonstrating the usage of the SubDagOperator.""" +from __future__ import annotations # [START example_subdag_operator] import datetime @@ -26,36 +26,36 @@ from airflow.operators.empty import EmptyOperator from airflow.operators.subdag import SubDagOperator -DAG_NAME = 'example_subdag_operator' +DAG_NAME = "example_subdag_operator" with DAG( dag_id=DAG_NAME, default_args={"retries": 2}, start_date=datetime.datetime(2022, 1, 1), - schedule_interval="@once", - tags=['example'], + schedule="@once", + tags=["example"], ) as dag: start = EmptyOperator( - task_id='start', + task_id="start", ) section_1 = SubDagOperator( - task_id='section-1', - subdag=subdag(DAG_NAME, 'section-1', dag.default_args), + task_id="section-1", + subdag=subdag(DAG_NAME, "section-1", dag.default_args), ) some_other_task = EmptyOperator( - task_id='some-other-task', + task_id="some-other-task", ) section_2 = SubDagOperator( - task_id='section-2', - subdag=subdag(DAG_NAME, 'section-2', dag.default_args), + task_id="section-2", + subdag=subdag(DAG_NAME, "section-2", dag.default_args), ) end = EmptyOperator( - task_id='end', + task_id="end", ) start >> section_1 >> some_other_task >> section_2 >> end diff --git a/airflow/example_dags/example_task_group.py b/airflow/example_dags/example_task_group.py index 7bc319f98b569..9d7a9f2e74d59 100644 --- a/airflow/example_dags/example_task_group.py +++ b/airflow/example_dags/example_task_group.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Example DAG demonstrating the usage of the TaskGroup.""" +from __future__ import annotations + import pendulum from airflow.models.dag import DAG @@ -36,7 +37,7 @@ # [START howto_task_group_section_1] with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1: task_1 = EmptyOperator(task_id="task_1") - task_2 = BashOperator(task_id="task_2", bash_command='echo 1') + task_2 = BashOperator(task_id="task_2", bash_command="echo 1") task_3 = EmptyOperator(task_id="task_3") task_1 >> [task_2, task_3] @@ -48,7 +49,7 @@ # [START howto_task_group_inner_section_2] with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2") as inner_section_2: - task_2 = BashOperator(task_id="task_2", bash_command='echo 1') + task_2 = BashOperator(task_id="task_2", bash_command="echo 1") task_3 = EmptyOperator(task_id="task_3") task_4 = EmptyOperator(task_id="task_4") @@ -57,7 +58,7 @@ # [END howto_task_group_section_2] - end = EmptyOperator(task_id='end') + end = EmptyOperator(task_id="end") start >> section_1 >> section_2 >> end # [END howto_task_group] diff --git a/airflow/example_dags/example_task_group_decorator.py b/airflow/example_dags/example_task_group_decorator.py index 637c721886bea..db8d7b302529c 100644 --- a/airflow/example_dags/example_task_group_decorator.py +++ b/airflow/example_dags/example_task_group_decorator.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Example DAG demonstrating the usage of the @taskgroup decorator.""" +from __future__ import annotations import pendulum @@ -29,31 +29,31 @@ @task def task_start(): """Empty Task which is First Task of Dag""" - return '[Task_start]' + return "[Task_start]" @task def task_1(value: int) -> str: """Empty Task1""" - return f'[ Task1 {value} ]' + return f"[ Task1 {value} ]" @task def task_2(value: str) -> str: """Empty Task2""" - return f'[ Task2 {value} ]' + return f"[ Task2 {value} ]" @task def task_3(value: str) -> None: """Empty Task3""" - print(f'[ Task3 {value} ]') + print(f"[ Task3 {value} ]") @task def task_end() -> None: """Empty Task which is Last Task of Dag""" - print('[ Task_End ]') + print("[ Task_End ]") # Creating TaskGroups diff --git a/airflow/example_dags/example_time_delta_sensor_async.py b/airflow/example_dags/example_time_delta_sensor_async.py index a2aa3cb66edee..d1562c5751d7d 100644 --- a/airflow/example_dags/example_time_delta_sensor_async.py +++ b/airflow/example_dags/example_time_delta_sensor_async.py @@ -15,11 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example DAG demonstrating ``TimeDeltaSensorAsync``, a drop in replacement for ``TimeDeltaSensor`` that defers and doesn't occupy a worker slot while it waits """ +from __future__ import annotations import datetime @@ -31,7 +31,7 @@ with DAG( dag_id="example_time_delta_sensor_async", - schedule_interval=None, + schedule=None, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=["example"], diff --git a/airflow/example_dags/example_trigger_controller_dag.py b/airflow/example_dags/example_trigger_controller_dag.py index a017c9a5b4176..c07fc3b190b0c 100644 --- a/airflow/example_dags/example_trigger_controller_dag.py +++ b/airflow/example_dags/example_trigger_controller_dag.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example usage of the TriggerDagRunOperator. This example holds 2 DAGs: 1. 1st DAG (example_trigger_controller_dag) holds a TriggerDagRunOperator, which will trigger the 2nd DAG 2. 2nd DAG (example_trigger_target_dag) which will be triggered by the TriggerDagRunOperator in the 1st DAG """ +from __future__ import annotations + import pendulum from airflow import DAG @@ -30,8 +31,8 @@ dag_id="example_trigger_controller_dag", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval="@once", - tags=['example'], + schedule="@once", + tags=["example"], ) as dag: trigger = TriggerDagRunOperator( task_id="test_trigger_dagrun", diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 20932338c8dd8..ff619376da0dc 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example usage of the TriggerDagRunOperator. This example holds 2 DAGs: 1. 1st DAG (example_trigger_controller_dag) holds a TriggerDagRunOperator, which will trigger the 2nd DAG 2. 
2nd DAG (example_trigger_target_dag) which will be triggered by the TriggerDagRunOperator in the 1st DAG """ +from __future__ import annotations + import pendulum from airflow import DAG @@ -42,13 +43,13 @@ def run_this_func(dag_run=None): dag_id="example_trigger_target_dag", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval=None, - tags=['example'], + schedule=None, + tags=["example"], ) as dag: run_this = run_this_func() bash_task = BashOperator( task_id="bash_task", bash_command='echo "Here is the message: $message"', - env={'message': '{{ dag_run.conf.get("message") }}'}, + env={"message": '{{ dag_run.conf.get("message") }}'}, ) diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py index b55d4e5d667cd..8455ad575cc82 100644 --- a/airflow/example_dags/example_xcom.py +++ b/airflow/example_dags/example_xcom.py @@ -15,22 +15,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Example DAG demonstrating the usage of XComs.""" +from __future__ import annotations + import pendulum -from airflow import DAG +from airflow import DAG, XComArg from airflow.decorators import task from airflow.operators.bash import BashOperator value_1 = [1, 2, 3] -value_2 = {'a': 'b'} +value_2 = {"a": "b"} @task def push(ti=None): """Pushes an XCom without a specific target""" - ti.xcom_push(key='value from pusher 1', value=value_1) + ti.xcom_push(key="value from pusher 1", value=value_1) @task @@ -41,7 +42,7 @@ def push_by_returning(): def _compare_values(pulled_value, check_value): if pulled_value != check_value: - raise ValueError(f'The two values differ {pulled_value} and {check_value}') + raise ValueError(f"The two values differ {pulled_value} and {check_value}") @task @@ -55,21 +56,21 @@ def puller(pulled_value_2, ti=None): @task def pull_value_from_bash_push(ti=None): - bash_pushed_via_return_value = ti.xcom_pull(key="return_value", task_ids='bash_push') - bash_manually_pushed_value = ti.xcom_pull(key="manually_pushed_value", task_ids='bash_push') + bash_pushed_via_return_value = ti.xcom_pull(key="return_value", task_ids="bash_push") + bash_manually_pushed_value = ti.xcom_pull(key="manually_pushed_value", task_ids="bash_push") print(f"The xcom value pushed by task push via return value is {bash_pushed_via_return_value}") print(f"The xcom value pushed by task push manually is {bash_manually_pushed_value}") with DAG( - 'example_xcom', - schedule_interval="@once", + "example_xcom", + schedule="@once", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - tags=['example'], + tags=["example"], ) as dag: bash_push = BashOperator( - task_id='bash_push', + task_id="bash_push", bash_command='echo "bash_push demo" && ' 'echo "Manually set xcom value ' '{{ ti.xcom_push(key="manually_pushed_value", value="manually_pushed_value") }}" && ' @@ -77,10 +78,10 @@ def pull_value_from_bash_push(ti=None): ) bash_pull = BashOperator( - task_id='bash_pull', + task_id="bash_pull", bash_command='echo "bash pull demo" && ' - f'echo "The xcom pushed manually is {bash_push.output["manually_pushed_value"]}" && ' - f'echo "The returned_value xcom is {bash_push.output}" && ' + f'echo "The xcom pushed manually is {XComArg(bash_push, key="manually_pushed_value")}" && ' + f'echo "The returned_value xcom is {XComArg(bash_push)}" && ' 'echo "finished"', do_xcom_push=False, ) @@ -90,6 +91,3 @@ def pull_value_from_bash_push(ti=None): [bash_pull, 
python_pull_from_bash] << bash_push puller(push_by_returning()) << push() - - # Task dependencies created via `XComArgs`: - # pull << push2 diff --git a/airflow/example_dags/example_xcomargs.py b/airflow/example_dags/example_xcomargs.py index 8312aca8c9d25..9d36cb535a8c1 100644 --- a/airflow/example_dags/example_xcomargs.py +++ b/airflow/example_dags/example_xcomargs.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Example DAG demonstrating the usage of the XComArgs.""" +from __future__ import annotations + import logging import pendulum @@ -41,11 +42,11 @@ def print_value(value, ts=None): with DAG( - dag_id='example_xcom_args', + dag_id="example_xcom_args", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval=None, - tags=['example'], + schedule=None, + tags=["example"], ) as dag: print_value(generate_value()) @@ -53,8 +54,8 @@ def print_value(value, ts=None): "example_xcom_args_with_operators", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval=None, - tags=['example'], + schedule=None, + tags=["example"], ) as dag2: bash_op1 = BashOperator(task_id="c", bash_command="echo c") bash_op2 = BashOperator(task_id="d", bash_command="echo c") diff --git a/airflow/example_dags/libs/helper.py b/airflow/example_dags/libs/helper.py index a3d3a720a0255..e6c2e3c4582fc 100644 --- a/airflow/example_dags/libs/helper.py +++ b/airflow/example_dags/libs/helper.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations def print_stuff(): diff --git a/airflow/example_dags/plugins/workday.py b/airflow/example_dags/plugins/workday.py index 77111a79396de..db68c29541a8f 100644 --- a/airflow/example_dags/plugins/workday.py +++ b/airflow/example_dags/plugins/workday.py @@ -15,12 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Plugin to demonstrate timetable registration and accommodate example DAGs.""" +from __future__ import annotations # [START howto_timetable] from datetime import timedelta -from typing import Optional from pendulum import UTC, Date, DateTime, Time @@ -47,9 +46,9 @@ def infer_manual_data_interval(self, run_after: DateTime) -> DataInterval: def next_dagrun_info( self, *, - last_automated_data_interval: Optional[DataInterval], + last_automated_data_interval: DataInterval | None, restriction: TimeRestriction, - ) -> Optional[DagRunInfo]: + ) -> DagRunInfo | None: if last_automated_data_interval is not None: # There was a previous run on the regular schedule. last_start = last_automated_data_interval.start last_start_weekday = last_start.weekday() diff --git a/airflow/example_dags/sql/sample.sql b/airflow/example_dags/sql/sample.sql new file mode 100644 index 0000000000000..23af6ab4b9bb3 --- /dev/null +++ b/airflow/example_dags/sql/sample.sql @@ -0,0 +1,24 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +CREATE TABLE Orders ( + order_id INT PRIMARY KEY, + name TEXT, + description TEXT +) diff --git a/airflow/example_dags/subdags/subdag.py b/airflow/example_dags/subdags/subdag.py index 2fcab731092fb..0bb9e86948eb1 100644 --- a/airflow/example_dags/subdags/subdag.py +++ b/airflow/example_dags/subdags/subdag.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Helper function to generate a DAG and operators given some arguments.""" +from __future__ import annotations # [START subdag] import pendulum @@ -25,7 +25,7 @@ from airflow.operators.empty import EmptyOperator -def subdag(parent_dag_name, child_dag_name, args): +def subdag(parent_dag_name, child_dag_name, args) -> DAG: """ Generate a DAG to be used as a subdag. @@ -33,19 +33,18 @@ def subdag(parent_dag_name, child_dag_name, args): :param str child_dag_name: Id of the child DAG :param dict args: Default arguments to provide to the subdag :return: DAG to use as a subdag - :rtype: airflow.models.DAG """ dag_subdag = DAG( - dag_id=f'{parent_dag_name}.{child_dag_name}', + dag_id=f"{parent_dag_name}.{child_dag_name}", default_args=args, start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, - schedule_interval="@daily", + schedule="@daily", ) for i in range(5): EmptyOperator( - task_id=f'{child_dag_name}-task-{i + 1}', + task_id=f"{child_dag_name}-task-{i + 1}", default_args=args, dag=dag_subdag, ) diff --git a/airflow/example_dags/tutorial.py b/airflow/example_dags/tutorial.py index ff2bd2fe95cf7..dc6399625c98a 100644 --- a/airflow/example_dags/tutorial.py +++ b/airflow/example_dags/tutorial.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """ ### Tutorial Documentation Documentation that goes along with the Airflow tutorial located [here](https://airflow.apache.org/tutorial.html) """ +from __future__ import annotations + # [START tutorial] # [START import_module] from datetime import datetime, timedelta @@ -37,17 +38,17 @@ # [START instantiate_dag] with DAG( - 'tutorial', + "tutorial", # [START default_args] # These args will get passed on to each operator # You can override them on a per-task basis during operator initialization default_args={ - 'depends_on_past': False, - 'email': ['airflow@example.com'], - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5), + "depends_on_past": False, + "email": ["airflow@example.com"], + "email_on_failure": False, + "email_on_retry": False, + "retries": 1, + "retry_delay": timedelta(minutes=5), # 'queue': 'bash_queue', # 'pool': 'backfill', # 'priority_weight': 10, @@ -62,25 +63,25 @@ # 'trigger_rule': 'all_success' }, # [END default_args] - description='A simple tutorial DAG', - schedule_interval=timedelta(days=1), + description="A simple tutorial DAG", + schedule=timedelta(days=1), start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as dag: # [END instantiate_dag] # t1, t2 and t3 are examples of tasks created by instantiating operators # [START basic_task] t1 = BashOperator( - task_id='print_date', - bash_command='date', + task_id="print_date", + bash_command="date", ) t2 = BashOperator( - task_id='sleep', + task_id="sleep", depends_on_past=False, - bash_command='sleep 5', + bash_command="sleep 5", retries=3, ) # [END basic_task] @@ -93,11 +94,11 @@ `doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets rendered in the UI's Task Instance Details page. ![img](http://montcs.bloomu.edu/~bobmon/Semesters/2012-01/491/import%20soul.png) - + **Image Credit:** Randall Munroe, [XKCD](https://xkcd.com/license.html) """ ) - dag.doc_md = __doc__ # providing that you have a docstring at the beginning of the DAG + dag.doc_md = __doc__ # providing that you have a docstring at the beginning of the DAG; OR dag.doc_md = """ This is a documentation placed anywhere """ # otherwise, type it like this @@ -114,7 +115,7 @@ ) t3 = BashOperator( - task_id='templated', + task_id="templated", depends_on_past=False, bash_command=templated_command, ) diff --git a/airflow/example_dags/tutorial_dag.py b/airflow/example_dags/tutorial_dag.py new file mode 100644 index 0000000000000..07b193865de4b --- /dev/null +++ b/airflow/example_dags/tutorial_dag.py @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +### DAG Tutorial Documentation +This DAG is demonstrating an Extract -> Transform -> Load pipeline +""" +from __future__ import annotations + +# [START tutorial] +# [START import_module] +import json +from textwrap import dedent + +import pendulum + +# The DAG object; we'll need this to instantiate a DAG +from airflow import DAG + +# Operators; we need this to operate! +from airflow.operators.python import PythonOperator + +# [END import_module] + +# [START instantiate_dag] +with DAG( + "tutorial_dag", + # [START default_args] + # These args will get passed on to each operator + # You can override them on a per-task basis during operator initialization + default_args={"retries": 2}, + # [END default_args] + description="DAG tutorial", + schedule=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["example"], +) as dag: + # [END instantiate_dag] + # [START documentation] + dag.doc_md = __doc__ + # [END documentation] + + # [START extract_function] + def extract(**kwargs): + ti = kwargs["ti"] + data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' + ti.xcom_push("order_data", data_string) + + # [END extract_function] + + # [START transform_function] + def transform(**kwargs): + ti = kwargs["ti"] + extract_data_string = ti.xcom_pull(task_ids="extract", key="order_data") + order_data = json.loads(extract_data_string) + + total_order_value = 0 + for value in order_data.values(): + total_order_value += value + + total_value = {"total_order_value": total_order_value} + total_value_json_string = json.dumps(total_value) + ti.xcom_push("total_order_value", total_value_json_string) + + # [END transform_function] + + # [START load_function] + def load(**kwargs): + ti = kwargs["ti"] + total_value_string = ti.xcom_pull(task_ids="transform", key="total_order_value") + total_order_value = json.loads(total_value_string) + + print(total_order_value) + + # [END load_function] + + # [START main_flow] + extract_task = PythonOperator( + task_id="extract", + python_callable=extract, + ) + extract_task.doc_md = dedent( + """\ + #### Extract task + A simple Extract task to get data ready for the rest of the data pipeline. + In this case, getting data is simulated by reading from a hardcoded JSON string. + This data is then put into xcom, so that it can be processed by the next task. + """ + ) + + transform_task = PythonOperator( + task_id="transform", + python_callable=transform, + ) + transform_task.doc_md = dedent( + """\ + #### Transform task + A simple Transform task which takes in the collection of order data from xcom + and computes the total order value. + This computed value is then put into xcom, so that it can be processed by the next task. + """ + ) + + load_task = PythonOperator( + task_id="load", + python_callable=load, + ) + load_task.doc_md = dedent( + """\ + #### Load task + A simple Load task which takes in the result of the Transform task, by reading it + from xcom and instead of saving it to end user review, just prints it out. + """ + ) + + extract_task >> transform_task >> load_task + +# [END main_flow] + +# [END tutorial] diff --git a/airflow/example_dags/tutorial_etl_dag.py b/airflow/example_dags/tutorial_etl_dag.py deleted file mode 100644 index d039a73488c18..0000000000000 --- a/airflow/example_dags/tutorial_etl_dag.py +++ /dev/null @@ -1,135 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -""" -### ETL DAG Tutorial Documentation -This ETL DAG is demonstrating an Extract -> Transform -> Load pipeline -""" -# [START tutorial] -# [START import_module] -import json -from textwrap import dedent - -import pendulum - -# The DAG object; we'll need this to instantiate a DAG -from airflow import DAG - -# Operators; we need this to operate! -from airflow.operators.python import PythonOperator - -# [END import_module] - -# [START instantiate_dag] -with DAG( - 'tutorial_etl_dag', - # [START default_args] - # These args will get passed on to each operator - # You can override them on a per-task basis during operator initialization - default_args={'retries': 2}, - # [END default_args] - description='ETL DAG tutorial', - schedule_interval=None, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - catchup=False, - tags=['example'], -) as dag: - # [END instantiate_dag] - # [START documentation] - dag.doc_md = __doc__ - # [END documentation] - - # [START extract_function] - def extract(**kwargs): - ti = kwargs['ti'] - data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' - ti.xcom_push('order_data', data_string) - - # [END extract_function] - - # [START transform_function] - def transform(**kwargs): - ti = kwargs['ti'] - extract_data_string = ti.xcom_pull(task_ids='extract', key='order_data') - order_data = json.loads(extract_data_string) - - total_order_value = 0 - for value in order_data.values(): - total_order_value += value - - total_value = {"total_order_value": total_order_value} - total_value_json_string = json.dumps(total_value) - ti.xcom_push('total_order_value', total_value_json_string) - - # [END transform_function] - - # [START load_function] - def load(**kwargs): - ti = kwargs['ti'] - total_value_string = ti.xcom_pull(task_ids='transform', key='total_order_value') - total_order_value = json.loads(total_value_string) - - print(total_order_value) - - # [END load_function] - - # [START main_flow] - extract_task = PythonOperator( - task_id='extract', - python_callable=extract, - ) - extract_task.doc_md = dedent( - """\ - #### Extract task - A simple Extract task to get data ready for the rest of the data pipeline. - In this case, getting data is simulated by reading from a hardcoded JSON string. - This data is then put into xcom, so that it can be processed by the next task. - """ - ) - - transform_task = PythonOperator( - task_id='transform', - python_callable=transform, - ) - transform_task.doc_md = dedent( - """\ - #### Transform task - A simple Transform task which takes in the collection of order data from xcom - and computes the total order value. - This computed value is then put into xcom, so that it can be processed by the next task. 
- """ - ) - - load_task = PythonOperator( - task_id='load', - python_callable=load, - ) - load_task.doc_md = dedent( - """\ - #### Load task - A simple Load task which takes in the result of the Transform task, by reading it - from xcom and instead of saving it to end user review, just prints it out. - """ - ) - - extract_task >> transform_task >> load_task - -# [END main_flow] - -# [END tutorial] diff --git a/airflow/example_dags/tutorial_taskflow_api.py b/airflow/example_dags/tutorial_taskflow_api.py new file mode 100644 index 0000000000000..f41f729af8870 --- /dev/null +++ b/airflow/example_dags/tutorial_taskflow_api.py @@ -0,0 +1,106 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +# [START tutorial] +# [START import_module] +import json + +import pendulum + +from airflow.decorators import dag, task + +# [END import_module] + + +# [START instantiate_dag] +@dag( + schedule=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["example"], +) +def tutorial_taskflow_api(): + """ + ### TaskFlow API Tutorial Documentation + This is a simple data pipeline example which demonstrates the use of + the TaskFlow API using three simple tasks for Extract, Transform, and Load. + Documentation that goes along with the Airflow TaskFlow API tutorial is + located + [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html) + """ + # [END instantiate_dag] + + # [START extract] + @task() + def extract(): + """ + #### Extract task + A simple Extract task to get data ready for the rest of the data + pipeline. In this case, getting data is simulated by reading from a + hardcoded JSON string. + """ + data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' + + order_data_dict = json.loads(data_string) + return order_data_dict + + # [END extract] + + # [START transform] + @task(multiple_outputs=True) + def transform(order_data_dict: dict): + """ + #### Transform task + A simple Transform task which takes in the collection of order data and + computes the total order value. + """ + total_order_value = 0 + + for value in order_data_dict.values(): + total_order_value += value + + return {"total_order_value": total_order_value} + + # [END transform] + + # [START load] + @task() + def load(total_order_value: float): + """ + #### Load task + A simple Load task which takes in the result of the Transform task and + instead of saving it to end user review, just prints it out. 
+ """ + + print(f"Total order value is: {total_order_value:.2f}") + + # [END load] + + # [START main_flow] + order_data = extract() + order_summary = transform(order_data) + load(order_summary["total_order_value"]) + # [END main_flow] + + +# [START dag_invocation] +tutorial_taskflow_api() +# [END dag_invocation] + +# [END tutorial] diff --git a/airflow/example_dags/tutorial_taskflow_api_etl.py b/airflow/example_dags/tutorial_taskflow_api_etl.py deleted file mode 100644 index f6af78f0a5a2c..0000000000000 --- a/airflow/example_dags/tutorial_taskflow_api_etl.py +++ /dev/null @@ -1,106 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -# [START tutorial] -# [START import_module] -import json - -import pendulum - -from airflow.decorators import dag, task - -# [END import_module] - - -# [START instantiate_dag] -@dag( - schedule_interval=None, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - catchup=False, - tags=['example'], -) -def tutorial_taskflow_api_etl(): - """ - ### TaskFlow API Tutorial Documentation - This is a simple ETL data pipeline example which demonstrates the use of - the TaskFlow API using three simple tasks for Extract, Transform, and Load. - Documentation that goes along with the Airflow TaskFlow API tutorial is - located - [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html) - """ - # [END instantiate_dag] - - # [START extract] - @task() - def extract(): - """ - #### Extract task - A simple Extract task to get data ready for the rest of the data - pipeline. In this case, getting data is simulated by reading from a - hardcoded JSON string. - """ - data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' - - order_data_dict = json.loads(data_string) - return order_data_dict - - # [END extract] - - # [START transform] - @task(multiple_outputs=True) - def transform(order_data_dict: dict): - """ - #### Transform task - A simple Transform task which takes in the collection of order data and - computes the total order value. - """ - total_order_value = 0 - - for value in order_data_dict.values(): - total_order_value += value - - return {"total_order_value": total_order_value} - - # [END transform] - - # [START load] - @task() - def load(total_order_value: float): - """ - #### Load task - A simple Load task which takes in the result of the Transform task and - instead of saving it to end user review, just prints it out. 
- """ - - print(f"Total order value is: {total_order_value:.2f}") - - # [END load] - - # [START main_flow] - order_data = extract() - order_summary = transform(order_data) - load(order_summary["total_order_value"]) - # [END main_flow] - - -# [START dag_invocation] -tutorial_etl_dag = tutorial_taskflow_api_etl() -# [END dag_invocation] - -# [END tutorial] diff --git a/airflow/example_dags/tutorial_taskflow_api_etl_virtualenv.py b/airflow/example_dags/tutorial_taskflow_api_etl_virtualenv.py deleted file mode 100644 index ac280956b7f45..0000000000000 --- a/airflow/example_dags/tutorial_taskflow_api_etl_virtualenv.py +++ /dev/null @@ -1,89 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -import logging -import shutil -from datetime import datetime - -from airflow.decorators import dag, task - -log = logging.getLogger(__name__) - -if not shutil.which("virtualenv"): - log.warning( - "The tutorial_taskflow_api_etl_virtualenv example DAG requires virtualenv, please install it." - ) -else: - - @dag(schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, tags=['example']) - def tutorial_taskflow_api_etl_virtualenv(): - """ - ### TaskFlow API example using virtualenv - This is a simple ETL data pipeline example which demonstrates the use of - the TaskFlow API using three simple tasks for Extract, Transform, and Load. - """ - - @task.virtualenv( - use_dill=True, - system_site_packages=False, - requirements=['funcsigs'], - ) - def extract(): - """ - #### Extract task - A simple Extract task to get data ready for the rest of the data - pipeline. In this case, getting data is simulated by reading from a - hardcoded JSON string. - """ - import json - - data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' - - order_data_dict = json.loads(data_string) - return order_data_dict - - @task(multiple_outputs=True) - def transform(order_data_dict: dict): - """ - #### Transform task - A simple Transform task which takes in the collection of order data and - computes the total order value. - """ - total_order_value = 0 - - for value in order_data_dict.values(): - total_order_value += value - - return {"total_order_value": total_order_value} - - @task() - def load(total_order_value: float): - """ - #### Load task - A simple Load task which takes in the result of the Transform task and - instead of saving it to end user review, just prints it out. 
- """ - - print(f"Total order value is: {total_order_value:.2f}") - - order_data = extract() - order_summary = transform(order_data) - load(order_summary["total_order_value"]) - - tutorial_etl_dag = tutorial_taskflow_api_etl_virtualenv() diff --git a/airflow/example_dags/tutorial_taskflow_api_virtualenv.py b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py new file mode 100644 index 0000000000000..a78116d2f7493 --- /dev/null +++ b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import shutil +from datetime import datetime + +from airflow.decorators import dag, task + +log = logging.getLogger(__name__) + +if not shutil.which("virtualenv"): + log.warning("The tutorial_taskflow_api_virtualenv example DAG requires virtualenv, please install it.") +else: + + @dag(schedule=None, start_date=datetime(2021, 1, 1), catchup=False, tags=["example"]) + def tutorial_taskflow_api_virtualenv(): + """ + ### TaskFlow API example using virtualenv + This is a simple data pipeline example which demonstrates the use of + the TaskFlow API using three simple tasks for Extract, Transform, and Load. + """ + + @task.virtualenv( + use_dill=True, + system_site_packages=False, + requirements=["funcsigs"], + ) + def extract(): + """ + #### Extract task + A simple Extract task to get data ready for the rest of the data + pipeline. In this case, getting data is simulated by reading from a + hardcoded JSON string. + """ + import json + + data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' + + order_data_dict = json.loads(data_string) + return order_data_dict + + @task(multiple_outputs=True) + def transform(order_data_dict: dict): + """ + #### Transform task + A simple Transform task which takes in the collection of order data and + computes the total order value. + """ + total_order_value = 0 + + for value in order_data_dict.values(): + total_order_value += value + + return {"total_order_value": total_order_value} + + @task() + def load(total_order_value: float): + """ + #### Load task + A simple Load task which takes in the result of the Transform task and + instead of saving it to end user review, just prints it out. + """ + + print(f"Total order value is: {total_order_value:.2f}") + + order_data = extract() + order_summary = transform(order_data) + load(order_summary["total_order_value"]) + + tutorial_dag = tutorial_taskflow_api_virtualenv() diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 2f1e53e182a10..a60e3f90ffc73 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -15,14 +15,18 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -# # Note: Any AirflowException raised is expected to cause the TaskInstance # to be marked in an ERROR state -"""Exceptions used by Airflow""" +"""Exceptions used by Airflow.""" +from __future__ import annotations + import datetime import warnings from http import HTTPStatus -from typing import Any, Dict, List, NamedTuple, Optional, Sized +from typing import TYPE_CHECKING, Any, NamedTuple, Sized + +if TYPE_CHECKING: + from airflow.models import DagRun class AirflowException(Exception): @@ -67,14 +71,6 @@ def __init__(self, reschedule_date): self.reschedule_date = reschedule_date -class AirflowSmartSensorException(AirflowException): - """ - Raise after the task register itself in the smart sensor service. - - It should exit without failing a task. - """ - - class InvalidStatsNameException(AirflowException): """Raise when name of the stats is invalid.""" @@ -88,7 +84,7 @@ class AirflowWebServerTimeout(AirflowException): class AirflowSkipException(AirflowException): - """Raise when the task should be skipped""" + """Raise when the task should be skipped.""" class AirflowFailException(AirflowException): @@ -99,6 +95,19 @@ class AirflowOptionalProviderFeatureException(AirflowException): """Raise by providers when imports are missing for optional provider features.""" +class XComNotFound(AirflowException): + """Raise when an XCom reference is being resolved against a non-existent XCom.""" + + def __init__(self, dag_id: str, task_id: str, key: str) -> None: + super().__init__() + self.dag_id = dag_id + self.task_id = task_id + self.key = key + + def __str__(self) -> str: + return f'XComArg result from {self.task_id} at {self.dag_id} with key="{self.key}" is not found!' 
+ + class UnmappableOperator(AirflowException): """Raise when an operator is not implemented to be mappable.""" @@ -113,12 +122,14 @@ def __str__(self) -> str: class UnmappableXComTypePushed(AirflowException): """Raise when an unmappable type is pushed as a mapped downstream's dependency.""" - def __init__(self, value: Any) -> None: - super().__init__(value) - self.value = value + def __init__(self, value: Any, *values: Any) -> None: + super().__init__(value, *values) def __str__(self) -> str: - return f"unmappable return type {type(self.value).__qualname__!r}" + typename = type(self.args[0]).__qualname__ + for arg in self.args[1:]: + typename = f"{typename}[{type(arg).__qualname__}]" + return f"unmappable return type {typename!r}" class UnmappableXComLengthPushed(AirflowException): @@ -150,6 +161,10 @@ def __str__(self) -> str: return f"Ignoring DAG {self.dag_id} from {self.incoming} - also found in {self.existing}" +class AirflowDagInconsistent(AirflowException): + """Raise when a DAG has inconsistent attributes.""" + + class AirflowClusterPolicyViolation(AirflowException): """Raise when there is a violation of a Cluster Policy in DAG definition.""" @@ -173,6 +188,12 @@ class DagRunNotFound(AirflowNotFoundException): class DagRunAlreadyExists(AirflowBadRequest): """Raise when creating a DAG run for DAG which already has DAG run entry.""" + def __init__(self, dag_run: DagRun, execution_date: datetime.datetime, run_id: str) -> None: + super().__init__( + f"A DAG Run already exists for DAG {dag_run.dag_id} at {execution_date} with run id {run_id}" + ) + self.dag_run = dag_run + class DagFileExists(AirflowBadRequest): """Raise when a DAG ID is still in DagBag i.e., DAG file is in DAG folder.""" @@ -186,12 +207,29 @@ class DuplicateTaskIdFound(AirflowException): """Raise when a Task with duplicate task_id is defined in the same DAG.""" +class TaskAlreadyInTaskGroup(AirflowException): + """Raise when a Task cannot be added to a TaskGroup since it already belongs to another TaskGroup.""" + + def __init__(self, task_id: str, existing_group_id: str | None, new_group_id: str) -> None: + super().__init__(task_id, new_group_id) + self.task_id = task_id + self.existing_group_id = existing_group_id + self.new_group_id = new_group_id + + def __str__(self) -> str: + if self.existing_group_id is None: + existing_group = "the DAG's root group" + else: + existing_group = f"group {self.existing_group_id!r}" + return f"cannot add {self.task_id!r} to {self.new_group_id!r} (already in {existing_group})" + + class SerializationError(AirflowException): - """A problem occurred when trying to serialize a DAG.""" + """A problem occurred when trying to serialize something.""" class ParamValidationError(AirflowException): - """Raise when DAG params is invalid""" + """Raise when DAG params is invalid.""" class TaskNotFound(AirflowNotFoundException): @@ -234,7 +272,7 @@ def __init__(self, message, ti_status): class FileSyntaxError(NamedTuple): """Information about a single error in a file.""" - line_no: Optional[int] + line_no: int | None message: str def __str__(self): @@ -250,7 +288,7 @@ class AirflowFileParseException(AirflowException): :param parse_errors: File syntax errors """ - def __init__(self, msg: str, file_path: str, parse_errors: List[FileSyntaxError]) -> None: + def __init__(self, msg: str, file_path: str, parse_errors: list[FileSyntaxError]) -> None: super().__init__(msg) self.msg = msg self.file_path = file_path @@ -279,6 +317,8 @@ class ConnectionNotUnique(AirflowException): class 
TaskDeferred(BaseException): """ + Signal an operator moving to deferred state. + Special exception raised to signal that the operator it was raised from wishes to defer until a trigger fires. """ @@ -288,8 +328,8 @@ def __init__( *, trigger, method_name: str, - kwargs: Optional[Dict[str, Any]] = None, - timeout: Optional[datetime.timedelta] = None, + kwargs: dict[str, Any] | None = None, + timeout: datetime.timedelta | None = None, ): super().__init__() self.trigger = trigger @@ -306,3 +346,25 @@ def __repr__(self) -> str: class TaskDeferralError(AirflowException): """Raised when a task failed during deferral for some reason.""" + + +class PodMutationHookException(AirflowException): + """Raised when exception happens during Pod Mutation Hook execution.""" + + +class PodReconciliationError(AirflowException): + """Raised when an error is encountered while trying to merge pod configs.""" + + +class RemovedInAirflow3Warning(DeprecationWarning): + """Issued for usage of deprecated features that will be removed in Airflow3.""" + + deprecated_since: str | None = None + "Indicates the airflow version that started raising this deprecation warning" + + +class AirflowProviderDeprecationWarning(DeprecationWarning): + """Issued for usage of deprecated features of Airflow provider.""" + + deprecated_provider_since: str | None = None + "Indicates the provider version that started raising this deprecation warning" diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 7fcbd0642e14d..0c9af11864f99 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -15,21 +15,23 @@ # specific language governing permissions and limitations # under the License. """Base executor - this is the base class for all the implemented executors.""" +from __future__ import annotations + import sys +import warnings from collections import OrderedDict -from typing import Any, Counter, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Counter, List, Optional, Sequence, Tuple from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest from airflow.configuration import conf +from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State -PARALLELISM: int = conf.getint('core', 'PARALLELISM') - -NOT_STARTED_MESSAGE = "The executor should be started first!" 
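# The typing churn in this and the following hunks repeats one pattern used across
# the patch: with `from __future__ import annotations` (PEP 563) annotations are
# never evaluated at runtime, so PEP 604 unions (`int | str | None`) and builtin
# generics (`dict[str, int]`, `set[...]`) can replace typing.Optional/Dict/Set even
# on interpreters that do not support that syntax at runtime. A self-contained
# sketch of the style:
from __future__ import annotations


def describe(job_id: int | str | None, buffer: dict[str, tuple[str, int]]) -> list[str]:
    # The annotations above are stored as strings and never executed.
    return [f"{key}={value}" for key, value in buffer.items() if job_id is not None]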
+PARALLELISM: int = conf.getint("core", "PARALLELISM") QUEUEING_ATTEMPTS = 5 @@ -62,15 +64,15 @@ class BaseExecutor(LoggingMixin): ``0`` for infinity """ - job_id: Union[None, int, str] = None - callback_sink: Optional[BaseCallbackSink] = None + job_id: None | int | str = None + callback_sink: BaseCallbackSink | None = None def __init__(self, parallelism: int = PARALLELISM): super().__init__() self.parallelism: int = parallelism self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict() - self.running: Set[TaskInstanceKey] = set() - self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {} + self.running: set[TaskInstanceKey] = set() + self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {} self.attempts: Counter[TaskInstanceKey] = Counter() def __repr__(self): @@ -84,7 +86,7 @@ def queue_command( task_instance: TaskInstance, command: CommandType, priority: int = 1, - queue: Optional[str] = None, + queue: str | None = None, ): """Queues command to task""" if task_instance.key not in self.queued_tasks: @@ -97,13 +99,13 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: Optional[str] = None, + pickle_id: str | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, ignore_task_deps: bool = False, ignore_ti_state: bool = False, - pool: Optional[str] = None, - cfg_path: Optional[str] = None, + pool: str | None = None, + cfg_path: str | None = None, ) -> None: """Queues task instance.""" pool = pool or task_instance.pool @@ -160,9 +162,9 @@ def heartbeat(self) -> None: self.log.debug("%s in queue", num_queued_tasks) self.log.debug("%s open slots", open_slots) - Stats.gauge('executor.open_slots', open_slots) - Stats.gauge('executor.queued_tasks', num_queued_tasks) - Stats.gauge('executor.running_tasks', num_running_tasks) + Stats.gauge("executor.open_slots", open_slots) + Stats.gauge("executor.queued_tasks", num_queued_tasks) + Stats.gauge("executor.running_tasks", num_running_tasks) self.trigger_tasks(open_slots) @@ -170,7 +172,7 @@ def heartbeat(self) -> None: self.log.debug("Calling the %s sync method", self.__class__) self.sync() - def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKey, QueuedTaskInstanceType]]: + def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTaskInstanceType]]: """ Orders the queued tasks by priority. @@ -221,7 +223,7 @@ def trigger_tasks(self, open_slots: int) -> None: if task_tuples: self._process_tasks(task_tuples) - def _process_tasks(self, task_tuples: List[TaskTuple]) -> None: + def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: for key, command, queue, executor_config in task_tuples: del self.queued_tasks[key] self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) @@ -239,7 +241,7 @@ def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: try: self.running.remove(key) except KeyError: - self.log.debug('Could not find key: %s', str(key)) + self.log.debug("Could not find key: %s", str(key)) self.event_buffer[key] = state, info def fail(self, key: TaskInstanceKey, info=None) -> None: @@ -260,7 +262,7 @@ def success(self, key: TaskInstanceKey, info=None) -> None: """ self.change_state(key, State.SUCCESS, info) - def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKey, EventBufferValueType]: + def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]: """ Returns and flush the event buffer. 
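# A toy sketch of the sort that order_queued_tasks_by_priority (above) performs:
# each queued value is a tuple whose second element is the priority, and open
# slots are handed to the highest priorities first. Toy data only.
queued_tasks = {
    "ti_low": ("cmd_a", 1, None, None),
    "ti_high": ("cmd_b", 10, None, None),
    "ti_mid": ("cmd_c", 5, None, None),
}
by_priority = sorted(queued_tasks.items(), key=lambda item: item[1][1], reverse=True)
# [("ti_high", ...), ("ti_mid", ...), ("ti_low", ...)]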
In case dag_ids is specified it will only return and flush events for the given dag_ids. Otherwise @@ -269,7 +271,7 @@ def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKey, EventBufferVal :param dag_ids: the dag_ids to return events for; returns all if given ``None``. :return: a dict of events """ - cleared_events: Dict[TaskInstanceKey, EventBufferValueType] = {} + cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {} if dag_ids is None: cleared_events = self.event_buffer self.event_buffer = {} @@ -284,8 +286,8 @@ def execute_async( self, key: TaskInstanceKey, command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None, + queue: str | None = None, + executor_config: Any | None = None, ) -> None: # pragma: no cover """ This method will execute the command asynchronously. @@ -309,7 +311,7 @@ def terminate(self): """This method is called when the daemon receives a SIGTERM""" raise NotImplementedError() - def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]: + def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ Try to adopt running task instances that have been abandoned by a SchedulerJob dying. @@ -317,7 +319,6 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance re-scheduling) :return: any TaskInstances that were unable to be adopted - :rtype: list[airflow.models.TaskInstance] """ # By default, assume Executors cannot adopt tasks, so just say we failed to adopt anything. # Subclasses can do better! @@ -332,10 +333,43 @@ def slots_available(self): return sys.maxsize @staticmethod - def validate_command(command: List[str]) -> None: - """Check if the command to execute is airflow command""" + def validate_command(command: list[str]) -> None: + """ + Back-compat method to Check if the command to execute is airflow command + + :param command: command to check + :return: None + """ + warnings.warn( + """ + The `validate_command` method is deprecated. 
Please use ``validate_airflow_tasks_run_command`` + """, + RemovedInAirflow3Warning, + stacklevel=2, + ) + BaseExecutor.validate_airflow_tasks_run_command(command) + + @staticmethod + def validate_airflow_tasks_run_command(command: list[str]) -> tuple[str | None, str | None]: + """ + Check if the command to execute is airflow command + + Returns tuple (dag_id,task_id) retrieved from the command (replaced with None values if missing) + """ if command[0:3] != ["airflow", "tasks", "run"]: raise ValueError('The command must start with ["airflow", "tasks", "run"].') + if len(command) > 3 and "--help" not in command: + dag_id: str | None = None + task_id: str | None = None + for arg in command[3:]: + if not arg.startswith("--"): + if dag_id is None: + dag_id = arg + else: + task_id = arg + break + return dag_id, task_id + return None, None def debug_dump(self): """Called in response to SIGUSR2 by the scheduler""" diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 7b4c04e225a75..896a46694cfd0 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -21,6 +21,8 @@ For more information on how the CeleryExecutor works, take a look at the guide: :ref:`executor:CeleryExecutor` """ +from __future__ import annotations + import datetime import logging import math @@ -33,7 +35,7 @@ from concurrent.futures import ProcessPoolExecutor from enum import Enum from multiprocessing import cpu_count -from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union +from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple from celery import Celery, Task, states as celery_states from celery.backends.base import BaseKeyValueStoreBackend @@ -50,6 +52,7 @@ from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType, TaskTuple from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.stats import Stats +from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.session import NEW_SESSION, provide_session @@ -60,43 +63,43 @@ log = logging.getLogger(__name__) # Make it constant for unit test. 
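# A rough usage sketch of the validate_airflow_tasks_run_command helper introduced
# above (assuming an Airflow build that contains it): it validates the command
# shape and extracts the dag_id/task_id positional arguments when present.
from airflow.executors.base_executor import BaseExecutor

command = ["airflow", "tasks", "run", "example_dag", "extract", "manual__2022-01-01"]
dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command)
# dag_id == "example_dag", task_id == "extract"

try:
    BaseExecutor.validate_airflow_tasks_run_command(["echo", "hello"])
except ValueError as err:
    print(err)  # The command must start with ["airflow", "tasks", "run"].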
-CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state' +CELERY_FETCH_ERR_MSG_HEADER = "Error fetching Celery task state" -CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task' +CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task" -OPERATION_TIMEOUT = conf.getfloat('celery', 'operation_timeout', fallback=1.0) +OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout", fallback=1.0) -''' +""" To start the celery worker, run the command: airflow celery worker -''' +""" -if conf.has_option('celery', 'celery_config_options'): - celery_configuration = conf.getimport('celery', 'celery_config_options') +if conf.has_option("celery", "celery_config_options"): + celery_configuration = conf.getimport("celery", "celery_config_options") else: celery_configuration = DEFAULT_CELERY_CONFIG -app = Celery(conf.get('celery', 'CELERY_APP_NAME'), config_source=celery_configuration) +app = Celery(conf.get("celery", "CELERY_APP_NAME"), config_source=celery_configuration) @app.task def execute_command(command_to_exec: CommandType) -> None: """Executes command.""" - BaseExecutor.validate_command(command_to_exec) + dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec) celery_task_id = app.current_task.request.id log.info("[%s] Executing command in Celery: %s", celery_task_id, command_to_exec) - - try: - if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - _execute_in_subprocess(command_to_exec, celery_task_id) - else: - _execute_in_fork(command_to_exec, celery_task_id) - except Exception: - Stats.incr("celery.execute_command.failure") - raise + with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): + try: + if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: + _execute_in_subprocess(command_to_exec, celery_task_id) + else: + _execute_in_fork(command_to_exec, celery_task_id) + except Exception: + Stats.incr("celery.execute_command.failure") + raise -def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None: +def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = None) -> None: pid = os.fork() if pid: # In parent, wait for the child @@ -104,7 +107,7 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] if ret == 0: return - msg = f'Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}' + msg = f"Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}" raise AirflowException(msg) from airflow.sentry import Sentry @@ -124,7 +127,6 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] args.external_executor_id = celery_task_id setproctitle(f"airflow task supervisor: {command_to_exec}") - args.func(args) ret = 0 except Exception as e: @@ -136,16 +138,16 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] os._exit(ret) -def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None: +def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: str | None = None) -> None: env = os.environ.copy() if celery_task_id: env["external_executor_id"] = celery_task_id try: subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env) except subprocess.CalledProcessError as e: - log.exception('[%s] execute_command encountered a CalledProcessError', celery_task_id) + log.exception("[%s] execute_command encountered a CalledProcessError", celery_task_id) log.error(e.output) - msg 
= f'Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}' + msg = f"Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}" raise AirflowException(msg) @@ -169,7 +171,7 @@ def __init__(self, exception: Exception, exception_traceback: str): def send_task_to_executor( task_tuple: TaskInstanceInCelery, -) -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]: +) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]: """Sends task to executor.""" key, command, queue, task_to_run = task_tuple try: @@ -233,36 +235,35 @@ def __init__(self): # Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters) # so we use a multiprocessing pool to speed this up. # How many worker processes are created for checking celery task state. - self._sync_parallelism = conf.getint('celery', 'SYNC_PARALLELISM') + self._sync_parallelism = conf.getint("celery", "SYNC_PARALLELISM") if self._sync_parallelism == 0: self._sync_parallelism = max(1, cpu_count() - 1) self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism) self.tasks = {} - self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {} + self.stalled_task_timeouts: dict[TaskInstanceKey, datetime.datetime] = {} self.stalled_task_timeout = datetime.timedelta( - seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0) + seconds=conf.getint("celery", "stalled_task_timeout", fallback=0) ) - self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {} + self.adopted_task_timeouts: dict[TaskInstanceKey, datetime.datetime] = {} self.task_adoption_timeout = ( - datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)) + datetime.timedelta(seconds=conf.getint("celery", "task_adoption_timeout", fallback=600)) or self.stalled_task_timeout ) self.task_publish_retries: Counter[TaskInstanceKey] = Counter() - self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3) + self.task_publish_max_retries = conf.getint("celery", "task_publish_max_retries", fallback=3) def start(self) -> None: - self.log.debug('Starting Celery Executor using %s processes for syncing', self._sync_parallelism) + self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism) def _num_tasks_per_send_process(self, to_send_count: int) -> int: """ How many Celery tasks should each worker process send. :return: Number of tasks that should be sent per process - :rtype: int """ return max(1, int(math.ceil(1.0 * to_send_count / self._sync_parallelism))) - def _process_tasks(self, task_tuples: List[TaskTuple]) -> None: + def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: task_tuples_to_send = [task_tuple[:3] + (execute_command,) for task_tuple in task_tuples] first_task = next(t[3] for t in task_tuples_to_send) @@ -271,7 +272,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None: cached_celery_backend = first_task.backend key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send) - self.log.debug('Sent all tasks.') + self.log.debug("Sent all tasks.") for key, _, result in key_and_async_results: if isinstance(result, ExceptionWithTraceback) and isinstance( @@ -305,9 +306,9 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None: self.event_buffer[key] = (State.QUEUED, result.task_id) # If the task runs _really quickly_ we may already have a result! 
- self.update_task_state(key, result.state, getattr(result, 'info', None)) + self.update_task_state(key, result.state, getattr(result, "info", None)) - def _send_tasks_to_celery(self, task_tuples_to_send: List[TaskInstanceInCelery]): + def _send_tasks_to_celery(self, task_tuples_to_send: list[TaskInstanceInCelery]): if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1: # One tuple, or max one process -> send it in the main thread. return list(map(send_task_to_executor, task_tuples_to_send)) @@ -354,8 +355,8 @@ def _check_for_stalled_tasks(self) -> None: self._send_stalled_tis_back_to_scheduler(timedout_keys) def _get_timedout_ti_keys( - self, task_timeouts: Dict[TaskInstanceKey, datetime.datetime] - ) -> List[TaskInstanceKey]: + self, task_timeouts: dict[TaskInstanceKey, datetime.datetime] + ) -> list[TaskInstanceKey]: """ These timeouts exist to check to see if any of our tasks have not progressed in the expected time. This can happen for few different reasons, usually related @@ -387,7 +388,7 @@ def _get_timedout_ti_keys( @provide_session def _send_stalled_tis_back_to_scheduler( - self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION + self, keys: list[TaskInstanceKey], session: Session = NEW_SESSION ) -> None: try: session.query(TaskInstance).filter( @@ -478,7 +479,7 @@ def end(self, synchronous: bool = False) -> None: def terminate(self): pass - def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]: + def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: # See which of the TIs are still alive (or have finished even!) # # Since Celery doesn't store "SENT" state for queued commands (if we create an AsyncResult with a made @@ -528,7 +529,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance adopted.append(f"{ti} in state {state}") if adopted: - task_instance_str = '\n\t'.join(adopted) + task_instance_str = "\n\t".join(adopted) self.log.info( "Adopted the following %d tasks from a dead executor\n\t%s", len(adopted), task_instance_str ) @@ -536,7 +537,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance return not_adopted_tis def _set_celery_pending_task_timeout( - self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType] + self, key: TaskInstanceKey, timeout_type: _CeleryPendingTaskTimeoutType | None ) -> None: """ We use the fact that dicts maintain insertion order, and the the timeout for a @@ -551,7 +552,7 @@ def _set_celery_pending_task_timeout( self.stalled_task_timeouts[key] = utcnow() + self.stalled_task_timeout -def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]: +def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]: """ Fetch and return the state of the given celery task. The scope of this function is global so that it can be called by subprocesses in the pool. 
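# A simplified sketch (not the executor's exact code) of the chunked fan-out used
# by _send_tasks_to_celery above: a single item or a single worker is handled
# in-process, otherwise the work is spread over a process pool in roughly equal
# chunks. `send_one` must be a picklable, module-level callable.
import math
from concurrent.futures import ProcessPoolExecutor


def send_all(items, send_one, sync_parallelism: int):
    if len(items) == 1 or sync_parallelism == 1:
        return list(map(send_one, items))
    chunksize = max(1, math.ceil(len(items) / sync_parallelism))
    with ProcessPoolExecutor(max_workers=sync_parallelism) as pool:
        return list(pool.map(send_one, items, chunksize=chunksize))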
@@ -560,13 +561,12 @@ def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, to fetch the task's state :return: a tuple of the Celery task key and the Celery state and the celery info of the task - :rtype: tuple[str, str, str] """ try: with timeout(seconds=OPERATION_TIMEOUT): # Accessing state property of celery task will make actual network request # to get the current state of the task - info = async_result.info if hasattr(async_result, 'info') else None + info = async_result.info if hasattr(async_result, "info") else None return async_result.task_id, async_result.state, info except Exception as e: exception_traceback = f"Celery Task ID: {async_result}\n{traceback.format_exc()}" @@ -586,7 +586,7 @@ def __init__(self, sync_parallelism=None): super().__init__() self._sync_parallelism = sync_parallelism - def _tasks_list_to_task_ids(self, async_tasks) -> Set[str]: + def _tasks_list_to_task_ids(self, async_tasks) -> set[str]: return {a.task_id for a in async_tasks} def get_many(self, async_results) -> Mapping[str, EventBufferValueType]: diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py index b1edc32235727..cf18158b09531 100644 --- a/airflow/executors/celery_kubernetes_executor.py +++ b/airflow/executors/celery_kubernetes_executor.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List, Optional, Set, Union +from __future__ import annotations + +from typing import Sequence from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest @@ -37,19 +39,19 @@ class CeleryKubernetesExecutor(LoggingMixin): """ supports_ad_hoc_ti_run: bool = True - callback_sink: Optional[BaseCallbackSink] = None + callback_sink: BaseCallbackSink | None = None - KUBERNETES_QUEUE = conf.get('celery_kubernetes_executor', 'kubernetes_queue') + KUBERNETES_QUEUE = conf.get("celery_kubernetes_executor", "kubernetes_queue") def __init__(self, celery_executor: CeleryExecutor, kubernetes_executor: KubernetesExecutor): super().__init__() - self._job_id: Optional[int] = None + self._job_id: int | None = None self.celery_executor = celery_executor self.kubernetes_executor = kubernetes_executor self.kubernetes_executor.kubernetes_queue = self.KUBERNETES_QUEUE @property - def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]: + def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: """Return queued tasks from celery and kubernetes executor""" queued_tasks = self.celery_executor.queued_tasks.copy() queued_tasks.update(self.kubernetes_executor.queued_tasks) @@ -57,12 +59,12 @@ def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]: return queued_tasks @property - def running(self) -> Set[TaskInstanceKey]: + def running(self) -> set[TaskInstanceKey]: """Return running tasks from celery and kubernetes executor""" return self.celery_executor.running.union(self.kubernetes_executor.running) @property - def job_id(self) -> Optional[int]: + def job_id(self) -> int | None: """ This is a class attribute in BaseExecutor but since this is not really an executor, but a wrapper of executors we implement as property so we can have custom setter. 
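# Note on the merged queued_tasks view above: because the property merges into a
# copy, callers that mutate the returned dict do not touch either wrapped
# executor's own queue. A tiny illustration with plain dicts standing in for the
# two executors:
celery_queued = {"ti_celery": "cmd_1"}
kubernetes_queued = {"ti_k8s": "cmd_2"}

merged = celery_queued.copy()
merged.update(kubernetes_queued)
merged.pop("ti_celery")

assert "ti_celery" in celery_queued  # the wrapped executor's queue is untouched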
@@ -70,7 +72,7 @@ def job_id(self) -> Optional[int]: return self._job_id @job_id.setter - def job_id(self, value: Optional[int]) -> None: + def job_id(self, value: int | None) -> None: """job_id is manipulated by SchedulerJob. We must propagate the job_id to wrapped executors.""" self._job_id = value self.kubernetes_executor.job_id = value @@ -91,7 +93,7 @@ def queue_command( task_instance: TaskInstance, command: CommandType, priority: int = 1, - queue: Optional[str] = None, + queue: str | None = None, ) -> None: """Queues command via celery or kubernetes executor""" executor = self._router(task_instance) @@ -102,13 +104,13 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: Optional[str] = None, + pickle_id: str | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, ignore_task_deps: bool = False, ignore_ti_state: bool = False, - pool: Optional[str] = None, - cfg_path: Optional[str] = None, + pool: str | None = None, + cfg_path: str | None = None, ) -> None: """Queues task instance via celery or kubernetes executor""" executor = self._router(SimpleTaskInstance.from_ti(task_instance)) @@ -144,8 +146,8 @@ def heartbeat(self) -> None: self.kubernetes_executor.heartbeat() def get_event_buffer( - self, dag_ids: Optional[List[str]] = None - ) -> Dict[TaskInstanceKey, EventBufferValueType]: + self, dag_ids: list[str] | None = None + ) -> dict[TaskInstanceKey, EventBufferValueType]: """ Returns and flush the event buffer from celery and kubernetes executor @@ -157,7 +159,7 @@ def get_event_buffer( return {**cleared_events_from_celery, **cleared_events_from_kubernetes} - def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]: + def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ Try to adopt running task instances that have been abandoned by a SchedulerJob dying. 
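# A condensed sketch of the routing rule this wrapper applies (see _router and the
# adoption split in the hunk below): task instances queued on the configured
# kubernetes queue go to the Kubernetes executor, everything else to Celery. The
# objects only need a `queue` attribute for this illustration.
KUBERNETES_QUEUE = "kubernetes"


def split_by_queue(task_instances):
    celery_tis = [ti for ti in task_instances if ti.queue != KUBERNETES_QUEUE]
    kubernetes_tis = [ti for ti in task_instances if ti.queue == KUBERNETES_QUEUE]
    return celery_tis, kubernetes_tis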
@@ -165,19 +167,13 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance re-scheduling) :return: any TaskInstances that were unable to be adopted - :rtype: list[airflow.models.TaskInstance] """ - celery_tis = [] - kubernetes_tis = [] - abandoned_tis = [] - for ti in tis: - if ti.queue == self.KUBERNETES_QUEUE: - kubernetes_tis.append(ti) - else: - celery_tis.append(ti) - abandoned_tis.extend(self.celery_executor.try_adopt_task_instances(celery_tis)) - abandoned_tis.extend(self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis)) - return abandoned_tis + celery_tis = [ti for ti in tis if ti.queue != self.KUBERNETES_QUEUE] + kubernetes_tis = [ti for ti in tis if ti.queue == self.KUBERNETES_QUEUE] + return [ + *self.celery_executor.try_adopt_task_instances(celery_tis), + *self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis), + ] def end(self) -> None: """End celery and kubernetes executor""" @@ -189,13 +185,12 @@ def terminate(self) -> None: self.celery_executor.terminate() self.kubernetes_executor.terminate() - def _router(self, simple_task_instance: SimpleTaskInstance) -> Union[CeleryExecutor, KubernetesExecutor]: + def _router(self, simple_task_instance: SimpleTaskInstance) -> CeleryExecutor | KubernetesExecutor: """ Return either celery_executor or kubernetes_executor :param simple_task_instance: SimpleTaskInstance :return: celery_executor or kubernetes_executor - :rtype: Union[CeleryExecutor, KubernetesExecutor] """ if simple_task_instance.queue == self.KUBERNETES_QUEUE: return self.kubernetes_executor diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py index 5d0896455e9fd..a2c2c571630a6 100644 --- a/airflow/executors/dask_executor.py +++ b/airflow/executors/dask_executor.py @@ -22,20 +22,22 @@ For more information on how the DaskExecutor works, take a look at the guide: :ref:`executor:DaskExecutor` """ +from __future__ import annotations + import subprocess -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any from distributed import Client, Future, as_completed from distributed.security import Security from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType +from airflow.executors.base_executor import BaseExecutor, CommandType from airflow.models.taskinstance import TaskInstanceKey # queue="default" is a special case since this is the base config default queue name, # with respect to DaskExecutor, treat it as if no queue is provided -_UNDEFINED_QUEUES = {None, 'default'} +_UNDEFINED_QUEUES = {None, "default"} class DaskExecutor(BaseExecutor): @@ -44,16 +46,16 @@ class DaskExecutor(BaseExecutor): def __init__(self, cluster_address=None): super().__init__(parallelism=0) if cluster_address is None: - cluster_address = conf.get('dask', 'cluster_address') + cluster_address = conf.get("dask", "cluster_address") if not cluster_address: - raise ValueError('Please provide a Dask cluster address in airflow.cfg') + raise ValueError("Please provide a Dask cluster address in airflow.cfg") self.cluster_address = cluster_address # ssl / tls parameters - self.tls_ca = conf.get('dask', 'tls_ca') - self.tls_key = conf.get('dask', 'tls_key') - self.tls_cert = conf.get('dask', 'tls_cert') - self.client: Optional[Client] = None - self.futures: Optional[Dict[Future, TaskInstanceKey]] = None + self.tls_ca = conf.get("dask", "tls_ca") + self.tls_key = conf.get("dask", "tls_key") + 
self.tls_cert = conf.get("dask", "tls_cert") + self.client: Client | None = None + self.futures: dict[Future, TaskInstanceKey] | None = None def start(self) -> None: if self.tls_ca or self.tls_key or self.tls_cert: @@ -73,23 +75,22 @@ def execute_async( self, key: TaskInstanceKey, command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None, + queue: str | None = None, + executor_config: Any | None = None, ) -> None: + if TYPE_CHECKING: + assert self.client - self.validate_command(command) + self.validate_airflow_tasks_run_command(command) def airflow_run(): return subprocess.check_call(command, close_fds=True) - if not self.client: - raise AirflowException(NOT_STARTED_MESSAGE) - resources = None if queue not in _UNDEFINED_QUEUES: scheduler_info = self.client.scheduler_info() avail_queues = { - resource for d in scheduler_info['workers'].values() for resource in d['resources'] + resource for d in scheduler_info["workers"].values() for resource in d["resources"] } if queue not in avail_queues: @@ -100,8 +101,9 @@ def airflow_run(): self.futures[future] = key # type: ignore def _process_future(self, future: Future) -> None: - if not self.futures: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.futures + if future.done(): key = self.futures[future] if future.exception(): @@ -115,23 +117,25 @@ def _process_future(self, future: Future) -> None: self.futures.pop(future) def sync(self) -> None: - if self.futures is None: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.futures + # make a copy so futures can be popped during iteration for future in self.futures.copy(): self._process_future(future) def end(self) -> None: - if not self.client: - raise AirflowException(NOT_STARTED_MESSAGE) - if self.futures is None: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.client + assert self.futures + self.client.cancel(list(self.futures.keys())) for future in as_completed(self.futures.copy()): self._process_future(future) def terminate(self): - if self.futures is None: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.futures + self.client.cancel(self.futures.keys()) self.end() diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index 7a2dddec79358..d727ee8795112 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -22,9 +22,10 @@ For more information on how the DebugExecutor works, take a look at the guide: :ref:`executor:DebugExecutor` """ +from __future__ import annotations import threading -from typing import Any, Dict, List, Optional +from typing import Any from airflow.configuration import conf from airflow.executors.base_executor import BaseExecutor @@ -44,9 +45,9 @@ class DebugExecutor(BaseExecutor): def __init__(self): super().__init__() - self.tasks_to_run: List[TaskInstance] = [] + self.tasks_to_run: list[TaskInstance] = [] # Place where we keep information for task instance raw run - self.tasks_params: Dict[TaskInstanceKey, Dict[str, Any]] = {} + self.tasks_params: dict[TaskInstanceKey, dict[str, Any]] = {} self.fail_fast = conf.getboolean("debug", "fail_fast") def execute_async(self, *args, **kwargs) -> None: @@ -75,7 +76,7 @@ def _run_task(self, ti: TaskInstance) -> bool: key = ti.key try: params = self.tasks_params.pop(ti.key, {}) - ti._run_raw_task(job_id=ti.job_id, **params) + ti.run(job_id=ti.job_id, **params) self.change_state(key, State.SUCCESS) return True 
except Exception as e: @@ -88,13 +89,13 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: Optional[str] = None, + pickle_id: str | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, ignore_task_deps: bool = False, ignore_ti_state: bool = False, - pool: Optional[str] = None, - cfg_path: Optional[str] = None, + pool: str | None = None, + cfg_path: str | None = None, ) -> None: """Queues task instance with empty command because we do not need it.""" self.queue_command( diff --git a/airflow/executors/executor_constants.py b/airflow/executors/executor_constants.py index 55a3a7f766301..c6219f3b4a965 100644 --- a/airflow/executors/executor_constants.py +++ b/airflow/executors/executor_constants.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations LOCAL_EXECUTOR = "LocalExecutor" LOCAL_KUBERNETES_EXECUTOR = "LocalKubernetesExecutor" diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py index 723060db0f179..56802017e4fe3 100644 --- a/airflow/executors/executor_loader.py +++ b/airflow/executors/executor_loader.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. """All executors.""" +from __future__ import annotations + import logging from contextlib import suppress from enum import Enum, unique -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING from airflow.exceptions import AirflowConfigException from airflow.executors.executor_constants import ( @@ -51,33 +53,33 @@ class ConnectorSource(Enum): class ExecutorLoader: """Keeps constants for all the currently available executors.""" - _default_executor: Optional["BaseExecutor"] = None + _default_executor: BaseExecutor | None = None executors = { - LOCAL_EXECUTOR: 'airflow.executors.local_executor.LocalExecutor', - LOCAL_KUBERNETES_EXECUTOR: 'airflow.executors.local_kubernetes_executor.LocalKubernetesExecutor', - SEQUENTIAL_EXECUTOR: 'airflow.executors.sequential_executor.SequentialExecutor', - CELERY_EXECUTOR: 'airflow.executors.celery_executor.CeleryExecutor', - CELERY_KUBERNETES_EXECUTOR: 'airflow.executors.celery_kubernetes_executor.CeleryKubernetesExecutor', - DASK_EXECUTOR: 'airflow.executors.dask_executor.DaskExecutor', - KUBERNETES_EXECUTOR: 'airflow.executors.kubernetes_executor.KubernetesExecutor', - DEBUG_EXECUTOR: 'airflow.executors.debug_executor.DebugExecutor', + LOCAL_EXECUTOR: "airflow.executors.local_executor.LocalExecutor", + LOCAL_KUBERNETES_EXECUTOR: "airflow.executors.local_kubernetes_executor.LocalKubernetesExecutor", + SEQUENTIAL_EXECUTOR: "airflow.executors.sequential_executor.SequentialExecutor", + CELERY_EXECUTOR: "airflow.executors.celery_executor.CeleryExecutor", + CELERY_KUBERNETES_EXECUTOR: "airflow.executors.celery_kubernetes_executor.CeleryKubernetesExecutor", + DASK_EXECUTOR: "airflow.executors.dask_executor.DaskExecutor", + KUBERNETES_EXECUTOR: "airflow.executors.kubernetes_executor.KubernetesExecutor", + DEBUG_EXECUTOR: "airflow.executors.debug_executor.DebugExecutor", } @classmethod - def get_default_executor(cls) -> "BaseExecutor": - """Creates a new instance of the configured executor if none exists and returns it""" + def get_default_executor(cls) -> BaseExecutor: + """Creates a new instance of the configured executor if none exists and returns it.""" if 
cls._default_executor is not None: return cls._default_executor from airflow.configuration import conf - executor_name = conf.get_mandatory_value('core', 'EXECUTOR') + executor_name = conf.get_mandatory_value("core", "EXECUTOR") cls._default_executor = cls.load_executor(executor_name) return cls._default_executor @classmethod - def load_executor(cls, executor_name: str) -> "BaseExecutor": + def load_executor(cls, executor_name: str) -> BaseExecutor: """ Loads the executor. @@ -107,7 +109,7 @@ def load_executor(cls, executor_name: str) -> "BaseExecutor": return executor_cls() @classmethod - def import_executor_cls(cls, executor_name: str) -> Tuple[Type["BaseExecutor"], ConnectorSource]: + def import_executor_cls(cls, executor_name: str) -> tuple[type[BaseExecutor], ConnectorSource]: """ Imports the executor class. @@ -133,8 +135,7 @@ def import_executor_cls(cls, executor_name: str) -> Tuple[Type["BaseExecutor"], return import_string(executor_name), ConnectorSource.CUSTOM_PATH @classmethod - def __load_celery_kubernetes_executor(cls) -> "BaseExecutor": - """:return: an instance of CeleryKubernetesExecutor""" + def __load_celery_kubernetes_executor(cls) -> BaseExecutor: celery_executor = import_string(cls.executors[CELERY_EXECUTOR])() kubernetes_executor = import_string(cls.executors[KUBERNETES_EXECUTOR])() @@ -142,8 +143,7 @@ def __load_celery_kubernetes_executor(cls) -> "BaseExecutor": return celery_kubernetes_executor_cls(celery_executor, kubernetes_executor) @classmethod - def __load_local_kubernetes_executor(cls) -> "BaseExecutor": - """:return: an instance of LocalKubernetesExecutor""" + def __load_local_kubernetes_executor(cls) -> BaseExecutor: local_executor = import_string(cls.executors[LOCAL_EXECUTOR])() kubernetes_executor = import_string(cls.executors[KUBERNETES_EXECUTOR])() diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index c76cf58f418d4..28f720f35e1c2 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -21,22 +21,24 @@ For more information on how the KubernetesExecutor works, take a look at the guide: :ref:`executor:KubernetesExecutor` """ +from __future__ import annotations import functools import json +import logging import multiprocessing import time from datetime import timedelta from queue import Empty, Queue -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple from kubernetes import client, watch from kubernetes.client import Configuration, models as k8s from kubernetes.client.rest import ApiException from urllib3.exceptions import ReadTimeoutError -from airflow.exceptions import AirflowException -from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType +from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError +from airflow.executors.base_executor import BaseExecutor, CommandType from airflow.kubernetes import pod_generator from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.kube_config import KubeConfig @@ -77,10 +79,10 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): def __init__( self, - namespace: Optional[str], + namespace: str | None, multi_namespace_mode: bool, - watcher_queue: 'Queue[KubernetesWatchType]', - resource_version: Optional[str], + watcher_queue: Queue[KubernetesWatchType], + resource_version: str | None, scheduler_job_id: str, kube_config: Configuration, 
): @@ -94,9 +96,10 @@ def __init__( def run(self) -> None: """Performs watching""" + if TYPE_CHECKING: + assert self.scheduler_job_id + kube_client: client.CoreV1Api = get_kube_client() - if not self.scheduler_job_id: - raise AirflowException(NOT_STARTED_MESSAGE) while True: try: self.resource_version = self._run( @@ -108,34 +111,34 @@ def run(self) -> None: ) time.sleep(1) except Exception: - self.log.exception('Unknown error in KubernetesJobWatcher. Failing') + self.log.exception("Unknown error in KubernetesJobWatcher. Failing") self.resource_version = "0" ResourceVersion().resource_version = "0" raise else: self.log.warning( - 'Watch died gracefully, starting back up with: last resource_version: %s', + "Watch died gracefully, starting back up with: last resource_version: %s", self.resource_version, ) def _run( self, kube_client: client.CoreV1Api, - resource_version: Optional[str], + resource_version: str | None, scheduler_job_id: str, kube_config: Any, - ) -> Optional[str]: - self.log.info('Event: and now my watch begins starting at resource_version: %s', resource_version) + ) -> str | None: + self.log.info("Event: and now my watch begins starting at resource_version: %s", resource_version) watcher = watch.Watch() - kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'} + kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"} if resource_version: - kwargs['resource_version'] = resource_version + kwargs["resource_version"] = resource_version if kube_config.kube_client_request_args: for key, value in kube_config.kube_client_request_args.items(): kwargs[key] = value - last_resource_version: Optional[str] = None + last_resource_version: str | None = None if self.multi_namespace_mode: list_worker_pods = functools.partial( watcher.stream, kube_client.list_pod_for_all_namespaces, **kwargs @@ -145,21 +148,21 @@ def _run( watcher.stream, kube_client.list_namespaced_pod, self.namespace, **kwargs ) for event in list_worker_pods(): - task = event['object'] - self.log.info('Event: %s had an event of type %s', task.metadata.name, event['type']) - if event['type'] == 'ERROR': + task = event["object"] + self.log.debug("Event: %s had an event of type %s", task.metadata.name, event["type"]) + if event["type"] == "ERROR": return self.process_error(event) annotations = task.metadata.annotations task_instance_related_annotations = { - 'dag_id': annotations['dag_id'], - 'task_id': annotations['task_id'], - 'execution_date': annotations.get('execution_date'), - 'run_id': annotations.get('run_id'), - 'try_number': annotations['try_number'], + "dag_id": annotations["dag_id"], + "task_id": annotations["task_id"], + "execution_date": annotations.get("execution_date"), + "run_id": annotations.get("run_id"), + "try_number": annotations["try_number"], } - map_index = annotations.get('map_index') + map_index = annotations.get("map_index") if map_index is not None: - task_instance_related_annotations['map_index'] = map_index + task_instance_related_annotations["map_index"] = map_index self.process_status( pod_id=task.metadata.name, @@ -175,14 +178,14 @@ def _run( def process_error(self, event: Any) -> str: """Process error response""" - self.log.error('Encountered Error response from k8s list namespaced pod stream => %s', event) - raw_object = event['raw_object'] - if raw_object['code'] == 410: + self.log.error("Encountered Error response from k8s list namespaced pod stream => %s", event) + raw_object = event["raw_object"] + if raw_object["code"] == 410: self.log.info( - 'Kubernetes resource 
version is too old, must reset to 0 => %s', (raw_object['message'],) + "Kubernetes resource version is too old, must reset to 0 => %s", (raw_object["message"],) ) # Return resource version 0 - return '0' + return "0" raise AirflowException( f"Kubernetes failure for {raw_object['reason']} with code {raw_object['code']} and message: " f"{raw_object['message']}" @@ -193,33 +196,33 @@ def process_status( pod_id: str, namespace: str, status: str, - annotations: Dict[str, str], + annotations: dict[str, str], resource_version: str, event: Any, ) -> None: """Process status response""" - if status == 'Pending': - if event['type'] == 'DELETED': - self.log.info('Event: Failed to start pod %s', pod_id) + if status == "Pending": + if event["type"] == "DELETED": + self.log.info("Event: Failed to start pod %s", pod_id) self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version)) else: - self.log.info('Event: %s Pending', pod_id) - elif status == 'Failed': - self.log.error('Event: %s Failed', pod_id) + self.log.debug("Event: %s Pending", pod_id) + elif status == "Failed": + self.log.error("Event: %s Failed", pod_id) self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version)) - elif status == 'Succeeded': - self.log.info('Event: %s Succeeded', pod_id) + elif status == "Succeeded": + self.log.info("Event: %s Succeeded", pod_id) self.watcher_queue.put((pod_id, namespace, None, annotations, resource_version)) - elif status == 'Running': - if event['type'] == 'DELETED': - self.log.info('Event: Pod %s deleted before it could complete', pod_id) + elif status == "Running": + if event["type"] == "DELETED": + self.log.info("Event: Pod %s deleted before it could complete", pod_id) self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version)) else: - self.log.info('Event: %s is Running', pod_id) + self.log.info("Event: %s is Running", pod_id) else: self.log.warning( - 'Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with ' - 'resource_version: %s', + "Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with " + "resource_version: %s", status, pod_id, namespace, @@ -234,8 +237,8 @@ class AirflowKubernetesScheduler(LoggingMixin): def __init__( self, kube_config: Any, - task_queue: 'Queue[KubernetesJobType]', - result_queue: 'Queue[KubernetesResultsType]', + task_queue: Queue[KubernetesJobType], + result_queue: Queue[KubernetesResultsType], kube_client: client.CoreV1Api, scheduler_job_id: str, ): @@ -254,19 +257,22 @@ def __init__( def run_pod_async(self, pod: k8s.V1Pod, **kwargs): """Runs POD asynchronously""" - pod_mutation_hook(pod) + try: + pod_mutation_hook(pod) + except Exception as e: + raise PodMutationHookException(e) sanitized_pod = self.kube_client.api_client.sanitize_for_serialization(pod) json_pod = json.dumps(sanitized_pod, indent=2) - self.log.debug('Pod Creation Request: \n%s', json_pod) + self.log.debug("Pod Creation Request: \n%s", json_pod) try: resp = self.kube_client.create_namespaced_pod( body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs ) - self.log.debug('Pod Creation Response: %s', resp) + self.log.debug("Pod Creation Response: %s", resp) except Exception as e: - self.log.exception('Exception when attempting to create Namespaced Pod: %s', json_pod) + self.log.exception("Exception when attempting to create Namespaced Pod: %s", json_pod) raise e return resp @@ -288,7 +294,7 @@ def _health_check_kube_watcher(self): self.log.debug("KubeJobWatcher alive, 
continuing") else: self.log.error( - 'Error while health checking kube watcher process. Process died for unknown reasons' + "Error while health checking kube watcher process. Process died for unknown reasons" ) ResourceVersion().resource_version = "0" self.kube_watcher = self._make_kube_watcher() @@ -300,8 +306,8 @@ def run_next(self, next_job: KubernetesJobType) -> None: and store relevant info in the current_jobs map so we can track the job's status """ - self.log.info('Kubernetes job is %s', str(next_job).replace("\n", " ")) key, command, kube_executor_config, pod_template_file = next_job + dag_id, task_id, run_id, try_number, map_index = key if command[0:3] != ["airflow", "tasks", "run"]: @@ -331,6 +337,7 @@ def run_next(self, next_job: KubernetesJobType) -> None: ) # Reconcile the pod generated by the Operator and the Pod # generated by the .cfg file + self.log.info("Creating kubernetes pod for job is %s, with pod name %s", key, pod.metadata.name) self.log.debug("Kubernetes running for command %s", command) self.log.debug("Kubernetes launching image %s", pod.spec.containers[0].image) @@ -378,21 +385,21 @@ def sync(self) -> None: def process_watcher_task(self, task: KubernetesWatchType) -> None: """Process the task by watcher.""" pod_id, namespace, state, annotations, resource_version = task - self.log.info( - 'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s', pod_id, state, annotations + self.log.debug( + "Attempting to finish pod; pod_id: %s; state: %s; annotations: %s", pod_id, state, annotations ) key = annotations_to_key(annotations=annotations) if key: - self.log.debug('finishing job %s - %s (%s)', key, state, pod_id) + self.log.debug("finishing job %s - %s (%s)", key, state, pod_id) self.result_queue.put((key, state, pod_id, namespace, resource_version)) def _flush_watcher_queue(self) -> None: - self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize()) + self.log.debug("Executor shutting down, watcher_queue approx. 
size=%d", self.watcher_queue.qsize()) while True: try: task = self.watcher_queue.get_nowait() # Ignoring it since it can only have either FAILED or SUCCEEDED pods - self.log.warning('Executor shutting down, IGNORING watcher task=%s', task) + self.log.warning("Executor shutting down, IGNORING watcher task=%s", task) self.watcher_queue.task_done() except Empty: break @@ -411,7 +418,7 @@ def terminate(self) -> None: self._manager.shutdown() -def get_base_pod_from_template(pod_template_file: Optional[str], kube_config: Any) -> k8s.V1Pod: +def get_base_pod_from_template(pod_template_file: str | None, kube_config: Any) -> k8s.V1Pod: """ Reads either the pod_template_file set in the executor_config or the base pod_template_file set in the airflow.cfg to craft a "base pod" that will be used by the KubernetesExecutor @@ -434,14 +441,14 @@ class KubernetesExecutor(BaseExecutor): def __init__(self): self.kube_config = KubeConfig() self._manager = multiprocessing.Manager() - self.task_queue: 'Queue[KubernetesJobType]' = self._manager.Queue() - self.result_queue: 'Queue[KubernetesResultsType]' = self._manager.Queue() - self.kube_scheduler: Optional[AirflowKubernetesScheduler] = None - self.kube_client: Optional[client.CoreV1Api] = None - self.scheduler_job_id: Optional[str] = None - self.event_scheduler: Optional[EventScheduler] = None - self.last_handled: Dict[TaskInstanceKey, float] = {} - self.kubernetes_queue: Optional[str] = None + self.task_queue: Queue[KubernetesJobType] = self._manager.Queue() + self.result_queue: Queue[KubernetesResultsType] = self._manager.Queue() + self.kube_scheduler: AirflowKubernetesScheduler | None = None + self.kube_client: client.CoreV1Api | None = None + self.scheduler_job_id: str | None = None + self.event_scheduler: EventScheduler | None = None + self.last_handled: dict[TaskInstanceKey, float] = {} + self.kubernetes_queue: str | None = None super().__init__(parallelism=self.kube_config.parallelism) @provide_session @@ -457,15 +464,17 @@ def clear_not_launched_queued_tasks(self, session=None) -> None: is around, and if not, and there's no matching entry in our own task_queue, marks it for re-execution. 
""" - self.log.debug("Clearing tasks that have not been launched") - if not self.kube_client: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.kube_client - query = session.query(TaskInstance).filter(TaskInstance.state == State.QUEUED) + self.log.debug("Clearing tasks that have not been launched") + query = session.query(TaskInstance).filter( + TaskInstance.state == State.QUEUED, TaskInstance.queued_by_job_id == self.job_id + ) if self.kubernetes_queue: query = query.filter(TaskInstance.queue == self.kubernetes_queue) - queued_tis: List[TaskInstance] = query.all() - self.log.info('Found %s queued task instances', len(queued_tis)) + queued_tis: list[TaskInstance] = query.all() + self.log.info("Found %s queued task instances", len(queued_tis)) # Go through the "last seen" dictionary and clean out old entries allowed_age = self.kube_config.worker_pods_queued_check_interval * 3 @@ -488,25 +497,25 @@ def clear_not_launched_queued_tasks(self, session=None) -> None: ) if ti.map_index >= 0: # Old tasks _couldn't_ be mapped, so we don't have to worry about compat - base_label_selector += f',map_index={ti.map_index}' + base_label_selector += f",map_index={ti.map_index}" kwargs = dict(label_selector=base_label_selector) if self.kube_config.kube_client_request_args: kwargs.update(**self.kube_config.kube_client_request_args) # Try run_id first - kwargs['label_selector'] += ',run_id=' + pod_generator.make_safe_label_value(ti.run_id) + kwargs["label_selector"] += ",run_id=" + pod_generator.make_safe_label_value(ti.run_id) pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs) if pod_list.items: continue # Fallback to old style of using execution_date - kwargs['label_selector'] = ( - f'{base_label_selector},' - f'execution_date={pod_generator.datetime_to_label_safe_datestring(ti.execution_date)}' + kwargs["label_selector"] = ( + f"{base_label_selector}," + f"execution_date={pod_generator.datetime_to_label_safe_datestring(ti.execution_date)}" ) pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs) if pod_list.items: continue - self.log.info('TaskInstance: %s found in queued state but was not launched, rescheduling', ti) + self.log.info("TaskInstance: %s found in queued state but was not launched, rescheduling", ti) session.query(TaskInstance).filter( TaskInstance.dag_id == ti.dag_id, TaskInstance.task_id == ti.task_id, @@ -516,11 +525,11 @@ def clear_not_launched_queued_tasks(self, session=None) -> None: def start(self) -> None: """Starts the executor""" - self.log.info('Start Kubernetes executor') + self.log.info("Start Kubernetes executor") if not self.job_id: raise AirflowException("Could not get scheduler_job_id") self.scheduler_job_id = str(self.job_id) - self.log.debug('Start with scheduler_job_id: %s', self.scheduler_job_id) + self.log.debug("Start with scheduler_job_id: %s", self.scheduler_job_id) self.kube_client = get_kube_client() self.kube_scheduler = AirflowKubernetesScheduler( self.kube_config, self.task_queue, self.result_queue, self.kube_client, self.scheduler_job_id @@ -530,6 +539,7 @@ def start(self) -> None: self.kube_config.worker_pods_pending_timeout_check_interval, self._check_worker_pods_pending_timeout, ) + self.event_scheduler.call_regular_interval( self.kube_config.worker_pods_queued_check_interval, self.clear_not_launched_queued_tasks, @@ -542,15 +552,22 @@ def execute_async( self, key: TaskInstanceKey, command: CommandType, - queue: Optional[str] = None, - executor_config: 
Optional[Any] = None, + queue: str | None = None, + executor_config: Any | None = None, ) -> None: """Executes task asynchronously""" - self.log.info('Add task %s with command %s with executor_config %s', key, command, executor_config) + if TYPE_CHECKING: + assert self.task_queue + + if self.log.isEnabledFor(logging.DEBUG): + self.log.debug("Add task %s with command %s, executor_config %s", key, command, executor_config) + else: + self.log.info("Add task %s with command %s", key, command) + try: kube_executor_config = PodGenerator.from_obj(executor_config) except Exception: - self.log.error("Invalid executor_config for %s", key) + self.log.error("Invalid executor_config for %s. Executor_config: %s", key, executor_config) self.fail(key=key, info="Invalid executor_config passed") return @@ -558,8 +575,6 @@ def execute_async( pod_template_file = executor_config.get("pod_template_file", None) else: pod_template_file = None - if not self.task_queue: - raise AirflowException(NOT_STARTED_MESSAGE) self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id) self.task_queue.put((key, command, kube_executor_config, pod_template_file)) # We keep a temporary local record that we've handled this so we don't @@ -568,22 +583,18 @@ def execute_async( def sync(self) -> None: """Synchronize task state.""" + if TYPE_CHECKING: + assert self.scheduler_job_id + assert self.kube_scheduler + assert self.kube_config + assert self.result_queue + assert self.task_queue + assert self.event_scheduler + if self.running: - self.log.debug('self.running: %s', self.running) + self.log.debug("self.running: %s", self.running) if self.queued_tasks: - self.log.debug('self.queued: %s', self.queued_tasks) - if not self.scheduler_job_id: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.kube_scheduler: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.kube_config: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.result_queue: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.task_queue: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.event_scheduler: - raise AirflowException(NOT_STARTED_MESSAGE) + self.log.debug("self.queued: %s", self.queued_tasks) self.kube_scheduler.sync() last_resource_version = None @@ -593,7 +604,7 @@ def sync(self) -> None: try: key, state, pod_id, namespace, resource_version = results last_resource_version = resource_version - self.log.info('Changing state of %s to %s', results, state) + self.log.info("Changing state of %s to %s", results, state) try: self._change_state(key, state, pod_id, namespace) except Exception as e: @@ -617,8 +628,14 @@ def sync(self) -> None: task = self.task_queue.get_nowait() try: self.kube_scheduler.run_next(task) + except PodReconciliationError as e: + self.log.error( + "Pod reconciliation failed, likely due to kubernetes library upgrade. " + "Try clearing the task to re-run.", + exc_info=True, + ) + self.fail(task[0], e) except ApiException as e: - # These codes indicate something is wrong with pod definition; otherwise we assume pod # definition is ok, and that retrying may work if e.status in (400, 422): @@ -627,11 +644,19 @@ def sync(self) -> None: self.change_state(key, State.FAILED, e) else: self.log.warning( - 'ApiException when attempting to run task, re-queueing. Reason: %r. Message: %s', + "ApiException when attempting to run task, re-queueing. Reason: %r. 
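Note: the recurring change in these methods is replacing the runtime `if not self.x: raise AirflowException(NOT_STARTED_MESSAGE)` guards with asserts under `if TYPE_CHECKING:`. A minimal stand-alone sketch of why that still satisfies the type checker while costing nothing at runtime (toy names, not Airflow's):

# Attributes that are None until start() runs are typed "X | None"; methods that
# require start() narrow them for mypy with an assert that never executes at
# runtime, because typing.TYPE_CHECKING is False outside of type checking.
from __future__ import annotations

from typing import TYPE_CHECKING


class Client:
    def list_pods(self) -> list[str]:
        return ["worker-1"]


class MiniExecutor:
    def __init__(self) -> None:
        self.client: Client | None = None  # populated by start()

    def start(self) -> None:
        self.client = Client()

    def sync(self) -> list[str]:
        if TYPE_CHECKING:
            assert self.client  # narrows Client | None -> Client for mypy only
        return self.client.list_pods()  # plain AttributeError if start() was skipped


executor = MiniExecutor()
executor.start()
print(executor.sync())
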
Message: %s", e.reason, - json.loads(e.body)['message'], + json.loads(e.body)["message"], ) self.task_queue.put(task) + except PodMutationHookException as e: + key, _, _, _ = task + self.log.error( + "Pod Mutation Hook failed for the task %s. Failing task. Details: %s", + key, + e, + ) + self.fail(key, e) finally: self.task_queue.task_done() except Empty: @@ -643,15 +668,16 @@ def sync(self) -> None: def _check_worker_pods_pending_timeout(self): """Check if any pending worker pods have timed out""" - if not self.scheduler_job_id: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.scheduler_job_id + timeout = self.kube_config.worker_pods_pending_timeout - self.log.debug('Looking for pending worker pods older than %d seconds', timeout) + self.log.debug("Looking for pending worker pods older than %d seconds", timeout) kwargs = { - 'limit': self.kube_config.worker_pods_pending_timeout_batch_size, - 'field_selector': 'status.phase=Pending', - 'label_selector': f'airflow-worker={self.scheduler_job_id}', + "limit": self.kube_config.worker_pods_pending_timeout_batch_size, + "field_selector": "status.phase=Pending", + "label_selector": f"airflow-worker={self.scheduler_job_id}", **self.kube_config.kube_client_request_args, } if self.kube_config.multi_namespace_mode: @@ -670,35 +696,36 @@ def _check_worker_pods_pending_timeout(self): self.log.error( ( 'Pod "%s" has been pending for longer than %d seconds.' - 'It will be deleted and set to failed.' + "It will be deleted and set to failed." ), pod.metadata.name, timeout, ) self.kube_scheduler.delete_pod(pod.metadata.name, pod.metadata.namespace) - def _change_state(self, key: TaskInstanceKey, state: Optional[str], pod_id: str, namespace: str) -> None: + def _change_state(self, key: TaskInstanceKey, state: str | None, pod_id: str, namespace: str) -> None: + if TYPE_CHECKING: + assert self.kube_scheduler + if state != State.RUNNING: if self.kube_config.delete_worker_pods: - if not self.kube_scheduler: - raise AirflowException(NOT_STARTED_MESSAGE) if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure: self.kube_scheduler.delete_pod(pod_id, namespace) - self.log.info('Deleted pod: %s in namespace %s', str(key), str(namespace)) + self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace)) try: self.running.remove(key) except KeyError: - self.log.debug('Could not find key: %s', str(key)) + self.log.debug("Could not find key: %s", str(key)) self.event_buffer[key] = state, None - def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]: + def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id] scheduler_job_ids = {ti.queued_by_job_id for ti in tis} pod_ids = {ti.key: ti for ti in tis if ti.queued_by_job_id} kube_client: client.CoreV1Api = self.kube_client for scheduler_job_id in scheduler_job_ids: scheduler_job_id = pod_generator.make_safe_label_value(str(scheduler_job_id)) - kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'} + kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"} pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs) for pod in pod_list.items: self.adopt_launched_task(kube_client, pod, pod_ids) @@ -707,7 +734,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance return tis_to_flush def adopt_launched_task( - self, kube_client: client.CoreV1Api, pod: k8s.V1Pod, 
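Note: the pending-pod timeout check keeps the same query shape after the quote style change: a bounded list_namespaced_pod call filtered by pod phase and by the scheduler's airflow-worker label. A sketch of that call with the official `kubernetes` client, assuming a reachable cluster/kubeconfig; the namespace, limit and label value below are illustrative, not taken from a real config:

from kubernetes import client, config

config.load_kube_config()  # or config.load_incluster_config() when running inside a pod
v1 = client.CoreV1Api()

kwargs = {
    "limit": 100,                              # worker_pods_pending_timeout_batch_size
    "field_selector": "status.phase=Pending",  # only pods the scheduler is still waiting on
    "label_selector": "airflow-worker=123",    # pods launched by this scheduler job id
}
pods = v1.list_namespaced_pod("airflow", **kwargs)
for pod in pods.items:
    print(pod.metadata.name, pod.status.start_time)
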
pod_ids: Dict[TaskInstanceKey, k8s.V1Pod] + self, kube_client: client.CoreV1Api, pod: k8s.V1Pod, pod_ids: dict[TaskInstanceKey, k8s.V1Pod] ) -> None: """ Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors @@ -716,83 +743,88 @@ def adopt_launched_task( :param pod: V1Pod spec that we will patch with new label :param pod_ids: pod_ids we expect to patch. """ - if not self.scheduler_job_id: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.scheduler_job_id + self.log.info("attempting to adopt pod %s", pod.metadata.name) - pod.metadata.labels['airflow-worker'] = pod_generator.make_safe_label_value(self.scheduler_job_id) pod_id = annotations_to_key(pod.metadata.annotations) if pod_id not in pod_ids: self.log.error("attempting to adopt taskinstance which was not specified by database: %s", pod_id) return + new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id) try: kube_client.patch_namespaced_pod( name=pod.metadata.name, namespace=pod.metadata.namespace, - body=PodGenerator.serialize_pod(pod), + body={"metadata": {"labels": {"airflow-worker": new_worker_id_label}}}, ) - pod_ids.pop(pod_id) - self.running.add(pod_id) except ApiException as e: self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e) + return + + del pod_ids[pod_id] + self.running.add(pod_id) def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None: """ - - Patch completed pod so that the KubernetesJobWatcher can delete it. + Patch completed pods so that the KubernetesJobWatcher can delete them. :param kube_client: kubernetes client for speaking to kube API """ - if not self.scheduler_job_id: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.scheduler_job_id + + new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id) kwargs = { - 'field_selector': "status.phase=Succeeded", - 'label_selector': 'kubernetes_executor=True', + "field_selector": "status.phase=Succeeded", + "label_selector": f"kubernetes_executor=True,airflow-worker!={new_worker_id_label}", } pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs) for pod in pod_list.items: self.log.info("Attempting to adopt pod %s", pod.metadata.name) - pod.metadata.labels['airflow-worker'] = pod_generator.make_safe_label_value(self.scheduler_job_id) try: kube_client.patch_namespaced_pod( name=pod.metadata.name, namespace=pod.metadata.namespace, - body=PodGenerator.serialize_pod(pod), + body={"metadata": {"labels": {"airflow-worker": new_worker_id_label}}}, ) except ApiException as e: self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e) def _flush_task_queue(self) -> None: - if not self.task_queue: - raise AirflowException(NOT_STARTED_MESSAGE) - self.log.debug('Executor shutting down, task_queue approximate size=%d', self.task_queue.qsize()) + if TYPE_CHECKING: + assert self.task_queue + + self.log.debug("Executor shutting down, task_queue approximate size=%d", self.task_queue.qsize()) while True: try: task = self.task_queue.get_nowait() # This is a new task to run thus ok to ignore. 
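Note: adopt_launched_task() and _adopt_completed_pods() no longer round-trip the whole V1Pod through PodGenerator.serialize_pod(); they send a minimal strategic-merge patch that only rewrites the airflow-worker label (and the completed-pod query now excludes pods already carrying the new label). A sketch of that minimal patch with the official `kubernetes` client, assuming a reachable cluster; the pod name, namespace and label value are illustrative:

from kubernetes import client, config
from kubernetes.client.rest import ApiException

config.load_kube_config()
v1 = client.CoreV1Api()

new_worker_id_label = "456"  # make_safe_label_value(scheduler_job_id) in the real code
try:
    v1.patch_namespaced_pod(
        name="example-task-pod",
        namespace="airflow",
        body={"metadata": {"labels": {"airflow-worker": new_worker_id_label}}},
    )
except ApiException as exc:
    print(f"Failed to adopt pod: {exc.reason}")
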
- self.log.warning('Executor shutting down, will NOT run task=%s', task) + self.log.warning("Executor shutting down, will NOT run task=%s", task) self.task_queue.task_done() except Empty: break def _flush_result_queue(self) -> None: - if not self.result_queue: - raise AirflowException(NOT_STARTED_MESSAGE) - self.log.debug('Executor shutting down, result_queue approximate size=%d', self.result_queue.qsize()) + if TYPE_CHECKING: + assert self.result_queue + + self.log.debug("Executor shutting down, result_queue approximate size=%d", self.result_queue.qsize()) while True: try: results = self.result_queue.get_nowait() - self.log.warning('Executor shutting down, flushing results=%s', results) + self.log.warning("Executor shutting down, flushing results=%s", results) try: key, state, pod_id, namespace, resource_version = results self.log.info( - 'Changing state of %s to %s : resource_version=%d', results, state, resource_version + "Changing state of %s to %s : resource_version=%d", results, state, resource_version ) try: self._change_state(key, state, pod_id, namespace) except Exception as e: self.log.exception( - 'Ignoring exception: %s when attempting to change state of %s to %s.', + "Ignoring exception: %s when attempting to change state of %s to %s.", e, results, state, @@ -804,20 +836,22 @@ def _flush_result_queue(self) -> None: def end(self) -> None: """Called when the executor shuts down""" - if not self.task_queue: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.result_queue: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.kube_scheduler: - raise AirflowException(NOT_STARTED_MESSAGE) - self.log.info('Shutting down Kubernetes executor') - self.log.debug('Flushing task_queue...') - self._flush_task_queue() - self.log.debug('Flushing result_queue...') - self._flush_result_queue() - # Both queues should be empty... - self.task_queue.join() - self.result_queue.join() + if TYPE_CHECKING: + assert self.task_queue + assert self.result_queue + assert self.kube_scheduler + + self.log.info("Shutting down Kubernetes executor") + try: + self.log.debug("Flushing task_queue...") + self._flush_task_queue() + self.log.debug("Flushing result_queue...") + self._flush_result_queue() + # Both queues should be empty... 
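Note: _flush_task_queue() and _flush_result_queue() both drain a joinable manager queue with get_nowait() until Empty, acknowledging every item with task_done() so the later join() in end() cannot block. A self-contained sketch of that shutdown pattern:

from multiprocessing import Manager
from queue import Empty

if __name__ == "__main__":
    manager = Manager()
    task_queue = manager.Queue()
    for item in ("task-1", "task-2", "task-3"):
        task_queue.put(item)

    while True:
        try:
            task = task_queue.get_nowait()
            print("Executor shutting down, will NOT run", task)
            task_queue.task_done()   # keep the unfinished-task counter balanced
        except Empty:
            break

    task_queue.join()  # returns immediately: every get was matched by task_done
    manager.shutdown()
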
+ self.task_queue.join() + self.result_queue.join() + except ConnectionResetError: + self.log.exception("Connection Reset error while flushing task_queue and result_queue.") if self.kube_scheduler: self.kube_scheduler.terminate() self._manager.shutdown() diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 431add69bcfbb..c2c82d863944a 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -22,6 +22,8 @@ For more information on how the LocalExecutor works, take a look at the guide: :ref:`executor:LocalExecutor` """ +from __future__ import annotations + import logging import os import subprocess @@ -29,13 +31,13 @@ from multiprocessing import Manager, Process from multiprocessing.managers import SyncManager from queue import Empty, Queue -from typing import Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Tuple from setproctitle import getproctitle, setproctitle from airflow import settings from airflow.exceptions import AirflowException -from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType +from airflow.executors.base_executor import PARALLELISM, BaseExecutor, CommandType from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -54,10 +56,10 @@ class LocalWorkerBase(Process, LoggingMixin): :param result_queue: the queue to store result state """ - def __init__(self, result_queue: 'Queue[TaskInstanceStateType]'): + def __init__(self, result_queue: Queue[TaskInstanceStateType]): super().__init__(target=self.do_work) self.daemon: bool = True - self.result_queue: 'Queue[TaskInstanceStateType]' = result_queue + self.result_queue: Queue[TaskInstanceStateType] = result_queue def run(self): # We know we've just started a new process, so lets disconnect from the metadata db now @@ -148,7 +150,7 @@ class LocalWorker(LocalWorkerBase): """ def __init__( - self, result_queue: 'Queue[TaskInstanceStateType]', key: TaskInstanceKey, command: CommandType + self, result_queue: Queue[TaskInstanceStateType], key: TaskInstanceKey, command: CommandType ): super().__init__(result_queue) self.key: TaskInstanceKey = key @@ -168,7 +170,7 @@ class QueuedLocalWorker(LocalWorkerBase): :param result_queue: queue where worker puts results after finishing tasks """ - def __init__(self, task_queue: 'Queue[ExecutorWorkType]', result_queue: 'Queue[TaskInstanceStateType]'): + def __init__(self, task_queue: Queue[ExecutorWorkType], result_queue: Queue[TaskInstanceStateType]): super().__init__(result_queue=result_queue) self.task_queue = task_queue @@ -205,14 +207,12 @@ def __init__(self, parallelism: int = PARALLELISM): super().__init__(parallelism=parallelism) if self.parallelism < 0: raise AirflowException("parallelism must be bigger than or equal to 0") - self.manager: Optional[SyncManager] = None - self.result_queue: Optional['Queue[TaskInstanceStateType]'] = None - self.workers: List[QueuedLocalWorker] = [] + self.manager: SyncManager | None = None + self.result_queue: Queue[TaskInstanceStateType] | None = None + self.workers: list[QueuedLocalWorker] = [] self.workers_used: int = 0 self.workers_active: int = 0 - self.impl: Optional[ - Union['LocalExecutor.UnlimitedParallelism', 'LocalExecutor.LimitedParallelism'] - ] = None + self.impl: None | (LocalExecutor.UnlimitedParallelism | LocalExecutor.LimitedParallelism) = None class 
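Note: end() now tolerates a ConnectionResetError while flushing and joining the queues, since the multiprocessing manager process can already be gone during shutdown. A minimal stand-alone sketch of that hardened shutdown (toy names, empty queues so the joins return immediately):

import logging
from multiprocessing import Manager

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("mini-executor")


def end(task_queue, result_queue, manager) -> None:
    log.info("Shutting down executor")
    try:
        task_queue.join()    # both queues are assumed to be drained with task_done()
        result_queue.join()
    except ConnectionResetError:
        # the manager's server process may already have exited
        log.exception("Connection Reset error while flushing task_queue and result_queue.")
    manager.shutdown()


if __name__ == "__main__":
    manager = Manager()
    end(manager.Queue(), manager.Queue(), manager)
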
UnlimitedParallelism: """ @@ -222,8 +222,8 @@ class UnlimitedParallelism: :param executor: the executor instance to implement. """ - def __init__(self, executor: 'LocalExecutor'): - self.executor: 'LocalExecutor' = executor + def __init__(self, executor: LocalExecutor): + self.executor: LocalExecutor = executor def start(self) -> None: """Starts the executor.""" @@ -234,8 +234,8 @@ def execute_async( self, key: TaskInstanceKey, command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None, + queue: str | None = None, + executor_config: Any | None = None, ) -> None: """ Executes task asynchronously. @@ -245,8 +245,9 @@ def execute_async( :param queue: Name of the queue :param executor_config: configuration for the executor """ - if not self.executor.result_queue: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.executor.result_queue + local_worker = LocalWorker(self.executor.result_queue, key=key, command=command) self.executor.workers_used += 1 self.executor.workers_active += 1 @@ -278,17 +279,17 @@ class LimitedParallelism: :param executor: the executor instance to implement. """ - def __init__(self, executor: 'LocalExecutor'): - self.executor: 'LocalExecutor' = executor - self.queue: Optional['Queue[ExecutorWorkType]'] = None + def __init__(self, executor: LocalExecutor): + self.executor: LocalExecutor = executor + self.queue: Queue[ExecutorWorkType] | None = None def start(self) -> None: """Starts limited parallelism implementation.""" - if not self.executor.manager: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.executor.manager + assert self.executor.result_queue + self.queue = self.executor.manager.Queue() - if not self.executor.result_queue: - raise AirflowException(NOT_STARTED_MESSAGE) self.executor.workers = [ QueuedLocalWorker(self.queue, self.executor.result_queue) for _ in range(self.executor.parallelism) @@ -303,8 +304,8 @@ def execute_async( self, key: TaskInstanceKey, command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None, + queue: str | None = None, + executor_config: Any | None = None, ) -> None: """ Executes task asynchronously. @@ -314,8 +315,9 @@ def execute_async( :param queue: name of the queue :param executor_config: configuration for the executor """ - if not self.queue: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.queue + self.queue.put((key, command)) def sync(self): @@ -361,21 +363,22 @@ def execute_async( self, key: TaskInstanceKey, command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None, + queue: str | None = None, + executor_config: Any | None = None, ) -> None: """Execute asynchronously.""" - if not self.impl: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.impl - self.validate_command(command) + self.validate_airflow_tasks_run_command(command) self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) def sync(self) -> None: """Sync will get called periodically by the heartbeat method.""" - if not self.impl: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.impl + self.impl.sync() def end(self) -> None: @@ -383,10 +386,10 @@ def end(self) -> None: Ends the executor. 
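Note: the LimitedParallelism strategy shown here feeds a fixed pool of QueuedLocalWorker processes from a shared manager queue and stops them with None sentinels, one per worker. A simplified stand-alone sketch of that mechanism (placeholder commands, not Airflow's CommandType):

from __future__ import annotations

import subprocess
import sys
from multiprocessing import Manager, Process


def queued_worker(task_queue, result_queue) -> None:
    while True:
        key_and_command = task_queue.get()
        try:
            if key_and_command is None:  # poison pill: executor is shutting down
                break
            key, command = key_and_command
            returncode = subprocess.call(command, close_fds=True)
            result_queue.put((key, "success" if returncode == 0 else "failed"))
        finally:
            task_queue.task_done()


if __name__ == "__main__":
    manager = Manager()
    task_queue, result_queue = manager.Queue(), manager.Queue()
    workers = [Process(target=queued_worker, args=(task_queue, result_queue)) for _ in range(2)]
    for worker in workers:
        worker.start()
    task_queue.put(("dag_a.task_b", [sys.executable, "-c", "print('running task')"]))
    for _ in workers:
        task_queue.put(None)  # one sentinel per worker, as in LocalExecutor.end()
    task_queue.join()
    print(result_queue.get())
    for worker in workers:
        worker.join()
    manager.shutdown()
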
:return: """ - if not self.impl: - raise AirflowException(NOT_STARTED_MESSAGE) - if not self.manager: - raise AirflowException(NOT_STARTED_MESSAGE) + if TYPE_CHECKING: + assert self.impl + assert self.manager + self.log.info( "Shutting down LocalExecutor" "; waiting for running tasks to finish. Signal again if you don't want to wait." diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py index 9944cfe1ef1fa..b6151d5d1c1d5 100644 --- a/airflow/executors/local_kubernetes_executor.py +++ b/airflow/executors/local_kubernetes_executor.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List, Optional, Set, Union +from __future__ import annotations + +from typing import Sequence from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest @@ -37,19 +39,19 @@ class LocalKubernetesExecutor(LoggingMixin): """ supports_ad_hoc_ti_run: bool = True - callback_sink: Optional[BaseCallbackSink] = None + callback_sink: BaseCallbackSink | None = None - KUBERNETES_QUEUE = conf.get('local_kubernetes_executor', 'kubernetes_queue') + KUBERNETES_QUEUE = conf.get("local_kubernetes_executor", "kubernetes_queue") def __init__(self, local_executor: LocalExecutor, kubernetes_executor: KubernetesExecutor): super().__init__() - self._job_id: Optional[str] = None + self._job_id: str | None = None self.local_executor = local_executor self.kubernetes_executor = kubernetes_executor self.kubernetes_executor.kubernetes_queue = self.KUBERNETES_QUEUE @property - def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]: + def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: """Return queued tasks from local and kubernetes executor""" queued_tasks = self.local_executor.queued_tasks.copy() queued_tasks.update(self.kubernetes_executor.queued_tasks) @@ -57,12 +59,12 @@ def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]: return queued_tasks @property - def running(self) -> Set[TaskInstanceKey]: + def running(self) -> set[TaskInstanceKey]: """Return running tasks from local and kubernetes executor""" return self.local_executor.running.union(self.kubernetes_executor.running) @property - def job_id(self) -> Optional[str]: + def job_id(self) -> str | None: """ This is a class attribute in BaseExecutor but since this is not really an executor, but a wrapper of executors we implement as property so we can have custom setter. @@ -70,7 +72,7 @@ def job_id(self) -> Optional[str]: return self._job_id @job_id.setter - def job_id(self, value: Optional[str]) -> None: + def job_id(self, value: str | None) -> None: """job_id is manipulated by SchedulerJob. 
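Note: LocalKubernetesExecutor is a thin wrapper, not a real executor: its queued/running views are merged from the two wrapped executors, and the job_id setter fans the value out to both. A toy sketch of that wrapper pattern (simplified classes, not the Airflow ones):

from __future__ import annotations


class ToyExecutor:
    def __init__(self) -> None:
        self.queued_tasks: dict[str, str] = {}
        self.running: set[str] = set()
        self.job_id: str | None = None


class ToyLocalKubernetesExecutor:
    def __init__(self, local: ToyExecutor, kubernetes: ToyExecutor) -> None:
        self._job_id: str | None = None
        self.local = local
        self.kubernetes = kubernetes

    @property
    def queued_tasks(self) -> dict[str, str]:
        queued = self.local.queued_tasks.copy()
        queued.update(self.kubernetes.queued_tasks)
        return queued

    @property
    def running(self) -> set[str]:
        return self.local.running | self.kubernetes.running

    @property
    def job_id(self) -> str | None:
        return self._job_id

    @job_id.setter
    def job_id(self, value: str | None) -> None:
        self._job_id = value            # propagate to both wrapped executors
        self.local.job_id = value
        self.kubernetes.job_id = value


wrapper = ToyLocalKubernetesExecutor(ToyExecutor(), ToyExecutor())
wrapper.job_id = "42"
print(wrapper.job_id, wrapper.local.job_id, wrapper.kubernetes.job_id)
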
We must propagate the job_id to wrapped executors.""" self._job_id = value self.kubernetes_executor.job_id = value @@ -92,7 +94,7 @@ def queue_command( task_instance: TaskInstance, command: CommandType, priority: int = 1, - queue: Optional[str] = None, + queue: str | None = None, ) -> None: """Queues command via local or kubernetes executor""" executor = self._router(task_instance) @@ -103,13 +105,13 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: Optional[str] = None, + pickle_id: str | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, ignore_task_deps: bool = False, ignore_ti_state: bool = False, - pool: Optional[str] = None, - cfg_path: Optional[str] = None, + pool: str | None = None, + cfg_path: str | None = None, ) -> None: """Queues task instance via local or kubernetes executor""" executor = self._router(SimpleTaskInstance.from_ti(task_instance)) @@ -143,8 +145,8 @@ def heartbeat(self) -> None: self.kubernetes_executor.heartbeat() def get_event_buffer( - self, dag_ids: Optional[List[str]] = None - ) -> Dict[TaskInstanceKey, EventBufferValueType]: + self, dag_ids: list[str] | None = None + ) -> dict[TaskInstanceKey, EventBufferValueType]: """ Returns and flush the event buffer from local and kubernetes executor @@ -156,7 +158,7 @@ def get_event_buffer( return {**cleared_events_from_local, **cleared_events_from_kubernetes} - def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]: + def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ Try to adopt running task instances that have been abandoned by a SchedulerJob dying. @@ -164,19 +166,13 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance re-scheduling) :return: any TaskInstances that were unable to be adopted - :rtype: list[airflow.models.TaskInstance] """ - local_tis = [] - kubernetes_tis = [] - abandoned_tis = [] - for ti in tis: - if ti.queue == self.KUBERNETES_QUEUE: - kubernetes_tis.append(ti) - else: - local_tis.append(ti) - abandoned_tis.extend(self.local_executor.try_adopt_task_instances(local_tis)) - abandoned_tis.extend(self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis)) - return abandoned_tis + local_tis = [ti for ti in tis if ti.queue != self.KUBERNETES_QUEUE] + kubernetes_tis = [ti for ti in tis if ti.queue == self.KUBERNETES_QUEUE] + return [ + *self.local_executor.try_adopt_task_instances(local_tis), + *self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis), + ] def end(self) -> None: """End local and kubernetes executor""" @@ -188,13 +184,12 @@ def terminate(self) -> None: self.local_executor.terminate() self.kubernetes_executor.terminate() - def _router(self, simple_task_instance: SimpleTaskInstance) -> Union[LocalExecutor, KubernetesExecutor]: + def _router(self, simple_task_instance: SimpleTaskInstance) -> LocalExecutor | KubernetesExecutor: """ Return either local_executor or kubernetes_executor :param simple_task_instance: SimpleTaskInstance :return: local_executor or kubernetes_executor - :rtype: Union[LocalExecutor, KubernetesExecutor] """ if simple_task_instance.queue == self.KUBERNETES_QUEUE: return self.kubernetes_executor diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py index 456e3e9893e8b..c7c2f00417a21 100644 --- a/airflow/executors/sequential_executor.py +++ b/airflow/executors/sequential_executor.py @@ -22,8 +22,10 @@ For more information on how 
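Note: the rewritten try_adopt_task_instances() simply partitions task instances by queue with two comprehensions and splices the two adoption results back together. A toy sketch of that routing, with plain dicts standing in for TaskInstance objects:

KUBERNETES_QUEUE = "kubernetes"

tis = [
    {"task_id": "a", "queue": "default"},
    {"task_id": "b", "queue": KUBERNETES_QUEUE},
    {"task_id": "c", "queue": "default"},
]

local_tis = [ti for ti in tis if ti["queue"] != KUBERNETES_QUEUE]
kubernetes_tis = [ti for ti in tis if ti["queue"] == KUBERNETES_QUEUE]


def adopt(batch):  # stand-in for each wrapped executor's try_adopt_task_instances()
    return [ti for ti in batch if ti["task_id"] == "c"]  # pretend only "c" could not be adopted


not_adopted = [*adopt(local_tis), *adopt(kubernetes_tis)]
print(not_adopted)
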
the SequentialExecutor works, take a look at the guide: :ref:`executor:SequentialExecutor` """ +from __future__ import annotations + import subprocess -from typing import Any, Optional +from typing import Any from airflow.executors.base_executor import BaseExecutor, CommandType from airflow.models.taskinstance import TaskInstanceKey @@ -48,10 +50,10 @@ def execute_async( self, key: TaskInstanceKey, command: CommandType, - queue: Optional[str] = None, - executor_config: Optional[Any] = None, + queue: str | None = None, + executor_config: Any | None = None, ) -> None: - self.validate_command(command) + self.validate_airflow_tasks_run_command(command) self.commands_to_run.append((key, command)) def sync(self) -> None: diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py deleted file mode 100644 index b59311a1ba9e4..0000000000000 --- a/airflow/hooks/S3_hook.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py index 5c933400c2f8a..db2b7f0a298ee 100644 --- a/airflow/hooks/__init__.py +++ b/airflow/hooks/__init__.py @@ -15,4 +15,80 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
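Note: both LocalExecutor and SequentialExecutor now call validate_airflow_tasks_run_command() instead of validate_command(). A hypothetical sketch of the kind of guard the renamed helper implies; the real implementation lives in BaseExecutor and may check more than this:

from __future__ import annotations


def validate_airflow_tasks_run_command(command: list[str]) -> None:
    if command[:3] != ["airflow", "tasks", "run"]:
        raise ValueError(f"The command must start with ['airflow', 'tasks', 'run']: {command}")


validate_airflow_tasks_run_command(["airflow", "tasks", "run", "dag_a", "task_b", "2022-01-01"])
print("command accepted")
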
+# fmt:, off """Hooks.""" +from __future__ import annotations + +from airflow.utils.deprecation_tools import add_deprecated_classes + +__deprecated_classes = { + "S3_hook": { + "S3Hook": "airflow.providers.amazon.aws.hooks.s3.S3Hook", + "provide_bucket_name": "airflow.providers.amazon.aws.hooks.s3.provide_bucket_name", + }, + "base_hook": { + "BaseHook": "airflow.hooks.base.BaseHook", + }, + "dbapi_hook": { + "DbApiHook": "airflow.providers.common.sql.hooks.sql.DbApiHook", + }, + "docker_hook": { + "DockerHook": "airflow.providers.docker.hooks.docker.DockerHook", + }, + "druid_hook": { + "DruidDbApiHook": "airflow.providers.apache.druid.hooks.druid.DruidDbApiHook", + "DruidHook": "airflow.providers.apache.druid.hooks.druid.DruidHook", + }, + "hdfs_hook": { + "HDFSHook": "airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook", + "HDFSHookException": "airflow.providers.apache.hdfs.hooks.hdfs.HDFSHookException", + }, + "hive_hooks": { + "HIVE_QUEUE_PRIORITIES": "airflow.providers.apache.hive.hooks.hive.HIVE_QUEUE_PRIORITIES", + "HiveCliHook": "airflow.providers.apache.hive.hooks.hive.HiveCliHook", + "HiveMetastoreHook": "airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook", + "HiveServer2Hook": "airflow.providers.apache.hive.hooks.hive.HiveServer2Hook", + }, + "http_hook": { + "HttpHook": "airflow.providers.http.hooks.http.HttpHook", + }, + "jdbc_hook": { + "JdbcHook": "airflow.providers.jdbc.hooks.jdbc.JdbcHook", + "jaydebeapi": "airflow.providers.jdbc.hooks.jdbc.jaydebeapi", + }, + "mssql_hook": { + "MsSqlHook": "airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook", + }, + "mysql_hook": { + "MySqlHook": "airflow.providers.mysql.hooks.mysql.MySqlHook", + }, + "oracle_hook": { + "OracleHook": "airflow.providers.oracle.hooks.oracle.OracleHook", + }, + "pig_hook": { + "PigCliHook": "airflow.providers.apache.pig.hooks.pig.PigCliHook", + }, + "postgres_hook": { + "PostgresHook": "airflow.providers.postgres.hooks.postgres.PostgresHook", + }, + "presto_hook": { + "PrestoHook": "airflow.providers.presto.hooks.presto.PrestoHook", + }, + "samba_hook": { + "SambaHook": "airflow.providers.samba.hooks.samba.SambaHook", + }, + "slack_hook": { + "SlackHook": "airflow.providers.slack.hooks.slack.SlackHook", + }, + "sqlite_hook": { + "SqliteHook": "airflow.providers.sqlite.hooks.sqlite.SqliteHook", + }, + "webhdfs_hook": { + "WebHDFSHook": "airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook", + }, + "zendesk_hook": { + "ZendeskHook": "airflow.providers.zendesk.hooks.zendesk.ZendeskHook", + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/hooks/base.py b/airflow/hooks/base.py index aa506cf62de28..9298a686889f7 100644 --- a/airflow/hooks/base.py +++ b/airflow/hooks/base.py @@ -15,11 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Base class for all hooks""" +"""Base class for all hooks.""" +from __future__ import annotations + import logging import warnings -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any +from airflow.exceptions import RemovedInAirflow3Warning from airflow.typing_compat import Protocol from airflow.utils.log.logging_mixin import LoggingMixin @@ -31,7 +34,9 @@ class BaseHook(LoggingMixin): """ - Abstract base class for hooks, hooks are meant as an interface to + Abstract base class for hooks. + + Hooks are meant as an interface to interact with external systems. 
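Note: the many one-class deprecation shim modules deleted below are replaced by the single mapping above, handed to add_deprecated_classes() from airflow.utils.deprecation_tools. A minimal sketch of the general mechanism such a helper can rely on, a PEP 562 module-level __getattr__ that warns and redirects on first access; this is an assumption about the approach rather than Airflow's exact implementation, and the mapping is flattened to plain names with a stdlib target so the sketch runs anywhere:

from __future__ import annotations

import importlib
import sys
import warnings

__redirects = {
    # Airflow maps e.g. "S3Hook" -> "airflow.providers.amazon.aws.hooks.s3.S3Hook";
    # a stdlib class is used here so the example has no provider dependency.
    "OrderedDict": "collections.OrderedDict",
}


def __getattr__(name: str):
    try:
        target = __redirects[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    module_path, _, attr = target.rpartition(".")
    warnings.warn(
        f"Accessing {name} here is deprecated; use {target} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return getattr(importlib.import_module(module_path), attr)


print(getattr(sys.modules[__name__], "OrderedDict"))  # warns, then prints the redirected class
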
MySqlHook, HiveHook, PigHook return object that can handle the connection and interaction to specific instances of these systems, and expose consistent methods to interact @@ -39,7 +44,7 @@ class BaseHook(LoggingMixin): """ @classmethod - def get_connections(cls, conn_id: str) -> List["Connection"]: + def get_connections(cls, conn_id: str) -> list[Connection]: """ Get all connections as an iterable, given the connection id. @@ -49,13 +54,13 @@ def get_connections(cls, conn_id: str) -> List["Connection"]: warnings.warn( "`BaseHook.get_connections` method will be deprecated in the future." "Please use `BaseHook.get_connection` instead.", - PendingDeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return [cls.get_connection(conn_id)] @classmethod - def get_connection(cls, conn_id: str) -> "Connection": + def get_connection(cls, conn_id: str) -> Connection: """ Get connection, given connection id. @@ -69,15 +74,13 @@ def get_connection(cls, conn_id: str) -> "Connection": return conn @classmethod - def get_hook(cls, conn_id: str) -> "BaseHook": + def get_hook(cls, conn_id: str) -> BaseHook: """ Returns default hook for this connection id. :param conn_id: connection id :return: default hook for this connection """ - # TODO: set method return type to BaseHook class when on 3.7+. - # See https://stackoverflow.com/a/33533514/3066428 connection = cls.get_connection(conn_id) return connection.get_hook() @@ -86,11 +89,11 @@ def get_conn(self) -> Any: raise NotImplementedError() @classmethod - def get_connection_form_widgets(cls) -> Dict[str, Any]: + def get_connection_form_widgets(cls) -> dict[str, Any]: ... @classmethod - def get_ui_field_behaviour(cls) -> Dict[str, Any]: + def get_ui_field_behaviour(cls) -> dict[str, Any]: ... @@ -139,7 +142,7 @@ def get_ui_field_behaviour(cls): hook_name: str @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """ Returns dictionary of widgets to be added for the hook to handle extra values. @@ -155,8 +158,10 @@ def get_connection_form_widgets() -> Dict[str, Any]: ... @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """ + Attributes of the UI field. + Returns dictionary describing customizations to implement in javascript handling the connection form. Should be compliant with airflow/customized_form_field_behaviours.schema.json' diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py deleted file mode 100644 index cf1594d18d284..0000000000000 --- a/airflow/hooks/base_hook.py +++ /dev/null @@ -1,24 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
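Note: the BaseHook surface touched here (single-connection lookup via get_connection(), optional get_connection_form_widgets()/get_ui_field_behaviour() hooks for the UI) is what provider hooks implement. A minimal sketch of a custom hook against that interface, assuming an Airflow installation; the class name, conn_id and field choices are illustrative:

from __future__ import annotations

from typing import Any

from airflow.hooks.base import BaseHook


class MyServiceHook(BaseHook):
    conn_name_attr = "my_service_conn_id"
    default_conn_name = "my_service_default"
    hook_name = "My Service"

    def __init__(self, my_service_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.my_service_conn_id = my_service_conn_id

    def get_conn(self) -> Any:
        # single lookup via get_connection(); get_connections() is deprecated above
        conn = self.get_connection(self.my_service_conn_id)
        return {"host": conn.host, "login": conn.login}

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        # hide fields this service never uses in the connection form
        return {"hidden_fields": ["schema", "port"], "relabeling": {"login": "API key"}}


print(MyServiceHook.get_ui_field_behaviour())
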
Please use :mod:`airflow.hooks.base`.""" - -import warnings - -from airflow.hooks.base import BaseHook # noqa - -warnings.warn("This module is deprecated. Please use `airflow.hooks.base`.", DeprecationWarning, stacklevel=2) diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py index 0b9ce4377be23..cd4a39af8d2a5 100644 --- a/airflow/hooks/dbapi.py +++ b/airflow/hooks/dbapi.py @@ -15,363 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from contextlib import closing -from datetime import datetime -from typing import Any, Optional +from __future__ import annotations -from sqlalchemy import create_engine +import warnings -from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook -from airflow.typing_compat import Protocol +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.providers.common.sql.hooks.sql import ConnectorProtocol # noqa +from airflow.providers.common.sql.hooks.sql import DbApiHook # noqa - -class ConnectorProtocol(Protocol): - """A protocol where you can connect to a database.""" - - def connect(self, host: str, port: int, username: str, schema: str) -> Any: - """ - Connect to a database. - - :param host: The database host to connect to. - :param port: The database port to connect to. - :param username: The database username used for the authentication. - :param schema: The database schema to connect to. - :return: the authorized connection object. - """ - - -######################################################################################### -# # -# Note! Be extra careful when changing this file. This hook is used as a base for # -# a number of DBApi-related hooks and providers depend on the methods implemented # -# here. Whatever you add here, has to backwards compatible unless # -# `>=` is added to providers' requirements using the new feature # -# # -######################################################################################### -class DbApiHook(BaseHook): - """ - Abstract base class for sql hooks. - - :param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that - if you change the schema parameter value in the constructor of the derived Hook, such change - should be done before calling the ``DBApiHook.__init__()``. - """ - - # Override to provide the connection name. - conn_name_attr = None # type: str - # Override to have a default connection id for a particular dbHook - default_conn_name = 'default_conn_id' - # Override if this db supports autocommit. - supports_autocommit = False - # Override with the object that exposes the connect method - connector = None # type: Optional[ConnectorProtocol] - # Override with db-specific query to check connection - _test_connection_sql = "select 1" - - def __init__(self, *args, schema: Optional[str] = None, **kwargs): - super().__init__() - if not self.conn_name_attr: - raise AirflowException("conn_name_attr is not defined") - elif len(args) == 1: - setattr(self, self.conn_name_attr, args[0]) - elif self.conn_name_attr not in kwargs: - setattr(self, self.conn_name_attr, self.default_conn_name) - else: - setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr]) - # We should not make schema available in deriving hooks for backwards compatibility - # If a hook deriving from DBApiHook has a need to access schema, then it should retrieve it - # from kwargs and store it on its own. 
We do not run "pop" here as we want to give the - # Hook deriving from the DBApiHook to still have access to the field in it's constructor - self.__schema = schema - - def get_conn(self): - """Returns a connection object""" - db = self.get_connection(getattr(self, self.conn_name_attr)) - return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema) - - def get_uri(self) -> str: - """ - Extract the URI from the connection. - - :return: the extracted uri. - """ - conn = self.get_connection(getattr(self, self.conn_name_attr)) - conn.schema = self.__schema or conn.schema - return conn.get_uri() - - def get_sqlalchemy_engine(self, engine_kwargs=None): - """ - Get an sqlalchemy_engine object. - - :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`. - :return: the created engine. - """ - if engine_kwargs is None: - engine_kwargs = {} - return create_engine(self.get_uri(), **engine_kwargs) - - def get_pandas_df(self, sql, parameters=None, **kwargs): - """ - Executes the sql and returns a pandas dataframe - - :param sql: the sql statement to be executed (str) or a list of - sql statements to execute - :param parameters: The parameters to render the SQL query with. - :param kwargs: (optional) passed into pandas.io.sql.read_sql method - """ - try: - from pandas.io import sql as psql - except ImportError: - raise Exception("pandas library not installed, run: pip install 'apache-airflow[pandas]'.") - - with closing(self.get_conn()) as conn: - return psql.read_sql(sql, con=conn, params=parameters, **kwargs) - - def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, **kwargs): - """ - Executes the sql and returns a generator - - :param sql: the sql statement to be executed (str) or a list of - sql statements to execute - :param parameters: The parameters to render the SQL query with - :param chunksize: number of rows to include in each chunk - :param kwargs: (optional) passed into pandas.io.sql.read_sql method - """ - try: - from pandas.io import sql as psql - except ImportError: - raise Exception("pandas library not installed, run: pip install 'apache-airflow[pandas]'.") - - with closing(self.get_conn()) as conn: - yield from psql.read_sql(sql, con=conn, params=parameters, chunksize=chunksize, **kwargs) - - def get_records(self, sql, parameters=None): - """ - Executes the sql and returns a set of records. - - :param sql: the sql statement to be executed (str) or a list of - sql statements to execute - :param parameters: The parameters to render the SQL query with. - """ - with closing(self.get_conn()) as conn: - with closing(conn.cursor()) as cur: - if parameters is not None: - cur.execute(sql, parameters) - else: - cur.execute(sql) - return cur.fetchall() - - def get_first(self, sql, parameters=None): - """ - Executes the sql and returns the first resulting row. - - :param sql: the sql statement to be executed (str) or a list of - sql statements to execute - :param parameters: The parameters to render the SQL query with. - """ - with closing(self.get_conn()) as conn: - with closing(conn.cursor()) as cur: - if parameters is not None: - cur.execute(sql, parameters) - else: - cur.execute(sql) - return cur.fetchone() - - def run(self, sql, autocommit=False, parameters=None, handler=None): - """ - Runs a command or a list of commands. 
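Note: the DbApiHook body being removed here now lives in airflow.providers.common.sql.hooks.sql; its fetch helpers wrap every connection and cursor in contextlib.closing(). A self-contained sketch of that pattern using the stdlib sqlite3 DB-API driver:

import sqlite3
from contextlib import closing


def get_first(sql: str, parameters=None):
    with closing(sqlite3.connect(":memory:")) as conn:   # the hook would call self.get_conn()
        with closing(conn.cursor()) as cur:
            if parameters is not None:
                cur.execute(sql, parameters)
            else:
                cur.execute(sql)
            return cur.fetchone()


print(get_first("SELECT 1 + ?", (2,)))  # (3,)
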
Pass a list of sql - statements to the sql parameter to get them to execute - sequentially - - :param sql: the sql statement to be executed (str) or a list of - sql statements to execute - :param autocommit: What to set the connection's autocommit setting to - before executing the query. - :param parameters: The parameters to render the SQL query with. - :param handler: The result handler which is called with the result of each statement. - :return: query results if handler was provided. - """ - scalar = isinstance(sql, str) - if scalar: - sql = [sql] - - if sql: - self.log.debug("Executing %d statements", len(sql)) - else: - raise ValueError("List of SQL statements is empty") - - with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, autocommit) - - with closing(conn.cursor()) as cur: - results = [] - for sql_statement in sql: - self._run_command(cur, sql_statement, parameters) - if handler is not None: - result = handler(cur) - results.append(result) - - # If autocommit was set to False for db that supports autocommit, - # or if db does not supports autocommit, we do a manual commit. - if not self.get_autocommit(conn): - conn.commit() - - if handler is None: - return None - - if scalar: - return results[0] - - return results - - def _run_command(self, cur, sql_statement, parameters): - """Runs a statement using an already open cursor.""" - self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) - if parameters: - cur.execute(sql_statement, parameters) - else: - cur.execute(sql_statement) - - # According to PEP 249, this is -1 when query result is not applicable. - if cur.rowcount >= 0: - self.log.info("Rows affected: %s", cur.rowcount) - - def set_autocommit(self, conn, autocommit): - """Sets the autocommit flag on the connection""" - if not self.supports_autocommit and autocommit: - self.log.warning( - "%s connection doesn't support autocommit but autocommit activated.", - getattr(self, self.conn_name_attr), - ) - conn.autocommit = autocommit - - def get_autocommit(self, conn): - """ - Get autocommit setting for the provided connection. - Return True if conn.autocommit is set to True. - Return False if conn.autocommit is not set or set to False or conn - does not support autocommit. - - :param conn: Connection to get autocommit setting from. - :return: connection autocommit setting. - :rtype: bool - """ - return getattr(conn, 'autocommit', False) and self.supports_autocommit - - def get_cursor(self): - """Returns a cursor""" - return self.get_conn().cursor() - - @staticmethod - def _generate_insert_sql(table, values, target_fields, replace, **kwargs): - """ - Static helper method that generates the INSERT SQL statement. - The REPLACE variant is specific to MySQL syntax. 
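Note: the run() semantics moved out of this module are worth spelling out: a scalar SQL string is wrapped into a one-element list, every statement runs on one connection, an optional handler collects per-statement results, and a scalar input yields a scalar result. A stand-alone sqlite3 sketch of that flow (simplified, no autocommit handling):

import sqlite3
from contextlib import closing


def run(sql, parameters=None, handler=None):
    scalar = isinstance(sql, str)
    statements = [sql] if scalar else sql
    results = []
    with closing(sqlite3.connect(":memory:")) as conn:
        with closing(conn.cursor()) as cur:
            for statement in statements:
                cur.execute(statement, parameters or [])
                if handler is not None:
                    results.append(handler(cur))
        conn.commit()  # manual commit, mirroring the non-autocommit branch
    if handler is None:
        return None
    return results[0] if scalar else results


print(run("SELECT 42", handler=lambda cur: cur.fetchone()))  # (42,)
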
- - :param table: Name of the target table - :param values: The row to insert into the table - :param target_fields: The names of the columns to fill in the table - :param replace: Whether to replace instead of insert - :return: The generated INSERT or REPLACE SQL statement - :rtype: str - """ - placeholders = [ - "%s", - ] * len(values) - - if target_fields: - target_fields = ", ".join(target_fields) - target_fields = f"({target_fields})" - else: - target_fields = '' - - if not replace: - sql = "INSERT INTO " - else: - sql = "REPLACE INTO " - sql += f"{table} {target_fields} VALUES ({','.join(placeholders)})" - return sql - - def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, **kwargs): - """ - A generic way to insert a set of tuples into a table, - a new transaction is created every commit_every rows - - :param table: Name of the target table - :param rows: The rows to insert into the table - :param target_fields: The names of the columns to fill in the table - :param commit_every: The maximum number of rows to insert in one - transaction. Set to 0 to insert all rows in one transaction. - :param replace: Whether to replace instead of insert - """ - i = 0 - with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, False) - - conn.commit() - - with closing(conn.cursor()) as cur: - for i, row in enumerate(rows, 1): - lst = [] - for cell in row: - lst.append(self._serialize_cell(cell, conn)) - values = tuple(lst) - sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs) - self.log.debug("Generated sql: %s", sql) - cur.execute(sql, values) - if commit_every and i % commit_every == 0: - conn.commit() - self.log.info("Loaded %s rows into %s so far", i, table) - - conn.commit() - self.log.info("Done loading. Loaded a total of %s rows", i) - - @staticmethod - def _serialize_cell(cell, conn=None): - """ - Returns the SQL literal of the cell as a string. - - :param cell: The cell to insert into the table - :param conn: The database connection - :return: The serialized cell - :rtype: str - """ - if cell is None: - return None - if isinstance(cell, datetime): - return cell.isoformat() - return str(cell) - - def bulk_dump(self, table, tmp_file): - """ - Dumps a database table into a tab-delimited file - - :param table: The name of the source table - :param tmp_file: The path of the target file - """ - raise NotImplementedError() - - def bulk_load(self, table, tmp_file): - """ - Loads a tab-delimited file into a database table - - :param table: The name of the target table - :param tmp_file: The path of the file to load into the table - """ - raise NotImplementedError() - - def test_connection(self): - """Tests the connection using db-specific query""" - status, message = False, '' - try: - if self.get_first(self._test_connection_sql): - status = True - message = 'Connection successfully tested' - except Exception as e: - status = False - message = str(e) - - return status, message +warnings.warn( + "This module is deprecated. Please use `airflow.providers.common.sql.hooks.sql`.", + RemovedInAirflow3Warning, + stacklevel=2, +) diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py deleted file mode 100644 index 4a441b0f50d59..0000000000000 --- a/airflow/hooks/dbapi_hook.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
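Note: a quick sketch of the statement shape the removed _generate_insert_sql() helper produced ("%s" placeholders, optional column list, REPLACE being MySQL-flavoured syntax); simplified re-implementation for illustration only:

def generate_insert_sql(table, values, target_fields=None, replace=False):
    placeholders = ",".join(["%s"] * len(values))
    columns = f"({', '.join(target_fields)}) " if target_fields else ""
    verb = "REPLACE" if replace else "INSERT"
    return f"{verb} INTO {table} {columns}VALUES ({placeholders})"


print(generate_insert_sql("users", ("Ada", 42), target_fields=["name", "age"]))
# INSERT INTO users (name, age) VALUES (%s,%s)
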
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.hooks.dbapi`.""" - -import warnings - -from airflow.hooks.dbapi import DbApiHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.hooks.dbapi`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/hooks/docker_hook.py b/airflow/hooks/docker_hook.py deleted file mode 100644 index aaedd7e637d93..0000000000000 --- a/airflow/hooks/docker_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.docker.hooks.docker`.""" - -import warnings - -from airflow.providers.docker.hooks.docker import DockerHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.docker.hooks.docker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py deleted file mode 100644 index 0a43debbabddc..0000000000000 --- a/airflow/hooks/druid_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.hooks.druid`.""" - -import warnings - -from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.apache.druid.hooks.druid`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/filesystem.py b/airflow/hooks/filesystem.py index d9ad6ded1cc12..39517e8cdc1ad 100644 --- a/airflow/hooks/filesystem.py +++ b/airflow/hooks/filesystem.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations from airflow.hooks.base import BaseHook @@ -33,13 +33,13 @@ class FSHook(BaseHook): Extra: {"path": "/tmp"} """ - def __init__(self, conn_id='fs_default'): + def __init__(self, conn_id: str = "fs_default"): super().__init__() conn = self.get_connection(conn_id) - self.basepath = conn.extra_dejson.get('path', '') + self.basepath = conn.extra_dejson.get("path", "") self.conn = conn - def get_conn(self): + def get_conn(self) -> None: pass def get_path(self) -> str: diff --git a/airflow/hooks/hdfs_hook.py b/airflow/hooks/hdfs_hook.py deleted file mode 100644 index fd13e7337e262..0000000000000 --- a/airflow/hooks/hdfs_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hdfs.hooks.hdfs`.""" - -import warnings - -from airflow.providers.apache.hdfs.hooks.hdfs import HDFSHook, HDFSHookException # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hdfs.hooks.hdfs`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py deleted file mode 100644 index 74d7863c8d947..0000000000000 --- a/airflow/hooks/hive_hooks.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
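Note: the FSHook touched above resolves its base path from the "path" key of the connection's Extra field. A short usage sketch, assuming an installed and initialized Airflow where the default "fs_default" connection exists; the relative path is illustrative:

import os

from airflow.hooks.filesystem import FSHook

hook = FSHook(conn_id="fs_default")
full_path = os.path.join(hook.get_path(), "data", "input.csv")
print(full_path)
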
Please use :mod:`airflow.providers.apache.hive.hooks.hive`.""" - -import warnings - -from airflow.providers.apache.hive.hooks.hive import ( # noqa - HIVE_QUEUE_PRIORITIES, - HiveCliHook, - HiveMetastoreHook, - HiveServer2Hook, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.hooks.hive`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py deleted file mode 100644 index 5b8c1fdf9b776..0000000000000 --- a/airflow/hooks/http_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.http.hooks.http`.""" - -import warnings - -from airflow.providers.http.hooks.http import HttpHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.http.hooks.http`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/jdbc_hook.py b/airflow/hooks/jdbc_hook.py deleted file mode 100644 index a032ab0e2598b..0000000000000 --- a/airflow/hooks/jdbc_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.jdbc.hooks.jdbc`.""" - -import warnings - -from airflow.providers.jdbc.hooks.jdbc import JdbcHook, jaydebeapi # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.jdbc.hooks.jdbc`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/mssql_hook.py b/airflow/hooks/mssql_hook.py deleted file mode 100644 index 64943eeea7905..0000000000000 --- a/airflow/hooks/mssql_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.mssql.hooks.mssql`.""" - -import warnings - -from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.mssql.hooks.mssql`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py deleted file mode 100644 index 437313680b09c..0000000000000 --- a/airflow/hooks/mysql_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.mysql.hooks.mysql`.""" - -import warnings - -from airflow.providers.mysql.hooks.mysql import MySqlHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.mysql.hooks.mysql`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/oracle_hook.py b/airflow/hooks/oracle_hook.py deleted file mode 100644 index 0dfe33a78ae2a..0000000000000 --- a/airflow/hooks/oracle_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.oracle.hooks.oracle`.""" - -import warnings - -from airflow.providers.oracle.hooks.oracle import OracleHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.oracle.hooks.oracle`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/pig_hook.py b/airflow/hooks/pig_hook.py deleted file mode 100644 index 3ead3df6c826e..0000000000000 --- a/airflow/hooks/pig_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.pig.hooks.pig`.""" - -import warnings - -from airflow.providers.apache.pig.hooks.pig import PigCliHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.pig.hooks.pig`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py deleted file mode 100644 index 16f79dc329593..0000000000000 --- a/airflow/hooks/postgres_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.postgres.hooks.postgres`.""" - -import warnings - -from airflow.providers.postgres.hooks.postgres import PostgresHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.postgres.hooks.postgres`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py deleted file mode 100644 index 0c33e1423d35d..0000000000000 --- a/airflow/hooks/presto_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.presto.hooks.presto`.""" - -import warnings - -from airflow.providers.presto.hooks.presto import PrestoHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.presto.hooks.presto`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/samba_hook.py b/airflow/hooks/samba_hook.py deleted file mode 100644 index b4c7cf83b05a6..0000000000000 --- a/airflow/hooks/samba_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.samba.hooks.samba`.""" - -import warnings - -from airflow.providers.samba.hooks.samba import SambaHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.samba.hooks.samba`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/slack_hook.py b/airflow/hooks/slack_hook.py deleted file mode 100644 index 43636b2c6eeef..0000000000000 --- a/airflow/hooks/slack_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.slack.hooks.slack`.""" - -import warnings - -from airflow.providers.slack.hooks.slack import SlackHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.slack.hooks.slack`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/sqlite_hook.py b/airflow/hooks/sqlite_hook.py deleted file mode 100644 index 773900400ccbc..0000000000000 --- a/airflow/hooks/sqlite_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.sqlite.hooks.sqlite`.""" - -import warnings - -from airflow.providers.sqlite.hooks.sqlite import SqliteHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.sqlite.hooks.sqlite`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/subprocess.py b/airflow/hooks/subprocess.py index fa8c706c6963d..84479a3a40809 100644 --- a/airflow/hooks/subprocess.py +++ b/airflow/hooks/subprocess.py @@ -14,32 +14,33 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import contextlib import os import signal from collections import namedtuple from subprocess import PIPE, STDOUT, Popen from tempfile import TemporaryDirectory, gettempdir -from typing import Dict, List, Optional from airflow.hooks.base import BaseHook -SubprocessResult = namedtuple('SubprocessResult', ['exit_code', 'output']) +SubprocessResult = namedtuple("SubprocessResult", ["exit_code", "output"]) class SubprocessHook(BaseHook): - """Hook for running processes with the ``subprocess`` module""" + """Hook for running processes with the ``subprocess`` module.""" def __init__(self) -> None: - self.sub_process: Optional[Popen[bytes]] = None + self.sub_process: Popen[bytes] | None = None super().__init__() def run_command( self, - command: List[str], - env: Optional[Dict[str, str]] = None, - output_encoding: str = 'utf-8', - cwd: Optional[str] = None, + command: list[str], + env: dict[str, str] | None = None, + output_encoding: str = "utf-8", + cwd: str | None = None, ) -> SubprocessResult: """ Execute the command. @@ -52,26 +53,26 @@ def run_command( environment in which ``command`` will be executed. If omitted, ``os.environ`` will be used. Note, that in case you have Sentry configured, original variables from the environment will also be passed to the subprocess with ``SUBPROCESS_`` prefix. See - :doc:`/logging-monitoring/errors` for details. + :doc:`/administration-and-deployment/logging-monitoring/errors` for details. :param output_encoding: encoding to use for decoding stdout :param cwd: Working directory to run the command in. If None (default), the command is run in a temporary directory. 
:return: :class:`namedtuple` containing ``exit_code`` and ``output``, the last line from stderr or stdout """ - self.log.info('Tmp dir root location: \n %s', gettempdir()) + self.log.info("Tmp dir root location: \n %s", gettempdir()) with contextlib.ExitStack() as stack: if cwd is None: - cwd = stack.enter_context(TemporaryDirectory(prefix='airflowtmp')) + cwd = stack.enter_context(TemporaryDirectory(prefix="airflowtmp")) def pre_exec(): # Restore default signal disposition and invoke setsid - for sig in ('SIGPIPE', 'SIGXFZ', 'SIGXFSZ'): + for sig in ("SIGPIPE", "SIGXFZ", "SIGXFSZ"): if hasattr(signal, sig): signal.signal(getattr(signal, sig), signal.SIG_DFL) os.setsid() - self.log.info('Running command: %s', command) + self.log.info("Running command: %s", command) self.sub_process = Popen( command, @@ -82,24 +83,24 @@ def pre_exec(): preexec_fn=pre_exec, ) - self.log.info('Output:') - line = '' + self.log.info("Output:") + line = "" if self.sub_process is None: raise RuntimeError("The subprocess should be created here and is None!") if self.sub_process.stdout is not None: - for raw_line in iter(self.sub_process.stdout.readline, b''): - line = raw_line.decode(output_encoding, errors='backslashreplace').rstrip() + for raw_line in iter(self.sub_process.stdout.readline, b""): + line = raw_line.decode(output_encoding, errors="backslashreplace").rstrip() self.log.info("%s", line) self.sub_process.wait() - self.log.info('Command exited with return code %s', self.sub_process.returncode) + self.log.info("Command exited with return code %s", self.sub_process.returncode) return_code: int = self.sub_process.returncode return SubprocessResult(exit_code=return_code, output=line) def send_sigterm(self): """Sends SIGTERM signal to ``self.sub_process`` if one exists.""" - self.log.info('Sending SIGTERM signal to process group') - if self.sub_process and hasattr(self.sub_process, 'pid'): + self.log.info("Sending SIGTERM signal to process group") + if self.sub_process and hasattr(self.sub_process, "pid"): os.killpg(os.getpgid(self.sub_process.pid), signal.SIGTERM) diff --git a/airflow/hooks/webhdfs_hook.py b/airflow/hooks/webhdfs_hook.py deleted file mode 100644 index 1c4353835cf00..0000000000000 --- a/airflow/hooks/webhdfs_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hdfs.hooks.webhdfs`.""" - -import warnings - -from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.apache.hdfs.hooks.webhdfs`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/hooks/zendesk_hook.py b/airflow/hooks/zendesk_hook.py deleted file mode 100644 index 38323c5880d74..0000000000000 --- a/airflow/hooks/zendesk_hook.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.zendesk.hooks.zendesk`.""" - -import warnings - -from airflow.providers.zendesk.hooks.zendesk import ZendeskHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.zendesk.hooks.zendesk`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/jobs/__init__.py b/airflow/jobs/__init__.py index 2f061162719d5..217e5db960782 100644 --- a/airflow/jobs/__init__.py +++ b/airflow/jobs/__init__.py @@ -15,9 +15,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -import airflow.jobs.backfill_job -import airflow.jobs.base_job -import airflow.jobs.local_task_job -import airflow.jobs.scheduler_job -import airflow.jobs.triggerer_job # noqa diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index c5b98c2a8df2a..1115edc78c5ea 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# +from __future__ import annotations import time -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple +from typing import TYPE_CHECKING, Any, Iterable, Iterator, Sequence import attr import pendulum @@ -50,7 +50,7 @@ from airflow.utils.types import DagRunType if TYPE_CHECKING: - from airflow.models.mappedoperator import MappedOperator + from airflow.models.abstractoperator import AbstractOperator class BackfillJob(BaseJob): @@ -62,7 +62,7 @@ class BackfillJob(BaseJob): STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED) - __mapper_args__ = {'polymorphic_identity': 'BackfillJob'} + __mapper_args__ = {"polymorphic_identity": "BackfillJob"} @attr.define class _DagRunTaskStatus: @@ -88,15 +88,15 @@ class _DagRunTaskStatus: :param total_runs: Number of total dag runs able to run """ - to_run: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict) - running: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict) - skipped: Set[TaskInstanceKey] = attr.ib(factory=set) - succeeded: Set[TaskInstanceKey] = attr.ib(factory=set) - failed: Set[TaskInstanceKey] = attr.ib(factory=set) - not_ready: Set[TaskInstanceKey] = attr.ib(factory=set) - deadlocked: Set[TaskInstance] = attr.ib(factory=set) - active_runs: List[DagRun] = attr.ib(factory=list) - executed_dag_run_dates: Set[pendulum.DateTime] = attr.ib(factory=set) + to_run: dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict) + running: dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict) + skipped: set[TaskInstanceKey] = attr.ib(factory=set) + succeeded: set[TaskInstanceKey] = attr.ib(factory=set) + failed: set[TaskInstanceKey] = attr.ib(factory=set) + not_ready: set[TaskInstanceKey] = attr.ib(factory=set) + deadlocked: set[TaskInstance] = attr.ib(factory=set) + active_runs: list[DagRun] = attr.ib(factory=list) + executed_dag_run_dates: set[pendulum.DateTime] = attr.ib(factory=set) finished_runs: int = 0 total_runs: int = 0 @@ -117,6 +117,7 @@ def __init__( run_backwards=False, run_at_least_once=False, continue_on_failures=False, + disable_retry=False, *args, **kwargs, ): @@ -156,6 +157,7 @@ def __init__( self.run_backwards = run_backwards self.run_at_least_once = run_at_least_once self.continue_on_failures = continue_on_failures + self.disable_retry = disable_retry super().__init__(*args, **kwargs) def _update_counters(self, ti_status, session=None): @@ -217,6 +219,12 @@ def _update_counters(self, ti_status, session=None): tis_to_be_scheduled.append(ti) ti_status.running.pop(reduced_key) ti_status.to_run[ti.key] = ti + # special case: Deferrable task can go from DEFERRED to SCHEDULED; + # when that happens, we need to put it back as in UP_FOR_RESCHEDULE + elif ti.state == TaskInstanceState.SCHEDULED: + self.log.debug("Task instance %s is resumed from deferred state", ti) + ti_status.running.pop(ti.key) + ti_status.to_run[ti.key] = ti # Batch schedule of task instances if tis_to_be_scheduled: @@ -228,7 +236,7 @@ def _update_counters(self, ti_status, session=None): def _manage_executor_state( self, running, session - ) -> Iterator[Tuple["MappedOperator", str, Sequence[TaskInstance], int]]: + ) -> Iterator[tuple[AbstractOperator, str, Sequence[TaskInstance], int]]: """ Checks if the executor agrees with the state of task instances that are running. 
@@ -263,10 +271,20 @@ def _manage_executor_state( self.log.error(msg) ti.handle_failure(error=msg) continue + + def _iter_task_needing_expansion() -> Iterator[AbstractOperator]: + from airflow.models.mappedoperator import AbstractOperator + + for node in self.dag.get_task(ti.task_id, include_subdags=True).iter_mapped_dependants(): + if isinstance(node, AbstractOperator): + yield node + else: # A (mapped) task group. All its children need expansion. + yield from node.iter_tasks() + if ti.state not in self.STATES_COUNT_AS_RUNNING: # Don't use ti.task; if this task is mapped, that attribute # would hold the unmapped task. We need to original task here. - for node in self.dag.get_task(ti.task_id, include_subdags=True).iter_mapped_dependants(): + for node in _iter_task_needing_expansion(): new_tis, num_mapped_tis = node.expand_mapped_task(ti.run_id, session=session) yield node, ti.run_id, new_tis, num_mapped_tis @@ -292,13 +310,15 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non # check if we are scheduling on top of a already existing dag_run # we could find a "scheduled" run instead of a "backfill" runs = DagRun.find(dag_id=dag.dag_id, execution_date=run_date, session=session) - run: Optional[DagRun] + run: DagRun | None if runs: run = runs[0] if run.state == DagRunState.RUNNING: respect_dag_max_active_limit = False # Fixes --conf overwrite for backfills with already existing DagRuns run.conf = self.conf or {} + # start_date is cleared for existing DagRuns + run.start_date = timezone.utcnow() else: run = None @@ -326,10 +346,12 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non run.state = DagRunState.RUNNING run.run_type = DagRunType.BACKFILL_JOB run.verify_integrity(session=session) + + run.notify_dagrun_state_changed(msg="started") return run @provide_session - def _task_instances_for_dag_run(self, dag_run, session=None): + def _task_instances_for_dag_run(self, dag, dag_run, session=None): """ Returns a map of task instance key to task instance object for the tasks to run in the given dag run. @@ -349,24 +371,25 @@ def _task_instances_for_dag_run(self, dag_run, session=None): dag_run.refresh_from_db() make_transient(dag_run) + dag_run.dag = dag + info = dag_run.task_instance_scheduling_decisions(session=session) + schedulable_tis = info.schedulable_tis try: - for ti in dag_run.get_task_instances(): - # all tasks part of the backfill are scheduled to run - if ti.state == State.NONE: - ti.set_state(TaskInstanceState.SCHEDULED, session=session) + for ti in dag_run.get_task_instances(session=session): + if ti in schedulable_tis: + ti.set_state(TaskInstanceState.SCHEDULED) if ti.state != TaskInstanceState.REMOVED: tasks_to_run[ti.key] = ti session.commit() except Exception: session.rollback() raise - return tasks_to_run def _log_progress(self, ti_status): self.log.info( - '[backfill progress] | finished run %s of %s | tasks waiting: %s | succeeded: %s | ' - 'running: %s | failed: %s | skipped: %s | deadlocked: %s | not ready: %s', + "[backfill progress] | finished run %s of %s | tasks waiting: %s | succeeded: %s | " + "running: %s | failed: %s | skipped: %s | deadlocked: %s | not ready: %s", ti_status.finished_runs, ti_status.total_runs, len(ti_status.to_run), @@ -388,7 +411,7 @@ def _process_backfill_task_instances( pickle_id, start_date=None, session=None, - ): + ) -> list: """ Process a set of task instances from a set of dag runs. 
Special handling is done to account for different task instance states that could be present when running @@ -400,11 +423,10 @@ def _process_backfill_task_instances( :param start_date: the start date of the backfill job :param session: the current session object :return: the list of execution_dates for the finished dag runs - :rtype: list """ executed_run_dates = [] - is_unit_test = airflow_conf.getboolean('core', 'unit_test_mode') + is_unit_test = airflow_conf.getboolean("core", "unit_test_mode") while (len(ti_status.to_run) > 0 or len(ti_status.running) > 0) and len(ti_status.deadlocked) == 0: self.log.debug("*** Clearing out not_ready list ***") @@ -439,13 +461,6 @@ def _per_task_process(key, ti: TaskInstance, session=None): ti_status.running.pop(key) return - # guard against externally modified tasks instances or - # in case max concurrency has been reached at task runtime - elif ti.state == State.NONE: - self.log.warning( - "FIXME: Task instance %s state was set to None externally. This should not happen", ti - ) - ti.set_state(TaskInstanceState.SCHEDULED, session=session) if self.rerun_failed_tasks: # Rerun failed tasks or upstreamed failed tasks if ti.state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED): @@ -485,7 +500,7 @@ def _per_task_process(key, ti: TaskInstance, session=None): if executor.has_task(ti): self.log.debug("Task Instance %s already in executor waiting for queue to clear", ti) else: - self.log.debug('Sending %s to executor', ti) + self.log.debug("Sending %s to executor", ti) # Skip scheduled state, we are executing immediately ti.state = TaskInstanceState.QUEUED ti.queued_by_job_id = self.id @@ -544,7 +559,7 @@ def _per_task_process(key, ti: TaskInstance, session=None): return # all remaining tasks - self.log.debug('Adding %s to not_ready', ti) + self.log.debug("Adding %s to not_ready", ti) ti_status.not_ready.add(key) try: @@ -555,7 +570,7 @@ def _per_task_process(key, ti: TaskInstance, session=None): pool = session.query(models.Pool).filter(models.Pool.pool == task.pool).first() if not pool: - raise PoolNotFound(f'Unknown pool: {task.pool}') + raise PoolNotFound(f"Unknown pool: {task.pool}") open_slots = pool.open_slots(session=session) if open_slots <= 0: @@ -627,6 +642,11 @@ def to_keep(key: TaskInstanceKey) -> bool: for new_ti in new_mapped_tis: new_ti.set_state(TaskInstanceState.SCHEDULED, session=session) + # Set state to failed for running TIs that are set up for retry if disable-retry flag is set + for ti in ti_status.running.values(): + if self.disable_retry and ti.state == TaskInstanceState.UP_FOR_RETRY: + ti.set_state(TaskInstanceState.FAILED, session=session) + # update the task counters self._update_counters(ti_status=ti_status, session=session) session.commit() @@ -669,12 +689,12 @@ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str: return tabulate(sorted_ti_keys, headers=headers) - err = '' + err = "" if ti_status.failed: err += "Some task instances failed:\n" err += tabulate_ti_keys_set(ti_status.failed) if ti_status.deadlocked: - err += 'BackfillJob is deadlocked.' + err += "BackfillJob is deadlocked." deadlocked_depends_on_past = any( t.are_dependencies_met( dep_context=DepContext(ignore_depends_on_past=False), @@ -688,26 +708,26 @@ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str: ) if deadlocked_depends_on_past: err += ( - 'Some of the deadlocked tasks were unable to run because ' + "Some of the deadlocked tasks were unable to run because " 'of "depends_on_past" relationships. 
Try running the ' - 'backfill with the option ' + "backfill with the option " '"ignore_first_depends_on_past=True" or passing "-I" at ' - 'the command line.' + "the command line." ) - err += '\nThese tasks have succeeded:\n' + err += "\nThese tasks have succeeded:\n" err += tabulate_ti_keys_set(ti_status.succeeded) - err += '\n\nThese tasks are running:\n' + err += "\n\nThese tasks are running:\n" err += tabulate_ti_keys_set(ti_status.running) - err += '\n\nThese tasks have failed:\n' + err += "\n\nThese tasks have failed:\n" err += tabulate_ti_keys_set(ti_status.failed) - err += '\n\nThese tasks are skipped:\n' + err += "\n\nThese tasks are skipped:\n" err += tabulate_ti_keys_set(ti_status.skipped) - err += '\n\nThese tasks are deadlocked:\n' + err += "\n\nThese tasks are deadlocked:\n" err += tabulate_ti_keys_set([ti.key for ti in ti_status.deadlocked]) return err - def _get_dag_with_subdags(self) -> List[DAG]: + def _get_dag_with_subdags(self) -> list[DAG]: return [self.dag] + self.dag.subdags @provide_session @@ -727,7 +747,7 @@ def _execute_dagruns(self, dagrun_infos, ti_status, executor, pickle_id, start_d for dagrun_info in dagrun_infos: for dag in self._get_dag_with_subdags(): dag_run = self._get_dag_run(dagrun_info, dag, session=session) - tis_map = self._task_instances_for_dag_run(dag_run, session=session) + tis_map = self._task_instances_for_dag_run(dag, dag_run, session=session) if dag_run is None: continue @@ -781,7 +801,7 @@ def _execute(self, session=None): tasks_that_depend_on_past = [t.task_id for t in self.dag.task_dict.values() if t.depends_on_past] if tasks_that_depend_on_past: raise AirflowException( - f'You cannot backfill backwards because one or more ' + f"You cannot backfill backwards because one or more " f'tasks depend_on_past: {",".join(tasks_that_depend_on_past)}' ) dagrun_infos = dagrun_infos[::-1] @@ -831,7 +851,7 @@ def _execute(self, session=None): pickle_id = pickle.id executor = self.executor - executor.job_id = "backfill" + executor.job_id = self.id executor.start() ti_status.total_runs = len(dagrun_infos) # total dag runs in backfill @@ -876,10 +896,10 @@ def _execute(self, session=None): session.commit() executor.end() - self.log.info("Backfill done. Exiting.") + self.log.info("Backfill done for DAG %s. Exiting.", self.dag) @provide_session - def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): + def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None) -> int | None: """ This function checks if there are any tasks in the dagrun (or all) that have a schedule or queued states but are not known by the executor. 
If @@ -889,7 +909,6 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None): :param filter_by_dag_run: the dag_run we want to process, None if all :return: the number of TIs reset - :rtype: int """ queued_tis = self.executor.queued_tasks # also consider running as the state might not have changed in the db yet @@ -934,7 +953,7 @@ def query(result, items): reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [], self.max_tis_per_query) - task_instance_str = '\n\t'.join(repr(x) for x in reset_tis) + task_instance_str = "\n\t".join(repr(x) for x in reset_tis) session.flush() self.log.info("Reset the following %s TaskInstances:\n\t%s", len(reset_tis), task_instance_str) diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index 7befccf956c41..d695615a5d8ba 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -15,12 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations from time import sleep -from typing import Optional -from sqlalchemy import Column, Index, Integer, String +from sqlalchemy import Column, Index, Integer, String, case from sqlalchemy.exc import OperationalError from sqlalchemy.orm import backref, foreign, relationship from sqlalchemy.orm.session import make_transient @@ -29,9 +28,8 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.executors.executor_loader import ExecutorLoader +from airflow.listeners.listener import get_listener_manager from airflow.models.base import ID_LEN, Base -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.helpers import convert_camel_to_snake @@ -43,6 +41,12 @@ from airflow.utils.state import State +def _resolve_dagrun_model(): + from airflow.models.dagrun import DagRun + + return DagRun + + class BaseJob(Base, LoggingMixin): """ Abstract class to be derived for jobs. Jobs are processing items with state @@ -66,24 +70,24 @@ class BaseJob(Base, LoggingMixin): hostname = Column(String(500)) unixname = Column(String(1000)) - __mapper_args__ = {'polymorphic_on': job_type, 'polymorphic_identity': 'BaseJob'} + __mapper_args__ = {"polymorphic_on": job_type, "polymorphic_identity": "BaseJob"} __table_args__ = ( - Index('job_type_heart', job_type, latest_heartbeat), - Index('idx_job_state_heartbeat', state, latest_heartbeat), - Index('idx_job_dag_id', dag_id), + Index("job_type_heart", job_type, latest_heartbeat), + Index("idx_job_state_heartbeat", state, latest_heartbeat), + Index("idx_job_dag_id", dag_id), ) task_instances_enqueued = relationship( - TaskInstance, - primaryjoin=id == foreign(TaskInstance.queued_by_job_id), # type: ignore[has-type] - backref=backref('queued_by_job', uselist=False), + "TaskInstance", + primaryjoin="BaseJob.id == foreign(TaskInstance.queued_by_job_id)", + backref=backref("queued_by_job", uselist=False), ) dag_runs = relationship( - DagRun, - primaryjoin=id == foreign(DagRun.creating_job_id), # type: ignore[has-type] - backref=backref('creating_job'), + "DagRun", + primaryjoin=lambda: BaseJob.id == foreign(_resolve_dagrun_model().creating_job_id), + backref="creating_job", ) """ @@ -92,7 +96,7 @@ class BaseJob(Base, LoggingMixin): Only makes sense for SchedulerJob and BackfillJob instances. 
""" - heartrate = conf.getfloat('scheduler', 'JOB_HEARTBEAT_SEC') + heartrate = conf.getfloat("scheduler", "JOB_HEARTBEAT_SEC") def __init__(self, executor=None, heartrate=None, *args, **kwargs): self.hostname = get_hostname() @@ -100,13 +104,14 @@ def __init__(self, executor=None, heartrate=None, *args, **kwargs): self.executor = executor self.executor_class = executor.__class__.__name__ else: - self.executor_class = conf.get('core', 'EXECUTOR') + self.executor_class = conf.get("core", "EXECUTOR") self.start_date = timezone.utcnow() self.latest_heartbeat = timezone.utcnow() if heartrate is not None: self.heartrate = heartrate self.unixname = getuser() - self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query') + self.max_tis_per_query: int = conf.getint("scheduler", "max_tis_per_query") + get_listener_manager().hook.on_starting(component=self) super().__init__(*args, **kwargs) @cached_property @@ -115,18 +120,28 @@ def executor(self): @classmethod @provide_session - def most_recent_job(cls, session=None) -> Optional['BaseJob']: + def most_recent_job(cls, session=None) -> BaseJob | None: """ Return the most recent job of this type, if any, based on last heartbeat received. + Jobs in "running" state take precedence over others to make sure alive + job is returned if it is available. This method should be called on a subclass (i.e. on SchedulerJob) to return jobs of that type. :param session: Database session - :rtype: BaseJob or None """ - return session.query(cls).order_by(cls.latest_heartbeat.desc()).limit(1).first() + return ( + session.query(cls) + .order_by( + # Put "running" jobs at the front. + case({State.RUNNING: 0}, value=cls.state, else_=1), + cls.latest_heartbeat.desc(), + ) + .limit(1) + .first() + ) def is_alive(self, grace_multiplier=2.1): """ @@ -137,7 +152,6 @@ def is_alive(self, grace_multiplier=2.1): :param grace_multiplier: multiplier of heartrate to require heart beat within - :rtype: boolean """ return ( self.state == State.RUNNING @@ -153,7 +167,7 @@ def kill(self, session=None): try: self.on_kill() except Exception as e: - self.log.error('on_kill() method failed: %s', str(e)) + self.log.error("on_kill() method failed: %s", str(e)) session.merge(job) session.commit() raise AirflowException("Job shut down externally.") @@ -223,16 +237,16 @@ def heartbeat(self, only_if_necessary: bool = False): previous_heartbeat = self.latest_heartbeat self.heartbeat_callback(session=session) - self.log.debug('[heartbeat]') + self.log.debug("[heartbeat]") except OperationalError: - Stats.incr(convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1, 1) + Stats.incr(convert_camel_to_snake(self.__class__.__name__) + "_heartbeat_failure", 1, 1) self.log.exception("%s heartbeat got an exception", self.__class__.__name__) # We didn't manage to heartbeat, so make sure that the timestamp isn't updated self.latest_heartbeat = previous_heartbeat def run(self): """Starts the job.""" - Stats.incr(self.__class__.__name__.lower() + '_start', 1, 1) + Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1) # Adding an entry in the DB with create_session() as session: self.state = State.RUNNING @@ -251,11 +265,12 @@ def run(self): self.state = State.FAILED raise finally: + get_listener_manager().hook.before_stopping(component=self) self.end_date = timezone.utcnow() session.merge(self) session.commit() - Stats.incr(self.__class__.__name__.lower() + '_end', 1, 1) + Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1) def _execute(self): raise 
NotImplementedError("This method needs to be overridden") diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 5711342e04d97..cc449e1fc56fb 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -15,34 +15,60 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import signal -from typing import Optional import psutil -from sqlalchemy.exc import OperationalError from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.jobs.base_job import BaseJob from airflow.listeners.events import register_task_instance_state_events from airflow.listeners.listener import get_listener_manager -from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance -from airflow.sentry import Sentry from airflow.stats import Stats from airflow.task.task_runner import get_task_runner from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State +SIGSEGV_MESSAGE = """ +******************************************* Received SIGSEGV ******************************************* +The SIGSEGV (Segmentation Violation) signal indicates a Segmentation Fault error, which refers to +an attempt by a program/library to write or read outside its allocated memory. + +In a Python environment this signal usually points to libraries which use a low-level C API. +Make sure that you use the right libraries/Docker images +for your architecture (Intel/ARM) and/or operating system (Linux/macOS). + +Suggested way to debug +====================== + - Set the environment variable 'PYTHONFAULTHANDLER' to 'true'. + - Start airflow services. + - Restart the failed airflow task. + - Check the 'scheduler' and 'worker' service logs for an additional traceback + which might contain information about the module/library where the actual error happened. + +Known Issues +============ + +Note: Only Linux-based distros are supported as a "Production" execution environment for Airflow. + +macOS +----- + 1. Due to limitations in Apple's libraries, not every process is 'fork'-safe. + One common error is being unable to query the macOS system configuration for network proxies. + If you are not using a proxy, you can disable it by setting the environment variable 'no_proxy' to '*'. 
+ See: https://github.com/python/cpython/issues/58037 and https://bugs.python.org/issue30385#msg293958 +********************************************************************************************************""" + class LocalTaskJob(BaseJob): """LocalTaskJob runs a single task instance.""" - __mapper_args__ = {'polymorphic_identity': 'LocalTaskJob'} + __mapper_args__ = {"polymorphic_identity": "LocalTaskJob"} def __init__( self, @@ -52,9 +78,9 @@ def __init__( ignore_task_deps: bool = False, ignore_ti_state: bool = False, mark_success: bool = False, - pickle_id: Optional[str] = None, - pool: Optional[str] = None, - external_executor_id: Optional[str] = None, + pickle_id: str | None = None, + pool: str | None = None, + external_executor_id: str | None = None, *args, **kwargs, ): @@ -87,7 +113,26 @@ def signal_handler(signum, frame): self.task_runner.terminate() self.handle_task_exit(128 + signum) + def segfault_signal_handler(signum, frame): + """Set the segmentation violation (SIGSEGV) signal handler.""" + self.log.critical(SIGSEGV_MESSAGE) + self.task_runner.terminate() + self.handle_task_exit(128 + signum) + raise AirflowException("Segmentation Fault detected.") + + def sigusr2_debug_handler(signum, frame): + import sys + import threading + import traceback + + id2name = {th.ident: th.name for th in threading.enumerate()} + for threadId, stack in sys._current_frames().items(): + print(id2name[threadId]) + traceback.print_stack(f=stack) + + signal.signal(signal.SIGSEGV, segfault_signal_handler) signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGUSR2, sigusr2_debug_handler) if not self.task_instance.check_and_change_state_before_execution( mark_success=self.mark_success, @@ -105,7 +150,7 @@ def signal_handler(signum, frame): try: self.task_runner.start() - heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold') + heartbeat_time_limit = conf.getint("scheduler", "scheduler_zombie_task_threshold") # LocalTaskJob should not run callbacks, which are handled by TaskInstance._run_raw_task # 1, LocalTaskJob does not parse DAG, thus cannot run callbacks @@ -143,7 +188,7 @@ def signal_handler(signum, frame): # This can only really happen if the worker can't read the DB for a long time time_since_last_heartbeat = (timezone.utcnow() - self.latest_heartbeat).total_seconds() if time_since_last_heartbeat > heartbeat_time_limit: - Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1) + Stats.incr("local_task_job_prolonged_heartbeat_failure", 1, 1) self.log.error("Heartbeat time limit exceeded!") raise AirflowException( f"Time since last heartbeat({time_since_last_heartbeat:.2f}s) exceeded limit " @@ -161,10 +206,11 @@ def handle_task_exit(self, return_code: int) -> None: # Without setting this, heartbeat may get us self.terminating = True self.log.info("Task exited with return code %s", return_code) + self._log_return_code_metric(return_code) if not self.task_instance.test_mode: - if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True): - self._run_mini_scheduler_on_child_tasks() + if conf.getboolean("scheduler", "schedule_after_task_execution", fallback=True): + self.task_instance.schedule_downstream_tasks() def on_kill(self): self.task_runner.terminate() @@ -195,7 +241,13 @@ def heartbeat_callback(self, session=None): recorded_pid = ti.pid same_process = recorded_pid == current_pid - if ti.run_as_user or self.task_runner.run_as_user: + if recorded_pid is not None and (ti.run_as_user or self.task_runner.run_as_user): + # when 
running as another user, compare the task runner pid to the parent of + # the recorded pid because user delegation becomes an extra process level. + # However, if recorded_pid is None, pass that through as it signals the task + # runner process has already completed and been cleared out. `psutil.Process` + # uses the current process if the parameter is None, which is not what is intended + # for comparison. recorded_pid = psutil.Process(ti.pid).ppid() same_process = recorded_pid == current_pid @@ -204,7 +256,7 @@ def heartbeat_callback(self, session=None): "Recorded pid %s does not match the current pid %s", recorded_pid, current_pid ) raise AirflowException("PID of job runner does not match") - elif self.task_runner.return_code() is None and hasattr(self.task_runner, 'process'): + elif self.task_runner.return_code() is None and hasattr(self.task_runner, "process"): if ti.state == State.SKIPPED: # A DagRun timeout will cause tasks to be externally marked as skipped. dagrun = ti.get_dagrun(session=session) @@ -223,56 +275,10 @@ def heartbeat_callback(self, session=None): self.terminating = True self._state_change_checks += 1 - @provide_session - @Sentry.enrich_errors - def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: - try: - # Re-select the row with a lock - dag_run = with_row_locks( - session.query(DagRun).filter_by( - dag_id=self.dag_id, - run_id=self.task_instance.run_id, - ), - session=session, - ).one() - - task = self.task_instance.task - assert task.dag # For Mypy. - - # Get a partial DAG with just the specific tasks we want to examine. - # In order for dep checks to work correctly, we include ourself (so - # TriggerRuleDep can check the state of the task we just executed). - partial_dag = task.dag.partial_subset( - task.downstream_task_ids, - include_downstream=True, - include_upstream=False, - include_direct_upstream=True, - ) - - dag_run.dag = partial_dag - info = dag_run.task_instance_scheduling_decisions(session) - - skippable_task_ids = { - task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids - } - - schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] - for schedulable_ti in schedulable_tis: - if not hasattr(schedulable_ti, "task"): - schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id) - - num = dag_run.schedule_tis(schedulable_tis) - self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) - - session.commit() - except OperationalError as e: - # Any kind of DB error here is _non fatal_ as this block is just an optimisation. - self.log.info( - "Skipping mini scheduling run due to exception: %s", - e.statement, - exc_info=True, - ) - session.rollback() + def _log_return_code_metric(self, return_code: int): + Stats.incr( + f"local_task_job.task_exit.{self.id}.{self.dag_id}.{self.task_instance.task_id}.{return_code}" + ) @staticmethod def _enable_task_listeners(): diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 22ba5decb4111..b8b608efcd847 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# +from __future__ import annotations + import itertools import logging import multiprocessing @@ -25,37 +26,44 @@ import time import warnings from collections import defaultdict -from datetime import timedelta -from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple +from datetime import datetime, timedelta +from pathlib import Path +from typing import TYPE_CHECKING, Collection, DefaultDict, Iterator -from sqlalchemy import func, not_, or_, text +from sqlalchemy import and_, func, not_, or_, text from sqlalchemy.exc import OperationalError from sqlalchemy.orm import load_only, selectinload from sqlalchemy.orm.session import Session, make_transient +from sqlalchemy.sql import expression -from airflow import models, settings +from airflow import settings from airflow.callbacks.callback_requests import DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest -from airflow.callbacks.database_callback_sink import DatabaseCallbackSink from airflow.callbacks.pipe_callback_sink import PipeCallbackSink from airflow.configuration import conf -from airflow.dag_processing.manager import DagFileProcessorAgent +from airflow.exceptions import RemovedInAirflow3Warning from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS from airflow.jobs.base_job import BaseJob -from airflow.jobs.local_task_job import LocalTaskJob -from airflow.models import DAG -from airflow.models.dag import DagModel +from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun +from airflow.models.dataset import ( + DagScheduleDatasetReference, + DatasetDagRunQueue, + DatasetEvent, + DatasetModel, + TaskOutletDatasetReference, +) from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey from airflow.stats import Stats from airflow.ti_deps.dependencies_states import EXECUTION_STATES +from airflow.timetables.simple import DatasetTriggeredTimetable from airflow.utils import timezone -from airflow.utils.docs import get_docs_url from airflow.utils.event_scheduler import EventScheduler from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries -from airflow.utils.session import create_session, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import ( + CommitProhibitorGuard, is_lock_not_available_error, prohibit_commit, skip_locked, @@ -65,17 +73,22 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType -TI = models.TaskInstance -DR = models.DagRun -DM = models.DagModel +if TYPE_CHECKING: + from types import FrameType + + from airflow.dag_processing.manager import DagFileProcessorAgent + +TI = TaskInstance +DR = DagRun +DM = DagModel -def _is_parent_process(): +def _is_parent_process() -> bool: """ Returns True if the current process is the parent process. False if the current process is a child process started by multiprocessing. 
""" - return multiprocessing.current_process().name == 'MainProcess' + return multiprocessing.current_process().name == "MainProcess" class SchedulerJob(BaseJob): @@ -100,18 +113,18 @@ class SchedulerJob(BaseJob): :param log: override the default Logger """ - __mapper_args__ = {'polymorphic_identity': 'SchedulerJob'} - heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') + __mapper_args__ = {"polymorphic_identity": "SchedulerJob"} + heartrate: int = conf.getint("scheduler", "SCHEDULER_HEARTBEAT_SEC") def __init__( self, subdir: str = settings.DAGS_FOLDER, - num_runs: int = conf.getint('scheduler', 'num_runs'), + num_runs: int = conf.getint("scheduler", "num_runs"), num_times_parse_dags: int = -1, - scheduler_idle_sleep_time: float = conf.getfloat('scheduler', 'scheduler_idle_sleep_time'), + scheduler_idle_sleep_time: float = conf.getfloat("scheduler", "scheduler_idle_sleep_time"), do_pickle: bool = False, - log: Optional[logging.Logger] = None, - processor_poll_interval: Optional[float] = None, + log: logging.Logger | None = None, + processor_poll_interval: float | None = None, *args, **kwargs, ): @@ -127,14 +140,15 @@ def __init__( warnings.warn( "The 'processor_poll_interval' parameter is deprecated. " "Please use 'scheduler_idle_sleep_time'.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) scheduler_idle_sleep_time = processor_poll_interval self._scheduler_idle_sleep_time = scheduler_idle_sleep_time # How many seconds do we wait for tasks to heartbeat before mark them as zombies. - self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold') + self._zombie_threshold_secs = conf.getint("scheduler", "scheduler_zombie_task_threshold") self._standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor") + self._dag_stale_not_seen_duration = conf.getint("scheduler", "dag_stale_not_seen_duration") self.do_pickle = do_pickle super().__init__(*args, **kwargs) @@ -142,28 +156,13 @@ def __init__( self._log = log # Check what SQL backend we use - sql_conn: str = conf.get_mandatory_value('database', 'sql_alchemy_conn').lower() - self.using_sqlite = sql_conn.startswith('sqlite') - self.using_mysql = sql_conn.startswith('mysql') + sql_conn: str = conf.get_mandatory_value("database", "sql_alchemy_conn").lower() + self.using_sqlite = sql_conn.startswith("sqlite") # Dag Processor agent - not used in Dag Processor standalone mode. - self.processor_agent: Optional[DagFileProcessorAgent] = None + self.processor_agent: DagFileProcessorAgent | None = None self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False) - self._paused_dag_without_running_dagruns: Set = set() - - if conf.getboolean('smart_sensor', 'use_smart_sensor'): - compatible_sensors = set( - map( - lambda l: l.strip(), - conf.get_mandatory_value('smart_sensor', 'sensors_enabled').split(','), - ) - ) - docs_url = get_docs_url('concepts/smart-sensors.html#migrating-to-deferrable-operators') - warnings.warn( - f'Smart sensors are deprecated, yet can be used for {compatible_sensors} sensors.' - f' Please use Deferrable Operators instead. 
See {docs_url} for more info.', - DeprecationWarning, - ) + self._paused_dag_without_running_dagruns: set = set() def register_signals(self) -> None: """Register signals that stop child processes""" @@ -171,7 +170,7 @@ def register_signals(self) -> None: signal.signal(signal.SIGTERM, self._exit_gracefully) signal.signal(signal.SIGUSR2, self._debug_dump) - def _exit_gracefully(self, signum, frame) -> None: + def _exit_gracefully(self, signum: int, frame: FrameType | None) -> None: """Helper method to clean up processor_agent to avoid leaving orphan processes.""" if not _is_parent_process(): # Only the parent process should perform the cleanup. @@ -182,7 +181,7 @@ def _exit_gracefully(self, signum, frame) -> None: self.processor_agent.end() sys.exit(os.EX_OK) - def _debug_dump(self, signum, frame): + def _debug_dump(self, signum: int, frame: FrameType | None) -> None: if not _is_parent_process(): # Only the parent process should perform the debug dump. return @@ -197,7 +196,7 @@ def _debug_dump(self, signum, frame): self.executor.debug_dump() self.log.info("-" * 80) - def is_alive(self, grace_multiplier: Optional[float] = None) -> bool: + def is_alive(self, grace_multiplier: float | None = None) -> bool: """ Is this SchedulerJob alive? @@ -207,44 +206,40 @@ def is_alive(self, grace_multiplier: Optional[float] = None) -> bool: ``grace_multiplier`` is accepted for compatibility with the parent class. - :rtype: boolean """ if grace_multiplier is not None: # Accept the same behaviour as superclass return super().is_alive(grace_multiplier=grace_multiplier) - scheduler_health_check_threshold: int = conf.getint('scheduler', 'scheduler_health_check_threshold') + scheduler_health_check_threshold: int = conf.getint("scheduler", "scheduler_health_check_threshold") return ( self.state == State.RUNNING and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold ) - @provide_session def __get_concurrency_maps( - self, states: List[TaskInstanceState], session: Session = None - ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]: + self, states: list[TaskInstanceState], session: Session + ) -> tuple[DefaultDict[str, int], DefaultDict[tuple[str, str], int]]: """ Get the concurrency maps. :param states: List of states to query for :return: A map from (dag_id, task_id) to # of task instances and a map from (dag_id, task_id) to # of task instances in the given state list - :rtype: tuple[dict[str, int], dict[tuple[str, str], int]] """ - ti_concurrency_query: List[Tuple[str, str, int]] = ( - session.query(TI.task_id, TI.dag_id, func.count('*')) + ti_concurrency_query: list[tuple[str, str, int]] = ( + session.query(TI.task_id, TI.dag_id, func.count("*")) .filter(TI.state.in_(states)) .group_by(TI.task_id, TI.dag_id) ).all() dag_map: DefaultDict[str, int] = defaultdict(int) - task_map: DefaultDict[Tuple[str, str], int] = defaultdict(int) + task_map: DefaultDict[tuple[str, str], int] = defaultdict(int) for result in ti_concurrency_query: task_id, dag_id, count = result dag_map[dag_id] += count task_map[(dag_id, task_id)] = count return dag_map, task_map - @provide_session - def _executable_task_instances_to_queued(self, max_tis: int, session: Session = None) -> List[TI]: + def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]: """ Finds TIs that are ready for execution with respect to pool limits, dag max_active_tasks, executor state, and priority. 
@@ -252,9 +247,10 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = :param max_tis: Maximum number of TIs to queue in this loop. :return: list[airflow.models.TaskInstance] """ + from airflow.models.pool import Pool from airflow.utils.db import DBLocks - executable_tis: List[TI] = [] + executable_tis: list[TI] = [] if session.get_bind().dialect.name == "postgresql": # Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock @@ -268,16 +264,16 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = if not lock_acquired: # Throw an error like the one that would happen with NOWAIT raise OperationalError( - "Failed to acquire advisory lock", params=None, orig=RuntimeError('55P03') + "Failed to acquire advisory lock", params=None, orig=RuntimeError("55P03") ) # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section" # Throws an exception if lock cannot be obtained, rather than blocking - pools = models.Pool.slots_stats(lock_rows=True, session=session) + pools = Pool.slots_stats(lock_rows=True, session=session) # If the pools are full, there is no point doing anything! # If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL - pool_slots_free = sum(max(0, pool['open']) for pool in pools.values()) + pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values()) if pool_slots_free == 0: self.log.debug("All pools are full!") @@ -285,11 +281,11 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = max_tis = min(max_tis, pool_slots_free) - starved_pools = {pool_name for pool_name, stats in pools.items() if stats['open'] <= 0} + starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0} # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks. 
dag_active_tasks_map: DefaultDict[str, int] - task_concurrency_map: DefaultDict[Tuple[str, str], int] + task_concurrency_map: DefaultDict[tuple[str, str], int] dag_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps( states=list(EXECUTION_STATES), session=session ) @@ -299,8 +295,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = num_starving_tasks_total = 0 # dag and task ids that can't be queued because of concurrency limits - starved_dags: Set[str] = set() - starved_tasks: Set[Tuple[str, str]] = set() + starved_dags: set[str] = set() + starved_tasks: set[tuple[str, str]] = set() pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int) @@ -315,13 +311,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = # and the dag is not paused query = ( session.query(TI) + .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") .join(TI.dag_run) .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING) .join(TI.dag_model) .filter(not_(DM.is_paused)) .filter(TI.state == TaskInstanceState.SCHEDULED) - .options(selectinload('dag_model')) - .order_by(-TI.priority_weight, DR.execution_date) + .options(selectinload("dag_model")) + .order_by(-TI.priority_weight, DR.execution_date, TI.map_index) ) if starved_pools: @@ -336,12 +333,21 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = query = query.limit(max_tis) - task_instances_to_examine: List[TI] = with_row_locks( - query, - of=TI, - session=session, - **skip_locked(session=session), - ).all() + timer = Stats.timer("scheduler.critical_section_query_duration") + timer.start() + + try: + task_instances_to_examine: list[TI] = with_row_locks( + query, + of=TI, + session=session, + **skip_locked(session=session), + ).all() + timer.stop(send=True) + except OperationalError as e: + timer.stop(send=False) + raise e + # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything. 
# Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) @@ -444,12 +450,12 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = dag_id, task_instance, ) - session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update( - {TI.state: State.FAILED}, synchronize_session='fetch' - ) + session.query(TI).filter( + TI.dag_id == dag_id, TI.state == TaskInstanceState.SCHEDULED + ).update({TI.state: TaskInstanceState.FAILED}, synchronize_session="fetch") continue - task_concurrency_limit: Optional[int] = None + task_concurrency_limit: int | None = None if serialized_dag.has_task(task_instance.task_id): task_concurrency_limit = serialized_dag.get_task( task_instance.task_id @@ -494,11 +500,11 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = ) for pool_name, num_starving_tasks in pool_num_starving_tasks.items(): - Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks) + Stats.gauge(f"pool.starving_tasks.{pool_name}", num_starving_tasks) - Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total) - Stats.gauge('scheduler.tasks.running', num_tasks_in_executor) - Stats.gauge('scheduler.tasks.executable', len(executable_tis)) + Stats.gauge("scheduler.tasks.starving", num_starving_tasks_total) + Stats.gauge("scheduler.tasks.running", num_tasks_in_executor) + Stats.gauge("scheduler.tasks.executable", len(executable_tis)) if len(executable_tis) > 0: task_instance_str = "\n\t".join(repr(x) for x in executable_tis) @@ -521,10 +527,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session = make_transient(ti) return executable_tis - @provide_session - def _enqueue_task_instances_with_queued_state( - self, task_instances: List[TI], session: Session = None - ) -> None: + def _enqueue_task_instances_with_queued_state(self, task_instances: list[TI], session: Session) -> None: """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. 
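Earlier in this hunk, the critical-section row-lock query is wrapped in a Stats timer that is started before the query and only sent when the query succeeds (on OperationalError it is stopped with send=False and the error is re-raised). A minimal sketch of that start/stop(send=...) pattern, using a hypothetical SketchTimer stand-in rather than Airflow's real Stats client:

from __future__ import annotations

import time


class SketchTimer:
    """Minimal stand-in for a statsd-style timer with an explicit start/stop(send=...) API."""

    def __init__(self, metric_name: str) -> None:
        self.metric_name = metric_name
        self._started_at: float | None = None

    def start(self) -> SketchTimer:
        self._started_at = time.monotonic()
        return self

    def stop(self, send: bool = True) -> None:
        if self._started_at is None:
            return
        elapsed_ms = (time.monotonic() - self._started_at) * 1000
        if send:
            # A real implementation would ship the duration to statsd here.
            print(f"{self.metric_name}: {elapsed_ms:.1f} ms")


def run_critical_section_query(query):
    # Time the query, but only publish the metric when it succeeds; on failure the
    # timer is stopped without sending and the error propagates to the caller.
    timer = SketchTimer("scheduler.critical_section_query_duration").start()
    try:
        rows = query()
        timer.stop(send=True)
        return rows
    except Exception:
        timer.stop(send=False)
        raise


if __name__ == "__main__":
    print(run_critical_section_query(lambda: ["ti-1", "ti-2"]))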
@@ -581,14 +584,13 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int: self._enqueue_task_instances_with_queued_state(queued_tis, session=session) return len(queued_tis) - @provide_session - def _process_executor_events(self, session: Session = None) -> int: + def _process_executor_events(self, session: Session) -> int: """Respond to executor events.""" if not self._standalone_dag_processor and not self.processor_agent: raise ValueError("Processor agent is not started.") - ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str, int], int] = {} + ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int] = {} event_buffer = self.executor.get_event_buffer() - tis_with_right_state: List[TaskInstanceKey] = [] + tis_with_right_state: list[TaskInstanceKey] = [] # Report execution for ti_key, value in event_buffer.items(): @@ -614,7 +616,7 @@ def _process_executor_events(self, session: Session = None) -> int: # Check state of finished tasks filter_for_tis = TI.filter_for_tis(tis_with_right_state) - query = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')) + query = session.query(TI).filter(filter_for_tis).options(selectinload("dag_model")) # row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have # multi-schedulers tis: Iterator[TI] = with_row_locks( @@ -628,7 +630,6 @@ def _process_executor_events(self, session: Session = None) -> int: buffer_key = ti.key.with_try_number(try_number) state, info = event_buffer.pop(buffer_key) - # TODO: should we fail RUNNING as well, as we do in Backfills? if state == TaskInstanceState.QUEUED: ti.external_executor_id = info self.log.info("Setting external_id for %s to %s", ti, info) @@ -664,8 +665,21 @@ def _process_executor_events(self, session: Session = None) -> int: ti.pid, ) - if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED: - Stats.incr('scheduler.tasks.killed_externally') + # There are two scenarios why the same TI with the same try_number is queued + # after the executor is finished with it: + # 1) the TI was killed externally and it had no time to mark itself failed + # - in this case we should mark it as failed here. + # 2) the TI has been requeued after getting deferred - in this case either our executor has it + # or the TI is queued by another job. Either way, we should not fail it. + + # All of this could also happen if the state is "running", + # but that is handled by the zombie detection. + + ti_queued = ti.try_number == buffer_key.try_number and ti.state == TaskInstanceState.QUEUED + ti_requeued = ti.queued_by_job_id != self.id or self.executor.has_task(ti) + + if ti_queued and not ti_requeued: + Stats.incr("scheduler.tasks.killed_externally") msg = ( "Executor reports task instance %s finished (%s) although the " "task says its %s. (Info: %s) Was the task killed externally?"
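The new ti_queued / ti_requeued checks above decide whether a queued task instance that the executor reports as finished was killed externally, or merely requeued (for example after a deferral) and therefore must not be failed by this job. A simplified, self-contained sketch of that decision, with a hypothetical TaskSnapshot in place of the real TaskInstance model:

from dataclasses import dataclass


@dataclass
class TaskSnapshot:
    """Hypothetical, simplified view of the task-instance fields the decision uses."""

    try_number: int
    state: str
    queued_by_job_id: int


def looks_killed_externally(
    ti: TaskSnapshot,
    reported_try_number: int,
    current_job_id: int,
    executor_still_has_task: bool,
) -> bool:
    # Only treat the TI as killed externally when it is still queued for the same
    # try *and* nobody else (another scheduler job, or our own executor holding a
    # requeued/deferred task) is responsible for it.
    ti_queued = ti.try_number == reported_try_number and ti.state == "queued"
    ti_requeued = ti.queued_by_job_id != current_job_id or executor_still_has_task
    return ti_queued and not ti_requeued


# Killed externally: same try, still queued by this job, executor no longer tracks it.
assert looks_killed_externally(TaskSnapshot(1, "queued", 42), 1, 42, False)
# Requeued after a deferral and picked up by another job: must not be failed here.
assert not looks_killed_externally(TaskSnapshot(1, "queued", 43), 1, 42, False)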
@@ -686,6 +700,7 @@ def _process_executor_events(self, session: Session = None) -> int: full_filepath=ti.dag_model.fileloc, simple_task_instance=SimpleTaskInstance.from_ti(ti), msg=msg % (ti, state, ti.state, info), + processor_subdir=ti.dag_model.processor_subdir, ) self.executor.send_callback(request) else: @@ -694,6 +709,8 @@ def _process_executor_events(self, session: Session = None) -> int: return len(event_buffer) def _execute(self) -> None: + from airflow.dag_processing.manager import DagFileProcessorAgent + self.log.info("Starting the scheduler") # DAGs can be pickled for easier remote execution by some executors @@ -705,11 +722,11 @@ def _execute(self) -> None: # so the scheduler job and DAG parser don't access the DB at the same time. async_mode = not self.using_sqlite - processor_timeout_seconds: int = conf.getint('core', 'dag_file_processor_timeout') + processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout") processor_timeout = timedelta(seconds=processor_timeout_seconds) if not self._standalone_dag_processor: self.processor_agent = DagFileProcessorAgent( - dag_directory=self.subdir, + dag_directory=Path(self.subdir), max_runs=self.num_times_parse_dags, processor_timeout=processor_timeout, dag_ids=[], @@ -725,6 +742,8 @@ def _execute(self) -> None: get_sink_pipe=self.processor_agent.get_callbacks_pipe ) else: + from airflow.callbacks.database_callback_sink import DatabaseCallbackSink + self.log.debug("Using DatabaseCallbackSink as callback sink.") self.executor.callback_sink = DatabaseCallbackSink() @@ -750,7 +769,7 @@ def _execute(self) -> None: self.log.info( "Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat() ) - models.DAG.deactivate_stale_dags(execute_start_time) + DAG.deactivate_stale_dags(execute_start_time) settings.Session.remove() # type: ignore except Exception: @@ -768,25 +787,32 @@ def _execute(self) -> None: self.log.exception("Exception when executing DagFileProcessorAgent.end") self.log.info("Exited execute loop") - def _update_dag_run_state_for_paused_dags(self): + @provide_session + def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION) -> None: try: - paused_dag_ids = DagModel.get_all_paused_dag_ids() - for dag_id in paused_dag_ids: - if dag_id in self._paused_dag_without_running_dagruns: - continue - - dag = SerializedDagModel.get_dag(dag_id) + paused_runs = ( + session.query(DagRun) + .join(DagRun.dag_model) + .join(TaskInstance) + .filter( + DagModel.is_paused == expression.true(), + DagRun.state == DagRunState.RUNNING, + DagRun.run_type != DagRunType.BACKFILL_JOB, + ) + .having(DagRun.last_scheduling_decision <= func.max(TaskInstance.updated_at)) + .group_by(DagRun) + ) + for dag_run in paused_runs: + dag = self.dagbag.get_dag(dag_run.dag_id, session=session) if dag is None: continue - dag_runs = DagRun.find(dag_id=dag_id, state=State.RUNNING) - for dag_run in dag_runs: - dag_run.dag = dag - _, callback_to_run = dag_run.update_state(execute_callbacks=False) - if callback_to_run: - self._send_dag_callbacks_to_processor(dag, callback_to_run) - self._paused_dag_without_running_dagruns.add(dag_id) + + dag_run.dag = dag + _, callback_to_run = dag_run.update_state(execute_callbacks=False, session=session) + if callback_to_run: + self._send_dag_callbacks_to_processor(dag, callback_to_run) except Exception as e: # should not fail the scheduler - self.log.exception('Failed to update dag run state for paused dags due to %s', str(e)) + self.log.exception("Failed to update dag run 
state for paused dags due to %s", str(e)) def _run_scheduler_loop(self) -> None: """ @@ -803,11 +829,10 @@ def _run_scheduler_loop(self) -> None: .. image:: ../docs/apache-airflow/img/scheduler_loop.jpg - :rtype: None """ if not self.processor_agent and not self._standalone_dag_processor: raise ValueError("Processor agent is not started.") - is_unit_test: bool = conf.getboolean('core', 'unit_test_mode') + is_unit_test: bool = conf.getboolean("core", "unit_test_mode") timers = EventScheduler() @@ -815,28 +840,39 @@ def _run_scheduler_loop(self) -> None: self.adopt_or_reset_orphaned_tasks() timers.call_regular_interval( - conf.getfloat('scheduler', 'orphaned_tasks_check_interval', fallback=300.0), + conf.getfloat("scheduler", "orphaned_tasks_check_interval", fallback=300.0), self.adopt_or_reset_orphaned_tasks, ) timers.call_regular_interval( - conf.getfloat('scheduler', 'trigger_timeout_check_interval', fallback=15.0), + conf.getfloat("scheduler", "trigger_timeout_check_interval", fallback=15.0), self.check_trigger_timeouts, ) timers.call_regular_interval( - conf.getfloat('scheduler', 'pool_metrics_interval', fallback=5.0), + conf.getfloat("scheduler", "pool_metrics_interval", fallback=5.0), self._emit_pool_metrics, ) timers.call_regular_interval( - conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0), + conf.getfloat("scheduler", "zombie_detection_interval", fallback=10.0), self._find_zombies, ) timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags) + timers.call_regular_interval( + conf.getfloat("scheduler", "parsing_cleanup_interval"), + self._orphan_unreferenced_datasets, + ) + + if self._standalone_dag_processor: + timers.call_regular_interval( + conf.getfloat("scheduler", "parsing_cleanup_interval"), + self._cleanup_stale_dags, + ) + for loop_count in itertools.count(start=1): - with Stats.timer() as timer: + with Stats.timer("scheduler.scheduler_loop_duration") as timer: if self.using_sqlite and self.processor_agent: self.processor_agent.run_single_parsing_loop() @@ -885,7 +921,7 @@ def _run_scheduler_loop(self) -> None: ) break - def _do_scheduling(self, session) -> int: + def _do_scheduling(self, session: Session) -> int: """ This function is where the main scheduling decisions take places. It: @@ -913,7 +949,6 @@ def _do_scheduling(self, session) -> int: See docs of _critical_section_enqueue_task_instances for more. 
:return: Number of TIs enqueued in this iteration - :rtype: int """ # Put a check in place to make sure we don't commit unexpectedly with prohibit_commit(session) as guard: @@ -926,12 +961,7 @@ def _do_scheduling(self, session) -> int: # Bulk fetch the currently active dag runs for the dags we are # examining, rather than making one query per DagRun - callback_tuples = [] - for dag_run in dag_runs: - callback_to_run = self._schedule_dag_run(dag_run, session) - callback_tuples.append((dag_run, callback_to_run)) - - guard.commit() + callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session) # Send the callbacks after we commit to ensure the context is up to date when it gets run for dag_run, callback_to_run in callback_tuples: @@ -954,7 +984,7 @@ def _do_scheduling(self, session) -> int: num_queued_tis = 0 else: try: - timer = Stats.timer('scheduler.critical_section_duration') + timer = Stats.timer("scheduler.critical_section_duration") timer.start() # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) @@ -968,7 +998,7 @@ def _do_scheduling(self, session) -> int: if is_lock_not_available_error(error=e): self.log.debug("Critical section lock held by another Scheduler") - Stats.incr('scheduler.critical_section_busy') + Stats.incr("scheduler.critical_section_busy") session.rollback() return 0 raise @@ -983,10 +1013,19 @@ def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session): return DagRun.next_dagruns_to_examine(state, session) @retry_db_transaction - def _create_dagruns_for_dags(self, guard, session): + def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None: """Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError""" - query = DagModel.dags_needing_dagruns(session) - self._create_dag_runs(query.all(), session) + query, dataset_triggered_dag_info = DagModel.dags_needing_dagruns(session) + all_dags_needing_dag_runs = set(query.all()) + dataset_triggered_dags = [ + dag for dag in all_dags_needing_dag_runs if dag.dag_id in dataset_triggered_dag_info + ] + non_dataset_dags = all_dags_needing_dag_runs.difference(dataset_triggered_dags) + self._create_dag_runs(non_dataset_dags, session) + if dataset_triggered_dags: + self._create_dag_runs_dataset_triggered( + dataset_triggered_dags, dataset_triggered_dag_info, session + ) # commit the session - Release the write lock on DagModel table. guard.commit() @@ -1052,7 +1091,107 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) - # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in # memory for larger dags? 
or expunge_all() - def _should_update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active_runs) -> bool: + def _create_dag_runs_dataset_triggered( + self, + dag_models: Collection[DagModel], + dataset_triggered_dag_info: dict[str, tuple[datetime, datetime]], + session: Session, + ) -> None: + """For DAGs that are triggered by datasets, create dag runs.""" + # Bulk Fetch DagRuns with dag_id and execution_date same + # as DagModel.dag_id and DagModel.next_dagrun + # This list is used to verify if the DagRun already exist so that we don't attempt to create + # duplicate dag runs + exec_dates = { + dag_id: timezone.coerce_datetime(last_time) + for dag_id, (_, last_time) in dataset_triggered_dag_info.items() + } + existing_dagruns: set[tuple[str, timezone.DateTime]] = set( + session.query(DagRun.dag_id, DagRun.execution_date).filter( + tuple_in_condition((DagRun.dag_id, DagRun.execution_date), exec_dates.items()) + ) + ) + + for dag_model in dag_models: + dag = self.dagbag.get_dag(dag_model.dag_id, session=session) + if not dag: + self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id) + continue + + if not isinstance(dag.timetable, DatasetTriggeredTimetable): + self.log.error( + "DAG '%s' was dataset-scheduled, but didn't have a DatasetTriggeredTimetable!", + dag_model.dag_id, + ) + continue + + dag_hash = self.dagbag.dags_hash.get(dag.dag_id) + + # Explicitly check if the DagRun already exists. This is an edge case + # where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after` + # are not updated. + # We opted to check DagRun existence instead + # of catching an Integrity error and rolling back the session i.e + # we need to set dag.next_dagrun_info if the Dag Run already exists or if we + # create a new one. This is so that in the next Scheduling loop we try to create new runs + # instead of falling in a loop of Integrity Error. 
+ exec_date = exec_dates[dag.dag_id] + if (dag.dag_id, exec_date) not in existing_dagruns: + + previous_dag_run = ( + session.query(DagRun) + .filter( + DagRun.dag_id == dag.dag_id, + DagRun.execution_date < exec_date, + DagRun.run_type == DagRunType.DATASET_TRIGGERED, + ) + .order_by(DagRun.execution_date.desc()) + .first() + ) + dataset_event_filters = [ + DagScheduleDatasetReference.dag_id == dag.dag_id, + DatasetEvent.timestamp <= exec_date, + ] + if previous_dag_run: + dataset_event_filters.append(DatasetEvent.timestamp > previous_dag_run.execution_date) + + dataset_events = ( + session.query(DatasetEvent) + .join( + DagScheduleDatasetReference, + DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id, + ) + .join(DatasetEvent.source_dag_run) + .filter(*dataset_event_filters) + .all() + ) + + data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events) + run_id = dag.timetable.generate_run_id( + run_type=DagRunType.DATASET_TRIGGERED, + logical_date=exec_date, + data_interval=data_interval, + session=session, + events=dataset_events, + ) + + dag_run = dag.create_dagrun( + run_id=run_id, + run_type=DagRunType.DATASET_TRIGGERED, + execution_date=exec_date, + data_interval=data_interval, + state=DagRunState.QUEUED, + external_trigger=False, + session=session, + dag_hash=dag_hash, + creating_job_id=self.id, + ) + dag_run.consumed_dataset_events.extend(dataset_events) + session.query(DatasetDagRunQueue).filter( + DatasetDagRunQueue.target_dag_id == dag_run.dag_id + ).delete() + + def _should_update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active_runs: int) -> bool: """Check if the dag's next_dagruns_create_after should be updated.""" if total_active_runs >= dag.max_active_runs: self.log.info( @@ -1065,10 +1204,7 @@ def _should_update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active return False return True - def _start_queued_dagruns( - self, - session: Session, - ) -> None: + def _start_queued_dagruns(self, session: Session) -> None: """Find DagRuns in queued state and decide moving them to running state""" dag_runs = self._get_next_dagruns_to_examine(DagRunState.QUEUED, session) @@ -1088,10 +1224,9 @@ def _update_state(dag: DAG, dag_run: DagRun): # always happening immediately after the data interval. 
expected_start_date = dag.get_run_data_interval(dag_run).end schedule_delay = dag_run.start_date - expected_start_date - Stats.timing(f'dagrun.schedule_delay.{dag.dag_id}', schedule_delay) + Stats.timing(f"dagrun.schedule_delay.{dag.dag_id}", schedule_delay) for dag_run in dag_runs: - dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) if not dag: self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id) @@ -1108,19 +1243,32 @@ def _update_state(dag: DAG, dag_run: DagRun): else: active_runs_of_dags[dag_run.dag_id] += 1 _update_state(dag, dag_run) + dag_run.notify_dagrun_state_changed() + + @retry_db_transaction + def _schedule_all_dag_runs(self, guard, dag_runs, session): + """Makes scheduling decisions for all `dag_runs`""" + callback_tuples = [] + for dag_run in dag_runs: + callback_to_run = self._schedule_dag_run(dag_run, session) + callback_tuples.append((dag_run, callback_to_run)) + + guard.commit() + + return callback_tuples def _schedule_dag_run( self, dag_run: DagRun, session: Session, - ) -> Optional[DagCallbackRequest]: + ) -> DagCallbackRequest | None: """ Make scheduling decisions about an individual dag run :param dag_run: The DagRun to schedule :return: Callback that needs to be executed """ - callback: Optional[DagCallbackRequest] = None + callback: DagCallbackRequest | None = None dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) @@ -1156,19 +1304,20 @@ def _schedule_dag_run( dag_id=dag.dag_id, run_id=dag_run.run_id, is_failure_callback=True, - msg='timed_out', + processor_subdir=dag_model.processor_subdir, + msg="timed_out", ) - # Send SLA & DAG Success/Failure Callbacks to be executed - self._send_dag_callbacks_to_processor(dag, callback_to_execute) - # Because we send the callback here, we need to return None - return callback + dag_run.notify_dagrun_state_changed() + return callback_to_execute if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates: self.log.error("Execution date is in future: %s", dag_run.execution_date) return callback - self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session) + if not self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session): + self.log.warning("The DAG disappeared before verifying integrity: %s. Skipping.", dag_run.dag_id) + return callback # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False) if dag_run.state in State.finished: @@ -1184,30 +1333,36 @@ def _schedule_dag_run( return callback_to_run - @provide_session - def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): - """Only run DagRun.verify integrity if Serialized DAG has changed since it is slow""" + def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) -> bool: + """ + Only run DagRun.verify integrity if Serialized DAG has changed since it is slow. + + Return True if we determine that DAG still exists. 
+ """ latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session) if dag_run.dag_hash == latest_version: self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id) - return + return True dag_run.dag_hash = latest_version # Refresh the DAG dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session) + if not dag_run.dag: + return False # Verify integrity also takes care of session.flush dag_run.verify_integrity(session=session) + return True - def _send_dag_callbacks_to_processor(self, dag: DAG, callback: Optional[DagCallbackRequest] = None): + def _send_dag_callbacks_to_processor(self, dag: DAG, callback: DagCallbackRequest | None = None) -> None: self._send_sla_callbacks_to_processor(dag) if callback: self.executor.send_callback(callback) else: self.log.debug("callback is empty") - def _send_sla_callbacks_to_processor(self, dag: DAG): + def _send_sla_callbacks_to_processor(self, dag: DAG) -> None: """Sends SLA Callbacks to DagFileProcessor if tasks have SLAs set and check_slas=True""" if not settings.CHECK_SLAS: return @@ -1216,32 +1371,42 @@ def _send_sla_callbacks_to_processor(self, dag: DAG): self.log.debug("Skipping SLA check for %s because no tasks in DAG have SLAs", dag) return - request = SlaCallbackRequest(full_filepath=dag.fileloc, dag_id=dag.dag_id) + if not dag.timetable.periodic: + self.log.debug("Skipping SLA check for %s because DAG is not scheduled", dag) + return + + dag_model = DagModel.get_dagmodel(dag.dag_id) + request = SlaCallbackRequest( + full_filepath=dag.fileloc, + dag_id=dag.dag_id, + processor_subdir=dag_model.processor_subdir, + ) self.executor.send_callback(request) @provide_session - def _emit_pool_metrics(self, session: Session = None) -> None: - pools = models.Pool.slots_stats(session=session) + def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None: + from airflow.models.pool import Pool + + pools = Pool.slots_stats(session=session) for pool_name, slot_stats in pools.items(): - Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"]) - Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats["queued"]) - Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats["running"]) + Stats.gauge(f"pool.open_slots.{pool_name}", slot_stats["open"]) + Stats.gauge(f"pool.queued_slots.{pool_name}", slot_stats["queued"]) + Stats.gauge(f"pool.running_slots.{pool_name}", slot_stats["running"]) @provide_session - def heartbeat_callback(self, session: Session = None) -> None: - Stats.incr('scheduler_heartbeat', 1, 1) + def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: + Stats.incr("scheduler_heartbeat", 1, 1) @provide_session - def adopt_or_reset_orphaned_tasks(self, session: Session = None): + def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int: """ Reset any TaskInstance still in QUEUED or SCHEDULED states that were enqueued by a SchedulerJob that is no longer running. 
:return: the number of TIs reset - :rtype: int """ self.log.info("Resetting orphaned tasks for active dag runs") - timeout = conf.getint('scheduler', 'scheduler_health_check_threshold') + timeout = conf.getint("scheduler", "scheduler_health_check_threshold") for attempt in run_with_db_retries(logger=self.log): with attempt: @@ -1264,7 +1429,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): if num_failed: self.log.info("Marked %d SchedulerJob instances as failed", num_failed) - Stats.incr(self.__class__.__name__.lower() + '_end', num_failed) + Stats.incr(self.__class__.__name__.lower() + "_end", num_failed) resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING] query = ( @@ -1299,11 +1464,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): for ti in set(tis_to_reset_or_adopt) - set(to_reset): ti.queued_by_job_id = self.id - Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset)) - Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset)) + Stats.incr("scheduler.orphaned_tasks.cleared", len(to_reset)) + Stats.incr("scheduler.orphaned_tasks.adopted", len(tis_to_reset_or_adopt) - len(to_reset)) if to_reset: - task_instance_str = '\n\t'.join(reset_tis_message) + task_instance_str = "\n\t".join(reset_tis_message) self.log.info( "Reset the following %s orphaned TaskInstances:\n\t%s", len(to_reset), @@ -1320,7 +1485,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): return len(to_reset) @provide_session - def check_trigger_timeouts(self, session: Session = None): + def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None: """ Looks at all tasks that are in the "deferred" state and whose trigger or execution timeout has passed, so they can be marked as failed. @@ -1345,38 +1510,117 @@ def check_trigger_timeouts(self, session: Session = None): if num_timed_out_tasks: self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks) - @provide_session - def _find_zombies(self, session): + def _find_zombies(self) -> None: """ Find zombie task instances, which are tasks haven't heartbeated for too long - and update the current zombie list. + or have a no-longer-running LocalTaskJob, and create a TaskCallbackRequest + to be handled by the DAG processor. 
""" + from airflow.jobs.local_task_job import LocalTaskJob + self.log.debug("Finding 'running' jobs without a recent heartbeat") limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs) - zombies = ( - session.query(TaskInstance, DagModel.fileloc) - .join(LocalTaskJob, TaskInstance.job_id == LocalTaskJob.id) - .join(DagModel, TaskInstance.dag_id == DagModel.dag_id) - .filter(TaskInstance.state == State.RUNNING) - .filter( - or_( - LocalTaskJob.state != State.RUNNING, - LocalTaskJob.latest_heartbeat < limit_dttm, + with create_session() as session: + zombies: list[tuple[TI, str, str]] = ( + session.query(TI, DM.fileloc, DM.processor_subdir) + .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") + .join(LocalTaskJob, TI.job_id == LocalTaskJob.id) + .join(DM, TI.dag_id == DM.dag_id) + .filter(TI.state == TaskInstanceState.RUNNING) + .filter( + or_( + LocalTaskJob.state != State.RUNNING, + LocalTaskJob.latest_heartbeat < limit_dttm, + ) ) + .filter(TI.queued_by_job_id == self.id) + .all() ) - .all() - ) if zombies: self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm) - for ti, file_loc in zombies: + for ti, file_loc, processor_subdir in zombies: + zombie_message_details = self._generate_zombie_message_details(ti) request = TaskCallbackRequest( full_filepath=file_loc, + processor_subdir=processor_subdir, simple_task_instance=SimpleTaskInstance.from_ti(ti), - msg=f"Detected {ti} as zombie", + msg=str(zombie_message_details), ) self.log.error("Detected zombie job: %s", request) self.executor.send_callback(request) - Stats.incr('zombies_killed') + Stats.incr("zombies_killed") + + @staticmethod + def _generate_zombie_message_details(ti: TaskInstance): + zombie_message_details = { + "DAG Id": ti.dag_id, + "Task Id": ti.task_id, + "Run Id": ti.run_id, + } + + if ti.map_index != -1: + zombie_message_details["Map Index"] = ti.map_index + if ti.hostname: + zombie_message_details["Hostname"] = ti.hostname + if ti.external_executor_id: + zombie_message_details["External Executor Id"] = ti.external_executor_id + + return zombie_message_details + + @provide_session + def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None: + """ + Find all dags that were not updated by Dag Processor recently and mark them as inactive. + + In case one of DagProcessors is stopped (in case there are multiple of them + for different dag folders), it's dags are never marked as inactive. + Also remove dags from SerializedDag table. + Executed on schedule only if [scheduler]standalone_dag_processor is True. 
+ """ + self.log.debug("Checking dags not parsed within last %s seconds.", self._dag_stale_not_seen_duration) + limit_lpt = timezone.utcnow() - timedelta(seconds=self._dag_stale_not_seen_duration) + stale_dags = ( + session.query(DagModel).filter(DagModel.is_active, DagModel.last_parsed_time < limit_lpt).all() + ) + if not stale_dags: + self.log.debug("Not stale dags found.") + return + + self.log.info("Found (%d) stales dags not parsed after %s.", len(stale_dags), limit_lpt) + for dag in stale_dags: + dag.is_active = False + SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session) + session.flush() + + @provide_session + def _orphan_unreferenced_datasets(self, session: Session = NEW_SESSION) -> None: + """ + Detects datasets that are no longer referenced in any DAG schedule parameters or task outlets and + sets the dataset is_orphaned flag to True + """ + orphaned_dataset_query = ( + session.query(DatasetModel) + .join( + DagScheduleDatasetReference, + isouter=True, + ) + .join( + TaskOutletDatasetReference, + isouter=True, + ) + # MSSQL doesn't like it when we select a column that we haven't grouped by. All other DBs let us + # group by id and select all columns. + .group_by(DatasetModel if session.get_bind().dialect.name == "mssql" else DatasetModel.id) + .having( + and_( + func.count(DagScheduleDatasetReference.dag_id) == 0, + func.count(TaskOutletDatasetReference.dag_id) == 0, + ) + ) + ) + for dataset in orphaned_dataset_query: + self.log.info("Orphaning unreferenced dataset '%s'", dataset.uri) + dataset.is_orphaned = expression.true() diff --git a/airflow/jobs/triggerer_job.py b/airflow/jobs/triggerer_job.py index ac7d22a6b1da9..20fcf20b169b0 100644 --- a/airflow/jobs/triggerer_job.py +++ b/airflow/jobs/triggerer_job.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import asyncio import os @@ -22,7 +23,7 @@ import threading import time from collections import deque -from typing import Deque, Dict, Set, Tuple, Type +from typing import Deque from sqlalchemy import func @@ -48,14 +49,14 @@ class TriggererJob(BaseJob): - A subthread runs all the async code """ - __mapper_args__ = {'polymorphic_identity': 'TriggererJob'} + __mapper_args__ = {"polymorphic_identity": "TriggererJob"} def __init__(self, capacity=None, *args, **kwargs): # Call superclass super().__init__(*args, **kwargs) if capacity is None: - self.capacity = conf.getint('triggerer', 'default_capacity', fallback=1000) + self.capacity = conf.getint("triggerer", "default_capacity", fallback=1000) elif isinstance(capacity, int) and capacity > 0: self.capacity = capacity else: @@ -158,7 +159,7 @@ def handle_events(self): # Tell the model to wake up its tasks Trigger.submit_event(trigger_id=trigger_id, event=event) # Emit stat event - Stats.incr('triggers.succeeded') + Stats.incr("triggers.succeeded") def handle_failed_triggers(self): """ @@ -170,10 +171,10 @@ def handle_failed_triggers(self): trigger_id, saved_exc = self.runner.failed_triggers.popleft() Trigger.submit_failure(trigger_id=trigger_id, exc=saved_exc) # Emit stat event - Stats.incr('triggers.failed') + Stats.incr("triggers.failed") def emit_metrics(self): - Stats.gauge('triggers.running', len(self.runner.triggers)) + Stats.gauge("triggers.running", len(self.runner.triggers)) class TriggerDetails(TypedDict): @@ -195,22 +196,22 @@ class TriggerRunner(threading.Thread, LoggingMixin): """ # Maps trigger IDs to their running tasks and other info - triggers: Dict[int, TriggerDetails] + triggers: dict[int, TriggerDetails] # Cache for looking up triggers by classpath - trigger_cache: Dict[str, Type[BaseTrigger]] + trigger_cache: dict[str, type[BaseTrigger]] # Inbound queue of new triggers - to_create: Deque[Tuple[int, BaseTrigger]] + to_create: Deque[tuple[int, BaseTrigger]] # Inbound queue of deleted triggers to_cancel: Deque[int] # Outbound queue of events - events: Deque[Tuple[int, TriggerEvent]] + events: Deque[tuple[int, TriggerEvent]] # Outbound queue of failed triggers - failed_triggers: Deque[Tuple[int, BaseException]] + failed_triggers: Deque[tuple[int, BaseException]] # Should-we-stop flag stop: bool = False @@ -346,7 +347,7 @@ async def block_watchdog(self): "to get more information on overrunning coroutines.", time_elapsed, ) - Stats.incr('triggers.blocked_main_thread') + Stats.incr("triggers.blocked_main_thread") # Async trigger logic @@ -355,10 +356,10 @@ async def run_trigger(self, trigger_id, trigger): Wrapper which runs an actual trigger (they are async generators) and pushes their events into our outbound event deque. """ - self.log.info("Trigger %s starting", self.triggers[trigger_id]['name']) + self.log.info("Trigger %s starting", self.triggers[trigger_id]["name"]) try: async for event in trigger.run(): - self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]['name'], event) + self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]["name"], event) self.triggers[trigger_id]["events"] += 1 self.events.append((trigger_id, event)) finally: @@ -370,7 +371,7 @@ async def run_trigger(self, trigger_id, trigger): # Main-thread sync API - def update_triggers(self, requested_trigger_ids: Set[int]): + def update_triggers(self, requested_trigger_ids: set[int]): """ Called from the main thread to request that we update what triggers we're running. 
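The TriggerRunner changes above keep the same shape as before: each trigger is an async generator, and a wrapper coroutine drains it and appends its events to an outbound deque that the main thread consumes. A toy sketch of that pattern, assuming a hypothetical interval_trigger generator and no Airflow dependencies:

from __future__ import annotations

import asyncio
from collections import deque
from typing import AsyncGenerator, Deque


async def interval_trigger(n: int) -> AsyncGenerator[str, None]:
    # A toy "trigger": an async generator that yields a few events and then finishes.
    for i in range(n):
        await asyncio.sleep(0)
        yield f"event-{i}"


async def run_trigger(trigger_id: int, trigger, events: Deque[tuple[int, str]]) -> None:
    # Drain the async generator and push its events onto a shared outbound deque,
    # which a consumer elsewhere can pop from and act on.
    async for event in trigger:
        events.append((trigger_id, event))


async def main() -> None:
    events: Deque[tuple[int, str]] = deque()
    await asyncio.gather(
        run_trigger(1, interval_trigger(2), events),
        run_trigger(2, interval_trigger(1), events),
    )
    print(list(events))  # e.g. [(1, 'event-0'), (2, 'event-0'), (1, 'event-1')]


if __name__ == "__main__":
    asyncio.run(main())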
@@ -413,7 +414,7 @@ def update_triggers(self, requested_trigger_ids: Set[int]): for old_id in cancel_trigger_ids: self.to_cancel.append(old_id) - def get_trigger_by_classpath(self, classpath: str) -> Type[BaseTrigger]: + def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]: """ Gets a trigger class by its classpath ("path.to.module.classname") diff --git a/airflow/kubernetes/k8s_model.py b/airflow/kubernetes/k8s_model.py index 01e294672aa52..123294e0bdc20 100644 --- a/airflow/kubernetes/k8s_model.py +++ b/airflow/kubernetes/k8s_model.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. """Classes for interacting with Kubernetes API.""" +from __future__ import annotations + from abc import ABC, abstractmethod from functools import reduce -from typing import List, Optional from kubernetes.client import models as k8s @@ -42,7 +43,7 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: """ -def append_to_pod(pod: k8s.V1Pod, k8s_objects: Optional[List[K8SModel]]): +def append_to_pod(pod: k8s.V1Pod, k8s_objects: list[K8SModel] | None): """ :param pod: A pod to attach a list of Kubernetes objects to :param k8s_objects: a potential None list of K8SModels diff --git a/airflow/kubernetes/kube_client.py b/airflow/kubernetes/kube_client.py index 7e6ba05119787..46aff02b9d947 100644 --- a/airflow/kubernetes/kube_client.py +++ b/airflow/kubernetes/kube_client.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """Client for kubernetes communication""" +from __future__ import annotations + import logging -from typing import Optional from airflow.configuration import conf @@ -30,7 +31,10 @@ has_kubernetes = True def _disable_verify_ssl() -> None: - configuration = Configuration() + if hasattr(Configuration, "get_default_copy"): + configuration = Configuration.get_default_copy() + else: + configuration = Configuration() configuration.verify_ssl = False Configuration.set_default(configuration) @@ -55,35 +59,35 @@ def _enable_tcp_keepalive() -> None: from urllib3.connection import HTTPConnection, HTTPSConnection - tcp_keep_idle = conf.getint('kubernetes', 'tcp_keep_idle') - tcp_keep_intvl = conf.getint('kubernetes', 'tcp_keep_intvl') - tcp_keep_cnt = conf.getint('kubernetes', 'tcp_keep_cnt') + tcp_keep_idle = conf.getint("kubernetes_executor", "tcp_keep_idle") + tcp_keep_intvl = conf.getint("kubernetes_executor", "tcp_keep_intvl") + tcp_keep_cnt = conf.getint("kubernetes_executor", "tcp_keep_cnt") socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)] if hasattr(socket, "TCP_KEEPIDLE"): socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keep_idle)) else: - log.warning("Unable to set TCP_KEEPIDLE on this platform") + log.debug("Unable to set TCP_KEEPIDLE on this platform") if hasattr(socket, "TCP_KEEPINTVL"): socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keep_intvl)) else: - log.warning("Unable to set TCP_KEEPINTVL on this platform") + log.debug("Unable to set TCP_KEEPINTVL on this platform") if hasattr(socket, "TCP_KEEPCNT"): socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, tcp_keep_cnt)) else: - log.warning("Unable to set TCP_KEEPCNT on this platform") + log.debug("Unable to set TCP_KEEPCNT on this platform") HTTPSConnection.default_socket_options = HTTPSConnection.default_socket_options + socket_options HTTPConnection.default_socket_options = HTTPConnection.default_socket_options + socket_options def get_kube_client( - in_cluster: 
bool = conf.getboolean('kubernetes', 'in_cluster'), - cluster_context: Optional[str] = None, - config_file: Optional[str] = None, + in_cluster: bool = conf.getboolean("kubernetes_executor", "in_cluster"), + cluster_context: str | None = None, + config_file: str | None = None, ) -> client.CoreV1Api: """ Retrieves Kubernetes client @@ -97,19 +101,19 @@ def get_kube_client( if not has_kubernetes: raise _import_err - if conf.getboolean('kubernetes', 'enable_tcp_keepalive'): + if conf.getboolean("kubernetes_executor", "enable_tcp_keepalive"): _enable_tcp_keepalive() - if not conf.getboolean('kubernetes', 'verify_ssl'): - _disable_verify_ssl() - if in_cluster: config.load_incluster_config() else: if cluster_context is None: - cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None) + cluster_context = conf.get("kubernetes_executor", "cluster_context", fallback=None) if config_file is None: - config_file = conf.get('kubernetes', 'config_file', fallback=None) + config_file = conf.get("kubernetes_executor", "config_file", fallback=None) config.load_kube_config(config_file=config_file, context=cluster_context) + if not conf.getboolean("kubernetes_executor", "verify_ssl"): + _disable_verify_ssl() + return client.CoreV1Api() diff --git a/airflow/kubernetes/kube_config.py b/airflow/kubernetes/kube_config.py index 2b550214543dc..0285f65208d01 100644 --- a/airflow/kubernetes/kube_config.py +++ b/airflow/kubernetes/kube_config.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.configuration import conf from airflow.exceptions import AirflowConfigException @@ -23,30 +24,30 @@ class KubeConfig: """Configuration for Kubernetes""" - core_section = 'core' - kubernetes_section = 'kubernetes' - logging_section = 'logging' + core_section = "core" + kubernetes_section = "kubernetes_executor" + logging_section = "logging" def __init__(self): configuration_dict = conf.as_dict(display_sensitive=True) self.core_configuration = configuration_dict[self.core_section] self.airflow_home = AIRFLOW_HOME - self.dags_folder = conf.get(self.core_section, 'dags_folder') - self.parallelism = conf.getint(self.core_section, 'parallelism') - self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file', fallback=None) + self.dags_folder = conf.get(self.core_section, "dags_folder") + self.parallelism = conf.getint(self.core_section, "parallelism") + self.pod_template_file = conf.get(self.kubernetes_section, "pod_template_file", fallback=None) - self.delete_worker_pods = conf.getboolean(self.kubernetes_section, 'delete_worker_pods') + self.delete_worker_pods = conf.getboolean(self.kubernetes_section, "delete_worker_pods") self.delete_worker_pods_on_failure = conf.getboolean( - self.kubernetes_section, 'delete_worker_pods_on_failure' + self.kubernetes_section, "delete_worker_pods_on_failure" ) self.worker_pods_creation_batch_size = conf.getint( - self.kubernetes_section, 'worker_pods_creation_batch_size' + self.kubernetes_section, "worker_pods_creation_batch_size" ) - self.worker_container_repository = conf.get(self.kubernetes_section, 'worker_container_repository') - self.worker_container_tag = conf.get(self.kubernetes_section, 'worker_container_tag') + self.worker_container_repository = conf.get(self.kubernetes_section, "worker_container_repository") + self.worker_container_tag = conf.get(self.kubernetes_section, "worker_container_tag") if 
self.worker_container_repository and self.worker_container_tag: - self.kube_image = f'{self.worker_container_repository}:{self.worker_container_tag}' + self.kube_image = f"{self.worker_container_repository}:{self.worker_container_tag}" else: self.kube_image = None @@ -54,27 +55,27 @@ def __init__(self): # that if your # cluster has RBAC enabled, your scheduler may need service account permissions to # create, watch, get, and delete pods in this namespace. - self.kube_namespace = conf.get(self.kubernetes_section, 'namespace') - self.multi_namespace_mode = conf.getboolean(self.kubernetes_section, 'multi_namespace_mode') + self.kube_namespace = conf.get(self.kubernetes_section, "namespace") + self.multi_namespace_mode = conf.getboolean(self.kubernetes_section, "multi_namespace_mode") # The Kubernetes Namespace in which pods will be created by the executor. Note # that if your # cluster has RBAC enabled, your workers may need service account permissions to # interact with cluster components. - self.executor_namespace = conf.get(self.kubernetes_section, 'namespace') + self.executor_namespace = conf.get(self.kubernetes_section, "namespace") - self.worker_pods_pending_timeout = conf.getint(self.kubernetes_section, 'worker_pods_pending_timeout') + self.worker_pods_pending_timeout = conf.getint(self.kubernetes_section, "worker_pods_pending_timeout") self.worker_pods_pending_timeout_check_interval = conf.getint( - self.kubernetes_section, 'worker_pods_pending_timeout_check_interval' + self.kubernetes_section, "worker_pods_pending_timeout_check_interval" ) self.worker_pods_pending_timeout_batch_size = conf.getint( - self.kubernetes_section, 'worker_pods_pending_timeout_batch_size' + self.kubernetes_section, "worker_pods_pending_timeout_batch_size" ) self.worker_pods_queued_check_interval = conf.getint( - self.kubernetes_section, 'worker_pods_queued_check_interval' + self.kubernetes_section, "worker_pods_queued_check_interval" ) self.kube_client_request_args = conf.getjson( - self.kubernetes_section, 'kube_client_request_args', fallback={} + self.kubernetes_section, "kube_client_request_args", fallback={} ) if not isinstance(self.kube_client_request_args, dict): raise AirflowConfigException( @@ -82,13 +83,13 @@ def __init__(self): + type(self.kube_client_request_args).__name__ ) if self.kube_client_request_args: - if '_request_timeout' in self.kube_client_request_args and isinstance( - self.kube_client_request_args['_request_timeout'], list + if "_request_timeout" in self.kube_client_request_args and isinstance( + self.kube_client_request_args["_request_timeout"], list ): - self.kube_client_request_args['_request_timeout'] = tuple( - self.kube_client_request_args['_request_timeout'] + self.kube_client_request_args["_request_timeout"] = tuple( + self.kube_client_request_args["_request_timeout"] ) - self.delete_option_kwargs = conf.getjson(self.kubernetes_section, 'delete_option_kwargs', fallback={}) + self.delete_option_kwargs = conf.getjson(self.kubernetes_section, "delete_option_kwargs", fallback={}) if not isinstance(self.delete_option_kwargs, dict): raise AirflowConfigException( f"[{self.kubernetes_section}] 'delete_option_kwargs' expected a JSON dict, got " diff --git a/airflow/kubernetes/kubernetes_helper_functions.py b/airflow/kubernetes/kubernetes_helper_functions.py index 1068a0521b894..1f0c809cf94bd 100644 --- a/airflow/kubernetes/kubernetes_helper_functions.py +++ b/airflow/kubernetes/kubernetes_helper_functions.py @@ -14,9 +14,9 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import logging -from typing import Dict, Optional import pendulum from slugify import slugify @@ -26,22 +26,7 @@ log = logging.getLogger(__name__) -def _strip_unsafe_kubernetes_special_chars(string: str) -> str: - """ - Kubernetes only supports lowercase alphanumeric characters, "-" and "." in - the pod name. - However, there are special rules about how "-" and "." can be used so let's - only keep - alphanumeric chars see here for detail: - https://kubernetes.io/docs/concepts/overview/working-with-objects/names/ - - :param string: The requested Pod name - :return: Pod name stripped of any unsafe characters - """ - return slugify(string, separator='', lowercase=True) - - -def create_pod_id(dag_id: str, task_id: str) -> str: +def create_pod_id(dag_id: str | None = None, task_id: str | None = None) -> str: """ Generates the kubernetes safe pod_id. Note that this is NOT the full ID that will be launched to k8s. We will add a uuid @@ -51,27 +36,32 @@ def create_pod_id(dag_id: str, task_id: str) -> str: :param task_id: Task ID :return: The non-unique pod_id for this task/DAG pairing """ - safe_dag_id = _strip_unsafe_kubernetes_special_chars(dag_id) - safe_task_id = _strip_unsafe_kubernetes_special_chars(task_id) - return safe_dag_id + safe_task_id + name = "" + if dag_id: + name += dag_id + if task_id: + if name: + name += "-" + name += task_id + return slugify(name, lowercase=True)[:253].strip("-.") -def annotations_to_key(annotations: Dict[str, str]) -> Optional[TaskInstanceKey]: +def annotations_to_key(annotations: dict[str, str]) -> TaskInstanceKey: """Build a TaskInstanceKey based on pod annotations""" log.debug("Creating task key for annotations %s", annotations) - dag_id = annotations['dag_id'] - task_id = annotations['task_id'] - try_number = int(annotations['try_number']) - annotation_run_id = annotations.get('run_id') - map_index = int(annotations.get('map_index', -1)) + dag_id = annotations["dag_id"] + task_id = annotations["task_id"] + try_number = int(annotations["try_number"]) + annotation_run_id = annotations.get("run_id") + map_index = int(annotations.get("map_index", -1)) - if not annotation_run_id and 'execution_date' in annotations: + if not annotation_run_id and "execution_date" in annotations: # Compat: Look up the run_id from the TI table! from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.settings import Session - execution_date = pendulum.parse(annotations['execution_date']) + execution_date = pendulum.parse(annotations["execution_date"]) # Do _not_ use create-session, we don't want to expunge session = Session() diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py index a5b6cde0e335b..5b946b2e3a885 100644 --- a/airflow/kubernetes/pod.py +++ b/airflow/kubernetes/pod.py @@ -19,16 +19,20 @@ This module is deprecated. Please use :mod:`kubernetes.client.models` for `V1ResourceRequirements` and `Port`. """ -# flake8: noqa +from __future__ import annotations import warnings +from airflow.exceptions import RemovedInAirflow3Warning + +# flake8: noqa + with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) from airflow.providers.cncf.kubernetes.backcompat.pod import Port, Resources # noqa: autoflake warnings.warn( "This module is deprecated. 
Please use `kubernetes.client.models` for `V1ResourceRequirements` and `Port`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index 52b45801ccabc..4382b390aa485 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -20,29 +20,33 @@ The advantage being that the full Kubernetes API is supported and no serialization need be written. """ +from __future__ import annotations + import copy import datetime import hashlib +import logging import os import re import uuid import warnings from functools import reduce -from typing import List, Optional, Union from dateutil import parser from kubernetes.client import models as k8s from kubernetes.client.api_client import ApiClient -from airflow.exceptions import AirflowConfigException +from airflow.exceptions import AirflowConfigException, PodReconciliationError, RemovedInAirflow3Warning from airflow.kubernetes.pod_generator_deprecated import PodDefaults, PodGenerator as PodGeneratorDeprecated from airflow.utils import yaml from airflow.version import version as airflow_version +log = logging.getLogger(__name__) + MAX_LABEL_LEN = 63 -def make_safe_label_value(string): +def make_safe_label_value(string: str) -> str: """ Valid label values must be 63 characters or less and must be empty or begin and end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_), @@ -70,7 +74,7 @@ def datetime_to_label_safe_datestring(datetime_obj: datetime.datetime) -> str: :param datetime_obj: datetime.datetime object :return: ISO-like string representing the datetime """ - return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_') + return datetime_obj.isoformat().replace(":", "_").replace("+", "_plus_") def label_safe_datestring_to_datetime(string: str) -> datetime.datetime: @@ -82,7 +86,7 @@ def label_safe_datestring_to_datetime(string: str) -> datetime.datetime: :param string: str :return: datetime.datetime object """ - return parser.parse(string.replace('_plus_', '+').replace("_", ":")) + return parser.parse(string.replace("_plus_", "+").replace("_", ":")) class PodGenerator: @@ -100,8 +104,8 @@ class PodGenerator: def __init__( self, - pod: Optional[k8s.V1Pod] = None, - pod_template_file: Optional[str] = None, + pod: k8s.V1Pod | None = None, + pod_template_file: str | None = None, extract_xcom: bool = True, ): if not pod_template_file and not pod: @@ -147,7 +151,7 @@ def add_xcom_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod: return pod_cp @staticmethod - def from_obj(obj) -> Optional[Union[dict, k8s.V1Pod]]: + def from_obj(obj) -> dict | k8s.V1Pod | None: """Converts to pod from obj""" if obj is None: return None @@ -169,19 +173,19 @@ def from_obj(obj) -> Optional[Union[dict, k8s.V1Pod]]: return k8s_object elif isinstance(k8s_legacy_object, dict): warnings.warn( - 'Using a dictionary for the executor_config is deprecated and will soon be removed.' + "Using a dictionary for the executor_config is deprecated and will soon be removed." 'please use a `kubernetes.client.models.V1Pod` class with a "pod_override" key' - ' instead. ', - category=DeprecationWarning, + " instead. 
", + category=RemovedInAirflow3Warning, ) return PodGenerator.from_legacy_obj(obj) else: raise TypeError( - 'Cannot convert a non-kubernetes.client.models.V1Pod object into a KubernetesExecutorConfig' + "Cannot convert a non-kubernetes.client.models.V1Pod object into a KubernetesExecutorConfig" ) @staticmethod - def from_legacy_obj(obj) -> Optional[k8s.V1Pod]: + def from_legacy_obj(obj) -> k8s.V1Pod | None: """Converts to pod from obj""" if obj is None: return None @@ -193,18 +197,18 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]: if not namespaced: return None - resources = namespaced.get('resources') + resources = namespaced.get("resources") if resources is None: requests = { - 'cpu': namespaced.pop('request_cpu', None), - 'memory': namespaced.pop('request_memory', None), - 'ephemeral-storage': namespaced.get('ephemeral-storage'), # We pop this one in limits + "cpu": namespaced.pop("request_cpu", None), + "memory": namespaced.pop("request_memory", None), + "ephemeral-storage": namespaced.get("ephemeral-storage"), # We pop this one in limits } limits = { - 'cpu': namespaced.pop('limit_cpu', None), - 'memory': namespaced.pop('limit_memory', None), - 'ephemeral-storage': namespaced.pop('ephemeral-storage', None), + "cpu": namespaced.pop("limit_cpu", None), + "memory": namespaced.pop("limit_memory", None), + "ephemeral-storage": namespaced.pop("ephemeral-storage", None), } all_resources = list(requests.values()) + list(limits.values()) if all(r is None for r in all_resources): @@ -214,11 +218,11 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]: requests = {k: v for k, v in requests.items() if v is not None} limits = {k: v for k, v in limits.items() if v is not None} resources = k8s.V1ResourceRequirements(requests=requests, limits=limits) - namespaced['resources'] = resources + namespaced["resources"] = resources return PodGeneratorDeprecated(**namespaced).gen_pod() @staticmethod - def reconcile_pods(base_pod: k8s.V1Pod, client_pod: Optional[k8s.V1Pod]) -> k8s.V1Pod: + def reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod | None) -> k8s.V1Pod: """ :param base_pod: has the base attributes which are overwritten if they exist in the client pod and remain if they do not exist in the client_pod @@ -253,17 +257,17 @@ def reconcile_metadata(base_meta, client_meta): elif client_meta and base_meta: client_meta.labels = merge_objects(base_meta.labels, client_meta.labels) client_meta.annotations = merge_objects(base_meta.annotations, client_meta.annotations) - extend_object_field(base_meta, client_meta, 'managed_fields') - extend_object_field(base_meta, client_meta, 'finalizers') - extend_object_field(base_meta, client_meta, 'owner_references') + extend_object_field(base_meta, client_meta, "managed_fields") + extend_object_field(base_meta, client_meta, "finalizers") + extend_object_field(base_meta, client_meta, "owner_references") return merge_objects(base_meta, client_meta) return None @staticmethod def reconcile_specs( - base_spec: Optional[k8s.V1PodSpec], client_spec: Optional[k8s.V1PodSpec] - ) -> Optional[k8s.V1PodSpec]: + base_spec: k8s.V1PodSpec | None, client_spec: k8s.V1PodSpec | None + ) -> k8s.V1PodSpec | None: """ :param base_spec: has the base attributes which are overwritten if they exist in the client_spec and remain if they do not exist in the client_spec @@ -278,16 +282,16 @@ def reconcile_specs( client_spec.containers = PodGenerator.reconcile_containers( base_spec.containers, client_spec.containers ) - merged_spec = extend_object_field(base_spec, client_spec, 
'init_containers') - merged_spec = extend_object_field(base_spec, merged_spec, 'volumes') + merged_spec = extend_object_field(base_spec, client_spec, "init_containers") + merged_spec = extend_object_field(base_spec, merged_spec, "volumes") return merge_objects(base_spec, merged_spec) return None @staticmethod def reconcile_containers( - base_containers: List[k8s.V1Container], client_containers: List[k8s.V1Container] - ) -> List[k8s.V1Container]: + base_containers: list[k8s.V1Container], client_containers: list[k8s.V1Container] + ) -> list[k8s.V1Container]: """ :param base_containers: has the base attributes which are overwritten if they exist in the client_containers and remain if they do not exist in the client_containers @@ -303,11 +307,11 @@ def reconcile_containers( client_container = client_containers[0] base_container = base_containers[0] - client_container = extend_object_field(base_container, client_container, 'volume_mounts') - client_container = extend_object_field(base_container, client_container, 'env') - client_container = extend_object_field(base_container, client_container, 'env_from') - client_container = extend_object_field(base_container, client_container, 'ports') - client_container = extend_object_field(base_container, client_container, 'volume_devices') + client_container = extend_object_field(base_container, client_container, "volume_mounts") + client_container = extend_object_field(base_container, client_container, "env") + client_container = extend_object_field(base_container, client_container, "env_from") + client_container = extend_object_field(base_container, client_container, "ports") + client_container = extend_object_field(base_container, client_container, "volume_devices") client_container = merge_objects(base_container, client_container) return [client_container] + PodGenerator.reconcile_containers( @@ -321,13 +325,13 @@ def construct_pod( pod_id: str, try_number: int, kube_image: str, - date: Optional[datetime.datetime], - args: List[str], - pod_override_object: Optional[k8s.V1Pod], + date: datetime.datetime | None, + args: list[str], + pod_override_object: k8s.V1Pod | None, base_worker_pod: k8s.V1Pod, namespace: str, scheduler_job_id: str, - run_id: Optional[str] = None, + run_id: str | None = None, map_index: int = -1, ) -> k8s.V1Pod: """ @@ -344,27 +348,27 @@ def construct_pod( image = kube_image annotations = { - 'dag_id': dag_id, - 'task_id': task_id, - 'try_number': str(try_number), + "dag_id": dag_id, + "task_id": task_id, + "try_number": str(try_number), } labels = { - 'airflow-worker': make_safe_label_value(scheduler_job_id), - 'dag_id': make_safe_label_value(dag_id), - 'task_id': make_safe_label_value(task_id), - 'try_number': str(try_number), - 'airflow_version': airflow_version.replace('+', '-'), - 'kubernetes_executor': 'True', + "airflow-worker": make_safe_label_value(scheduler_job_id), + "dag_id": make_safe_label_value(dag_id), + "task_id": make_safe_label_value(task_id), + "try_number": str(try_number), + "airflow_version": airflow_version.replace("+", "-"), + "kubernetes_executor": "True", } if map_index >= 0: - annotations['map_index'] = str(map_index) - labels['map_index'] = str(map_index) + annotations["map_index"] = str(map_index) + labels["map_index"] = str(map_index) if date: - annotations['execution_date'] = date.isoformat() - labels['execution_date'] = datetime_to_label_safe_datestring(date) + annotations["execution_date"] = date.isoformat() + labels["execution_date"] = datetime_to_label_safe_datestring(date) if run_id: - 
annotations['run_id'] = run_id - labels['run_id'] = make_safe_label_value(run_id) + annotations["run_id"] = run_id + labels["run_id"] = make_safe_label_value(run_id) dynamic_pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( @@ -389,7 +393,10 @@ def construct_pod( # Pod from the pod_template_File -> Pod from executor_config arg -> Pod from the K8s executor pod_list = [base_worker_pod, pod_override_object, dynamic_pod] - return reduce(PodGenerator.reconcile_pods, pod_list) + try: + return reduce(PodGenerator.reconcile_pods, pod_list) + except Exception as e: + raise PodReconciliationError from e @staticmethod def serialize_pod(pod: k8s.V1Pod) -> dict: @@ -408,24 +415,25 @@ def deserialize_model_file(path: str) -> k8s.V1Pod: """ :param path: Path to the file :return: a kubernetes.client.models.V1Pod - - Unfortunately we need access to the private method - ``_ApiClient__deserialize_model`` from the kubernetes client. - This issue is tracked here; https://github.com/kubernetes-client/python/issues/977. """ if os.path.exists(path): with open(path) as stream: pod = yaml.safe_load(stream) else: - pod = yaml.safe_load(path) + pod = None + log.warning("Model file %s does not exist", path) return PodGenerator.deserialize_model_dict(pod) @staticmethod - def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod: + def deserialize_model_dict(pod_dict: dict | None) -> k8s.V1Pod: """ Deserializes python dictionary to k8s.V1Pod + Unfortunately we need access to the private method + ``_ApiClient__deserialize_model`` from the kubernetes client. + This issue is tracked here; https://github.com/kubernetes-client/python/issues/977. + :param pod_dict: Serialized dict of k8s.V1Pod object :return: De-serialized k8s.V1Pod """ @@ -433,7 +441,7 @@ def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod: return api_client._ApiClient__deserialize_model(pod_dict, k8s.V1Pod) @staticmethod - def make_unique_pod_id(pod_id: str) -> Optional[str]: + def make_unique_pod_id(pod_id: str) -> str | None: r""" Kubernetes pod names must consist of one or more lowercase rfc1035/rfc1123 labels separated by '.' with a maximum length of 253 @@ -456,7 +464,7 @@ def make_unique_pod_id(pod_id: str) -> Optional[str]: # Get prefix length after subtracting the uuid length. Clean up '.' and '-' from # end of podID ('.' can't be followed by '-'). label_prefix_length = MAX_LABEL_LEN - len(safe_uuid) - 1 # -1 for separator - trimmed_pod_id = pod_id[:label_prefix_length].rstrip('-.') + trimmed_pod_id = pod_id[:label_prefix_length].rstrip("-.") # previously used a '.' as the separator, but this could create errors in some situations return f"{trimmed_pod_id}-{safe_uuid}" diff --git a/airflow/kubernetes/pod_generator_deprecated.py b/airflow/kubernetes/pod_generator_deprecated.py index fcbb2bb402d6d..f08a6c45d2231 100644 --- a/airflow/kubernetes/pod_generator_deprecated.py +++ b/airflow/kubernetes/pod_generator_deprecated.py @@ -20,11 +20,12 @@ The advantage being that the full Kubernetes API is supported and no serialization need be written. 
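The behavioral piece of the construct_pod() hunk above is the new try/except: the template pod, the executor_config override and the dynamically built per-task pod are still folded together with functools.reduce, but a failure anywhere in that fold is now surfaced as PodReconciliationError rather than an unlabeled exception. A minimal standalone sketch of the same fold-and-wrap pattern, using plain dicts and a simplified merge helper instead of Airflow's k8s.V1Pod reconciliation:

from __future__ import annotations

from functools import reduce


class PodReconciliationError(Exception):
    """Stand-in for airflow.exceptions.PodReconciliationError."""


def reconcile_pods(base: dict, override: dict | None) -> dict:
    # Simplified merge: later, non-None values win. The real code merges full
    # k8s.V1Pod objects field by field (metadata, spec, containers, ...).
    if override is None:
        return base
    merged = dict(base)
    merged.update({k: v for k, v in override.items() if v is not None})
    return merged


def construct_pod(template: dict, executor_config: dict | None, dynamic: dict) -> dict:
    # Precedence, lowest to highest: pod_template_file -> executor_config -> per-task pod.
    pod_list = [template, executor_config, dynamic]
    try:
        return reduce(reconcile_pods, pod_list)
    except Exception as e:
        raise PodReconciliationError("could not merge pod definitions") from e


print(construct_pod({"image": "template", "labels": {}}, None, {"image": "task-image"}))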
""" +from __future__ import annotations + import copy import hashlib import re import uuid -from typing import Dict, List, Optional, Union from kubernetes.client import models as k8s @@ -36,15 +37,15 @@ class PodDefaults: """Static defaults for Pods""" - XCOM_MOUNT_PATH = '/airflow/xcom' - SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' + XCOM_MOUNT_PATH = "/airflow/xcom" + SIDECAR_CONTAINER_NAME = "airflow-xcom-sidecar" XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;' - VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH) - VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource()) + VOLUME_MOUNT = k8s.V1VolumeMount(name="xcom", mount_path=XCOM_MOUNT_PATH) + VOLUME = k8s.V1Volume(name="xcom", empty_dir=k8s.V1EmptyDirVolumeSource()) SIDECAR_CONTAINER = k8s.V1Container( name=SIDECAR_CONTAINER_NAME, - command=['sh', '-c', XCOM_CMD], - image='alpine', + command=["sh", "-c", XCOM_CMD], + image="alpine", volume_mounts=[VOLUME_MOUNT], resources=k8s.V1ResourceRequirements( requests={ @@ -117,38 +118,38 @@ class PodGenerator: def __init__( self, - image: Optional[str] = None, - name: Optional[str] = None, - namespace: Optional[str] = None, - volume_mounts: Optional[List[Union[k8s.V1VolumeMount, dict]]] = None, - envs: Optional[Dict[str, str]] = None, - cmds: Optional[List[str]] = None, - args: Optional[List[str]] = None, - labels: Optional[Dict[str, str]] = None, - node_selectors: Optional[Dict[str, str]] = None, - ports: Optional[List[Union[k8s.V1ContainerPort, dict]]] = None, - volumes: Optional[List[Union[k8s.V1Volume, dict]]] = None, - image_pull_policy: Optional[str] = None, - restart_policy: Optional[str] = None, - image_pull_secrets: Optional[str] = None, - init_containers: Optional[List[k8s.V1Container]] = None, - service_account_name: Optional[str] = None, - resources: Optional[Union[k8s.V1ResourceRequirements, dict]] = None, - annotations: Optional[Dict[str, str]] = None, - affinity: Optional[dict] = None, + image: str | None = None, + name: str | None = None, + namespace: str | None = None, + volume_mounts: list[k8s.V1VolumeMount | dict] | None = None, + envs: dict[str, str] | None = None, + cmds: list[str] | None = None, + args: list[str] | None = None, + labels: dict[str, str] | None = None, + node_selectors: dict[str, str] | None = None, + ports: list[k8s.V1ContainerPort | dict] | None = None, + volumes: list[k8s.V1Volume | dict] | None = None, + image_pull_policy: str | None = None, + restart_policy: str | None = None, + image_pull_secrets: str | None = None, + init_containers: list[k8s.V1Container] | None = None, + service_account_name: str | None = None, + resources: k8s.V1ResourceRequirements | dict | None = None, + annotations: dict[str, str] | None = None, + affinity: dict | None = None, hostnetwork: bool = False, - tolerations: Optional[list] = None, - security_context: Optional[Union[k8s.V1PodSecurityContext, dict]] = None, - configmaps: Optional[List[str]] = None, - dnspolicy: Optional[str] = None, - schedulername: Optional[str] = None, + tolerations: list | None = None, + security_context: k8s.V1PodSecurityContext | dict | None = None, + configmaps: list[str] | None = None, + dnspolicy: str | None = None, + schedulername: str | None = None, extract_xcom: bool = False, - priority_class_name: Optional[str] = None, + priority_class_name: str | None = None, ): self.pod = k8s.V1Pod() - self.pod.api_version = 'v1' - self.pod.kind = 'Pod' + self.pod.api_version = "v1" + self.pod.kind = "Pod" # Pod Metadata self.metadata = 
k8s.V1ObjectMeta() @@ -158,7 +159,7 @@ def __init__( self.metadata.annotations = annotations # Pod Container - self.container = k8s.V1Container(name='base') + self.container = k8s.V1Container(name="base") self.container.image = image self.container.env = [] @@ -204,7 +205,7 @@ def __init__( self.spec.image_pull_secrets = [] if image_pull_secrets: - for image_pull_secret in image_pull_secrets.split(','): + for image_pull_secret in image_pull_secrets.split(","): self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference(name=image_pull_secret)) # Attach sidecar @@ -240,7 +241,7 @@ def add_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod: return pod_cp @staticmethod - def from_obj(obj) -> Optional[k8s.V1Pod]: + def from_obj(obj) -> k8s.V1Pod | None: """Converts to pod from obj""" if obj is None: return None @@ -250,8 +251,8 @@ def from_obj(obj) -> Optional[k8s.V1Pod]: if not isinstance(obj, dict): raise TypeError( - 'Cannot convert a non-dictionary or non-PodGenerator ' - 'object into a KubernetesExecutorConfig' + "Cannot convert a non-dictionary or non-PodGenerator " + "object into a KubernetesExecutorConfig" ) # We do not want to extract constant here from ExecutorLoader because it is just @@ -261,25 +262,25 @@ def from_obj(obj) -> Optional[k8s.V1Pod]: if not namespaced: return None - resources = namespaced.get('resources') + resources = namespaced.get("resources") if resources is None: requests = { - 'cpu': namespaced.get('request_cpu'), - 'memory': namespaced.get('request_memory'), - 'ephemeral-storage': namespaced.get('ephemeral-storage'), + "cpu": namespaced.get("request_cpu"), + "memory": namespaced.get("request_memory"), + "ephemeral-storage": namespaced.get("ephemeral-storage"), } limits = { - 'cpu': namespaced.get('limit_cpu'), - 'memory': namespaced.get('limit_memory'), - 'ephemeral-storage': namespaced.get('ephemeral-storage'), + "cpu": namespaced.get("limit_cpu"), + "memory": namespaced.get("limit_memory"), + "ephemeral-storage": namespaced.get("ephemeral-storage"), } all_resources = list(requests.values()) + list(limits.values()) if all(r is None for r in all_resources): resources = None else: resources = k8s.V1ResourceRequirements(requests=requests, limits=limits) - namespaced['resources'] = resources + namespaced["resources"] = resources return PodGenerator(**namespaced).gen_pod() @staticmethod diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index 0b9cbbe45a481..bd52f49653313 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -19,4 +19,6 @@ This module is deprecated. Please use :mod:`kubernetes.client.models` for V1ResourceRequirements and Port. """ +from __future__ import annotations + from airflow.kubernetes.pod_launcher_deprecated import PodLauncher, PodStatus # noqa: autoflake diff --git a/airflow/kubernetes/pod_launcher_deprecated.py b/airflow/kubernetes/pod_launcher_deprecated.py index 97845dad51d5a..13476c43760be 100644 --- a/airflow/kubernetes/pod_launcher_deprecated.py +++ b/airflow/kubernetes/pod_launcher_deprecated.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. 
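For context on the from_obj() hunk just above: the deprecated generator still accepts the old flat executor_config keys (request_cpu, request_memory, limit_cpu, limit_memory, ephemeral-storage) and only builds a resources object when at least one of them is set. A rough sketch of that mapping, with plain dictionaries standing in for kubernetes.client.models.V1ResourceRequirements:

from __future__ import annotations


def legacy_resources(namespaced: dict) -> dict | None:
    # Collect the legacy flat keys; a real implementation would construct
    # k8s.V1ResourceRequirements rather than returning a dict.
    requests = {
        "cpu": namespaced.get("request_cpu"),
        "memory": namespaced.get("request_memory"),
        "ephemeral-storage": namespaced.get("ephemeral-storage"),
    }
    limits = {
        "cpu": namespaced.get("limit_cpu"),
        "memory": namespaced.get("limit_memory"),
        "ephemeral-storage": namespaced.get("ephemeral-storage"),
    }
    if all(v is None for v in (*requests.values(), *limits.values())):
        return None
    return {
        "requests": {k: v for k, v in requests.items() if v is not None},
        "limits": {k: v for k, v in limits.items() if v is not None},
    }


print(legacy_resources({"request_cpu": "500m", "limit_memory": "1Gi"}))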
"""Launches PODs""" +from __future__ import annotations + import json import math import time import warnings from datetime import datetime as dt -from typing import Optional, Tuple import pendulum import tenacity @@ -30,7 +31,7 @@ from kubernetes.stream import stream as kubernetes_stream from requests.exceptions import HTTPError -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.kubernetes.kube_client import get_kube_client from airflow.kubernetes.pod_generator import PodDefaults from airflow.settings import pod_mutation_hook @@ -46,7 +47,7 @@ https://pypi.org/project/apache-airflow-providers-cncf-kubernetes/ """, - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) @@ -54,10 +55,10 @@ class PodStatus: """Status of the PODs""" - PENDING = 'pending' - RUNNING = 'running' - FAILED = 'failed' - SUCCEEDED = 'succeeded' + PENDING = "pending" + RUNNING = "running" + FAILED = "failed" + SUCCEEDED = "succeeded" class PodLauncher(LoggingMixin): @@ -69,7 +70,7 @@ def __init__( self, kube_client: client.CoreV1Api = None, in_cluster: bool = True, - cluster_context: Optional[str] = None, + cluster_context: str | None = None, extract_xcom: bool = False, ): """ @@ -94,14 +95,14 @@ def run_pod_async(self, pod: V1Pod, **kwargs): sanitized_pod = self._client.api_client.sanitize_for_serialization(pod) json_pod = json.dumps(sanitized_pod, indent=2) - self.log.debug('Pod Creation Request: \n%s', json_pod) + self.log.debug("Pod Creation Request: \n%s", json_pod) try: resp = self._client.create_namespaced_pod( body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs ) - self.log.debug('Pod Creation Response: %s', resp) + self.log.debug("Pod Creation Response: %s", resp) except Exception as e: - self.log.exception('Exception when attempting to create Namespaced Pod: %s', json_pod) + self.log.exception("Exception when attempting to create Namespaced Pod: %s", json_pod) raise e return resp @@ -134,13 +135,12 @@ def start_pod(self, pod: V1Pod, startup_timeout: int = 120): raise AirflowException("Pod took too long to start") time.sleep(1) - def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]: + def monitor_pod(self, pod: V1Pod, get_logs: bool) -> tuple[State, str | None]: """ Monitors a pod and returns the final state :param pod: pod spec that will be monitored :param get_logs: whether to read the logs locally - :return: Tuple[State, Optional[str]] """ if get_logs: read_logs_since_sec = None @@ -148,7 +148,7 @@ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]] while True: logs = self.read_pod_logs(pod, timestamps=True, since_seconds=read_logs_since_sec) for line in logs: - timestamp, message = self.parse_log_line(line.decode('utf-8')) + timestamp, message = self.parse_log_line(line.decode("utf-8")) if timestamp: last_log_time = pendulum.parse(timestamp) self.log.info(message) @@ -157,7 +157,7 @@ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]] if not self.base_container_is_running(pod): break - self.log.warning('Pod %s log read interrupted', pod.metadata.name) + self.log.warning("Pod %s log read interrupted", pod.metadata.name) if last_log_time: delta = pendulum.now() - last_log_time # Prefer logs duplication rather than loss @@ -165,25 +165,24 @@ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]] result = None if self.extract_xcom: while self.base_container_is_running(pod): - 
self.log.info('Container %s has state %s', pod.metadata.name, State.RUNNING) + self.log.info("Container %s has state %s", pod.metadata.name, State.RUNNING) time.sleep(2) result = self._extract_xcom(pod) self.log.info(result) result = json.loads(result) while self.pod_is_running(pod): - self.log.info('Pod %s has state %s', pod.metadata.name, State.RUNNING) + self.log.info("Pod %s has state %s", pod.metadata.name, State.RUNNING) time.sleep(2) return self._task_status(self.read_pod(pod)), result - def parse_log_line(self, line: str) -> Tuple[Optional[str], str]: + def parse_log_line(self, line: str) -> tuple[str | None, str]: """ Parse K8s log line and returns the final state :param line: k8s log line :return: timestamp and log message - :rtype: Tuple[str, str] """ - split_at = line.find(' ') + split_at = line.find(" ") if split_at == -1: self.log.error( "Error parsing timestamp (no timestamp in message: %r). " @@ -196,7 +195,7 @@ def parse_log_line(self, line: str) -> Tuple[Optional[str], str]: return timestamp, message def _task_status(self, event): - self.log.info('Event: %s had an event of type %s', event.metadata.name, event.status.phase) + self.log.info("Event: %s had an event of type %s", event.metadata.name, event.status.phase) status = self.process_status(event.metadata.name, event.status.phase) return status @@ -213,7 +212,7 @@ def pod_is_running(self, pod: V1Pod): def base_container_is_running(self, pod: V1Pod): """Tests if base container is running""" event = self.read_pod(pod) - status = next(iter(filter(lambda s: s.name == 'base', event.status.container_statuses)), None) + status = next((s for s in event.status.container_statuses if s.name == "base"), None) if not status: return False return status.state.running is not None @@ -222,30 +221,30 @@ def base_container_is_running(self, pod: V1Pod): def read_pod_logs( self, pod: V1Pod, - tail_lines: Optional[int] = None, + tail_lines: int | None = None, timestamps: bool = False, - since_seconds: Optional[int] = None, + since_seconds: int | None = None, ): """Reads log from the POD""" additional_kwargs = {} if since_seconds: - additional_kwargs['since_seconds'] = since_seconds + additional_kwargs["since_seconds"] = since_seconds if tail_lines: - additional_kwargs['tail_lines'] = tail_lines + additional_kwargs["tail_lines"] = tail_lines try: return self._client.read_namespaced_pod_log( name=pod.metadata.name, namespace=pod.metadata.namespace, - container='base', + container="base", follow=True, timestamps=timestamps, _preload_content=False, **additional_kwargs, ) except HTTPError as e: - raise AirflowException(f'There was an error reading the kubernetes API: {e}') + raise AirflowException(f"There was an error reading the kubernetes API: {e}") @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) def read_pod_events(self, pod): @@ -255,7 +254,7 @@ def read_pod_events(self, pod): namespace=pod.metadata.namespace, field_selector=f"involvedObject.name={pod.metadata.name}" ) except HTTPError as e: - raise AirflowException(f'There was an error reading the kubernetes API: {e}') + raise AirflowException(f"There was an error reading the kubernetes API: {e}") @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) def read_pod(self, pod: V1Pod): @@ -263,7 +262,7 @@ def read_pod(self, pod: V1Pod): try: return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace) except HTTPError as e: - raise AirflowException(f'There was an error 
reading the kubernetes API: {e}') + raise AirflowException(f"There was an error reading the kubernetes API: {e}") def _extract_xcom(self, pod: V1Pod): resp = kubernetes_stream( @@ -271,7 +270,7 @@ def _extract_xcom(self, pod: V1Pod): pod.metadata.name, pod.metadata.namespace, container=PodDefaults.SIDECAR_CONTAINER_NAME, - command=['/bin/sh'], + command=["/bin/sh"], stdin=True, stdout=True, stderr=True, @@ -279,18 +278,18 @@ def _extract_xcom(self, pod: V1Pod): _preload_content=False, ) try: - result = self._exec_pod_command(resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json') - self._exec_pod_command(resp, 'kill -s SIGINT 1') + result = self._exec_pod_command(resp, f"cat {PodDefaults.XCOM_MOUNT_PATH}/return.json") + self._exec_pod_command(resp, "kill -s SIGINT 1") finally: resp.close() if result is None: - raise AirflowException(f'Failed to extract xcom from pod: {pod.metadata.name}') + raise AirflowException(f"Failed to extract xcom from pod: {pod.metadata.name}") return result def _exec_pod_command(self, resp, command): if resp.is_open(): - self.log.info('Running command... %s\n', command) - resp.write_stdin(command + '\n') + self.log.info("Running command... %s\n", command) + resp.write_stdin(command + "\n") while resp.is_open(): resp.update(timeout=1) if resp.peek_stdout(): @@ -306,13 +305,13 @@ def process_status(self, job_id, status): if status == PodStatus.PENDING: return State.QUEUED elif status == PodStatus.FAILED: - self.log.error('Event with job id %s Failed', job_id) + self.log.error("Event with job id %s Failed", job_id) return State.FAILED elif status == PodStatus.SUCCEEDED: - self.log.info('Event with job id %s Succeeded', job_id) + self.log.info("Event with job id %s Succeeded", job_id) return State.SUCCESS elif status == PodStatus.RUNNING: return State.RUNNING else: - self.log.error('Event: Invalid state %s on job %s', status, job_id) + self.log.error("Event: Invalid state %s on job %s", status, job_id) return State.FAILED diff --git a/airflow/kubernetes/pod_runtime_info_env.py b/airflow/kubernetes/pod_runtime_info_env.py index a51f3b96fc39a..32e178263b126 100644 --- a/airflow/kubernetes/pod_runtime_info_env.py +++ b/airflow/kubernetes/pod_runtime_info_env.py @@ -16,14 +16,18 @@ # specific language governing permissions and limitations # under the License. """This module is deprecated. Please use :mod:`kubernetes.client.models.V1EnvVar`.""" +from __future__ import annotations + import warnings +from airflow.exceptions import RemovedInAirflow3Warning + with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) from airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env import PodRuntimeInfoEnv # noqa warnings.warn( "This module is deprecated. Please use `kubernetes.client.models.V1EnvVar`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) diff --git a/airflow/kubernetes/secret.py b/airflow/kubernetes/secret.py index afb30916ff357..145dcd5b2cb5f 100644 --- a/airflow/kubernetes/secret.py +++ b/airflow/kubernetes/secret.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. 
"""Classes for interacting with Kubernetes API""" +from __future__ import annotations + import copy import uuid -from typing import Tuple from kubernetes.client import models as k8s @@ -45,19 +46,19 @@ def __init__(self, deploy_type, deploy_target, secret, key=None, items=None): secret keys to paths https://kubernetes.io/docs/concepts/configuration/secret/#projection-of-secret-keys-to-specific-paths """ - if deploy_type not in ('env', 'volume'): + if deploy_type not in ("env", "volume"): raise AirflowConfigException("deploy_type must be env or volume") self.deploy_type = deploy_type self.deploy_target = deploy_target self.items = items or [] - if deploy_target is not None and deploy_type == 'env': + if deploy_target is not None and deploy_type == "env": # if deploying to env, capitalize the deploy target self.deploy_target = deploy_target.upper() if key is not None and deploy_target is None: - raise AirflowConfigException('If `key` is set, `deploy_target` should not be None') + raise AirflowConfigException("If `key` is set, `deploy_target` should not be None") self.secret = secret self.key = key @@ -75,9 +76,9 @@ def to_env_from_secret(self) -> k8s.V1EnvFromSource: """Reads from environment to secret""" return k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=self.secret)) - def to_volume_secret(self) -> Tuple[k8s.V1Volume, k8s.V1VolumeMount]: + def to_volume_secret(self) -> tuple[k8s.V1Volume, k8s.V1VolumeMount]: """Converts to volume secret""" - vol_id = f'secretvol{uuid.uuid4()}' + vol_id = f"secretvol{uuid.uuid4()}" volume = k8s.V1Volume(name=vol_id, secret=k8s.V1SecretVolumeSource(secret_name=self.secret)) if self.items: volume.secret.items = self.items @@ -87,7 +88,7 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: """Attaches to pod""" cp_pod = copy.deepcopy(pod) - if self.deploy_type == 'volume': + if self.deploy_type == "volume": volume, volume_mount = self.to_volume_secret() if cp_pod.spec.volumes is None: cp_pod.spec.volumes = [] @@ -96,13 +97,13 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod: cp_pod.spec.containers[0].volume_mounts = [] cp_pod.spec.containers[0].volume_mounts.append(volume_mount) - if self.deploy_type == 'env' and self.key is not None: + if self.deploy_type == "env" and self.key is not None: env = self.to_env_secret() if cp_pod.spec.containers[0].env is None: cp_pod.spec.containers[0].env = [] cp_pod.spec.containers[0].env.append(env) - if self.deploy_type == 'env' and self.key is None: + if self.deploy_type == "env" and self.key is None: env_from = self.to_env_from_secret() if cp_pod.spec.containers[0].env_from is None: cp_pod.spec.containers[0].env_from = [] @@ -119,4 +120,4 @@ def __eq__(self, other): ) def __repr__(self): - return f'Secret({self.deploy_type}, {self.deploy_target}, {self.secret}, {self.key})' + return f"Secret({self.deploy_type}, {self.deploy_target}, {self.secret}, {self.key})" diff --git a/airflow/kubernetes/volume.py b/airflow/kubernetes/volume.py index 81b4fda3a15da..ecb39e457fd4a 100644 --- a/airflow/kubernetes/volume.py +++ b/airflow/kubernetes/volume.py @@ -16,14 +16,18 @@ # specific language governing permissions and limitations # under the License. """This module is deprecated. 
Please use :mod:`kubernetes.client.models.V1Volume`.""" +from __future__ import annotations + import warnings +from airflow.exceptions import RemovedInAirflow3Warning + with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) from airflow.providers.cncf.kubernetes.backcompat.volume import Volume # noqa: autoflake warnings.warn( "This module is deprecated. Please use `kubernetes.client.models.V1Volume`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) diff --git a/airflow/kubernetes/volume_mount.py b/airflow/kubernetes/volume_mount.py index f558425881752..e65351d85f5af 100644 --- a/airflow/kubernetes/volume_mount.py +++ b/airflow/kubernetes/volume_mount.py @@ -16,14 +16,18 @@ # specific language governing permissions and limitations # under the License. """This module is deprecated. Please use :mod:`kubernetes.client.models.V1VolumeMount`.""" +from __future__ import annotations + import warnings +from airflow.exceptions import RemovedInAirflow3Warning + with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) from airflow.providers.cncf.kubernetes.backcompat.volume_mount import VolumeMount # noqa: autoflake warnings.warn( "This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py index 3d8c696487915..173956b74c289 100644 --- a/airflow/lineage/__init__.py +++ b/airflow/lineage/__init__.py @@ -15,21 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Provides lineage support functions""" -import json +"""Provides lineage support functions.""" +from __future__ import annotations + +import itertools import logging from functools import wraps -from typing import Any, Callable, Dict, Optional, TypeVar, cast - -import attr -import jinja2 -from cattr import structure, unstructure +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast from airflow.configuration import conf from airflow.lineage.backend import LineageBackend -from airflow.utils.module_loading import import_string -ENV = jinja2.Environment() +if TYPE_CHECKING: + from airflow.utils.context import Context + PIPELINE_OUTLETS = "pipeline_outlets" PIPELINE_INLETS = "pipeline_inlets" @@ -38,17 +37,8 @@ log = logging.getLogger(__name__) -@attr.s(auto_attribs=True) -class Metadata: - """Class for serialized entities.""" - - type_name: str = attr.ib() - source: str = attr.ib() - data: Dict = attr.ib() - - -def get_backend() -> Optional[LineageBackend]: - """Gets the lineage backend if defined in the configs""" +def get_backend() -> LineageBackend | None: + """Gets the lineage backend if defined in the configs.""" clazz = conf.getimport("lineage", "backend", fallback=None) if clazz: @@ -63,33 +53,8 @@ def get_backend() -> Optional[LineageBackend]: return None -def _get_instance(meta: Metadata): - """Instantiate an object from Metadata""" - cls = import_string(meta.type_name) - return structure(meta.data, cls) - - -def _render_object(obj: Any, context) -> Any: - """Renders a attr annotated object. 
Will set non serializable attributes to none""" - return structure( - json.loads( - ENV.from_string(json.dumps(unstructure(obj), default=lambda o: None)) - .render(**context) - .encode('utf-8') - ), - type(obj), - ) - - -def _to_dataset(obj: Any, source: str) -> Optional[Metadata]: - """Create Metadata from attr annotated object""" - if not attr.has(obj): - return None - - type_name = obj.__module__ + '.' + obj.__class__.__name__ - data = unstructure(obj) - - return Metadata(type_name, source, data) +def _render_object(obj: Any, context: Context) -> dict: + return context["ti"].task.render_template(obj, context) T = TypeVar("T", bound=Callable) @@ -97,6 +62,8 @@ def _to_dataset(obj: Any, source: str) -> Optional[Metadata]: def apply_lineage(func: T) -> T: """ + Conditionally send lineage to the backend. + Saves the lineage to XCom and if configured to do so sends it to the backend. """ @@ -104,20 +71,22 @@ def apply_lineage(func: T) -> T: @wraps(func) def wrapper(self, context, *args, **kwargs): + self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets) + ret_val = func(self, context, *args, **kwargs) - outlets = [unstructure(_to_dataset(x, f"{self.dag_id}.{self.task_id}")) for x in self.outlets] - inlets = [unstructure(_to_dataset(x, None)) for x in self.inlets] + outlets = list(self.outlets) + inlets = list(self.inlets) - if self.outlets: + if outlets: self.xcom_push( - context, key=PIPELINE_OUTLETS, value=outlets, execution_date=context['ti'].execution_date + context, key=PIPELINE_OUTLETS, value=outlets, execution_date=context["ti"].execution_date ) - if self.inlets: + if inlets: self.xcom_push( - context, key=PIPELINE_INLETS, value=inlets, execution_date=context['ti'].execution_date + context, key=PIPELINE_INLETS, value=inlets, execution_date=context["ti"].execution_date ) if _backend: @@ -130,7 +99,9 @@ def wrapper(self, context, *args, **kwargs): def prepare_lineage(func: T) -> T: """ - Prepares the lineage inlets and outlets. Inlets can be: + Prepares the lineage inlets and outlets. + + Inlets can be: * "auto" -> picks up any outlets from direct upstream tasks that have outlets defined, as such that if A -> B -> C and B does not have outlets but A does, these are provided as inlets. 
@@ -145,49 +116,44 @@ def wrapper(self, context, *args, **kwargs): self.log.debug("Preparing lineage inlets and outlets") - if isinstance(self._inlets, (str, AbstractOperator)) or attr.has(self._inlets): - self._inlets = [ - self._inlets, - ] + if isinstance(self.inlets, (str, AbstractOperator)): + self.inlets = [self.inlets] - if self._inlets and isinstance(self._inlets, list): + if self.inlets and isinstance(self.inlets, list): # get task_ids that are specified as parameter and make sure they are upstream task_ids = ( - {o for o in self._inlets if isinstance(o, str)} - .union(op.task_id for op in self._inlets if isinstance(op, AbstractOperator)) + {o for o in self.inlets if isinstance(o, str)} + .union(op.task_id for op in self.inlets if isinstance(op, AbstractOperator)) .intersection(self.get_flat_relative_ids(upstream=True)) ) # pick up unique direct upstream task_ids if AUTO is specified - if AUTO.upper() in self._inlets or AUTO.lower() in self._inlets: + if AUTO.upper() in self.inlets or AUTO.lower() in self.inlets: task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids)) + # Remove auto and task_ids + self.inlets = [i for i in self.inlets if not isinstance(i, str)] _inlets = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS) # re-instantiate the obtained inlets - _inlets = [ - _get_instance(structure(item, Metadata)) for sublist in _inlets if sublist for item in sublist - ] + # xcom_pull returns a list of items for each given task_id + _inlets = [item for item in itertools.chain.from_iterable(_inlets)] self.inlets.extend(_inlets) - self.inlets.extend(self._inlets) - elif self._inlets: + elif self.inlets: raise AttributeError("inlets is not a list, operator, string or attr annotated object") - if not isinstance(self._outlets, list): - self._outlets = [ - self._outlets, - ] - - self.outlets.extend(self._outlets) + if not isinstance(self.outlets, list): + self.outlets = [self.outlets] # render inlets and outlets - self.inlets = [_render_object(i, context) for i in self.inlets if attr.has(i)] + self.inlets = [_render_object(i, context) for i in self.inlets] - self.outlets = [_render_object(i, context) for i in self.outlets if attr.has(i)] + self.outlets = [_render_object(i, context) for i in self.outlets] self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets) + return func(self, context, *args, **kwargs) return cast(T, wrapper) diff --git a/airflow/lineage/backend.py b/airflow/lineage/backend.py index ca072f434105a..29a755109c64f 100644 --- a/airflow/lineage/backend.py +++ b/airflow/lineage/backend.py @@ -15,25 +15,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Sends lineage metadata to a backend""" -from typing import TYPE_CHECKING, Optional +"""Sends lineage metadata to a backend.""" +from __future__ import annotations + +from typing import TYPE_CHECKING if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator class LineageBackend: - """Sends lineage metadata to a backend""" + """Sends lineage metadata to a backend.""" def send_lineage( self, - operator: 'BaseOperator', - inlets: Optional[list] = None, - outlets: Optional[list] = None, - context: Optional[dict] = None, + operator: BaseOperator, + inlets: list | None = None, + outlets: list | None = None, + context: dict | None = None, ): """ - Sends lineage metadata to a backend + Sends lineage metadata to a backend. 
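The net effect of the lineage rewrite above is that inlets and outlets are no longer round-tripped through cattrs Metadata objects; they are kept as-is, pushed to XCom as plain lists, and rendered through the task's regular template machinery via context["ti"].task.render_template(...), with the lineage entities further below declaring template_fields for that purpose. A small illustrative sketch of the idea, using attrs plus a toy jinja2 renderer in place of Airflow's real render_template:

from __future__ import annotations

from typing import ClassVar

import attr
import jinja2


@attr.s(auto_attribs=True)
class File:
    # Simplified stand-in for airflow.lineage.entities.File; template_fields lists
    # the attributes that may contain Jinja expressions.
    template_fields: ClassVar = ("url",)
    url: str = attr.ib()
    type_hint: str | None = None


def render(entity, context: dict):
    # Toy renderer; Airflow itself delegates to context["ti"].task.render_template(obj, context).
    env = jinja2.Environment()
    for field in entity.template_fields:
        value = getattr(entity, field)
        setattr(entity, field, env.from_string(value).render(**context))
    return entity


print(render(File(url="s3://bucket/{{ ds }}/data.csv"), {"ds": "2022-07-01"}).url)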
:param operator: the operator executing a transformation on the inlets and outlets :param inlets: the inlets to this operator diff --git a/airflow/lineage/entities.py b/airflow/lineage/entities.py index 87703edfbaf10..c52dc4cfd0133 100644 --- a/airflow/lineage/entities.py +++ b/airflow/lineage/entities.py @@ -15,31 +15,33 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Defines base entities used for providing lineage information.""" +from __future__ import annotations -""" -Defines the base entities that can be used for providing lineage -information. -""" -from typing import Any, Dict, List, Optional +from typing import Any, ClassVar import attr @attr.s(auto_attribs=True) class File: - """File entity. Refers to a file""" + """File entity. Refers to a file.""" + + template_fields: ClassVar = ("url",) url: str = attr.ib() - type_hint: Optional[str] = None + type_hint: str | None = None @attr.s(auto_attribs=True, kw_only=True) class User: - """User entity. Identifies a user""" + """User entity. Identifies a user.""" email: str = attr.ib() - first_name: Optional[str] = None - last_name: Optional[str] = None + first_name: str | None = None + last_name: str | None = None + + template_fields: ClassVar = ("email", "first_name", "last_name") @attr.s(auto_attribs=True, kw_only=True) @@ -48,15 +50,19 @@ class Tag: tag_name: str = attr.ib() + template_fields: ClassVar = ("tag_name",) + @attr.s(auto_attribs=True, kw_only=True) class Column: - """Column of a Table""" + """Column of a Table.""" name: str = attr.ib() - description: Optional[str] = None + description: str | None = None data_type: str = attr.ib() - tags: List[Tag] = [] + tags: list[Tag] = [] + + template_fields: ClassVar = ("name", "description", "data_type", "tags") # this is a temporary hack to satisfy mypy. Once @@ -64,20 +70,32 @@ class Column: # `attr.converters.default_if_none(default=False)` -def default_if_none(arg: Optional[bool]) -> bool: +def default_if_none(arg: bool | None) -> bool: + """Get default value when None.""" return arg or False @attr.s(auto_attribs=True, kw_only=True) class Table: - """Table entity""" + """Table entity.""" database: str = attr.ib() cluster: str = attr.ib() name: str = attr.ib() - tags: List[Tag] = [] - description: Optional[str] = None - columns: List[Column] = [] - owners: List[User] = [] - extra: Dict[str, Any] = {} - type_hint: Optional[str] = None + tags: list[Tag] = [] + description: str | None = None + columns: list[Column] = [] + owners: list[User] = [] + extra: dict[str, Any] = {} + type_hint: str | None = None + + template_fields: ClassVar = ( + "database", + "cluster", + "name", + "tags", + "description", + "columns", + "owners", + "extra", + ) diff --git a/airflow/listeners/__init__.py b/airflow/listeners/__init__.py index d1df70bda1158..87840b50e2fa5 100644 --- a/airflow/listeners/__init__.py +++ b/airflow/listeners/__init__.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from pluggy import HookimplMarker hookimpl = HookimplMarker("airflow") diff --git a/airflow/listeners/events.py b/airflow/listeners/events.py index d5af64710a17a..53c113af8ee66 100644 --- a/airflow/listeners/events.py +++ b/airflow/listeners/events.py @@ -14,6 +14,8 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging from sqlalchemy import event @@ -28,6 +30,8 @@ def on_task_instance_state_session_flush(session, flush_context): """ + Flush task instance's state. + Listens for session.flush() events that modify TaskInstance's state, and notify listeners that listen for that event. Doing it this way enable us to be stateless in the SQLAlchemy event listener. """ @@ -38,7 +42,7 @@ def on_task_instance_state_session_flush(session, flush_context): if isinstance(state.object, TaskInstance) and session.is_modified( state.object, include_collections=False ): - added, unchanged, deleted = flush_context.get_attribute_history(state, 'state') + added, unchanged, deleted = flush_context.get_attribute_history(state, "state") logger.debug( "session flush listener: added %s unchanged %s deleted %s - %s", @@ -67,13 +71,15 @@ def on_task_instance_state_session_flush(session, flush_context): def register_task_instance_state_events(): + """Register a task instance state event.""" global _is_listening if not _is_listening: - event.listen(Session, 'after_flush', on_task_instance_state_session_flush) + event.listen(Session, "after_flush", on_task_instance_state_session_flush) _is_listening = True def unregister_task_instance_state_events(): + """Unregister a task instance state event.""" global _is_listening - event.remove(Session, 'after_flush', on_task_instance_state_session_flush) + event.remove(Session, "after_flush", on_task_instance_state_session_flush) _is_listening = False diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py index 3c4d052399dab..546d732513f76 100644 --- a/airflow/listeners/listener.py +++ b/airflow/listeners/listener.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import logging -from types import ModuleType from typing import TYPE_CHECKING import pluggy @@ -33,37 +34,38 @@ class ListenerManager: - """Class that manages registration of listeners and provides hook property for calling them""" + """Manage listener registration and provides hook property for calling them.""" def __init__(self): - from airflow.listeners import spec + from airflow.listeners.spec import dagrun, lifecycle, taskinstance self.pm = pluggy.PluginManager("airflow") - self.pm.add_hookspecs(spec) + self.pm.add_hookspecs(lifecycle) + self.pm.add_hookspecs(dagrun) + self.pm.add_hookspecs(taskinstance) @property def has_listeners(self) -> bool: return len(self.pm.get_plugins()) > 0 @property - def hook(self) -> "_HookRelay": - """Returns hook, on which plugin methods specified in spec can be called.""" + def hook(self) -> _HookRelay: + """Return hook, on which plugin methods specified in spec can be called.""" return self.pm.hook def add_listener(self, listener): - if not isinstance(listener, ModuleType): - raise TypeError("Listener %s is not module", str(listener)) if self.pm.is_registered(listener): return self.pm.register(listener) def clear(self): - """Remove registered plugins""" + """Remove registered plugins.""" for plugin in self.pm.get_plugins(): self.pm.unregister(plugin) def get_listener_manager() -> ListenerManager: + """Get singleton listener manager.""" global _listener_manager if not _listener_manager: _listener_manager = ListenerManager() diff --git a/airflow/listeners/spec.py b/airflow/listeners/spec.py deleted file mode 100644 index fbaf63e89ac6e..0000000000000 --- a/airflow/listeners/spec.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import TYPE_CHECKING, Optional - -from pluggy import HookspecMarker - -if TYPE_CHECKING: - from sqlalchemy.orm.session import Session - - from airflow.models.taskinstance import TaskInstance - from airflow.utils.state import TaskInstanceState - -hookspec = HookspecMarker("airflow") - - -@hookspec -def on_task_instance_running( - previous_state: "TaskInstanceState", task_instance: "TaskInstance", session: Optional["Session"] -): - """Called when task state changes to RUNNING. Previous_state can be State.NONE.""" - - -@hookspec -def on_task_instance_success( - previous_state: "TaskInstanceState", task_instance: "TaskInstance", session: Optional["Session"] -): - """Called when task state changes to SUCCESS. Previous_state can be State.NONE.""" - - -@hookspec -def on_task_instance_failed( - previous_state: "TaskInstanceState", task_instance: "TaskInstance", session: Optional["Session"] -): - """Called when task state changes to FAIL. 
Previous_state can be State.NONE.""" diff --git a/airflow/providers/airbyte/example_dags/__init__.py b/airflow/listeners/spec/__init__.py similarity index 100% rename from airflow/providers/airbyte/example_dags/__init__.py rename to airflow/listeners/spec/__init__.py diff --git a/airflow/listeners/spec/dagrun.py b/airflow/listeners/spec/dagrun.py new file mode 100644 index 0000000000000..d2ae1a6b78cb5 --- /dev/null +++ b/airflow/listeners/spec/dagrun.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pluggy import HookspecMarker + +if TYPE_CHECKING: + from airflow.models.dagrun import DagRun + +hookspec = HookspecMarker("airflow") + + +@hookspec +def on_dag_run_running(dag_run: DagRun, msg: str): + """Called when dag run state changes to RUNNING.""" + + +@hookspec +def on_dag_run_success(dag_run: DagRun, msg: str): + """Called when dag run state changes to SUCCESS.""" + + +@hookspec +def on_dag_run_failed(dag_run: DagRun, msg: str): + """Called when dag run state changes to FAIL.""" diff --git a/airflow/listeners/spec/lifecycle.py b/airflow/listeners/spec/lifecycle.py new file mode 100644 index 0000000000000..6ab0aa3b5cde1 --- /dev/null +++ b/airflow/listeners/spec/lifecycle.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from pluggy import HookspecMarker + +hookspec = HookspecMarker("airflow") + + +@hookspec +def on_starting(component): + """ + Called before Airflow component - jobs like scheduler, worker, or task runner starts. + + It's guaranteed this will be called before any other plugin method. + + :param component: Component that calls this method + """ + + +@hookspec +def before_stopping(component): + """ + Called before Airflow component - jobs like scheduler, worker, or task runner stops. + + It's guaranteed this will be called after any other plugin method. 
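The listener plumbing above replaces the single airflow/listeners/spec.py module with per-topic spec modules (lifecycle, dagrun, taskinstance), all registered on one pluggy PluginManager by the ListenerManager changes shown earlier. The snippet below is a self-contained sketch of how such a hookspec/hookimpl pair behaves; it uses class-based specs for brevity, whereas Airflow's specs are module-level functions, and DagRunSpec/MyListener are made-up names:

from __future__ import annotations

import pluggy

hookspec = pluggy.HookspecMarker("airflow")
hookimpl = pluggy.HookimplMarker("airflow")


class DagRunSpec:
    # Shape mirrors airflow/listeners/spec/dagrun.py: one hook per dag run state change.
    @hookspec
    def on_dag_run_success(self, dag_run, msg: str):
        """Called when dag run state changes to SUCCESS."""


class MyListener:
    @hookimpl
    def on_dag_run_success(self, dag_run, msg: str):
        print(f"dag run {dag_run} finished: {msg}")


pm = pluggy.PluginManager("airflow")
pm.add_hookspecs(DagRunSpec)
pm.register(MyListener())

# Hooks are always called with keyword arguments; every registered listener runs.
pm.hook.on_dag_run_success(dag_run="example_dag / 2022-07-01", msg="all tasks done")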
+ + :param component: Component that calls this method + """ diff --git a/airflow/listeners/spec/taskinstance.py b/airflow/listeners/spec/taskinstance.py new file mode 100644 index 0000000000000..78de8a5f62b14 --- /dev/null +++ b/airflow/listeners/spec/taskinstance.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pluggy import HookspecMarker + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + from airflow.models.taskinstance import TaskInstance + from airflow.utils.state import TaskInstanceState + +hookspec = HookspecMarker("airflow") + + +@hookspec +def on_task_instance_running( + previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None +): + """Called when task state changes to RUNNING. Previous_state can be State.NONE.""" + + +@hookspec +def on_task_instance_success( + previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None +): + """Called when task state changes to SUCCESS. Previous_state can be State.NONE.""" + + +@hookspec +def on_task_instance_failed( + previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None +): + """Called when task state changes to FAIL. Previous_state can be State.NONE.""" diff --git a/airflow/logging_config.py b/airflow/logging_config.py index 645e53eb3efaa..d78f84cb6a443 100644 --- a/airflow/logging_config.py +++ b/airflow/logging_config.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import logging import warnings from logging.config import dictConfig @@ -29,11 +30,11 @@ def configure_logging(): """Configure & Validate Airflow Logging""" - logging_class_path = '' + logging_class_path = "" try: - logging_class_path = conf.get('logging', 'logging_config_class') + logging_class_path = conf.get("logging", "logging_config_class") except AirflowConfigException: - log.debug('Could not find key logging_config_class in config') + log.debug("Could not find key logging_config_class in config") if logging_class_path: try: @@ -43,31 +44,31 @@ def configure_logging(): if not isinstance(logging_config, dict): raise ValueError("Logging Config should be of dict type") - log.info('Successfully imported user-defined logging config from %s', logging_class_path) + log.info("Successfully imported user-defined logging config from %s", logging_class_path) except Exception as err: # Import default logging configurations. 
- raise ImportError(f'Unable to load custom logging from {logging_class_path} due to {err}') + raise ImportError(f"Unable to load custom logging from {logging_class_path} due to {err}") else: - logging_class_path = 'airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG' + logging_class_path = "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG" logging_config = import_string(logging_class_path) - log.debug('Unable to load custom logging, using default config instead') + log.debug("Unable to load custom logging, using default config instead") try: # Ensure that the password masking filter is applied to the 'task' handler # no matter what the user did. - if 'filters' in logging_config and 'mask_secrets' in logging_config['filters']: + if "filters" in logging_config and "mask_secrets" in logging_config["filters"]: # But if they replace the logging config _entirely_, don't try to set this, it won't work - task_handler_config = logging_config['handlers']['task'] + task_handler_config = logging_config["handlers"]["task"] - task_handler_config.setdefault('filters', []) + task_handler_config.setdefault("filters", []) - if 'mask_secrets' not in task_handler_config['filters']: - task_handler_config['filters'].append('mask_secrets') + if "mask_secrets" not in task_handler_config["filters"]: + task_handler_config["filters"].append("mask_secrets") # Try to init logging dictConfig(logging_config) except (ValueError, KeyError) as e: - log.error('Unable to load the config, contains a configuration error.') + log.error("Unable to load the config, contains a configuration error.") # When there is an error in the config, escalate the exception # otherwise Airflow would silently fall back on the default config raise e @@ -80,9 +81,9 @@ def configure_logging(): def validate_logging_config(logging_config): """Validate the provided Logging Config""" # Now lets validate the other logging-related settings - task_log_reader = conf.get('logging', 'task_log_reader') + task_log_reader = conf.get("logging", "task_log_reader") - logger = logging.getLogger('airflow.task') + logger = logging.getLogger("airflow.task") def _get_handler(name): return next((h for h in logger.handlers if h.name == name), None) @@ -96,7 +97,7 @@ def _get_handler(name): "Running config has been adjusted to match", DeprecationWarning, ) - conf.set('logging', 'task_log_reader', 'task') + conf.set("logging", "task_log_reader", "task") else: raise AirflowConfigException( f"Configured task_log_reader {task_log_reader!r} was not a handler of " diff --git a/airflow/macros/__init__.py b/airflow/macros/__init__.py index e1f27411e5069..4364d3278a1e1 100644 --- a/airflow/macros/__init__.py +++ b/airflow/macros/__init__.py @@ -15,11 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import time # noqa import uuid # noqa from datetime import datetime, timedelta from random import random # noqa -from typing import Any, Optional +from typing import Any import dateutil # noqa from pendulum import DateTime @@ -29,7 +31,7 @@ def ds_add(ds: str, days: int) -> str: """ - Add or subtract days from a YYYY-MM-DD + Add or subtract days from a YYYY-MM-DD. 
:param ds: anchor date in ``YYYY-MM-DD`` format to add to :param days: number of days to add to the ds, you can use negative values @@ -47,8 +49,7 @@ def ds_add(ds: str, days: int) -> str: def ds_format(ds: str, input_format: str, output_format: str) -> str: """ - Takes an input string and outputs another string - as specified in the output format + Output datetime string in a given format. :param ds: input string which contains a date :param input_format: input string format. E.g. %Y-%m-%d @@ -62,15 +63,15 @@ def ds_format(ds: str, input_format: str, output_format: str) -> str: return datetime.strptime(str(ds), input_format).strftime(output_format) -def datetime_diff_for_humans(dt: Any, since: Optional[DateTime] = None) -> str: +def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str: """ - Return a human-readable/approximate difference between two datetimes, or - one and now. + Return a human-readable/approximate difference between datetimes. + + When only one datetime is provided, the comparison will be based on now. :param dt: The datetime to display the diff for :param since: When to display the date from. If ``None`` then the diff is between ``dt`` and now. - :rtype: str """ import pendulum diff --git a/airflow/macros/hive.py b/airflow/macros/hive.py index fe5685fcf2536..7da0202eb2596 100644 --- a/airflow/macros/hive.py +++ b/airflow/macros/hive.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import datetime def max_partition( - table, schema="default", field=None, filter_map=None, metastore_conn_id='metastore_default' + table, schema="default", field=None, filter_map=None, metastore_conn_id="metastore_default" ): """ Gets the max partition for a table. @@ -43,13 +44,13 @@ def max_partition( """ from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook - if '.' in table: - schema, table = table.split('.') + if "." in table: + schema, table = table.split(".") hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id) return hive_hook.max_partition(schema=schema, table_name=table, field=field, filter_map=filter_map) -def _closest_date(target_dt, date_list, before_target=None): +def _closest_date(target_dt, date_list, before_target=None) -> datetime.date | None: """ This function finds the date in a list closest to the target date. An optional parameter can be given to get the closest before or after. @@ -58,7 +59,6 @@ def _closest_date(target_dt, date_list, before_target=None): :param date_list: The list of dates to search :param before_target: closest before or after the target :returns: The closest date - :rtype: datetime.date or None """ time_before = lambda d: target_dt - d if d <= target_dt else datetime.timedelta.max time_after = lambda d: d - target_dt if d >= target_dt else datetime.timedelta.max @@ -71,7 +71,9 @@ def _closest_date(target_dt, date_list, before_target=None): return min(date_list, key=time_after).date() -def closest_ds_partition(table, ds, before=True, schema="default", metastore_conn_id='metastore_default'): +def closest_ds_partition( + table, ds, before=True, schema="default", metastore_conn_id="metastore_default" +) -> str | None: """ This function finds the date in a list closest to the target date. An optional parameter can be given to get the closest before or after. 
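Behind closest_ds_partition(), the helper _closest_date() picks the nearest candidate by mapping dates on the wrong side of the target to timedelta.max, so min() prefers dates on the requested side. A standalone sketch of that selection trick (closest_date_before is a made-up name; the real helper can also search after the target):

from __future__ import annotations

import datetime


def closest_date_before(
    target: datetime.datetime, candidates: list[datetime.datetime]
) -> datetime.date | None:
    """Return the candidate date closest to, but not after, the target."""
    if not candidates:
        return None

    def distance(d: datetime.datetime) -> datetime.timedelta:
        # Candidates after the target get an "infinite" distance, so earlier dates win.
        return target - d if d <= target else datetime.timedelta.max

    return min(candidates, key=distance).date()


dates = [datetime.datetime(2015, 1, 1), datetime.datetime(2015, 1, 3)]
print(closest_date_before(datetime.datetime(2015, 1, 2), dates))  # 2015-01-01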
@@ -82,7 +84,6 @@ def closest_ds_partition(table, ds, before=True, schema="default", metastore_con :param schema: table schema :param metastore_conn_id: which metastore connection to use :returns: The closest date - :rtype: str or None >>> tbl = 'airflow.static_babynames_partitioned' >>> closest_ds_partition(tbl, '2015-01-02') @@ -90,8 +91,8 @@ def closest_ds_partition(table, ds, before=True, schema="default", metastore_con """ from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook - if '.' in table: - schema, table = table.split('.') + if "." in table: + schema, table = table.split(".") hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id) partitions = hive_hook.get_partitions(schema=schema, table_name=table) if not partitions: @@ -100,7 +101,9 @@ def closest_ds_partition(table, ds, before=True, schema="default", metastore_con if ds in part_vals: return ds else: - parts = [datetime.datetime.strptime(pv, '%Y-%m-%d') for pv in part_vals] - target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d') + parts = [datetime.datetime.strptime(pv, "%Y-%m-%d") for pv in part_vals] + target_dt = datetime.datetime.strptime(ds, "%Y-%m-%d") closest_ds = _closest_date(target_dt, parts, before_target=before) - return closest_ds.isoformat() + if closest_ds is not None: + return closest_ds.isoformat() + return None diff --git a/airflow/migrations/db_types.py b/airflow/migrations/db_types.py index 9b8f3e974ffd9..70a1a1e6b47f4 100644 --- a/airflow/migrations/db_types.py +++ b/airflow/migrations/db_types.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import sqlalchemy as sa from alembic import context from lazy_object_proxy import Proxy @@ -71,7 +72,7 @@ def lazy_load(): module = globals() # Lookup the type based on the dialect specific type, or fallback to the generic type - type_ = module.get(f'_{dialect}_{name}', None) or module.get(f'_sa_{name}') + type_ = module.get(f"_{dialect}_{name}", None) or module.get(f"_sa_{name}") val = module[name] = type_() return val diff --git a/airflow/migrations/env.py b/airflow/migrations/env.py index 58a8f7f4a0465..9f97195a6d9c4 100644 --- a/airflow/migrations/env.py +++ b/airflow/migrations/env.py @@ -15,12 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations +import contextlib +import sys from logging.config import fileConfig from alembic import context from airflow import models, settings +from airflow.utils.db import compare_server_default, compare_type def include_object(_, name, type_, *args): @@ -32,6 +36,9 @@ def include_object(_, name, type_, *args): return True +# Make sure everything is imported so that alembic can find it all +models.import_all_models() + # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config @@ -51,8 +58,6 @@ def include_object(_, name, type_, *args): # my_important_option = config.get_main_option("my_important_option") # ... etc. -COMPARE_TYPE = False - def run_migrations_offline(): """Run migrations in 'offline' mode. 
@@ -70,7 +75,8 @@ def run_migrations_offline(): url=settings.SQL_ALCHEMY_CONN, target_metadata=target_metadata, literal_binds=True, - compare_type=COMPARE_TYPE, + compare_type=compare_type, + compare_server_default=compare_server_default, render_as_batch=True, ) @@ -85,14 +91,18 @@ def run_migrations_online(): and associate a connection with the context. """ - connectable = settings.engine + with contextlib.ExitStack() as stack: + connection = config.attributes.get("connection", None) + + if not connection: + connection = stack.push(settings.engine.connect()) - with connectable.connect() as connection: context.configure( connection=connection, transaction_per_migration=True, target_metadata=target_metadata, - compare_type=COMPARE_TYPE, + compare_type=compare_type, + compare_server_default=compare_server_default, include_object=include_object, render_as_batch=True, ) @@ -105,3 +115,9 @@ def run_migrations_online(): run_migrations_offline() else: run_migrations_online() + +if "airflow.www.app" in sys.modules: + # Already imported, make sure we clear out any cached app + from airflow.www.app import purge_cached_app + + purge_cached_app() diff --git a/airflow/migrations/utils.py b/airflow/migrations/utils.py index 5737fa950774d..78f925ac84b53 100644 --- a/airflow/migrations/utils.py +++ b/airflow/migrations/utils.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from collections import defaultdict from contextlib import contextmanager -def get_mssql_table_constraints(conn, table_name): +def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]: """ This function return primary and unique constraint along with column name. Some tables like `task_instance` @@ -29,7 +30,6 @@ def get_mssql_table_constraints(conn, table_name): :param conn: sql connection object :param table_name: table name :return: a dictionary of ((constraint name, constraint type), column name) of table - :rtype: defaultdict(list) """ query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc @@ -47,7 +47,7 @@ def get_mssql_table_constraints(conn, table_name): @contextmanager def disable_sqlite_fkeys(op): - if op.get_bind().dialect.name == 'sqlite': + if op.get_bind().dialect.name == "sqlite": op.execute("PRAGMA foreign_keys=off") yield op op.execute("PRAGMA foreign_keys=on") diff --git a/airflow/migrations/versions/0001_1_5_0_current_schema.py b/airflow/migrations/versions/0001_1_5_0_current_schema.py index 9824db7dad36f..0bfc7ca518308 100644 --- a/airflow/migrations/versions/0001_1_5_0_current_schema.py +++ b/airflow/migrations/versions/0001_1_5_0_current_schema.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """current schema Revision ID: e3a246e0dc1 @@ -23,20 +22,20 @@ Create Date: 2015-08-18 16:35:00.883495 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op -from sqlalchemy import func +from sqlalchemy import func, inspect -from airflow.compat.sqlalchemy import inspect from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. 
-revision = 'e3a246e0dc1' +revision = "e3a246e0dc1" down_revision = None branch_labels = None depends_on = None -airflow_version = '1.5.0' +airflow_version = "1.5.0" def upgrade(): @@ -44,199 +43,199 @@ def upgrade(): inspector = inspect(conn) tables = inspector.get_table_names() - if 'connection' not in tables: + if "connection" not in tables: op.create_table( - 'connection', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('conn_id', StringID(), nullable=True), - sa.Column('conn_type', sa.String(length=500), nullable=True), - sa.Column('host', sa.String(length=500), nullable=True), - sa.Column('schema', sa.String(length=500), nullable=True), - sa.Column('login', sa.String(length=500), nullable=True), - sa.Column('password', sa.String(length=500), nullable=True), - sa.Column('port', sa.Integer(), nullable=True), - sa.Column('extra', sa.String(length=5000), nullable=True), - sa.PrimaryKeyConstraint('id'), + "connection", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("conn_id", StringID(), nullable=True), + sa.Column("conn_type", sa.String(length=500), nullable=True), + sa.Column("host", sa.String(length=500), nullable=True), + sa.Column("schema", sa.String(length=500), nullable=True), + sa.Column("login", sa.String(length=500), nullable=True), + sa.Column("password", sa.String(length=500), nullable=True), + sa.Column("port", sa.Integer(), nullable=True), + sa.Column("extra", sa.String(length=5000), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - if 'dag' not in tables: + if "dag" not in tables: op.create_table( - 'dag', - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('is_paused', sa.Boolean(), nullable=True), - sa.Column('is_subdag', sa.Boolean(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('last_scheduler_run', sa.DateTime(), nullable=True), - sa.Column('last_pickled', sa.DateTime(), nullable=True), - sa.Column('last_expired', sa.DateTime(), nullable=True), - sa.Column('scheduler_lock', sa.Boolean(), nullable=True), - sa.Column('pickle_id', sa.Integer(), nullable=True), - sa.Column('fileloc', sa.String(length=2000), nullable=True), - sa.Column('owners', sa.String(length=2000), nullable=True), - sa.PrimaryKeyConstraint('dag_id'), + "dag", + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("is_paused", sa.Boolean(), nullable=True), + sa.Column("is_subdag", sa.Boolean(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("last_scheduler_run", sa.DateTime(), nullable=True), + sa.Column("last_pickled", sa.DateTime(), nullable=True), + sa.Column("last_expired", sa.DateTime(), nullable=True), + sa.Column("scheduler_lock", sa.Boolean(), nullable=True), + sa.Column("pickle_id", sa.Integer(), nullable=True), + sa.Column("fileloc", sa.String(length=2000), nullable=True), + sa.Column("owners", sa.String(length=2000), nullable=True), + sa.PrimaryKeyConstraint("dag_id"), ) - if 'dag_pickle' not in tables: + if "dag_pickle" not in tables: op.create_table( - 'dag_pickle', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('pickle', sa.PickleType(), nullable=True), - sa.Column('created_dttm', sa.DateTime(), nullable=True), - sa.Column('pickle_hash', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id'), + "dag_pickle", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("pickle", sa.PickleType(), nullable=True), + sa.Column("created_dttm", sa.DateTime(), nullable=True), + sa.Column("pickle_hash", sa.BigInteger(), nullable=True), + 
sa.PrimaryKeyConstraint("id"), ) - if 'import_error' not in tables: + if "import_error" not in tables: op.create_table( - 'import_error', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('timestamp', sa.DateTime(), nullable=True), - sa.Column('filename', sa.String(length=1024), nullable=True), - sa.Column('stacktrace', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id'), + "import_error", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=True), + sa.Column("filename", sa.String(length=1024), nullable=True), + sa.Column("stacktrace", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - if 'job' not in tables: + if "job" not in tables: op.create_table( - 'job', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('dag_id', sa.String(length=250), nullable=True), - sa.Column('state', sa.String(length=20), nullable=True), - sa.Column('job_type', sa.String(length=30), nullable=True), - sa.Column('start_date', sa.DateTime(), nullable=True), - sa.Column('end_date', sa.DateTime(), nullable=True), - sa.Column('latest_heartbeat', sa.DateTime(), nullable=True), - sa.Column('executor_class', sa.String(length=500), nullable=True), - sa.Column('hostname', sa.String(length=500), nullable=True), - sa.Column('unixname', sa.String(length=1000), nullable=True), - sa.PrimaryKeyConstraint('id'), + "job", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("dag_id", sa.String(length=250), nullable=True), + sa.Column("state", sa.String(length=20), nullable=True), + sa.Column("job_type", sa.String(length=30), nullable=True), + sa.Column("start_date", sa.DateTime(), nullable=True), + sa.Column("end_date", sa.DateTime(), nullable=True), + sa.Column("latest_heartbeat", sa.DateTime(), nullable=True), + sa.Column("executor_class", sa.String(length=500), nullable=True), + sa.Column("hostname", sa.String(length=500), nullable=True), + sa.Column("unixname", sa.String(length=1000), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False) - if 'log' not in tables: + op.create_index("job_type_heart", "job", ["job_type", "latest_heartbeat"], unique=False) + if "log" not in tables: op.create_table( - 'log', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('dttm', sa.DateTime(), nullable=True), - sa.Column('dag_id', StringID(), nullable=True), - sa.Column('task_id', StringID(), nullable=True), - sa.Column('event', sa.String(length=30), nullable=True), - sa.Column('execution_date', sa.DateTime(), nullable=True), - sa.Column('owner', sa.String(length=500), nullable=True), - sa.PrimaryKeyConstraint('id'), + "log", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("dttm", sa.DateTime(), nullable=True), + sa.Column("dag_id", StringID(), nullable=True), + sa.Column("task_id", StringID(), nullable=True), + sa.Column("event", sa.String(length=30), nullable=True), + sa.Column("execution_date", sa.DateTime(), nullable=True), + sa.Column("owner", sa.String(length=500), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - if 'sla_miss' not in tables: + if "sla_miss" not in tables: op.create_table( - 'sla_miss', - sa.Column('task_id', StringID(), nullable=False), - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('execution_date', sa.DateTime(), nullable=False), - sa.Column('email_sent', sa.Boolean(), nullable=True), - sa.Column('timestamp', sa.DateTime(), nullable=True), - sa.Column('description', sa.Text(), nullable=True), - 
sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'), + "sla_miss", + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("execution_date", sa.DateTime(), nullable=False), + sa.Column("email_sent", sa.Boolean(), nullable=True), + sa.Column("timestamp", sa.DateTime(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("task_id", "dag_id", "execution_date"), ) - if 'slot_pool' not in tables: + if "slot_pool" not in tables: op.create_table( - 'slot_pool', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('pool', StringID(length=50), nullable=True), - sa.Column('slots', sa.Integer(), nullable=True), - sa.Column('description', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('pool'), + "slot_pool", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("pool", StringID(length=50), nullable=True), + sa.Column("slots", sa.Integer(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("pool"), ) - if 'task_instance' not in tables: + if "task_instance" not in tables: op.create_table( - 'task_instance', - sa.Column('task_id', StringID(), nullable=False), - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('execution_date', sa.DateTime(), nullable=False), - sa.Column('start_date', sa.DateTime(), nullable=True), - sa.Column('end_date', sa.DateTime(), nullable=True), - sa.Column('duration', sa.Integer(), nullable=True), - sa.Column('state', sa.String(length=20), nullable=True), - sa.Column('try_number', sa.Integer(), nullable=True), - sa.Column('hostname', sa.String(length=1000), nullable=True), - sa.Column('unixname', sa.String(length=1000), nullable=True), - sa.Column('job_id', sa.Integer(), nullable=True), - sa.Column('pool', sa.String(length=50), nullable=True), - sa.Column('queue', sa.String(length=50), nullable=True), - sa.Column('priority_weight', sa.Integer(), nullable=True), - sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'), + "task_instance", + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("execution_date", sa.DateTime(), nullable=False), + sa.Column("start_date", sa.DateTime(), nullable=True), + sa.Column("end_date", sa.DateTime(), nullable=True), + sa.Column("duration", sa.Integer(), nullable=True), + sa.Column("state", sa.String(length=20), nullable=True), + sa.Column("try_number", sa.Integer(), nullable=True), + sa.Column("hostname", sa.String(length=1000), nullable=True), + sa.Column("unixname", sa.String(length=1000), nullable=True), + sa.Column("job_id", sa.Integer(), nullable=True), + sa.Column("pool", sa.String(length=50), nullable=True), + sa.Column("queue", sa.String(length=50), nullable=True), + sa.Column("priority_weight", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("task_id", "dag_id", "execution_date"), ) - op.create_index('ti_dag_state', 'task_instance', ['dag_id', 'state'], unique=False) - op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'], unique=False) + op.create_index("ti_dag_state", "task_instance", ["dag_id", "state"], unique=False) + op.create_index("ti_pool", "task_instance", ["pool", "state", "priority_weight"], unique=False) op.create_index( - 'ti_state_lkp', 'task_instance', ['dag_id', 'task_id', 'execution_date', 'state'], unique=False + "ti_state_lkp", "task_instance", ["dag_id", "task_id", 
"execution_date", "state"], unique=False ) - if 'user' not in tables: + if "user" not in tables: op.create_table( - 'user', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('username', StringID(), nullable=True), - sa.Column('email', sa.String(length=500), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('username'), + "user", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("username", StringID(), nullable=True), + sa.Column("email", sa.String(length=500), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("username"), ) - if 'variable' not in tables: + if "variable" not in tables: op.create_table( - 'variable', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('key', StringID(), nullable=True), - sa.Column('val', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('key'), + "variable", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("key", StringID(), nullable=True), + sa.Column("val", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("key"), ) - if 'chart' not in tables: + if "chart" not in tables: op.create_table( - 'chart', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('label', sa.String(length=200), nullable=True), - sa.Column('conn_id', sa.String(length=250), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('chart_type', sa.String(length=100), nullable=True), - sa.Column('sql_layout', sa.String(length=50), nullable=True), - sa.Column('sql', sa.Text(), nullable=True), - sa.Column('y_log_scale', sa.Boolean(), nullable=True), - sa.Column('show_datatable', sa.Boolean(), nullable=True), - sa.Column('show_sql', sa.Boolean(), nullable=True), - sa.Column('height', sa.Integer(), nullable=True), - sa.Column('default_params', sa.String(length=5000), nullable=True), - sa.Column('x_is_date', sa.Boolean(), nullable=True), - sa.Column('iteration_no', sa.Integer(), nullable=True), - sa.Column('last_modified', sa.DateTime(), nullable=True), + "chart", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("label", sa.String(length=200), nullable=True), + sa.Column("conn_id", sa.String(length=250), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("chart_type", sa.String(length=100), nullable=True), + sa.Column("sql_layout", sa.String(length=50), nullable=True), + sa.Column("sql", sa.Text(), nullable=True), + sa.Column("y_log_scale", sa.Boolean(), nullable=True), + sa.Column("show_datatable", sa.Boolean(), nullable=True), + sa.Column("show_sql", sa.Boolean(), nullable=True), + sa.Column("height", sa.Integer(), nullable=True), + sa.Column("default_params", sa.String(length=5000), nullable=True), + sa.Column("x_is_date", sa.Boolean(), nullable=True), + sa.Column("iteration_no", sa.Integer(), nullable=True), + sa.Column("last_modified", sa.DateTime(), nullable=True), sa.ForeignKeyConstraint( - ['user_id'], - ['user.id'], + ["user_id"], + ["user.id"], ), - sa.PrimaryKeyConstraint('id'), + sa.PrimaryKeyConstraint("id"), ) - if 'xcom' not in tables: + if "xcom" not in tables: op.create_table( - 'xcom', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('key', StringID(length=512), nullable=True), - sa.Column('value', sa.PickleType(), nullable=True), - sa.Column('timestamp', sa.DateTime(), default=func.now(), nullable=False), - sa.Column('execution_date', sa.DateTime(), nullable=False), - sa.Column('task_id', StringID(), nullable=False), - sa.Column('dag_id', 
StringID(), nullable=False), - sa.PrimaryKeyConstraint('id'), + "xcom", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("key", StringID(length=512), nullable=True), + sa.Column("value", sa.PickleType(), nullable=True), + sa.Column("timestamp", sa.DateTime(), default=func.now(), nullable=False), + sa.Column("execution_date", sa.DateTime(), nullable=False), + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), + sa.PrimaryKeyConstraint("id"), ) def downgrade(): - op.drop_table('chart') - op.drop_table('variable') - op.drop_table('user') - op.drop_index('ti_state_lkp', table_name='task_instance') - op.drop_index('ti_pool', table_name='task_instance') - op.drop_index('ti_dag_state', table_name='task_instance') - op.drop_table('task_instance') - op.drop_table('slot_pool') - op.drop_table('sla_miss') - op.drop_table('log') - op.drop_index('job_type_heart', table_name='job') - op.drop_table('job') - op.drop_table('import_error') - op.drop_table('dag_pickle') - op.drop_table('dag') - op.drop_table('connection') - op.drop_table('xcom') + op.drop_table("chart") + op.drop_table("variable") + op.drop_table("user") + op.drop_index("ti_state_lkp", table_name="task_instance") + op.drop_index("ti_pool", table_name="task_instance") + op.drop_index("ti_dag_state", table_name="task_instance") + op.drop_table("task_instance") + op.drop_table("slot_pool") + op.drop_table("sla_miss") + op.drop_table("log") + op.drop_index("job_type_heart", table_name="job") + op.drop_table("job") + op.drop_table("import_error") + op.drop_table("dag_pickle") + op.drop_table("dag") + op.drop_table("connection") + op.drop_table("xcom") diff --git a/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py b/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py index 2d13be3d2e234..9ab70672c7457 100644 --- a/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py +++ b/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``is_encrypted`` column in ``connection`` table Revision ID: 1507a7289a2f @@ -23,20 +22,21 @@ Create Date: 2015-08-18 18:57:51.927315 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op - -from airflow.compat.sqlalchemy import inspect +from sqlalchemy import inspect # revision identifiers, used by Alembic.
-revision = '1507a7289a2f' -down_revision = 'e3a246e0dc1' +revision = "1507a7289a2f" +down_revision = "e3a246e0dc1" branch_labels = None depends_on = None -airflow_version = '1.5.0' +airflow_version = "1.5.0" connectionhelper = sa.Table( - 'connection', sa.MetaData(), sa.Column('id', sa.Integer, primary_key=True), sa.Column('is_encrypted') + "connection", sa.MetaData(), sa.Column("id", sa.Integer, primary_key=True), sa.Column("is_encrypted") ) @@ -49,16 +49,16 @@ def upgrade(): # this will only be true if 'connection' already exists in the db, # but not if alembic created it in a previous migration - if 'connection' in inspector.get_table_names(): - col_names = [c['name'] for c in inspector.get_columns('connection')] - if 'is_encrypted' in col_names: + if "connection" in inspector.get_table_names(): + col_names = [c["name"] for c in inspector.get_columns("connection")] + if "is_encrypted" in col_names: return - op.add_column('connection', sa.Column('is_encrypted', sa.Boolean, unique=False, default=False)) + op.add_column("connection", sa.Column("is_encrypted", sa.Boolean, unique=False, default=False)) conn = op.get_bind() conn.execute(connectionhelper.update().values(is_encrypted=False)) def downgrade(): - op.drop_column('connection', 'is_encrypted') + op.drop_column("connection", "is_encrypted") diff --git a/airflow/migrations/versions/0003_1_5_0_for_compatibility.py b/airflow/migrations/versions/0003_1_5_0_for_compatibility.py index cd45e1fa745da..99c0a5df0db17 100644 --- a/airflow/migrations/versions/0003_1_5_0_for_compatibility.py +++ b/airflow/migrations/versions/0003_1_5_0_for_compatibility.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Maintain history for compatibility with earlier migrations Revision ID: 13eb55f81627 @@ -23,13 +22,14 @@ Create Date: 2015-08-23 05:12:49.732174 """ +from __future__ import annotations # revision identifiers, used by Alembic. -revision = '13eb55f81627' -down_revision = '1507a7289a2f' +revision = "13eb55f81627" +down_revision = "1507a7289a2f" branch_labels = None depends_on = None -airflow_version = '1.5.0' +airflow_version = "1.5.0" def upgrade(): diff --git a/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py b/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py index 2eb793addaed5..e54d86c3771f7 100644 --- a/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py +++ b/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``operator`` and ``queued_dttm`` to ``task_instance`` table Revision ID: 338e90f54d61 @@ -23,22 +22,24 @@ Create Date: 2015-08-25 06:09:20.460147 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '338e90f54d61' -down_revision = '13eb55f81627' +revision = "338e90f54d61" +down_revision = "13eb55f81627" branch_labels = None depends_on = None -airflow_version = '1.5.0' +airflow_version = "1.5.0" def upgrade(): - op.add_column('task_instance', sa.Column('operator', sa.String(length=1000), nullable=True)) - op.add_column('task_instance', sa.Column('queued_dttm', sa.DateTime(), nullable=True)) + op.add_column("task_instance", sa.Column("operator", sa.String(length=1000), nullable=True)) + op.add_column("task_instance", sa.Column("queued_dttm", sa.DateTime(), nullable=True)) def downgrade(): - op.drop_column('task_instance', 'queued_dttm') - op.drop_column('task_instance', 'operator') + op.drop_column("task_instance", "queued_dttm") + op.drop_column("task_instance", "operator") diff --git a/airflow/migrations/versions/0005_1_5_2_job_id_indices.py b/airflow/migrations/versions/0005_1_5_2_job_id_indices.py index e6ba4fd226670..e2443e676a6ba 100644 --- a/airflow/migrations/versions/0005_1_5_2_job_id_indices.py +++ b/airflow/migrations/versions/0005_1_5_2_job_id_indices.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add indices in ``job`` table Revision ID: 52d714495f0 @@ -23,19 +22,21 @@ Create Date: 2015-10-20 03:17:01.962542 """ +from __future__ import annotations + from alembic import op # revision identifiers, used by Alembic. -revision = '52d714495f0' -down_revision = '338e90f54d61' +revision = "52d714495f0" +down_revision = "338e90f54d61" branch_labels = None depends_on = None -airflow_version = '1.5.2' +airflow_version = "1.5.2" def upgrade(): - op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False) + op.create_index("idx_job_state_heartbeat", "job", ["state", "latest_heartbeat"], unique=False) def downgrade(): - op.drop_index('idx_job_state_heartbeat', table_name='job') + op.drop_index("idx_job_state_heartbeat", table_name="job") diff --git a/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py b/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py index 5bc28ad372633..1fec347e8fd17 100644 --- a/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py +++ b/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Adding ``extra`` column to ``Log`` table Revision ID: 502898887f84 @@ -23,20 +22,22 @@ Create Date: 2015-11-03 22:50:49.794097 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '502898887f84' -down_revision = '52d714495f0' +revision = "502898887f84" +down_revision = "52d714495f0" branch_labels = None depends_on = None -airflow_version = '1.6.0' +airflow_version = "1.6.0" def upgrade(): - op.add_column('log', sa.Column('extra', sa.Text(), nullable=True)) + op.add_column("log", sa.Column("extra", sa.Text(), nullable=True)) def downgrade(): - op.drop_column('log', 'extra') + op.drop_column("log", "extra") diff --git a/airflow/migrations/versions/0007_1_6_0_add_dagrun.py b/airflow/migrations/versions/0007_1_6_0_add_dagrun.py index 66a65fd3c3d3c..93440146b4929 100644 --- a/airflow/migrations/versions/0007_1_6_0_add_dagrun.py +++ b/airflow/migrations/versions/0007_1_6_0_add_dagrun.py @@ -15,7 +15,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """Add ``dag_run`` table Revision ID: 1b38cef5b76e @@ -23,6 +22,7 @@ Create Date: 2015-10-27 08:31:48.475140 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,27 +30,27 @@ from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. -revision = '1b38cef5b76e' -down_revision = '502898887f84' +revision = "1b38cef5b76e" +down_revision = "502898887f84" branch_labels = None depends_on = None -airflow_version = '1.6.0' +airflow_version = "1.6.0" def upgrade(): op.create_table( - 'dag_run', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('dag_id', StringID(), nullable=True), - sa.Column('execution_date', sa.DateTime(), nullable=True), - sa.Column('state', sa.String(length=50), nullable=True), - sa.Column('run_id', StringID(), nullable=True), - sa.Column('external_trigger', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('dag_id', 'execution_date'), - sa.UniqueConstraint('dag_id', 'run_id'), + "dag_run", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("dag_id", StringID(), nullable=True), + sa.Column("execution_date", sa.DateTime(), nullable=True), + sa.Column("state", sa.String(length=50), nullable=True), + sa.Column("run_id", StringID(), nullable=True), + sa.Column("external_trigger", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("dag_id", "execution_date"), + sa.UniqueConstraint("dag_id", "run_id"), ) def downgrade(): - op.drop_table('dag_run') + op.drop_table("dag_run") diff --git a/airflow/migrations/versions/0008_1_6_0_task_duration.py b/airflow/migrations/versions/0008_1_6_0_task_duration.py index 9fa217d1aa181..0a17100e25844 100644 --- a/airflow/migrations/versions/0008_1_6_0_task_duration.py +++ b/airflow/migrations/versions/0008_1_6_0_task_duration.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Change ``task_instance.task_duration`` type to ``FLOAT`` Revision ID: 2e541a1dcfed @@ -23,24 +22,25 @@ Create Date: 2015-10-28 20:38:41.266143 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = '2e541a1dcfed' -down_revision = '1b38cef5b76e' +revision = "2e541a1dcfed" +down_revision = "1b38cef5b76e" branch_labels = None depends_on = None -airflow_version = '1.6.0' +airflow_version = "1.6.0" def upgrade(): # use batch_alter_table to support SQLite workaround with op.batch_alter_table("task_instance") as batch_op: batch_op.alter_column( - 'duration', + "duration", existing_type=mysql.INTEGER(display_width=11), type_=sa.Float(), existing_nullable=True, diff --git a/airflow/migrations/versions/0009_1_6_0_dagrun_config.py b/airflow/migrations/versions/0009_1_6_0_dagrun_config.py index ae7790713797c..502598b09c6b7 100644 --- a/airflow/migrations/versions/0009_1_6_0_dagrun_config.py +++ b/airflow/migrations/versions/0009_1_6_0_dagrun_config.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Add ``conf`` column in ``dag_run`` table Revision ID: 40e67319e3a9 @@ -23,20 +22,22 @@ Create Date: 2015-10-29 08:36:31.726728 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '40e67319e3a9' -down_revision = '2e541a1dcfed' +revision = "40e67319e3a9" +down_revision = "2e541a1dcfed" branch_labels = None depends_on = None -airflow_version = '1.6.0' +airflow_version = "1.6.0" def upgrade(): - op.add_column('dag_run', sa.Column('conf', sa.PickleType(), nullable=True)) + op.add_column("dag_run", sa.Column("conf", sa.PickleType(), nullable=True)) def downgrade(): - op.drop_column('dag_run', 'conf') + op.drop_column("dag_run", "conf") diff --git a/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py b/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py index d1004ba52a8d3..5938c3e1cf0d6 100644 --- a/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py +++ b/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``password`` column to ``user`` table Revision ID: 561833c1c74b @@ -23,20 +22,22 @@ Create Date: 2015-11-30 06:51:25.872557 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '561833c1c74b' -down_revision = '40e67319e3a9' +revision = "561833c1c74b" +down_revision = "40e67319e3a9" branch_labels = None depends_on = None -airflow_version = '1.6.2' +airflow_version = "1.6.2" def upgrade(): - op.add_column('user', sa.Column('password', sa.String(255))) + op.add_column("user", sa.Column("password", sa.String(255))) def downgrade(): - op.drop_column('user', 'password') + op.drop_column("user", "password") diff --git a/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py b/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py index ae471443b04e8..e03f49a8a7024 100644 --- a/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py +++ b/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``start_date`` and ``end_date`` in ``dag_run`` table Revision ID: 4446e08588 @@ -23,23 +22,24 @@ Create Date: 2015-12-10 11:26:18.439223 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '4446e08588' -down_revision = '561833c1c74b' +revision = "4446e08588" +down_revision = "561833c1c74b" branch_labels = None depends_on = None -airflow_version = '1.6.2' +airflow_version = "1.6.2" def upgrade(): - op.add_column('dag_run', sa.Column('end_date', sa.DateTime(), nullable=True)) - op.add_column('dag_run', sa.Column('start_date', sa.DateTime(), nullable=True)) + op.add_column("dag_run", sa.Column("end_date", sa.DateTime(), nullable=True)) + op.add_column("dag_run", sa.Column("start_date", sa.DateTime(), nullable=True)) def downgrade(): - op.drop_column('dag_run', 'start_date') - op.drop_column('dag_run', 'end_date') + op.drop_column("dag_run", "start_date") + op.drop_column("dag_run", "end_date") diff --git a/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py b/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py index 0e3548257c7a8..58ff703aeed16 100644 --- a/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py +++ b/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``notification_sent`` column to ``sla_miss`` table Revision ID: bbc73705a13e @@ -23,20 +22,22 @@ Create Date: 2016-01-14 18:05:54.871682 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'bbc73705a13e' -down_revision = '4446e08588' +revision = "bbc73705a13e" +down_revision = "4446e08588" branch_labels = None depends_on = None -airflow_version = '1.7.0' +airflow_version = "1.7.0" def upgrade(): - op.add_column('sla_miss', sa.Column('notification_sent', sa.Boolean, default=False)) + op.add_column("sla_miss", sa.Column("notification_sent", sa.Boolean, default=False)) def downgrade(): - op.drop_column('sla_miss', 'notification_sent') + op.drop_column("sla_miss", "notification_sent") diff --git a/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py b/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py index 4e0d53e062654..f8bd23bf5f441 100644 --- a/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py +++ b/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add a column to track the encryption state of the 'Extra' field in connection Revision ID: bba5a7cfc896 @@ -23,21 +22,22 @@ Create Date: 2016-01-29 15:10:32.656425 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = 'bba5a7cfc896' -down_revision = 'bbc73705a13e' +revision = "bba5a7cfc896" +down_revision = "bbc73705a13e" branch_labels = None depends_on = None -airflow_version = '1.7.0' +airflow_version = "1.7.0" def upgrade(): - op.add_column('connection', sa.Column('is_extra_encrypted', sa.Boolean, default=False)) + op.add_column("connection", sa.Column("is_extra_encrypted", sa.Boolean, default=False)) def downgrade(): - op.drop_column('connection', 'is_extra_encrypted') + op.drop_column("connection", "is_extra_encrypted") diff --git a/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py b/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py index e72260ec2089d..eba4999da7f8c 100644 --- a/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py +++ b/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``is_encrypted`` column to variable table Revision ID: 1968acfc09e3 @@ -23,20 +22,22 @@ Create Date: 2016-02-02 17:20:55.692295 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '1968acfc09e3' -down_revision = 'bba5a7cfc896' +revision = "1968acfc09e3" +down_revision = "bba5a7cfc896" branch_labels = None depends_on = None -airflow_version = '1.7.0' +airflow_version = "1.7.0" def upgrade(): - op.add_column('variable', sa.Column('is_encrypted', sa.Boolean, default=False)) + op.add_column("variable", sa.Column("is_encrypted", sa.Boolean, default=False)) def downgrade(): - op.drop_column('variable', 'is_encrypted') + op.drop_column("variable", "is_encrypted") diff --git a/airflow/migrations/versions/0015_1_7_1_rename_user_table.py b/airflow/migrations/versions/0015_1_7_1_rename_user_table.py index 7259eb5c1821c..3b644e1453118 100644 --- a/airflow/migrations/versions/0015_1_7_1_rename_user_table.py +++ b/airflow/migrations/versions/0015_1_7_1_rename_user_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Rename user table Revision ID: 2e82aab8ef20 @@ -23,19 +22,21 @@ Create Date: 2016-04-02 19:28:15.211915 """ +from __future__ import annotations + from alembic import op # revision identifiers, used by Alembic. -revision = '2e82aab8ef20' -down_revision = '1968acfc09e3' +revision = "2e82aab8ef20" +down_revision = "1968acfc09e3" branch_labels = None depends_on = None -airflow_version = '1.7.1' +airflow_version = "1.7.1" def upgrade(): - op.rename_table('user', 'users') + op.rename_table("user", "users") def downgrade(): - op.rename_table('users', 'user') + op.rename_table("users", "user") diff --git a/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py b/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py index eae9bfd06450d..1f10af8d156fe 100644 --- a/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py +++ b/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Add TI state index Revision ID: 211e584da130 @@ -23,19 +22,21 @@ Create Date: 2016-06-30 10:54:24.323588 """ +from __future__ import annotations + from alembic import op # revision identifiers, used by Alembic. -revision = '211e584da130' -down_revision = '2e82aab8ef20' +revision = "211e584da130" +down_revision = "2e82aab8ef20" branch_labels = None depends_on = None -airflow_version = '1.7.1.3' +airflow_version = "1.7.1.3" def upgrade(): - op.create_index('ti_state', 'task_instance', ['state'], unique=False) + op.create_index("ti_state", "task_instance", ["state"], unique=False) def downgrade(): - op.drop_index('ti_state', table_name='task_instance') + op.drop_index("ti_state", table_name="task_instance") diff --git a/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py b/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py index 679efb70a1818..45aad1d2c403f 100644 --- a/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py +++ b/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``task_fail`` table Revision ID: 64de9cddf6c9 @@ -23,32 +22,34 @@ Create Date: 2016-08-03 14:02:59.203021 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. -revision = '64de9cddf6c9' -down_revision = '211e584da130' +revision = "64de9cddf6c9" +down_revision = "211e584da130" branch_labels = None depends_on = None -airflow_version = '1.7.1.3' +airflow_version = "1.7.1.3" def upgrade(): op.create_table( - 'task_fail', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('task_id', StringID(), nullable=False), - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('execution_date', sa.DateTime(), nullable=False), - sa.Column('start_date', sa.DateTime(), nullable=True), - sa.Column('end_date', sa.DateTime(), nullable=True), - sa.Column('duration', sa.Integer(), nullable=True), - sa.PrimaryKeyConstraint('id'), + "task_fail", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("execution_date", sa.DateTime(), nullable=False), + sa.Column("start_date", sa.DateTime(), nullable=True), + sa.Column("end_date", sa.DateTime(), nullable=True), + sa.Column("duration", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) def downgrade(): - op.drop_table('task_fail') + op.drop_table("task_fail") diff --git a/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py b/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py index 9cb37d2a41cbc..19726d872e395 100644 --- a/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py +++ b/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``dag_stats`` table Revision ID: f2ca10b85618 @@ -23,29 +22,31 @@ Create Date: 2016-07-20 15:08:28.247537 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. 
-revision = 'f2ca10b85618' -down_revision = '64de9cddf6c9' +revision = "f2ca10b85618" +down_revision = "64de9cddf6c9" branch_labels = None depends_on = None -airflow_version = '1.7.1.3' +airflow_version = "1.7.1.3" def upgrade(): op.create_table( - 'dag_stats', - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('state', sa.String(length=50), nullable=False), - sa.Column('count', sa.Integer(), nullable=False, default=0), - sa.Column('dirty', sa.Boolean(), nullable=False, default=False), - sa.PrimaryKeyConstraint('dag_id', 'state'), + "dag_stats", + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("state", sa.String(length=50), nullable=False), + sa.Column("count", sa.Integer(), nullable=False, default=0), + sa.Column("dirty", sa.Boolean(), nullable=False, default=False), + sa.PrimaryKeyConstraint("dag_id", "state"), ) def downgrade(): - op.drop_table('dag_stats') + op.drop_table("dag_stats") diff --git a/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py b/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py index 9cee531e749d5..23ec9bf1776d1 100644 --- a/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py +++ b/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add fractional seconds to MySQL tables Revision ID: 4addfa1236f1 @@ -23,100 +22,101 @@ Create Date: 2016-09-11 13:39:18.592072 """ +from __future__ import annotations from alembic import op from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = '4addfa1236f1' -down_revision = 'f2ca10b85618' +revision = "4addfa1236f1" +down_revision = "f2ca10b85618" branch_labels = None depends_on = None -airflow_version = '1.7.1.3' +airflow_version = "1.7.1.3" def upgrade(): conn = op.get_bind() if conn.dialect.name == "mysql": - op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag_run", column_name="execution_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="import_error", column_name="timestamp", 
type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="job", column_name="start_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="job", column_name="end_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="job", column_name="latest_heartbeat", type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="log", column_name="dttm", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="log", column_name="execution_date", type_=mysql.DATETIME(fsp=6)) op.alter_column( - table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(fsp=6), nullable=False + table_name="sla_miss", column_name="execution_date", type_=mysql.DATETIME(fsp=6), nullable=False ) - op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="task_fail", column_name="execution_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="task_fail", column_name="start_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.DATETIME(fsp=6)) op.alter_column( - table_name='task_instance', - column_name='execution_date', + table_name="task_instance", + column_name="execution_date", type_=mysql.DATETIME(fsp=6), nullable=False, ) - op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="task_instance", column_name="start_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="task_instance", column_name="end_date", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="task_instance", column_name="queued_dttm", type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME(fsp=6)) - op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.DATETIME(fsp=6)) + op.alter_column(table_name="xcom", column_name="execution_date", type_=mysql.DATETIME(fsp=6)) def downgrade(): conn = op.get_bind() if conn.dialect.name == "mysql": - op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME()) - op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME()) - op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME()) + op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=mysql.DATETIME()) + 
op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.DATETIME()) + op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.DATETIME()) - op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mysql.DATETIME()) + op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=mysql.DATETIME()) - op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME()) - op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME()) - op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME()) + op.alter_column(table_name="dag_run", column_name="execution_date", type_=mysql.DATETIME()) + op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.DATETIME()) + op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.DATETIME()) - op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME()) + op.alter_column(table_name="import_error", column_name="timestamp", type_=mysql.DATETIME()) - op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME()) - op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME()) - op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME()) + op.alter_column(table_name="job", column_name="start_date", type_=mysql.DATETIME()) + op.alter_column(table_name="job", column_name="end_date", type_=mysql.DATETIME()) + op.alter_column(table_name="job", column_name="latest_heartbeat", type_=mysql.DATETIME()) - op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME()) - op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME()) + op.alter_column(table_name="log", column_name="dttm", type_=mysql.DATETIME()) + op.alter_column(table_name="log", column_name="execution_date", type_=mysql.DATETIME()) op.alter_column( - table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(), nullable=False + table_name="sla_miss", column_name="execution_date", type_=mysql.DATETIME(), nullable=False ) - op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME()) + op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME()) - op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME()) - op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME()) - op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME()) + op.alter_column(table_name="task_fail", column_name="execution_date", type_=mysql.DATETIME()) + op.alter_column(table_name="task_fail", column_name="start_date", type_=mysql.DATETIME()) + op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.DATETIME()) op.alter_column( - table_name='task_instance', column_name='execution_date', type_=mysql.DATETIME(), nullable=False + table_name="task_instance", column_name="execution_date", type_=mysql.DATETIME(), nullable=False ) - op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME()) - op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME()) - op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME()) + op.alter_column(table_name="task_instance", column_name="start_date", type_=mysql.DATETIME()) + op.alter_column(table_name="task_instance", 
column_name="end_date", type_=mysql.DATETIME()) + op.alter_column(table_name="task_instance", column_name="queued_dttm", type_=mysql.DATETIME()) - op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME()) - op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME()) + op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.DATETIME()) + op.alter_column(table_name="xcom", column_name="execution_date", type_=mysql.DATETIME()) diff --git a/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py b/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py index ea2aba8e32620..d162a5872ccc4 100644 --- a/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py +++ b/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add indices on ``xcom`` table Revision ID: 8504051e801b @@ -23,22 +22,23 @@ Create Date: 2016-11-29 08:13:03.253312 """ +from __future__ import annotations from alembic import op # revision identifiers, used by Alembic. -revision = '8504051e801b' -down_revision = '4addfa1236f1' +revision = "8504051e801b" +down_revision = "4addfa1236f1" branch_labels = None depends_on = None -airflow_version = '1.7.1.3' +airflow_version = "1.7.1.3" def upgrade(): """Create Index.""" - op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False) + op.create_index("idx_xcom_dag_task_date", "xcom", ["dag_id", "task_id", "execution_date"], unique=False) def downgrade(): """Drop Index.""" - op.drop_index('idx_xcom_dag_task_date', table_name='xcom') + op.drop_index("idx_xcom_dag_task_date", table_name="xcom") diff --git a/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py b/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py index cea1d1010a888..0eb7e9bf730d4 100644 --- a/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py +++ b/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``pid`` field to ``TaskInstance`` Revision ID: 5e7d17757c7a @@ -23,23 +22,24 @@ Create Date: 2016-12-07 15:51:37.119478 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '5e7d17757c7a' -down_revision = '8504051e801b' +revision = "5e7d17757c7a" +down_revision = "8504051e801b" branch_labels = None depends_on = None -airflow_version = '1.7.1.3' +airflow_version = "1.7.1.3" def upgrade(): """Add pid column to task_instance table.""" - op.add_column('task_instance', sa.Column('pid', sa.Integer)) + op.add_column("task_instance", sa.Column("pid", sa.Integer)) def downgrade(): """Drop pid column from task_instance table.""" - op.drop_column('task_instance', 'pid') + op.drop_column("task_instance", "pid") diff --git a/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py b/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py index a7acdb2e7b187..bdf070b739f15 100644 --- a/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py +++ b/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``dag_id``/``state`` index on ``dag_run`` table Revision ID: 127d2bf2dfa7 @@ -23,19 +22,21 @@ Create Date: 2017-01-25 11:43:51.635667 """ +from __future__ import annotations + from alembic import op # revision identifiers, used by Alembic. -revision = '127d2bf2dfa7' -down_revision = '5e7d17757c7a' +revision = "127d2bf2dfa7" +down_revision = "5e7d17757c7a" branch_labels = None depends_on = None -airflow_version = '1.7.1.3' +airflow_version = "1.7.1.3" def upgrade(): - op.create_index('dag_id_state', 'dag_run', ['dag_id', 'state'], unique=False) + op.create_index("dag_id_state", "dag_run", ["dag_id", "state"], unique=False) def downgrade(): - op.drop_index('dag_id_state', table_name='dag_run') + op.drop_index("dag_id_state", table_name="dag_run") diff --git a/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py b/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py index 7685b77afd04d..c6eb6222dc190 100644 --- a/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py +++ b/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``max_tries`` column to ``task_instance`` Revision ID: cc1e65623dc7 @@ -22,22 +21,22 @@ Create Date: 2017-06-19 16:53:12.851141 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op -from sqlalchemy import Column, Integer, String +from sqlalchemy import Column, Integer, String, inspect from sqlalchemy.ext.declarative import declarative_base from airflow import settings -from airflow.compat.sqlalchemy import inspect from airflow.models import DagBag # revision identifiers, used by Alembic. -revision = 'cc1e65623dc7' -down_revision = '127d2bf2dfa7' +revision = "cc1e65623dc7" +down_revision = "127d2bf2dfa7" branch_labels = None depends_on = None -airflow_version = '1.8.2' +airflow_version = "1.8.2" Base = declarative_base() BATCH_SIZE = 5000 @@ -56,7 +55,7 @@ class TaskInstance(Base): # type: ignore def upgrade(): - op.add_column('task_instance', sa.Column('max_tries', sa.Integer, server_default="-1")) + op.add_column("task_instance", sa.Column("max_tries", sa.Integer, server_default="-1")) # Check if table task_instance exist before data migration. 
This check is # needed for database that does not create table until migration finishes. # Checking task_instance table exists prevent the error of querying @@ -65,7 +64,7 @@ def upgrade(): inspector = inspect(connection) tables = inspector.get_table_names() - if 'task_instance' in tables: + if "task_instance" in tables: # Get current session sessionmaker = sa.orm.sessionmaker() session = sessionmaker(bind=connection) @@ -102,7 +101,7 @@ def upgrade(): def downgrade(): engine = settings.engine connection = op.get_bind() - if engine.dialect.has_table(connection, 'task_instance'): + if engine.dialect.has_table(connection, "task_instance"): sessionmaker = sa.orm.sessionmaker() session = sessionmaker(bind=connection) dagbag = DagBag(settings.DAGS_FOLDER) @@ -124,4 +123,4 @@ def downgrade(): session.merge(ti) session.commit() session.commit() - op.drop_column('task_instance', 'max_tries') + op.drop_column("task_instance", "max_tries") diff --git a/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py b/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py index 00a21bf543125..4512918b5ed7d 100644 --- a/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py +++ b/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Make xcom value column a large binary Revision ID: bdaa763e6c56 @@ -23,16 +22,18 @@ Create Date: 2017-08-14 16:06:31.568971 """ +from __future__ import annotations + import dill import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'bdaa763e6c56' -down_revision = 'cc1e65623dc7' +revision = "bdaa763e6c56" +down_revision = "cc1e65623dc7" branch_labels = None depends_on = None -airflow_version = '1.8.2' +airflow_version = "1.8.2" def upgrade(): @@ -40,10 +41,10 @@ def upgrade(): # type. # use batch_alter_table to support SQLite workaround with op.batch_alter_table("xcom") as batch_op: - batch_op.alter_column('value', type_=sa.LargeBinary()) + batch_op.alter_column("value", type_=sa.LargeBinary()) def downgrade(): # use batch_alter_table to support SQLite workaround with op.batch_alter_table("xcom") as batch_op: - batch_op.alter_column('value', type_=sa.PickleType(pickler=dill)) + batch_op.alter_column("value", type_=sa.PickleType(pickler=dill)) diff --git a/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py b/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py index 364e57d50360d..79f148ff17acf 100644 --- a/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py +++ b/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Create index on ``job_id`` column in ``task_instance`` table Revision ID: 947454bf1dff @@ -23,19 +22,21 @@ Create Date: 2017-08-15 15:12:13.845074 """ +from __future__ import annotations + from alembic import op # revision identifiers, used by Alembic. 
-revision = '947454bf1dff' -down_revision = 'bdaa763e6c56' +revision = "947454bf1dff" +down_revision = "bdaa763e6c56" branch_labels = None depends_on = None -airflow_version = '1.8.2' +airflow_version = "1.8.2" def upgrade(): - op.create_index('ti_job_id', 'task_instance', ['job_id'], unique=False) + op.create_index("ti_job_id", "task_instance", ["job_id"], unique=False) def downgrade(): - op.drop_index('ti_job_id', table_name='task_instance') + op.drop_index("ti_job_id", table_name="task_instance") diff --git a/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py b/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py index 41450aec0f9c2..711ba3662a4a9 100644 --- a/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py +++ b/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase text size for MySQL (not relevant for other DBs' text types) Revision ID: d2ae31099d61 @@ -23,24 +22,26 @@ Create Date: 2017-08-18 17:07:16.686130 """ +from __future__ import annotations + from alembic import op from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = 'd2ae31099d61' -down_revision = '947454bf1dff' +revision = "d2ae31099d61" +down_revision = "947454bf1dff" branch_labels = None depends_on = None -airflow_version = '1.8.2' +airflow_version = "1.8.2" def upgrade(): conn = op.get_bind() if conn.dialect.name == "mysql": - op.alter_column(table_name='variable', column_name='val', type_=mysql.MEDIUMTEXT) + op.alter_column(table_name="variable", column_name="val", type_=mysql.MEDIUMTEXT) def downgrade(): conn = op.get_bind() if conn.dialect.name == "mysql": - op.alter_column(table_name='variable', column_name='val', type_=mysql.TEXT) + op.alter_column(table_name="variable", column_name="val", type_=mysql.TEXT) diff --git a/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py b/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py index 8945ff06c99ca..2b36361c5b87e 100644 --- a/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py +++ b/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add time zone awareness Revision ID: 0e2a74e0fc9f @@ -22,6 +21,7 @@ Create Date: 2017-11-10 22:22:31.326152 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -32,7 +32,7 @@ down_revision = "d2ae31099d61" branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" def upgrade(): diff --git a/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py b/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py index c37831ac05c0d..8ed7187a508fa 100644 --- a/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py +++ b/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- - """Add Kubernetes resource check-pointing Revision ID: 33ae817a1ff4 @@ -23,17 +21,18 @@ Create Date: 2017-09-11 15:26:47.598494 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op - -from airflow.compat.sqlalchemy import inspect +from sqlalchemy import inspect # revision identifiers, used by Alembic. -revision = '33ae817a1ff4' -down_revision = 'd2ae31099d61' +revision = "33ae817a1ff4" +down_revision = "d2ae31099d61" branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" RESOURCE_TABLE = "kube_resource_version" diff --git a/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py b/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py index 67bda056ecb51..9808c2c882363 100644 --- a/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py +++ b/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - """Add ``executor_config`` column to ``task_instance`` table Revision ID: 33ae817a1ff4 @@ -23,17 +21,18 @@ Create Date: 2017-09-11 15:26:47.598494 """ +from __future__ import annotations import dill import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '27c6a30d7c24' -down_revision = '33ae817a1ff4' +revision = "27c6a30d7c24" +down_revision = "33ae817a1ff4" branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" TASK_INSTANCE_TABLE = "task_instance" NEW_COLUMN = "executor_config" diff --git a/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py b/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py index 4dd1864580a9d..c90934576e8c7 100644 --- a/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py +++ b/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - """Add kubernetes scheduler uniqueness Revision ID: 86770d1215c0 @@ -23,15 +21,17 @@ Create Date: 2018-04-03 15:31:20.814328 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '86770d1215c0' -down_revision = '27c6a30d7c24' +revision = "86770d1215c0" +down_revision = "27c6a30d7c24" branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" RESOURCE_TABLE = "kube_worker_uuid" diff --git a/airflow/migrations/versions/0031_1_10_0_merge_heads.py b/airflow/migrations/versions/0031_1_10_0_merge_heads.py index 33691b7b04ec3..2f9f404351ae0 100644 --- a/airflow/migrations/versions/0031_1_10_0_merge_heads.py +++ b/airflow/migrations/versions/0031_1_10_0_merge_heads.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Merge migrations Heads Revision ID: 05f30312d566 @@ -22,13 +21,14 @@ Create Date: 2018-06-17 10:47:23.339972 """ +from __future__ import annotations # revision identifiers, used by Alembic. 
-revision = '05f30312d566' -down_revision = ('86770d1215c0', '0e2a74e0fc9f') +revision = "05f30312d566" +down_revision = ("86770d1215c0", "0e2a74e0fc9f") branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" def upgrade(): diff --git a/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py b/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py index 765d43762e1aa..c5d0eecb9eb9a 100644 --- a/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py +++ b/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Fix MySQL not null constraint Revision ID: f23433877c24 @@ -22,30 +21,32 @@ Create Date: 2018-06-17 10:16:31.412131 """ +from __future__ import annotations + from alembic import op from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = 'f23433877c24' -down_revision = '05f30312d566' +revision = "f23433877c24" +down_revision = "05f30312d566" branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" def upgrade(): conn = op.get_bind() - if conn.dialect.name == 'mysql': + if conn.dialect.name == "mysql": conn.execute("SET time_zone = '+00:00'") - op.alter_column('task_fail', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False) - op.alter_column('xcom', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False) - op.alter_column('xcom', 'timestamp', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False) + op.alter_column("task_fail", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=False) + op.alter_column("xcom", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=False) + op.alter_column("xcom", "timestamp", existing_type=mysql.TIMESTAMP(fsp=6), nullable=False) def downgrade(): conn = op.get_bind() - if conn.dialect.name == 'mysql': + if conn.dialect.name == "mysql": conn.execute("SET time_zone = '+00:00'") - op.alter_column('xcom', 'timestamp', existing_type=mysql.TIMESTAMP(fsp=6), nullable=True) - op.alter_column('xcom', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=True) - op.alter_column('task_fail', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=True) + op.alter_column("xcom", "timestamp", existing_type=mysql.TIMESTAMP(fsp=6), nullable=True) + op.alter_column("xcom", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=True) + op.alter_column("task_fail", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=True) diff --git a/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py b/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py index abf289572fbda..f13891b4d4079 100644 --- a/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py +++ b/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Fix Sqlite foreign key Revision ID: 856955da8476 @@ -22,22 +21,23 @@ Create Date: 2018-06-17 15:54:53.844230 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '856955da8476' -down_revision = 'f23433877c24' +revision = "856955da8476" +down_revision = "f23433877c24" branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" def upgrade(): """Fix broken foreign-key constraint for existing SQLite DBs.""" conn = op.get_bind() - if conn.dialect.name == 'sqlite': + if conn.dialect.name == "sqlite": # Fix broken foreign-key constraint for existing SQLite DBs. # # Re-define tables and use copy_from to avoid reflection @@ -45,27 +45,27 @@ def upgrade(): # # Use batch_alter_table to support SQLite workaround. chart_table = sa.Table( - 'chart', + "chart", sa.MetaData(), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('label', sa.String(length=200), nullable=True), - sa.Column('conn_id', sa.String(length=250), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('chart_type', sa.String(length=100), nullable=True), - sa.Column('sql_layout', sa.String(length=50), nullable=True), - sa.Column('sql', sa.Text(), nullable=True), - sa.Column('y_log_scale', sa.Boolean(), nullable=True), - sa.Column('show_datatable', sa.Boolean(), nullable=True), - sa.Column('show_sql', sa.Boolean(), nullable=True), - sa.Column('height', sa.Integer(), nullable=True), - sa.Column('default_params', sa.String(length=5000), nullable=True), - sa.Column('x_is_date', sa.Boolean(), nullable=True), - sa.Column('iteration_no', sa.Integer(), nullable=True), - sa.Column('last_modified', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id'), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("label", sa.String(length=200), nullable=True), + sa.Column("conn_id", sa.String(length=250), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("chart_type", sa.String(length=100), nullable=True), + sa.Column("sql_layout", sa.String(length=50), nullable=True), + sa.Column("sql", sa.Text(), nullable=True), + sa.Column("y_log_scale", sa.Boolean(), nullable=True), + sa.Column("show_datatable", sa.Boolean(), nullable=True), + sa.Column("show_sql", sa.Boolean(), nullable=True), + sa.Column("height", sa.Integer(), nullable=True), + sa.Column("default_params", sa.String(length=5000), nullable=True), + sa.Column("x_is_date", sa.Boolean(), nullable=True), + sa.Column("iteration_no", sa.Integer(), nullable=True), + sa.Column("last_modified", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - with op.batch_alter_table('chart', copy_from=chart_table) as batch_op: - batch_op.create_foreign_key('chart_user_id_fkey', 'users', ['user_id'], ['id']) + with op.batch_alter_table("chart", copy_from=chart_table) as batch_op: + batch_op.create_foreign_key("chart_user_id_fkey", "users", ["user_id"], ["id"]) def downgrade(): diff --git a/airflow/migrations/versions/0034_1_10_0_index_taskfail.py b/airflow/migrations/versions/0034_1_10_0_index_taskfail.py index 1323e77f7f040..084cb0d925ac9 100644 --- a/airflow/migrations/versions/0034_1_10_0_index_taskfail.py +++ b/airflow/migrations/versions/0034_1_10_0_index_taskfail.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Create index on ``task_fail`` table Revision ID: 9635ae0956e7 @@ -22,21 +21,23 @@ Create Date: 2018-06-17 21:40:01.963540 """ +from __future__ import annotations + from alembic import op # revision identifiers, used by Alembic. 
-revision = '9635ae0956e7' -down_revision = '856955da8476' +revision = "9635ae0956e7" +down_revision = "856955da8476" branch_labels = None depends_on = None -airflow_version = '1.10.0' +airflow_version = "1.10.0" def upgrade(): op.create_index( - 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_task_fail_dag_task_date", "task_fail", ["dag_id", "task_id", "execution_date"], unique=False ) def downgrade(): - op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') + op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail") diff --git a/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py b/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py index 76fcfd375b9e2..149c6ca8d52e3 100644 --- a/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py +++ b/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py @@ -22,19 +22,21 @@ Create Date: 2018-08-07 06:41:41.028249 """ +from __future__ import annotations + from alembic import op # revision identifiers, used by Alembic. -revision = 'dd25f486b8ea' -down_revision = '9635ae0956e7' +revision = "dd25f486b8ea" +down_revision = "9635ae0956e7" branch_labels = None depends_on = None -airflow_version = '1.10.2' +airflow_version = "1.10.2" def upgrade(): - op.create_index('idx_log_dag', 'log', ['dag_id'], unique=False) + op.create_index("idx_log_dag", "log", ["dag_id"], unique=False) def downgrade(): - op.drop_index('idx_log_dag', table_name='log') + op.drop_index("idx_log_dag", table_name="log") diff --git a/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py b/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py index a52663dcc9dea..f0453d0e8fb36 100644 --- a/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py +++ b/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add index to ``task_instance`` table Revision ID: bf00311e1990 @@ -23,20 +22,21 @@ Create Date: 2018-09-12 09:53:52.007433 """ +from __future__ import annotations from alembic import op # revision identifiers, used by Alembic. -revision = 'bf00311e1990' -down_revision = 'dd25f486b8ea' +revision = "bf00311e1990" +down_revision = "dd25f486b8ea" branch_labels = None depends_on = None -airflow_version = '1.10.2' +airflow_version = "1.10.2" def upgrade(): - op.create_index('ti_dag_date', 'task_instance', ['dag_id', 'execution_date'], unique=False) + op.create_index("ti_dag_date", "task_instance", ["dag_id", "execution_date"], unique=False) def downgrade(): - op.drop_index('ti_dag_date', table_name='task_instance') + op.drop_index("ti_dag_date", table_name="task_instance") diff --git a/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py b/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py index a26204fff122e..c30916725f154 100644 --- a/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py +++ b/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Add ``task_reschedule`` table Revision ID: 0a2a5b66e19d @@ -22,49 +21,51 @@ Create Date: 2018-06-17 22:50:00.053620 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. -revision = '0a2a5b66e19d' -down_revision = '9635ae0956e7' +revision = "0a2a5b66e19d" +down_revision = "9635ae0956e7" branch_labels = None depends_on = None -airflow_version = '1.10.2' +airflow_version = "1.10.2" -TABLE_NAME = 'task_reschedule' -INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date' +TABLE_NAME = "task_reschedule" +INDEX_NAME = "idx_" + TABLE_NAME + "_dag_task_date" def upgrade(): # See 0e2a74e0fc9f_add_time_zone_awareness timestamp = TIMESTAMP - if op.get_bind().dialect.name == 'mssql': + if op.get_bind().dialect.name == "mssql": # We need to keep this as it was for this old migration on mssql timestamp = sa.DateTime() op.create_table( TABLE_NAME, - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('task_id', StringID(), nullable=False), - sa.Column('dag_id', StringID(), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), # use explicit server_default=None otherwise mysql implies defaults for first timestamp column - sa.Column('execution_date', timestamp, nullable=False, server_default=None), - sa.Column('try_number', sa.Integer(), nullable=False), - sa.Column('start_date', timestamp, nullable=False), - sa.Column('end_date', timestamp, nullable=False), - sa.Column('duration', sa.Integer(), nullable=False), - sa.Column('reschedule_date', timestamp, nullable=False), - sa.PrimaryKeyConstraint('id'), + sa.Column("execution_date", timestamp, nullable=False, server_default=None), + sa.Column("try_number", sa.Integer(), nullable=False), + sa.Column("start_date", timestamp, nullable=False), + sa.Column("end_date", timestamp, nullable=False), + sa.Column("duration", sa.Integer(), nullable=False), + sa.Column("reschedule_date", timestamp, nullable=False), + sa.PrimaryKeyConstraint("id"), sa.ForeignKeyConstraint( - ['task_id', 'dag_id', 'execution_date'], - ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'], - name='task_reschedule_dag_task_date_fkey', + ["task_id", "dag_id", "execution_date"], + ["task_instance.task_id", "task_instance.dag_id", "task_instance.execution_date"], + name="task_reschedule_dag_task_date_fkey", ), ) - op.create_index(INDEX_NAME, TABLE_NAME, ['dag_id', 'task_id', 'execution_date'], unique=False) + op.create_index(INDEX_NAME, TABLE_NAME, ["dag_id", "task_id", "execution_date"], unique=False) def downgrade(): diff --git a/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py b/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py index a4bf9a03e1985..f3188901e46d6 100644 --- a/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py +++ b/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py @@ -14,28 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -"""Merge migrations Heads +"""Merge migrations Heads. Revision ID: 03bc53e68815 Revises: 0a2a5b66e19d, bf00311e1990 Create Date: 2018-11-24 20:21:46.605414 """ +from __future__ import annotations from alembic import op # revision identifiers, used by Alembic. 
-revision = '03bc53e68815' -down_revision = ('0a2a5b66e19d', 'bf00311e1990') +revision = "03bc53e68815" +down_revision = ("0a2a5b66e19d", "bf00311e1990") branch_labels = None depends_on = None -airflow_version = '1.10.2' +airflow_version = "1.10.2" def upgrade(): - op.create_index('sm_dag', 'sla_miss', ['dag_id'], unique=False) + op.create_index("sm_dag", "sla_miss", ["dag_id"], unique=False) def downgrade(): - op.drop_index('sm_dag', table_name='sla_miss') + op.drop_index("sm_dag", table_name="sla_miss") diff --git a/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py b/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py index 11ea2938deb7d..314be0bba8c65 100644 --- a/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py +++ b/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add superuser field Revision ID: 41f5f12752f8 @@ -22,21 +21,22 @@ Create Date: 2018-12-04 15:50:04.456875 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '41f5f12752f8' -down_revision = '03bc53e68815' +revision = "41f5f12752f8" +down_revision = "03bc53e68815" branch_labels = None depends_on = None -airflow_version = '1.10.2' +airflow_version = "1.10.2" def upgrade(): - op.add_column('users', sa.Column('superuser', sa.Boolean(), default=False)) + op.add_column("users", sa.Column("superuser", sa.Boolean(), default=False)) def downgrade(): - op.drop_column('users', 'superuser') + op.drop_column("users", "superuser") diff --git a/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py b/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py index a1eb00f14fb52..2c9c2d04a1bb1 100644 --- a/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py +++ b/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``description`` and ``default_view`` column to ``dag`` table Revision ID: c8ffec048a3b @@ -23,23 +22,24 @@ Create Date: 2018-12-23 21:55:46.463634 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = 'c8ffec048a3b' -down_revision = '41f5f12752f8' +revision = "c8ffec048a3b" +down_revision = "41f5f12752f8" branch_labels = None depends_on = None -airflow_version = '1.10.3' +airflow_version = "1.10.3" def upgrade(): - op.add_column('dag', sa.Column('description', sa.Text(), nullable=True)) - op.add_column('dag', sa.Column('default_view', sa.String(25), nullable=True)) + op.add_column("dag", sa.Column("description", sa.Text(), nullable=True)) + op.add_column("dag", sa.Column("default_view", sa.String(25), nullable=True)) def downgrade(): - op.drop_column('dag', 'description') - op.drop_column('dag', 'default_view') + op.drop_column("dag", "description") + op.drop_column("dag", "default_view") diff --git a/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py b/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py index 42e0a5db92882..2f741f0e65b65 100644 --- a/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py +++ b/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add schedule interval to dag Revision ID: dd4ecb8fbee3 @@ -23,21 +22,22 @@ Create Date: 2018-12-27 18:39:25.748032 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'dd4ecb8fbee3' -down_revision = 'c8ffec048a3b' +revision = "dd4ecb8fbee3" +down_revision = "c8ffec048a3b" branch_labels = None depends_on = None -airflow_version = '1.10.3' +airflow_version = "1.10.3" def upgrade(): - op.add_column('dag', sa.Column('schedule_interval', sa.Text(), nullable=True)) + op.add_column("dag", sa.Column("schedule_interval", sa.Text(), nullable=True)) def downgrade(): - op.drop_column('dag', 'schedule_interval') + op.drop_column("dag", "schedule_interval") diff --git a/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py b/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py index 9a535cbc99c48..9c42ef2b188ff 100644 --- a/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py +++ b/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """task reschedule foreign key on cascade delete Revision ID: 939bb1e647c8 @@ -22,35 +21,36 @@ Create Date: 2019-02-04 20:21:50.669751 """ +from __future__ import annotations from alembic import op # revision identifiers, used by Alembic. 
-revision = '939bb1e647c8' -down_revision = 'dd4ecb8fbee3' +revision = "939bb1e647c8" +down_revision = "dd4ecb8fbee3" branch_labels = None depends_on = None -airflow_version = '1.10.3' +airflow_version = "1.10.3" def upgrade(): - with op.batch_alter_table('task_reschedule') as batch_op: - batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') + with op.batch_alter_table("task_reschedule") as batch_op: + batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey") batch_op.create_foreign_key( - 'task_reschedule_dag_task_date_fkey', - 'task_instance', - ['task_id', 'dag_id', 'execution_date'], - ['task_id', 'dag_id', 'execution_date'], - ondelete='CASCADE', + "task_reschedule_dag_task_date_fkey", + "task_instance", + ["task_id", "dag_id", "execution_date"], + ["task_id", "dag_id", "execution_date"], + ondelete="CASCADE", ) def downgrade(): - with op.batch_alter_table('task_reschedule') as batch_op: - batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') + with op.batch_alter_table("task_reschedule") as batch_op: + batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey") batch_op.create_foreign_key( - 'task_reschedule_dag_task_date_fkey', - 'task_instance', - ['task_id', 'dag_id', 'execution_date'], - ['task_id', 'dag_id', 'execution_date'], + "task_reschedule_dag_task_date_fkey", + "task_instance", + ["task_id", "dag_id", "execution_date"], + ["task_id", "dag_id", "execution_date"], ) diff --git a/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py b/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py index 53261af4a87e8..52069c56daba3 100644 --- a/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py +++ b/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Make ``TaskInstance.pool`` not nullable Revision ID: 6e96a59344a4 @@ -23,6 +22,7 @@ Create Date: 2019-06-13 21:51:32.878437 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -33,11 +33,11 @@ from airflow.utils.sqlalchemy import UtcDateTime # revision identifiers, used by Alembic. 
-revision = '6e96a59344a4' -down_revision = '939bb1e647c8' +revision = "6e96a59344a4" +down_revision = "939bb1e647c8" branch_labels = None depends_on = None -airflow_version = '1.10.4' +airflow_version = "1.10.4" Base = declarative_base() ID_LEN = 250 @@ -58,45 +58,45 @@ def upgrade(): """Make TaskInstance.pool field not nullable.""" with create_session() as session: session.query(TaskInstance).filter(TaskInstance.pool.is_(None)).update( - {TaskInstance.pool: 'default_pool'}, synchronize_session=False + {TaskInstance.pool: "default_pool"}, synchronize_session=False ) # Avoid select updated rows session.commit() conn = op.get_bind() if conn.dialect.name == "mssql": - op.drop_index('ti_pool', table_name='task_instance') + op.drop_index("ti_pool", table_name="task_instance") # use batch_alter_table to support SQLite workaround - with op.batch_alter_table('task_instance') as batch_op: + with op.batch_alter_table("task_instance") as batch_op: batch_op.alter_column( - column_name='pool', + column_name="pool", type_=sa.String(50), nullable=False, ) if conn.dialect.name == "mssql": - op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight']) + op.create_index("ti_pool", "task_instance", ["pool", "state", "priority_weight"]) def downgrade(): """Make TaskInstance.pool field nullable.""" conn = op.get_bind() if conn.dialect.name == "mssql": - op.drop_index('ti_pool', table_name='task_instance') + op.drop_index("ti_pool", table_name="task_instance") # use batch_alter_table to support SQLite workaround - with op.batch_alter_table('task_instance') as batch_op: + with op.batch_alter_table("task_instance") as batch_op: batch_op.alter_column( - column_name='pool', + column_name="pool", type_=sa.String(50), nullable=True, ) if conn.dialect.name == "mssql": - op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight']) + op.create_index("ti_pool", "task_instance", ["pool", "state", "priority_weight"]) with create_session() as session: - session.query(TaskInstance).filter(TaskInstance.pool == 'default_pool').update( + session.query(TaskInstance).filter(TaskInstance.pool == "default_pool").update( {TaskInstance.pool: None}, synchronize_session=False ) # Avoid select updated rows session.commit() diff --git a/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py b/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py index 38179f2a7033a..2feb636de5e07 100644 --- a/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py +++ b/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``serialized_dag`` table Revision ID: d38e04c12aa2 @@ -23,6 +22,8 @@ Create Date: 2019-08-01 14:39:35.616417 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import mysql @@ -30,11 +31,11 @@ from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. 
-revision = 'd38e04c12aa2' -down_revision = '6e96a59344a4' +revision = "d38e04c12aa2" +down_revision = "6e96a59344a4" branch_labels = None depends_on = None -airflow_version = '1.10.7' +airflow_version = "1.10.7" def upgrade(): @@ -51,15 +52,15 @@ def upgrade(): json_type = sa.Text op.create_table( - 'serialized_dag', - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('fileloc', sa.String(length=2000), nullable=False), - sa.Column('fileloc_hash', sa.Integer(), nullable=False), - sa.Column('data', json_type(), nullable=False), - sa.Column('last_updated', sa.DateTime(), nullable=False), - sa.PrimaryKeyConstraint('dag_id'), + "serialized_dag", + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("fileloc", sa.String(length=2000), nullable=False), + sa.Column("fileloc_hash", sa.Integer(), nullable=False), + sa.Column("data", json_type(), nullable=False), + sa.Column("last_updated", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("dag_id"), ) - op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) + op.create_index("idx_fileloc_hash", "serialized_dag", ["fileloc_hash"]) if conn.dialect.name == "mysql": conn.execute("SET time_zone = '+00:00'") @@ -93,4 +94,4 @@ def upgrade(): def downgrade(): """Downgrade version.""" - op.drop_table('serialized_dag') + op.drop_table("serialized_dag") diff --git a/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py b/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py index f879450369ab8..8c230df3830da 100644 --- a/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py +++ b/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``root_dag_id`` to ``DAG`` Revision ID: b3b105409875 @@ -23,6 +22,7 @@ Create Date: 2019-09-28 23:20:01.744775 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,20 +30,20 @@ from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. -revision = 'b3b105409875' -down_revision = 'd38e04c12aa2' +revision = "b3b105409875" +down_revision = "d38e04c12aa2" branch_labels = None depends_on = None -airflow_version = '1.10.7' +airflow_version = "1.10.7" def upgrade(): """Apply Add ``root_dag_id`` to ``DAG``""" - op.add_column('dag', sa.Column('root_dag_id', StringID(), nullable=True)) - op.create_index('idx_root_dag_id', 'dag', ['root_dag_id'], unique=False) + op.add_column("dag", sa.Column("root_dag_id", StringID(), nullable=True)) + op.create_index("idx_root_dag_id", "dag", ["root_dag_id"], unique=False) def downgrade(): """Unapply Add ``root_dag_id`` to ``DAG``""" - op.drop_index('idx_root_dag_id', table_name='dag') - op.drop_column('dag', 'root_dag_id') + op.drop_index("idx_root_dag_id", table_name="dag") + op.drop_column("dag", "root_dag_id") diff --git a/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py b/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py index 49e3e9b84a7e0..3b2970aae7139 100644 --- a/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py +++ b/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -"""change datetime to datetime2(6) on MSSQL tables +"""change datetime to datetime2(6) on MSSQL tables. Revision ID: 74effc47d867 Revises: 6e96a59344a4 Create Date: 2019-08-01 15:19:57.585620 """ +from __future__ import annotations from collections import defaultdict @@ -30,15 +30,15 @@ from sqlalchemy.dialects import mssql # revision identifiers, used by Alembic. -revision = '74effc47d867' -down_revision = '6e96a59344a4' +revision = "74effc47d867" +down_revision = "6e96a59344a4" branch_labels = None depends_on = None -airflow_version = '1.10.5' +airflow_version = "1.10.5" def upgrade(): - """Change datetime to datetime2(6) when using MSSQL as backend""" + """Change datetime to datetime2(6) when using MSSQL as backend.""" conn = op.get_bind() if conn.dialect.name == "mssql": result = conn.execute( @@ -50,106 +50,106 @@ def upgrade(): if mssql_version in ("2000", "2005"): return - with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date') - task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') + with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op: + task_reschedule_batch_op.drop_index("idx_task_reschedule_dag_task_date") + task_reschedule_batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey") task_reschedule_batch_op.alter_column( column_name="execution_date", type_=mssql.DATETIME2(precision=6), nullable=False, ) task_reschedule_batch_op.alter_column( - column_name='start_date', type_=mssql.DATETIME2(precision=6) + column_name="start_date", type_=mssql.DATETIME2(precision=6) ) - task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6)) + task_reschedule_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME2(precision=6)) task_reschedule_batch_op.alter_column( - column_name='reschedule_date', type_=mssql.DATETIME2(precision=6) + column_name="reschedule_date", type_=mssql.DATETIME2(precision=6) ) - with op.batch_alter_table('task_instance') as task_instance_batch_op: - task_instance_batch_op.drop_index('ti_state_lkp') - task_instance_batch_op.drop_index('ti_dag_date') + with op.batch_alter_table("task_instance") as task_instance_batch_op: + task_instance_batch_op.drop_index("ti_state_lkp") + task_instance_batch_op.drop_index("ti_dag_date") modify_execution_date_with_constraint( - conn, task_instance_batch_op, 'task_instance', mssql.DATETIME2(precision=6), False + conn, task_instance_batch_op, "task_instance", mssql.DATETIME2(precision=6), False ) - task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6)) - task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6)) - task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME2(precision=6)) + task_instance_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME2(precision=6)) + task_instance_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME2(precision=6)) + task_instance_batch_op.alter_column(column_name="queued_dttm", type_=mssql.DATETIME2(precision=6)) task_instance_batch_op.create_index( - 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False + "ti_state_lkp", ["dag_id", "task_id", "execution_date"], unique=False ) - task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False) + 
task_instance_batch_op.create_index("ti_dag_date", ["dag_id", "execution_date"], unique=False) - with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: + with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op: task_reschedule_batch_op.create_foreign_key( - 'task_reschedule_dag_task_date_fkey', - 'task_instance', - ['task_id', 'dag_id', 'execution_date'], - ['task_id', 'dag_id', 'execution_date'], - ondelete='CASCADE', + "task_reschedule_dag_task_date_fkey", + "task_instance", + ["task_id", "dag_id", "execution_date"], + ["task_id", "dag_id", "execution_date"], + ondelete="CASCADE", ) task_reschedule_batch_op.create_index( - 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_task_reschedule_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False ) - with op.batch_alter_table('dag_run') as dag_run_batch_op: + with op.batch_alter_table("dag_run") as dag_run_batch_op: modify_execution_date_with_constraint( - conn, dag_run_batch_op, 'dag_run', mssql.DATETIME2(precision=6), None + conn, dag_run_batch_op, "dag_run", mssql.DATETIME2(precision=6), None ) - dag_run_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6)) - dag_run_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6)) + dag_run_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME2(precision=6)) + dag_run_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME2(precision=6)) op.alter_column(table_name="log", column_name="execution_date", type_=mssql.DATETIME2(precision=6)) - op.alter_column(table_name='log', column_name='dttm', type_=mssql.DATETIME2(precision=6)) + op.alter_column(table_name="log", column_name="dttm", type_=mssql.DATETIME2(precision=6)) - with op.batch_alter_table('sla_miss') as sla_miss_batch_op: + with op.batch_alter_table("sla_miss") as sla_miss_batch_op: modify_execution_date_with_constraint( - conn, sla_miss_batch_op, 'sla_miss', mssql.DATETIME2(precision=6), False + conn, sla_miss_batch_op, "sla_miss", mssql.DATETIME2(precision=6), False ) - sla_miss_batch_op.alter_column(column_name='timestamp', type_=mssql.DATETIME2(precision=6)) + sla_miss_batch_op.alter_column(column_name="timestamp", type_=mssql.DATETIME2(precision=6)) - op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') + op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail") op.alter_column( table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME2(precision=6) ) - op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME2(precision=6)) - op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME2(precision=6)) + op.alter_column(table_name="task_fail", column_name="start_date", type_=mssql.DATETIME2(precision=6)) + op.alter_column(table_name="task_fail", column_name="end_date", type_=mssql.DATETIME2(precision=6)) op.create_index( - 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_task_fail_dag_task_date", "task_fail", ["dag_id", "task_id", "execution_date"], unique=False ) - op.drop_index('idx_xcom_dag_task_date', table_name='xcom') + op.drop_index("idx_xcom_dag_task_date", table_name="xcom") op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME2(precision=6)) - op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME2(precision=6)) + 
op.alter_column(table_name="xcom", column_name="timestamp", type_=mssql.DATETIME2(precision=6)) op.create_index( - 'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_xcom_dag_task_date", "xcom", ["dag_id", "task_id", "execution_date"], unique=False ) op.alter_column( - table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME2(precision=6) + table_name="dag", column_name="last_scheduler_run", type_=mssql.DATETIME2(precision=6) ) - op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME2(precision=6)) - op.alter_column(table_name='dag', column_name='last_expired', type_=mssql.DATETIME2(precision=6)) + op.alter_column(table_name="dag", column_name="last_pickled", type_=mssql.DATETIME2(precision=6)) + op.alter_column(table_name="dag", column_name="last_expired", type_=mssql.DATETIME2(precision=6)) op.alter_column( - table_name='dag_pickle', column_name='created_dttm', type_=mssql.DATETIME2(precision=6) + table_name="dag_pickle", column_name="created_dttm", type_=mssql.DATETIME2(precision=6) ) op.alter_column( - table_name='import_error', column_name='timestamp', type_=mssql.DATETIME2(precision=6) + table_name="import_error", column_name="timestamp", type_=mssql.DATETIME2(precision=6) ) - op.drop_index('job_type_heart', table_name='job') - op.drop_index('idx_job_state_heartbeat', table_name='job') - op.alter_column(table_name='job', column_name='start_date', type_=mssql.DATETIME2(precision=6)) - op.alter_column(table_name='job', column_name='end_date', type_=mssql.DATETIME2(precision=6)) - op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mssql.DATETIME2(precision=6)) - op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False) - op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False) + op.drop_index("job_type_heart", table_name="job") + op.drop_index("idx_job_state_heartbeat", table_name="job") + op.alter_column(table_name="job", column_name="start_date", type_=mssql.DATETIME2(precision=6)) + op.alter_column(table_name="job", column_name="end_date", type_=mssql.DATETIME2(precision=6)) + op.alter_column(table_name="job", column_name="latest_heartbeat", type_=mssql.DATETIME2(precision=6)) + op.create_index("idx_job_state_heartbeat", "job", ["state", "latest_heartbeat"], unique=False) + op.create_index("job_type_heart", "job", ["job_type", "latest_heartbeat"], unique=False) def downgrade(): - """Change datetime2(6) back to datetime""" + """Change datetime2(6) back to datetime.""" conn = op.get_bind() if conn.dialect.name == "mssql": result = conn.execute( @@ -161,88 +161,89 @@ def downgrade(): if mssql_version in ("2000", "2005"): return - with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date') - task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey') + with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op: + task_reschedule_batch_op.drop_index("idx_task_reschedule_dag_task_date") + task_reschedule_batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey") task_reschedule_batch_op.alter_column( column_name="execution_date", type_=mssql.DATETIME, nullable=False ) - task_reschedule_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME) - task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME) - 
task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=mssql.DATETIME) + task_reschedule_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME) + task_reschedule_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME) + task_reschedule_batch_op.alter_column(column_name="reschedule_date", type_=mssql.DATETIME) - with op.batch_alter_table('task_instance') as task_instance_batch_op: - task_instance_batch_op.drop_index('ti_state_lkp') - task_instance_batch_op.drop_index('ti_dag_date') + with op.batch_alter_table("task_instance") as task_instance_batch_op: + task_instance_batch_op.drop_index("ti_state_lkp") + task_instance_batch_op.drop_index("ti_dag_date") modify_execution_date_with_constraint( - conn, task_instance_batch_op, 'task_instance', mssql.DATETIME, False + conn, task_instance_batch_op, "task_instance", mssql.DATETIME, False ) - task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME) - task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME) - task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME) + task_instance_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME) + task_instance_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME) + task_instance_batch_op.alter_column(column_name="queued_dttm", type_=mssql.DATETIME) task_instance_batch_op.create_index( - 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False + "ti_state_lkp", ["dag_id", "task_id", "execution_date"], unique=False ) - task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False) + task_instance_batch_op.create_index("ti_dag_date", ["dag_id", "execution_date"], unique=False) - with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: + with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op: task_reschedule_batch_op.create_foreign_key( - 'task_reschedule_dag_task_date_fkey', - 'task_instance', - ['task_id', 'dag_id', 'execution_date'], - ['task_id', 'dag_id', 'execution_date'], - ondelete='CASCADE', + "task_reschedule_dag_task_date_fkey", + "task_instance", + ["task_id", "dag_id", "execution_date"], + ["task_id", "dag_id", "execution_date"], + ondelete="CASCADE", ) task_reschedule_batch_op.create_index( - 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_task_reschedule_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False ) - with op.batch_alter_table('dag_run') as dag_run_batch_op: - modify_execution_date_with_constraint(conn, dag_run_batch_op, 'dag_run', mssql.DATETIME, None) - dag_run_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME) - dag_run_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME) + with op.batch_alter_table("dag_run") as dag_run_batch_op: + modify_execution_date_with_constraint(conn, dag_run_batch_op, "dag_run", mssql.DATETIME, None) + dag_run_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME) + dag_run_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME) op.alter_column(table_name="log", column_name="execution_date", type_=mssql.DATETIME) - op.alter_column(table_name='log', column_name='dttm', type_=mssql.DATETIME) + op.alter_column(table_name="log", column_name="dttm", type_=mssql.DATETIME) - with op.batch_alter_table('sla_miss') as sla_miss_batch_op: - modify_execution_date_with_constraint(conn, sla_miss_batch_op, 'sla_miss', 
mssql.DATETIME, False) - sla_miss_batch_op.alter_column(column_name='timestamp', type_=mssql.DATETIME) + with op.batch_alter_table("sla_miss") as sla_miss_batch_op: + modify_execution_date_with_constraint(conn, sla_miss_batch_op, "sla_miss", mssql.DATETIME, False) + sla_miss_batch_op.alter_column(column_name="timestamp", type_=mssql.DATETIME) - op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') + op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail") op.alter_column(table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME) - op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME) - op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME) + op.alter_column(table_name="task_fail", column_name="start_date", type_=mssql.DATETIME) + op.alter_column(table_name="task_fail", column_name="end_date", type_=mssql.DATETIME) op.create_index( - 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_task_fail_dag_task_date", "task_fail", ["dag_id", "task_id", "execution_date"], unique=False ) - op.drop_index('idx_xcom_dag_task_date', table_name='xcom') + op.drop_index("idx_xcom_dag_task_date", table_name="xcom") op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME) - op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME) + op.alter_column(table_name="xcom", column_name="timestamp", type_=mssql.DATETIME) op.create_index( - 'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_ild', 'execution_date'], unique=False + "idx_xcom_dag_task_date", "xcom", ["dag_id", "task_ild", "execution_date"], unique=False ) - op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME) - op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME) - op.alter_column(table_name='dag', column_name='last_expired', type_=mssql.DATETIME) + op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=mssql.DATETIME) + op.alter_column(table_name="dag", column_name="last_pickled", type_=mssql.DATETIME) + op.alter_column(table_name="dag", column_name="last_expired", type_=mssql.DATETIME) - op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mssql.DATETIME) + op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=mssql.DATETIME) - op.alter_column(table_name='import_error', column_name='timestamp', type_=mssql.DATETIME) + op.alter_column(table_name="import_error", column_name="timestamp", type_=mssql.DATETIME) - op.drop_index('job_type_heart', table_name='job') - op.drop_index('idx_job_state_heartbeat', table_name='job') - op.alter_column(table_name='job', column_name='start_date', type_=mssql.DATETIME) - op.alter_column(table_name='job', column_name='end_date', type_=mssql.DATETIME) - op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mssql.DATETIME) - op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False) - op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False) + op.drop_index("job_type_heart", table_name="job") + op.drop_index("idx_job_state_heartbeat", table_name="job") + op.alter_column(table_name="job", column_name="start_date", type_=mssql.DATETIME) + op.alter_column(table_name="job", column_name="end_date", type_=mssql.DATETIME) + op.alter_column(table_name="job", column_name="latest_heartbeat", 
type_=mssql.DATETIME) + op.create_index("idx_job_state_heartbeat", "job", ["state", "latest_heartbeat"], unique=False) + op.create_index("job_type_heart", "job", ["job_type", "latest_heartbeat"], unique=False) -def get_table_constraints(conn, table_name): - """ +def get_table_constraints(conn, table_name) -> dict[tuple[str, str], list[str]]: + """Return primary and unique constraint along with column name. + This function return primary and unique constraint along with column name. some tables like task_instance is missing primary key constraint name and the name is @@ -252,7 +253,6 @@ def get_table_constraints(conn, table_name): :param conn: sql connection object :param table_name: table name :return: a dictionary of ((constraint name, constraint type), column name) of table - :rtype: defaultdict(list) """ query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc @@ -268,49 +268,47 @@ def get_table_constraints(conn, table_name): def reorder_columns(columns): - """ - Reorder the columns for creating constraint, preserve primary key ordering + """Reorder the columns for creating constraint. + Preserve primary key ordering ``['task_id', 'dag_id', 'execution_date']`` :param columns: columns retrieved from DB related to constraint :return: ordered column """ ordered_columns = [] - for column in ['task_id', 'dag_id', 'execution_date']: + for column in ["task_id", "dag_id", "execution_date"]: if column in columns: ordered_columns.append(column) for column in columns: - if column not in ['task_id', 'dag_id', 'execution_date']: + if column not in ["task_id", "dag_id", "execution_date"]: ordered_columns.append(column) return ordered_columns def drop_constraint(operator, constraint_dict): - """ - Drop a primary key or unique constraint + """Drop a primary key or unique constraint. :param operator: batch_alter_table for the table :param constraint_dict: a dictionary of ((constraint name, constraint type), column name) of table """ for constraint, columns in constraint_dict.items(): - if 'execution_date' in columns: + if "execution_date" in columns: if constraint[1].lower().startswith("primary"): - operator.drop_constraint(constraint[0], type_='primary') + operator.drop_constraint(constraint[0], type_="primary") elif constraint[1].lower().startswith("unique"): - operator.drop_constraint(constraint[0], type_='unique') + operator.drop_constraint(constraint[0], type_="unique") def create_constraint(operator, constraint_dict): - """ - Create a primary key or unique constraint + """Create a primary key or unique constraint. :param operator: batch_alter_table for the table :param constraint_dict: a dictionary of ((constraint name, constraint type), column name) of table """ for constraint, columns in constraint_dict.items(): - if 'execution_date' in columns: + if "execution_date" in columns: if constraint[1].lower().startswith("primary"): operator.create_primary_key(constraint_name=constraint[0], columns=reorder_columns(columns)) elif constraint[1].lower().startswith("unique"): @@ -319,8 +317,8 @@ def create_constraint(operator, constraint_dict): ) -def modify_execution_date_with_constraint(conn, batch_operator, table_name, type_, nullable): - """ +def modify_execution_date_with_constraint(conn, batch_operator, table_name, type_, nullable) -> None: + """Change type of column execution_date. 
Helper function changes type of column execution_date by dropping and recreating any primary/unique constraint associated with the column @@ -331,7 +329,6 @@ def modify_execution_date_with_constraint(conn, batch_operator, table_name, type :param type_: DB column type :param nullable: nullable (boolean) :return: a dictionary of ((constraint name, constraint type), column name) of table - :rtype: defaultdict(list) """ constraint_dict = get_table_constraints(conn, table_name) drop_constraint(batch_operator, constraint_dict) diff --git a/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py b/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py index 50dbc43caea89..806495981a175 100644 --- a/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py +++ b/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase queue name size limit Revision ID: 004c1210f153 @@ -23,16 +22,17 @@ Create Date: 2019-06-07 07:46:04.262275 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '004c1210f153' -down_revision = '939bb1e647c8' +revision = "004c1210f153" +down_revision = "939bb1e647c8" branch_labels = None depends_on = None -airflow_version = '1.10.4' +airflow_version = "1.10.4" def upgrade(): @@ -41,12 +41,12 @@ def upgrade(): by broker backends that might use unusually large queue names. """ # use batch_alter_table to support SQLite workaround - with op.batch_alter_table('task_instance') as batch_op: - batch_op.alter_column('queue', type_=sa.String(256)) + with op.batch_alter_table("task_instance") as batch_op: + batch_op.alter_column("queue", type_=sa.String(256)) def downgrade(): """Revert column size from 256 to 50 characters, might result in data loss.""" # use batch_alter_table to support SQLite workaround - with op.batch_alter_table('task_instance') as batch_op: - batch_op.alter_column('queue', type_=sa.String(50)) + with op.batch_alter_table("task_instance") as batch_op: + batch_op.alter_column("queue", type_=sa.String(50)) diff --git a/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py b/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py index 454482e71877a..b837c10487f1a 100644 --- a/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py +++ b/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Remove ``dag_stat`` table Revision ID: a56c9515abdc @@ -23,16 +22,17 @@ Create Date: 2018-12-27 10:27:59.715872 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
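The constraint helpers touched above (get_table_constraints, reorder_columns, drop_constraint, create_constraint) share one idea: collect every primary/unique constraint involving execution_date into a mapping keyed by (constraint name, constraint type), drop those constraints, alter the column, then recreate them with task_id, dag_id, execution_date ordered first. A minimal, self-contained sketch of that bookkeeping, using a hypothetical constraint map instead of a live INFORMATION_SCHEMA query:

    from __future__ import annotations

    from collections import defaultdict

    # Hypothetical constraint map shaped like the result of get_table_constraints():
    # keys are (constraint name, constraint type), values are the constrained columns.
    constraint_dict: dict[tuple[str, str], list[str]] = defaultdict(list)
    constraint_dict[("task_instance_pkey", "PRIMARY KEY")] = ["execution_date", "dag_id", "task_id"]
    constraint_dict[("uq_example", "UNIQUE")] = ["some_other_column"]

    def reorder_columns(columns: list[str]) -> list[str]:
        """Put task_id, dag_id, execution_date first; keep any remaining columns after."""
        preferred = ["task_id", "dag_id", "execution_date"]
        return [c for c in preferred if c in columns] + [c for c in columns if c not in preferred]

    for (name, ctype), columns in constraint_dict.items():
        if "execution_date" in columns:  # only these need a drop/recreate around the alter
            print(name, ctype, "->", reorder_columns(columns))
    # prints: task_instance_pkey PRIMARY KEY -> ['task_id', 'dag_id', 'execution_date']

Only constraints that actually touch execution_date are dropped and recreated; everything else on the table is left alone.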
-revision = 'a56c9515abdc' -down_revision = 'c8ffec048a3b' +revision = "a56c9515abdc" +down_revision = "c8ffec048a3b" branch_labels = None depends_on = None -airflow_version = '1.10.3' +airflow_version = "1.10.3" def upgrade(): @@ -43,10 +43,10 @@ def upgrade(): def downgrade(): """Create dag_stats table""" op.create_table( - 'dag_stats', - sa.Column('dag_id', sa.String(length=250), nullable=False), - sa.Column('state', sa.String(length=50), nullable=False), - sa.Column('count', sa.Integer(), nullable=False, default=0), - sa.Column('dirty', sa.Boolean(), nullable=False, default=False), - sa.PrimaryKeyConstraint('dag_id', 'state'), + "dag_stats", + sa.Column("dag_id", sa.String(length=250), nullable=False), + sa.Column("state", sa.String(length=50), nullable=False), + sa.Column("count", sa.Integer(), nullable=False, default=0), + sa.Column("dirty", sa.Boolean(), nullable=False, default=False), + sa.PrimaryKeyConstraint("dag_id", "state"), ) diff --git a/airflow/migrations/versions/0049_1_10_7_merge_heads.py b/airflow/migrations/versions/0049_1_10_7_merge_heads.py index 1f589e993fe7a..c600a17ab2d05 100644 --- a/airflow/migrations/versions/0049_1_10_7_merge_heads.py +++ b/airflow/migrations/versions/0049_1_10_7_merge_heads.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Straighten out the migrations Revision ID: 08364691d074 @@ -23,13 +22,14 @@ Create Date: 2019-11-19 22:05:11.752222 """ +from __future__ import annotations # revision identifiers, used by Alembic. -revision = '08364691d074' -down_revision = ('a56c9515abdc', '004c1210f153', '74effc47d867', 'b3b105409875') +revision = "08364691d074" +down_revision = ("a56c9515abdc", "004c1210f153", "74effc47d867", "b3b105409875") branch_labels = None depends_on = None -airflow_version = '1.10.7' +airflow_version = "1.10.7" def upgrade(): diff --git a/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py b/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py index 955e1a668623c..c855ce7750411 100644 --- a/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py +++ b/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase length for connection password Revision ID: fe461863935f @@ -23,23 +22,24 @@ Create Date: 2019-12-08 09:47:09.033009 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
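The merge-heads revision above is worth a note: when several migration branches exist in parallel, Alembic joins them with a revision whose down_revision is a tuple of all parent revisions and whose upgrade/downgrade bodies are empty. A sketch of that shape with hypothetical revision identifiers:

    from __future__ import annotations

    # Hypothetical identifiers for an Alembic merge revision.
    revision = "0123abcd4567"
    down_revision = ("branch_a_head", "branch_b_head")  # a tuple means multiple parents
    branch_labels = None
    depends_on = None

    def upgrade():
        # Intentionally empty: a merge revision only records the join point
        # in the alembic_version table; it performs no schema changes.
        pass

    def downgrade():
        pass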
-revision = 'fe461863935f' -down_revision = '08364691d074' +revision = "fe461863935f" +down_revision = "08364691d074" branch_labels = None depends_on = None -airflow_version = '1.10.7' +airflow_version = "1.10.7" def upgrade(): """Apply Increase length for connection password""" - with op.batch_alter_table('connection', schema=None) as batch_op: + with op.batch_alter_table("connection", schema=None) as batch_op: batch_op.alter_column( - 'password', + "password", existing_type=sa.VARCHAR(length=500), type_=sa.String(length=5000), existing_nullable=True, @@ -48,9 +48,9 @@ def upgrade(): def downgrade(): """Unapply Increase length for connection password""" - with op.batch_alter_table('connection', schema=None) as batch_op: + with op.batch_alter_table("connection", schema=None) as batch_op: batch_op.alter_column( - 'password', + "password", existing_type=sa.String(length=5000), type_=sa.VARCHAR(length=500), existing_nullable=True, diff --git a/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py b/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py index a7ae10860e8aa..7d1c03f7ba1d0 100644 --- a/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py +++ b/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``DagTags`` table Revision ID: 7939bcff74ba @@ -23,6 +22,7 @@ Create Date: 2020-01-07 19:39:01.247442 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,27 +30,27 @@ from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. -revision = '7939bcff74ba' -down_revision = 'fe461863935f' +revision = "7939bcff74ba" +down_revision = "fe461863935f" branch_labels = None depends_on = None -airflow_version = '1.10.8' +airflow_version = "1.10.8" def upgrade(): """Apply Add ``DagTags`` table""" op.create_table( - 'dag_tag', - sa.Column('name', sa.String(length=100), nullable=False), - sa.Column('dag_id', StringID(), nullable=False), + "dag_tag", + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), sa.ForeignKeyConstraint( - ['dag_id'], - ['dag.dag_id'], + ["dag_id"], + ["dag.dag_id"], ), - sa.PrimaryKeyConstraint('name', 'dag_id'), + sa.PrimaryKeyConstraint("name", "dag_id"), ) def downgrade(): """Unapply Add ``DagTags`` table""" - op.drop_table('dag_tag') + op.drop_table("dag_tag") diff --git a/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py b/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py index 26bff68ba2141..ac209a87c46a6 100644 --- a/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py +++ b/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``pool_slots`` field to ``task_instance`` Revision ID: a4c2fd67d16b @@ -23,21 +22,22 @@ Create Date: 2020-01-14 03:35:01.161519 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
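On the connection-password hunk above: existing_type and existing_nullable do not change the column by themselves. They describe the column's current state so Alembic can render a correct ALTER, and so batch mode (which rebuilds the whole table on SQLite) can reproduce the column faithfully. A sketch of the pattern with placeholder table and column names:

    import sqlalchemy as sa
    from alembic import op

    def upgrade():
        # Widen a VARCHAR column; batch mode makes the same call work on SQLite,
        # which cannot ALTER column types in place and rebuilds the table instead.
        with op.batch_alter_table("example_table", schema=None) as batch_op:
            batch_op.alter_column(
                "example_col",
                existing_type=sa.VARCHAR(length=500),  # what the column is today
                type_=sa.String(length=5000),          # what it should become
                existing_nullable=True,                # preserved, not changed
            )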
-revision = 'a4c2fd67d16b' -down_revision = '7939bcff74ba' +revision = "a4c2fd67d16b" +down_revision = "7939bcff74ba" branch_labels = None depends_on = None -airflow_version = '1.10.10' +airflow_version = "1.10.10" def upgrade(): - op.add_column('task_instance', sa.Column('pool_slots', sa.Integer, default=1)) + op.add_column("task_instance", sa.Column("pool_slots", sa.Integer, default=1)) def downgrade(): - op.drop_column('task_instance', 'pool_slots') + op.drop_column("task_instance", "pool_slots") diff --git a/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py b/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py index 2027dd4745fb7..9bffb96b6db48 100644 --- a/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py +++ b/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``RenderedTaskInstanceFields`` table Revision ID: 852ae6c715af @@ -23,6 +22,7 @@ Create Date: 2020-03-10 22:19:18.034961 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,13 +30,13 @@ from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. -revision = '852ae6c715af' -down_revision = 'a4c2fd67d16b' +revision = "852ae6c715af" +down_revision = "a4c2fd67d16b" branch_labels = None depends_on = None -airflow_version = '1.10.10' +airflow_version = "1.10.10" -TABLE_NAME = 'rendered_task_instance_fields' +TABLE_NAME = "rendered_task_instance_fields" def upgrade(): @@ -54,11 +54,11 @@ def upgrade(): op.create_table( TABLE_NAME, - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('task_id', StringID(), nullable=False), - sa.Column('execution_date', sa.TIMESTAMP(timezone=True), nullable=False), - sa.Column('rendered_fields', json_type(), nullable=False), - sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'), + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("task_id", StringID(), nullable=False), + sa.Column("execution_date", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("rendered_fields", json_type(), nullable=False), + sa.PrimaryKeyConstraint("dag_id", "task_id", "execution_date"), ) diff --git a/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py b/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py index 523a1a0f8c0e5..e0628d142d9bb 100644 --- a/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py +++ b/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``dag_code`` table Revision ID: 952da73b5eff @@ -23,6 +22,7 @@ Create Date: 2020-03-12 12:39:01.797462 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,11 +30,11 @@ from airflow.models.dagcode import DagCode # revision identifiers, used by Alembic. 
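One subtlety in the pool_slots hunk above: default=1 is a client-side SQLAlchemy default, so it only affects rows inserted through the ORM; rows that already exist keep NULL until something backfills them. When a migration needs the database itself to populate existing rows, the usual alternative is a server_default, sketched here with a hypothetical table and column:

    import sqlalchemy as sa
    from alembic import op

    def upgrade():
        # server_default is applied by the database itself, so existing rows are
        # populated immediately and the column can be NOT NULL from the start.
        op.add_column(
            "example_table",
            sa.Column("example_slots", sa.Integer(), nullable=False, server_default="1"),
        )

    def downgrade():
        op.drop_column("example_table", "example_slots")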
-revision = '952da73b5eff' -down_revision = '852ae6c715af' +revision = "952da73b5eff" +down_revision = "852ae6c715af" branch_labels = None depends_on = None -airflow_version = '1.10.10' +airflow_version = "1.10.10" def upgrade(): @@ -44,7 +44,7 @@ def upgrade(): Base = declarative_base() class SerializedDagModel(Base): - __tablename__ = 'serialized_dag' + __tablename__ = "serialized_dag" # There are other columns here, but these are the only ones we need for the SELECT/UPDATE we are doing dag_id = sa.Column(sa.String(250), primary_key=True) @@ -53,23 +53,23 @@ class SerializedDagModel(Base): """Apply add source code table""" op.create_table( - 'dag_code', - sa.Column('fileloc_hash', sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False), - sa.Column('fileloc', sa.String(length=2000), nullable=False), - sa.Column('source_code', sa.UnicodeText(), nullable=False), - sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False), + "dag_code", + sa.Column("fileloc_hash", sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False), + sa.Column("fileloc", sa.String(length=2000), nullable=False), + sa.Column("source_code", sa.UnicodeText(), nullable=False), + sa.Column("last_updated", sa.TIMESTAMP(timezone=True), nullable=False), ) conn = op.get_bind() - if conn.dialect.name != 'sqlite': + if conn.dialect.name != "sqlite": if conn.dialect.name == "mssql": - op.drop_index('idx_fileloc_hash', 'serialized_dag') + op.drop_index("idx_fileloc_hash", "serialized_dag") op.alter_column( - table_name='serialized_dag', column_name='fileloc_hash', type_=sa.BigInteger(), nullable=False + table_name="serialized_dag", column_name="fileloc_hash", type_=sa.BigInteger(), nullable=False ) if conn.dialect.name == "mssql": - op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) + op.create_index("idx_fileloc_hash", "serialized_dag", ["fileloc_hash"]) sessionmaker = sa.orm.sessionmaker() session = sessionmaker(bind=conn) @@ -82,4 +82,4 @@ class SerializedDagModel(Base): def downgrade(): """Unapply add source code table""" - op.drop_table('dag_code') + op.drop_table("dag_code") diff --git a/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py b/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py index 873ba148d03d6..d9e13b8ec6d6a 100644 --- a/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py +++ b/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add Precision to ``execution_date`` in ``RenderedTaskInstanceFields`` table Revision ID: a66efa278eea @@ -23,19 +22,20 @@ Create Date: 2020-06-16 21:44:02.883132 """ +from __future__ import annotations from alembic import op from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. 
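The dag_code hunk above also shows a handy pattern for data migrations: declare a minimal throwaway model with only the columns the migration needs, bind a session to the migration connection, and update rows through it rather than importing the real models, whose definitions may be newer than the schema being migrated. A generic sketch with hypothetical table and column names:

    import sqlalchemy as sa
    from alembic import op
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    def upgrade():
        Base = declarative_base()

        class ExampleRow(Base):
            # Only the columns this migration actually reads or writes.
            __tablename__ = "example_table"
            id = sa.Column(sa.Integer, primary_key=True)
            fileloc = sa.Column(sa.String(2000))

        conn = op.get_bind()
        session = sessionmaker(bind=conn)()
        for row in session.query(ExampleRow):
            row.fileloc = (row.fileloc or "").strip()
        session.commit()

Binding the session to op.get_bind() keeps the data change on the same connection Alembic is already using for the schema changes.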
-revision = 'a66efa278eea' -down_revision = '952da73b5eff' +revision = "a66efa278eea" +down_revision = "952da73b5eff" branch_labels = None depends_on = None -airflow_version = '1.10.11' +airflow_version = "1.10.11" -TABLE_NAME = 'rendered_task_instance_fields' -COLUMN_NAME = 'execution_date' +TABLE_NAME = "rendered_task_instance_fields" +COLUMN_NAME = "execution_date" def upgrade(): diff --git a/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py b/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py index 5113e2f28cec0..bbed486d096d8 100644 --- a/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py +++ b/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``dag_hash`` Column to ``serialized_dag`` table Revision ID: da3f683c3a5a @@ -23,26 +22,27 @@ Create Date: 2020-08-07 20:52:09.178296 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'da3f683c3a5a' -down_revision = 'a66efa278eea' +revision = "da3f683c3a5a" +down_revision = "a66efa278eea" branch_labels = None depends_on = None -airflow_version = '1.10.12' +airflow_version = "1.10.12" def upgrade(): """Apply Add ``dag_hash`` Column to ``serialized_dag`` table""" op.add_column( - 'serialized_dag', - sa.Column('dag_hash', sa.String(32), nullable=False, server_default='Hash not calculated yet'), + "serialized_dag", + sa.Column("dag_hash", sa.String(32), nullable=False, server_default="Hash not calculated yet"), ) def downgrade(): """Unapply Add ``dag_hash`` Column to ``serialized_dag`` table""" - op.drop_column('serialized_dag', 'dag_hash') + op.drop_column("serialized_dag", "dag_hash") diff --git a/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py b/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py index bd3fe44e9ee87..56d0a8981528a 100644 --- a/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py +++ b/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Create FAB Tables Revision ID: 92c57b58940d @@ -23,18 +22,18 @@ Create Date: 2020-11-13 19:27:10.161814 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op - -from airflow.compat.sqlalchemy import inspect +from sqlalchemy import inspect # revision identifiers, used by Alembic. 
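For context on the add-precision migration above (only its header and the TABLE_NAME/COLUMN_NAME constants appear in this hunk): MySQL DATETIME/TIMESTAMP columns default to whole-second precision, and microseconds must be requested explicitly via fsp. The body presumably does something along these lines; treat this as an illustrative sketch, not the file's exact code:

    from alembic import op
    from sqlalchemy.dialects import mysql

    TABLE_NAME = "rendered_task_instance_fields"
    COLUMN_NAME = "execution_date"

    def upgrade():
        conn = op.get_bind()
        if conn.dialect.name == "mysql":
            # fsp=6 keeps microseconds instead of truncating to whole seconds.
            op.alter_column(
                table_name=TABLE_NAME,
                column_name=COLUMN_NAME,
                type_=mysql.TIMESTAMP(fsp=6),
                nullable=False,
            )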
-revision = '92c57b58940d' -down_revision = 'da3f683c3a5a' +revision = "92c57b58940d" +down_revision = "da3f683c3a5a" branch_labels = None depends_on = None -airflow_version = '1.10.13' +airflow_version = "1.10.13" def upgrade(): @@ -44,110 +43,110 @@ def upgrade(): tables = inspector.get_table_names() if "ab_permission" not in tables: op.create_table( - 'ab_permission', - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('name', sa.String(length=100), nullable=False), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name'), + "ab_permission", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("name", sa.String(length=100), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) if "ab_view_menu" not in tables: op.create_table( - 'ab_view_menu', - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('name', sa.String(length=100), nullable=False), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name'), + "ab_view_menu", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("name", sa.String(length=100), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) if "ab_role" not in tables: op.create_table( - 'ab_role', - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('name', sa.String(length=64), nullable=False), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name'), + "ab_role", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("name", sa.String(length=64), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) if "ab_permission_view" not in tables: op.create_table( - 'ab_permission_view', - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('permission_id', sa.Integer(), nullable=True), - sa.Column('view_menu_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['permission_id'], ['ab_permission.id']), - sa.ForeignKeyConstraint(['view_menu_id'], ['ab_view_menu.id']), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('permission_id', 'view_menu_id'), + "ab_permission_view", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("permission_id", sa.Integer(), nullable=True), + sa.Column("view_menu_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["permission_id"], ["ab_permission.id"]), + sa.ForeignKeyConstraint(["view_menu_id"], ["ab_view_menu.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("permission_id", "view_menu_id"), ) if "ab_permission_view_role" not in tables: op.create_table( - 'ab_permission_view_role', - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('permission_view_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['permission_view_id'], ['ab_permission_view.id']), - sa.ForeignKeyConstraint(['role_id'], ['ab_role.id']), - sa.PrimaryKeyConstraint('id'), + "ab_permission_view_role", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("permission_view_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["permission_view_id"], ["ab_permission_view.id"]), + sa.ForeignKeyConstraint(["role_id"], ["ab_role.id"]), + sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("permission_view_id", "role_id"), ) if "ab_user" not in tables: op.create_table( - 'ab_user', 
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('first_name', sa.String(length=64), nullable=False), - sa.Column('last_name', sa.String(length=64), nullable=False), - sa.Column('username', sa.String(length=64), nullable=False), - sa.Column('password', sa.String(length=256), nullable=True), - sa.Column('active', sa.Boolean(), nullable=True), - sa.Column('email', sa.String(length=64), nullable=False), - sa.Column('last_login', sa.DateTime(), nullable=True), - sa.Column('login_count', sa.Integer(), nullable=True), - sa.Column('fail_login_count', sa.Integer(), nullable=True), - sa.Column('created_on', sa.DateTime(), nullable=True), - sa.Column('changed_on', sa.DateTime(), nullable=True), - sa.Column('created_by_fk', sa.Integer(), nullable=True), - sa.Column('changed_by_fk', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['changed_by_fk'], ['ab_user.id']), - sa.ForeignKeyConstraint(['created_by_fk'], ['ab_user.id']), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('email'), - sa.UniqueConstraint('username'), + "ab_user", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("first_name", sa.String(length=64), nullable=False), + sa.Column("last_name", sa.String(length=64), nullable=False), + sa.Column("username", sa.String(length=64), nullable=False), + sa.Column("password", sa.String(length=256), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.Column("email", sa.String(length=64), nullable=False), + sa.Column("last_login", sa.DateTime(), nullable=True), + sa.Column("login_count", sa.Integer(), nullable=True), + sa.Column("fail_login_count", sa.Integer(), nullable=True), + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]), + sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email"), + sa.UniqueConstraint("username"), ) if "ab_user_role" not in tables: op.create_table( - 'ab_user_role', - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), + "ab_user_role", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), sa.ForeignKeyConstraint( - ['role_id'], - ['ab_role.id'], + ["role_id"], + ["ab_role.id"], ), sa.ForeignKeyConstraint( - ['user_id'], - ['ab_user.id'], + ["user_id"], + ["ab_user.id"], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('user_id', 'role_id'), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "role_id"), ) if "ab_register_user" not in tables: op.create_table( - 'ab_register_user', - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('first_name', sa.String(length=64), nullable=False), - sa.Column('last_name', sa.String(length=64), nullable=False), - sa.Column('username', sa.String(length=64), nullable=False), - sa.Column('password', sa.String(length=256), nullable=True), - sa.Column('email', sa.String(length=64), nullable=False), - sa.Column('registration_date', sa.DateTime(), nullable=True), - sa.Column('registration_hash', sa.String(length=256), nullable=True), - sa.PrimaryKeyConstraint('id'), - 
sa.UniqueConstraint('username'), + "ab_register_user", + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("first_name", sa.String(length=64), nullable=False), + sa.Column("last_name", sa.String(length=64), nullable=False), + sa.Column("username", sa.String(length=64), nullable=False), + sa.Column("password", sa.String(length=256), nullable=True), + sa.Column("email", sa.String(length=64), nullable=False), + sa.Column("registration_date", sa.DateTime(), nullable=True), + sa.Column("registration_hash", sa.String(length=256), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("username"), ) @@ -172,7 +171,7 @@ def downgrade(): indexes = inspector.get_foreign_keys(table) for index in indexes: if conn.dialect.name != "sqlite": - op.drop_constraint(index.get('name'), table, type_='foreignkey') + op.drop_constraint(index.get("name"), table, type_="foreignkey") for table in fab_tables: if table in tables: diff --git a/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py b/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py index 4378c8bd0c084..811562c286877 100644 --- a/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py +++ b/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase length of ``Flask-AppBuilder`` ``ab_view_menu.name`` column Revision ID: 03afc6b6f902 @@ -23,19 +22,20 @@ Create Date: 2020-11-13 22:21:41.619565 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op +from sqlalchemy import inspect -from airflow.compat.sqlalchemy import inspect from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. -revision = '03afc6b6f902' -down_revision = '92c57b58940d' +revision = "03afc6b6f902" +down_revision = "92c57b58940d" branch_labels = None depends_on = None -airflow_version = '1.10.13' +airflow_version = "1.10.13" def upgrade(): @@ -62,8 +62,8 @@ def upgrade(): op.execute("PRAGMA foreign_keys=on") else: op.alter_column( - table_name='ab_view_menu', - column_name='name', + table_name="ab_view_menu", + column_name="name", type_=StringID(length=250), nullable=False, ) @@ -92,5 +92,5 @@ def downgrade(): op.execute("PRAGMA foreign_keys=on") else: op.alter_column( - table_name='ab_view_menu', column_name='name', type_=sa.String(length=100), nullable=False + table_name="ab_view_menu", column_name="name", type_=sa.String(length=100), nullable=False ) diff --git a/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py b/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py index deb2778661f68..a427133ae394c 100644 --- a/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py +++ b/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Drop ``user`` and ``chart`` table Revision ID: cf5dc11e79ad @@ -22,18 +21,19 @@ Create Date: 2019-01-24 15:30:35.834740 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op +from sqlalchemy import inspect from sqlalchemy.dialects import mysql -from airflow.compat.sqlalchemy import inspect - # revision identifiers, used by Alembic. 
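The FAB-tables migration above guards every create_table with a lookup against the live schema, so rerunning it on a database where Flask-AppBuilder already created its tables is a no-op. The guard is just SQLAlchemy's runtime inspector; a minimal sketch with a placeholder table:

    import sqlalchemy as sa
    from alembic import op
    from sqlalchemy import inspect

    def upgrade():
        conn = op.get_bind()
        if "example_table" not in inspect(conn).get_table_names():
            op.create_table(
                "example_table",
                sa.Column("id", sa.Integer(), primary_key=True),
                sa.Column("name", sa.String(length=100), nullable=False),
                sa.UniqueConstraint("name"),
            )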
-revision = 'cf5dc11e79ad' -down_revision = '03afc6b6f902' +revision = "cf5dc11e79ad" +down_revision = "03afc6b6f902" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): @@ -47,11 +47,11 @@ def upgrade(): inspector = inspect(conn) tables = inspector.get_table_names() - if 'known_event' in tables: + if "known_event" in tables: for fkey in inspector.get_foreign_keys(table_name="known_event", referred_table="users"): - if fkey['name']: - with op.batch_alter_table(table_name='known_event') as bop: - bop.drop_constraint(fkey['name'], type_="foreignkey") + if fkey["name"]: + with op.batch_alter_table(table_name="known_event") as bop: + bop.drop_constraint(fkey["name"], type_="foreignkey") if "chart" in tables: op.drop_table( @@ -66,48 +66,48 @@ def downgrade(): conn = op.get_bind() op.create_table( - 'users', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('username', sa.String(length=250), nullable=True), - sa.Column('email', sa.String(length=500), nullable=True), - sa.Column('password', sa.String(255)), - sa.Column('superuser', sa.Boolean(), default=False), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('username'), + "users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("username", sa.String(length=250), nullable=True), + sa.Column("email", sa.String(length=500), nullable=True), + sa.Column("password", sa.String(255)), + sa.Column("superuser", sa.Boolean(), default=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("username"), ) op.create_table( - 'chart', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('label', sa.String(length=200), nullable=True), - sa.Column('conn_id', sa.String(length=250), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('chart_type', sa.String(length=100), nullable=True), - sa.Column('sql_layout', sa.String(length=50), nullable=True), - sa.Column('sql', sa.Text(), nullable=True), - sa.Column('y_log_scale', sa.Boolean(), nullable=True), - sa.Column('show_datatable', sa.Boolean(), nullable=True), - sa.Column('show_sql', sa.Boolean(), nullable=True), - sa.Column('height', sa.Integer(), nullable=True), - sa.Column('default_params', sa.String(length=5000), nullable=True), - sa.Column('x_is_date', sa.Boolean(), nullable=True), - sa.Column('iteration_no', sa.Integer(), nullable=True), - sa.Column('last_modified', sa.DateTime(), nullable=True), + "chart", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("label", sa.String(length=200), nullable=True), + sa.Column("conn_id", sa.String(length=250), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("chart_type", sa.String(length=100), nullable=True), + sa.Column("sql_layout", sa.String(length=50), nullable=True), + sa.Column("sql", sa.Text(), nullable=True), + sa.Column("y_log_scale", sa.Boolean(), nullable=True), + sa.Column("show_datatable", sa.Boolean(), nullable=True), + sa.Column("show_sql", sa.Boolean(), nullable=True), + sa.Column("height", sa.Integer(), nullable=True), + sa.Column("default_params", sa.String(length=5000), nullable=True), + sa.Column("x_is_date", sa.Boolean(), nullable=True), + sa.Column("iteration_no", sa.Integer(), nullable=True), + sa.Column("last_modified", sa.DateTime(), nullable=True), sa.ForeignKeyConstraint( - ['user_id'], - ['users.id'], + ["user_id"], + ["users.id"], ), - sa.PrimaryKeyConstraint('id'), + sa.PrimaryKeyConstraint("id"), ) - if conn.dialect.name == 'mysql': + if conn.dialect.name == "mysql": 
conn.execute("SET time_zone = '+00:00'") - op.alter_column(table_name='chart', column_name='last_modified', type_=mysql.TIMESTAMP(fsp=6)) + op.alter_column(table_name="chart", column_name="last_modified", type_=mysql.TIMESTAMP(fsp=6)) else: - if conn.dialect.name in ('sqlite', 'mssql'): + if conn.dialect.name in ("sqlite", "mssql"): return - if conn.dialect.name == 'postgresql': + if conn.dialect.name == "postgresql": conn.execute("set timezone=UTC") - op.alter_column(table_name='chart', column_name='last_modified', type_=sa.TIMESTAMP(timezone=True)) + op.alter_column(table_name="chart", column_name="last_modified", type_=sa.TIMESTAMP(timezone=True)) diff --git a/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py b/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py index a588af5c53917..64ccc753a5bf0 100644 --- a/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py +++ b/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Remove id column from xcom Revision ID: bbf4a7ad0465 @@ -23,23 +22,22 @@ Create Date: 2019-10-29 13:53:09.445943 """ +from __future__ import annotations from collections import defaultdict from alembic import op -from sqlalchemy import Column, Integer - -from airflow.compat.sqlalchemy import inspect +from sqlalchemy import Column, Integer, inspect # revision identifiers, used by Alembic. -revision = 'bbf4a7ad0465' -down_revision = 'cf5dc11e79ad' +revision = "bbf4a7ad0465" +down_revision = "cf5dc11e79ad" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" -def get_table_constraints(conn, table_name): +def get_table_constraints(conn, table_name) -> dict[tuple[str, str], list[str]]: """ This function return primary and unique constraint along with column name. 
Some tables like `task_instance` @@ -50,7 +48,6 @@ def get_table_constraints(conn, table_name): :param conn: sql connection object :param table_name: table name :return: a dictionary of ((constraint name, constraint type), column name) of table - :rtype: defaultdict(list) """ query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc @@ -75,9 +72,9 @@ def drop_column_constraints(operator, column_name, constraint_dict): for constraint, columns in constraint_dict.items(): if column_name in columns: if constraint[1].lower().startswith("primary"): - operator.drop_constraint(constraint[0], type_='primary') + operator.drop_constraint(constraint[0], type_="primary") elif constraint[1].lower().startswith("unique"): - operator.drop_constraint(constraint[0], type_='unique') + operator.drop_constraint(constraint[0], type_="unique") def create_constraints(operator, column_name, constraint_dict): @@ -100,25 +97,25 @@ def upgrade(): conn = op.get_bind() inspector = inspect(conn) - with op.batch_alter_table('xcom') as bop: - xcom_columns = [col.get('name') for col in inspector.get_columns("xcom")] + with op.batch_alter_table("xcom") as bop: + xcom_columns = [col.get("name") for col in inspector.get_columns("xcom")] if "id" in xcom_columns: - if conn.dialect.name == 'mssql': + if conn.dialect.name == "mssql": constraint_dict = get_table_constraints(conn, "xcom") - drop_column_constraints(bop, 'id', constraint_dict) - bop.drop_column('id') - bop.drop_index('idx_xcom_dag_task_date') + drop_column_constraints(bop, "id", constraint_dict) + bop.drop_column("id") + bop.drop_index("idx_xcom_dag_task_date") # mssql doesn't allow primary keys with nullable columns - if conn.dialect.name != 'mssql': - bop.create_primary_key('pk_xcom', ['dag_id', 'task_id', 'key', 'execution_date']) + if conn.dialect.name != "mssql": + bop.create_primary_key("pk_xcom", ["dag_id", "task_id", "key", "execution_date"]) def downgrade(): """Unapply Remove id column from xcom""" conn = op.get_bind() - with op.batch_alter_table('xcom') as bop: - if conn.dialect.name != 'mssql': - bop.drop_constraint('pk_xcom', type_='primary') - bop.add_column(Column('id', Integer, nullable=False)) - bop.create_primary_key('id', ['id']) - bop.create_index('idx_xcom_dag_task_date', ['dag_id', 'task_id', 'key', 'execution_date']) + with op.batch_alter_table("xcom") as bop: + if conn.dialect.name != "mssql": + bop.drop_constraint("pk_xcom", type_="primary") + bop.add_column(Column("id", Integer, nullable=False)) + bop.create_primary_key("id", ["id"]) + bop.create_index("idx_xcom_dag_task_date", ["dag_id", "task_id", "key", "execution_date"]) diff --git a/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py b/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py index 1174c9be6bf90..62256e2d6ab30 100644 --- a/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py +++ b/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase length of pool name Revision ID: b25a55525161 @@ -23,6 +22,7 @@ Create Date: 2020-03-09 08:48:14.534700 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,21 +30,21 @@ from airflow.models.base import COLLATION_ARGS # revision identifiers, used by Alembic. 
-revision = 'b25a55525161' -down_revision = 'bbf4a7ad0465' +revision = "b25a55525161" +down_revision = "bbf4a7ad0465" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): """Increase column length of pool name from 50 to 256 characters""" # use batch_alter_table to support SQLite workaround - with op.batch_alter_table('slot_pool', table_args=sa.UniqueConstraint('pool')) as batch_op: - batch_op.alter_column('pool', type_=sa.String(256, **COLLATION_ARGS)) + with op.batch_alter_table("slot_pool", table_args=sa.UniqueConstraint("pool")) as batch_op: + batch_op.alter_column("pool", type_=sa.String(256, **COLLATION_ARGS)) def downgrade(): """Revert Increased length of pool name from 256 to 50 characters""" - with op.batch_alter_table('slot_pool', table_args=sa.UniqueConstraint('pool')) as batch_op: - batch_op.alter_column('pool', type_=sa.String(50)) + with op.batch_alter_table("slot_pool", table_args=sa.UniqueConstraint("pool")) as batch_op: + batch_op.alter_column("pool", type_=sa.String(50)) diff --git a/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py b/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py index ce3d090742dd9..69da52c3850ff 100644 --- a/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py +++ b/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Add ``run_type`` column in ``dag_run`` table @@ -24,13 +23,13 @@ Create Date: 2020-04-08 13:35:25.671327 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op -from sqlalchemy import Column, Integer, String +from sqlalchemy import Column, Integer, String, inspect from sqlalchemy.ext.declarative import declarative_base -from airflow.compat.sqlalchemy import inspect from airflow.utils.types import DagRunType # revision identifiers, used by Alembic. @@ -38,7 +37,7 @@ down_revision = "b25a55525161" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" Base = declarative_base() @@ -59,7 +58,7 @@ def upgrade(): conn = op.get_bind() inspector = inspect(conn) - dag_run_columns = [col.get('name') for col in inspector.get_columns("dag_run")] + dag_run_columns = [col.get("name") for col in inspector.get_columns("dag_run")] if "run_type" not in dag_run_columns: diff --git a/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py b/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py index 1df5bfbfe324c..411a5785980e9 100644 --- a/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py +++ b/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Set ``conn_type`` as non-nullable Revision ID: 8f966b9c467a @@ -23,6 +22,7 @@ Create Date: 2020-06-08 22:36:34.534121 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -33,7 +33,7 @@ down_revision = "3c20cacc0044" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): diff --git a/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py b/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py index e3172ddf34172..094935b141190 100644 --- a/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py +++ b/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add unique constraint to ``conn_id`` Revision ID: 8d48763f6d53 @@ -23,6 +22,7 @@ Create Date: 2020-05-03 16:55:01.834231 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,17 +30,17 @@ from airflow.models.base import COLLATION_ARGS # revision identifiers, used by Alembic. -revision = '8d48763f6d53' -down_revision = '8f966b9c467a' +revision = "8d48763f6d53" +down_revision = "8f966b9c467a" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): """Apply Add unique constraint to ``conn_id`` and set it as non-nullable""" try: - with op.batch_alter_table('connection') as batch_op: + with op.batch_alter_table("connection") as batch_op: batch_op.alter_column("conn_id", nullable=False, existing_type=sa.String(250, **COLLATION_ARGS)) batch_op.create_unique_constraint(constraint_name="unique_conn_id", columns=["conn_id"]) @@ -50,7 +50,7 @@ def upgrade(): def downgrade(): """Unapply Add unique constraint to ``conn_id`` and set it as non-nullable""" - with op.batch_alter_table('connection') as batch_op: + with op.batch_alter_table("connection") as batch_op: batch_op.drop_constraint(constraint_name="unique_conn_id", type_="unique") batch_op.alter_column("conn_id", nullable=True, existing_type=sa.String(250)) diff --git a/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py b/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py index d98be012b7b52..93279c3dea092 100644 --- a/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py +++ b/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``sensor_instance`` table Revision ID: e38be357a868 @@ -23,19 +22,20 @@ Create Date: 2019-06-07 04:03:17.003939 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op -from sqlalchemy import func +from sqlalchemy import func, inspect -from airflow.compat.sqlalchemy import inspect from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. 
-revision = 'e38be357a868' -down_revision = '8d48763f6d53' +revision = "e38be357a868" +down_revision = "8d48763f6d53" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): @@ -43,38 +43,38 @@ def upgrade(): conn = op.get_bind() inspector = inspect(conn) tables = inspector.get_table_names() - if 'sensor_instance' in tables: + if "sensor_instance" in tables: return op.create_table( - 'sensor_instance', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('task_id', StringID(), nullable=False), - sa.Column('dag_id', StringID(), nullable=False), - sa.Column('execution_date', TIMESTAMP, nullable=False), - sa.Column('state', sa.String(length=20), nullable=True), - sa.Column('try_number', sa.Integer(), nullable=True), - sa.Column('start_date', TIMESTAMP, nullable=True), - sa.Column('operator', sa.String(length=1000), nullable=False), - sa.Column('op_classpath', sa.String(length=1000), nullable=False), - sa.Column('hashcode', sa.BigInteger(), nullable=False), - sa.Column('shardcode', sa.Integer(), nullable=False), - sa.Column('poke_context', sa.Text(), nullable=False), - sa.Column('execution_context', sa.Text(), nullable=True), - sa.Column('created_at', TIMESTAMP, default=func.now(), nullable=False), - sa.Column('updated_at', TIMESTAMP, default=func.now(), nullable=False), - sa.PrimaryKeyConstraint('id'), + "sensor_instance", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("execution_date", TIMESTAMP, nullable=False), + sa.Column("state", sa.String(length=20), nullable=True), + sa.Column("try_number", sa.Integer(), nullable=True), + sa.Column("start_date", TIMESTAMP, nullable=True), + sa.Column("operator", sa.String(length=1000), nullable=False), + sa.Column("op_classpath", sa.String(length=1000), nullable=False), + sa.Column("hashcode", sa.BigInteger(), nullable=False), + sa.Column("shardcode", sa.Integer(), nullable=False), + sa.Column("poke_context", sa.Text(), nullable=False), + sa.Column("execution_context", sa.Text(), nullable=True), + sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False), + sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('ti_primary_key', 'sensor_instance', ['dag_id', 'task_id', 'execution_date'], unique=True) - op.create_index('si_hashcode', 'sensor_instance', ['hashcode'], unique=False) - op.create_index('si_shardcode', 'sensor_instance', ['shardcode'], unique=False) - op.create_index('si_state_shard', 'sensor_instance', ['state', 'shardcode'], unique=False) - op.create_index('si_updated_at', 'sensor_instance', ['updated_at'], unique=False) + op.create_index("ti_primary_key", "sensor_instance", ["dag_id", "task_id", "execution_date"], unique=True) + op.create_index("si_hashcode", "sensor_instance", ["hashcode"], unique=False) + op.create_index("si_shardcode", "sensor_instance", ["shardcode"], unique=False) + op.create_index("si_state_shard", "sensor_instance", ["state", "shardcode"], unique=False) + op.create_index("si_updated_at", "sensor_instance", ["updated_at"], unique=False) def downgrade(): conn = op.get_bind() inspector = inspect(conn) tables = inspector.get_table_names() - if 'sensor_instance' in tables: - op.drop_table('sensor_instance') + if "sensor_instance" in tables: + op.drop_table("sensor_instance") diff --git a/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py 
b/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py index 90247d02ef1b5..011fef7a3aa9e 100644 --- a/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py +++ b/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add queued by Job ID to TI Revision ID: b247b1e3d1ed @@ -23,25 +22,26 @@ Create Date: 2020-09-04 11:53:00.978882 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'b247b1e3d1ed' -down_revision = 'e38be357a868' +revision = "b247b1e3d1ed" +down_revision = "e38be357a868" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): """Apply Add queued by Job ID to TI""" - with op.batch_alter_table('task_instance') as batch_op: - batch_op.add_column(sa.Column('queued_by_job_id', sa.Integer(), nullable=True)) + with op.batch_alter_table("task_instance") as batch_op: + batch_op.add_column(sa.Column("queued_by_job_id", sa.Integer(), nullable=True)) def downgrade(): """Unapply Add queued by Job ID to TI""" - with op.batch_alter_table('task_instance') as batch_op: - batch_op.drop_column('queued_by_job_id') + with op.batch_alter_table("task_instance") as batch_op: + batch_op.drop_column("queued_by_job_id") diff --git a/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py b/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py index a9b6030caff8f..aec294aef7f11 100644 --- a/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py +++ b/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add external executor ID to TI Revision ID: e1a11ece99cc @@ -23,25 +22,26 @@ Create Date: 2020-09-12 08:23:45.698865 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'e1a11ece99cc' -down_revision = 'b247b1e3d1ed' +revision = "e1a11ece99cc" +down_revision = "b247b1e3d1ed" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): """Apply Add external executor ID to TI""" - with op.batch_alter_table('task_instance', schema=None) as batch_op: - batch_op.add_column(sa.Column('external_executor_id', sa.String(length=250), nullable=True)) + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.add_column(sa.Column("external_executor_id", sa.String(length=250), nullable=True)) def downgrade(): """Unapply Add external executor ID to TI""" - with op.batch_alter_table('task_instance', schema=None) as batch_op: - batch_op.drop_column('external_executor_id') + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.drop_column("external_executor_id") diff --git a/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py b/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py index d2449aa48dc1f..2d90f587c259d 100644 --- a/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py +++ b/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """Drop ``KubeResourceVersion`` and ``KubeWorkerId`` Revision ID: bef4f3d11e8b @@ -23,18 +22,18 @@ Create Date: 2020-09-22 18:45:28.011654 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op - -from airflow.compat.sqlalchemy import inspect +from sqlalchemy import inspect # revision identifiers, used by Alembic. -revision = 'bef4f3d11e8b' -down_revision = 'e1a11ece99cc' +revision = "bef4f3d11e8b" +down_revision = "e1a11ece99cc" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" WORKER_UUID_TABLE = "kube_worker_uuid" diff --git a/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py index 66f25cf29fac8..43c79291729b3 100644 --- a/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py +++ b/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``scheduling_decision`` to ``DagRun`` and ``DAG`` Revision ID: 98271e7606e2 @@ -23,6 +22,7 @@ Create Date: 2020-10-01 12:13:32.968148 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,11 +30,11 @@ from airflow.migrations.db_types import TIMESTAMP # revision identifiers, used by Alembic. -revision = '98271e7606e2' -down_revision = 'bef4f3d11e8b' +revision = "98271e7606e2" +down_revision = "bef4f3d11e8b" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): @@ -46,24 +46,24 @@ def upgrade(): if is_sqlite: op.execute("PRAGMA foreign_keys=off") - with op.batch_alter_table('dag_run', schema=None) as batch_op: - batch_op.add_column(sa.Column('last_scheduling_decision', TIMESTAMP, nullable=True)) - batch_op.create_index('idx_last_scheduling_decision', ['last_scheduling_decision'], unique=False) - batch_op.add_column(sa.Column('dag_hash', sa.String(32), nullable=True)) + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.add_column(sa.Column("last_scheduling_decision", TIMESTAMP, nullable=True)) + batch_op.create_index("idx_last_scheduling_decision", ["last_scheduling_decision"], unique=False) + batch_op.add_column(sa.Column("dag_hash", sa.String(32), nullable=True)) - with op.batch_alter_table('dag', schema=None) as batch_op: - batch_op.add_column(sa.Column('next_dagrun', TIMESTAMP, nullable=True)) - batch_op.add_column(sa.Column('next_dagrun_create_after', TIMESTAMP, nullable=True)) + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.add_column(sa.Column("next_dagrun", TIMESTAMP, nullable=True)) + batch_op.add_column(sa.Column("next_dagrun_create_after", TIMESTAMP, nullable=True)) # Create with nullable and no default, then ALTER to set values, to avoid table level lock - batch_op.add_column(sa.Column('concurrency', sa.Integer(), nullable=True)) - batch_op.add_column(sa.Column('has_task_concurrency_limits', sa.Boolean(), nullable=True)) + batch_op.add_column(sa.Column("concurrency", sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column("has_task_concurrency_limits", sa.Boolean(), nullable=True)) - batch_op.create_index('idx_next_dagrun_create_after', ['next_dagrun_create_after'], unique=False) + 
batch_op.create_index("idx_next_dagrun_create_after", ["next_dagrun_create_after"], unique=False) try: from airflow.configuration import conf - concurrency = conf.getint('core', 'dag_concurrency', fallback=16) + concurrency = conf.getint("core", "dag_concurrency", fallback=16) except: # noqa concurrency = 16 @@ -79,9 +79,9 @@ def upgrade(): """ ) - with op.batch_alter_table('dag', schema=None) as batch_op: - batch_op.alter_column('concurrency', type_=sa.Integer(), nullable=False) - batch_op.alter_column('has_task_concurrency_limits', type_=sa.Boolean(), nullable=False) + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.alter_column("concurrency", type_=sa.Integer(), nullable=False) + batch_op.alter_column("has_task_concurrency_limits", type_=sa.Boolean(), nullable=False) if is_sqlite: op.execute("PRAGMA foreign_keys=on") @@ -95,17 +95,17 @@ def downgrade(): if is_sqlite: op.execute("PRAGMA foreign_keys=off") - with op.batch_alter_table('dag_run', schema=None) as batch_op: - batch_op.drop_index('idx_last_scheduling_decision') - batch_op.drop_column('last_scheduling_decision') - batch_op.drop_column('dag_hash') - - with op.batch_alter_table('dag', schema=None) as batch_op: - batch_op.drop_index('idx_next_dagrun_create_after') - batch_op.drop_column('next_dagrun_create_after') - batch_op.drop_column('next_dagrun') - batch_op.drop_column('concurrency') - batch_op.drop_column('has_task_concurrency_limits') + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.drop_index("idx_last_scheduling_decision") + batch_op.drop_column("last_scheduling_decision") + batch_op.drop_column("dag_hash") + + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.drop_index("idx_next_dagrun_create_after") + batch_op.drop_column("next_dagrun_create_after") + batch_op.drop_column("next_dagrun") + batch_op.drop_column("concurrency") + batch_op.drop_column("has_task_concurrency_limits") if is_sqlite: op.execute("PRAGMA foreign_keys=on") diff --git a/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py b/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py index dd14290729997..97bb6a4e83df2 100644 --- a/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py +++ b/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """fix_mssql_exec_date_rendered_task_instance_fields_for_MSSQL Revision ID: 52d53670a240 @@ -23,18 +22,20 @@ Create Date: 2020-10-13 15:13:24.911486 """ +from __future__ import annotations + import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import mssql # revision identifiers, used by Alembic. 
-revision = '52d53670a240' -down_revision = '98271e7606e2' +revision = "52d53670a240" +down_revision = "98271e7606e2" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" -TABLE_NAME = 'rendered_task_instance_fields' +TABLE_NAME = "rendered_task_instance_fields" def upgrade(): @@ -49,11 +50,11 @@ def upgrade(): op.create_table( TABLE_NAME, - sa.Column('dag_id', sa.String(length=250), nullable=False), - sa.Column('task_id', sa.String(length=250), nullable=False), - sa.Column('execution_date', mssql.DATETIME2, nullable=False), - sa.Column('rendered_fields', json_type(), nullable=False), - sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'), + sa.Column("dag_id", sa.String(length=250), nullable=False), + sa.Column("task_id", sa.String(length=250), nullable=False), + sa.Column("execution_date", mssql.DATETIME2, nullable=False), + sa.Column("rendered_fields", json_type(), nullable=False), + sa.PrimaryKeyConstraint("dag_id", "task_id", "execution_date"), ) @@ -69,9 +70,9 @@ def downgrade(): op.create_table( TABLE_NAME, - sa.Column('dag_id', sa.String(length=250), nullable=False), - sa.Column('task_id', sa.String(length=250), nullable=False), - sa.Column('execution_date', sa.TIMESTAMP, nullable=False), - sa.Column('rendered_fields', json_type(), nullable=False), - sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'), + sa.Column("dag_id", sa.String(length=250), nullable=False), + sa.Column("task_id", sa.String(length=250), nullable=False), + sa.Column("execution_date", sa.TIMESTAMP, nullable=False), + sa.Column("rendered_fields", json_type(), nullable=False), + sa.PrimaryKeyConstraint("dag_id", "task_id", "execution_date"), ) diff --git a/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py b/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py index c42308393fef2..f0dd1c4d969a3 100644 --- a/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py +++ b/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``creating_job_id`` to ``DagRun`` table Revision ID: 364159666cbd @@ -23,23 +22,24 @@ Create Date: 2020-10-10 09:08:07.332456 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '364159666cbd' -down_revision = '52d53670a240' +revision = "364159666cbd" +down_revision = "52d53670a240" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): """Apply Add ``creating_job_id`` to ``DagRun`` table""" - op.add_column('dag_run', sa.Column('creating_job_id', sa.Integer)) + op.add_column("dag_run", sa.Column("creating_job_id", sa.Integer)) def downgrade(): """Unapply Add job_id to DagRun table""" - op.drop_column('dag_run', 'creating_job_id') + op.drop_column("dag_run", "creating_job_id") diff --git a/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py b/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py index 675df14259230..242878d5a8b52 100644 --- a/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py +++ b/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py @@ -15,7 +15,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """add-k8s-yaml-to-rendered-templates Revision ID: 45ba3f1493b9 @@ -23,6 +22,7 @@ Create Date: 2020-10-23 23:01:52.471442 """ +from __future__ import annotations import sqlalchemy_jsonfield from alembic import op @@ -31,14 +31,14 @@ from airflow.settings import json # revision identifiers, used by Alembic. -revision = '45ba3f1493b9' -down_revision = '364159666cbd' +revision = "45ba3f1493b9" +down_revision = "364159666cbd" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" __tablename__ = "rendered_task_instance_fields" -k8s_pod_yaml = Column('k8s_pod_yaml', sqlalchemy_jsonfield.JSONField(json=json), nullable=True) +k8s_pod_yaml = Column("k8s_pod_yaml", sqlalchemy_jsonfield.JSONField(json=json), nullable=True) def upgrade(): @@ -50,4 +50,4 @@ def upgrade(): def downgrade(): """Unapply add-k8s-yaml-to-rendered-templates""" with op.batch_alter_table(__tablename__, schema=None) as batch_op: - batch_op.drop_column('k8s_pod_yaml') + batch_op.drop_column("k8s_pod_yaml") diff --git a/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py b/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py index 2538a12a89494..660da4ac2c4f9 100644 --- a/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py +++ b/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Prefix DAG permissions. Revision ID: 849da589634d @@ -23,6 +22,7 @@ Create Date: 2020-10-01 17:25:10.006322 """ +from __future__ import annotations from flask_appbuilder import SQLA @@ -31,23 +31,23 @@ from airflow.www.fab_security.sqla.models import Action, Permission, Resource # revision identifiers, used by Alembic. 
-revision = '849da589634d' -down_revision = '45ba3f1493b9' +revision = "849da589634d" +down_revision = "45ba3f1493b9" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def prefix_individual_dag_permissions(session): - dag_perms = ['can_dag_read', 'can_dag_edit'] + dag_perms = ["can_dag_read", "can_dag_edit"] prefix = "DAG:" perms = ( session.query(Permission) .join(Action) .filter(Action.name.in_(dag_perms)) .join(Resource) - .filter(Resource.name != 'all_dags') - .filter(Resource.name.notlike(prefix + '%')) + .filter(Resource.name != "all_dags") + .filter(Resource.name.notlike(prefix + "%")) .all() ) resource_ids = {permission.resource.id for permission in perms} @@ -57,14 +57,14 @@ def prefix_individual_dag_permissions(session): def remove_prefix_in_individual_dag_permissions(session): - dag_perms = ['can_read', 'can_edit'] + dag_perms = ["can_read", "can_edit"] prefix = "DAG:" perms = ( session.query(Permission) .join(Action) .filter(Action.name.in_(dag_perms)) .join(Resource) - .filter(Resource.name.like(prefix + '%')) + .filter(Resource.name.like(prefix + "%")) .all() ) for permission in perms: @@ -86,12 +86,12 @@ def get_or_create_dag_resource(session): def get_or_create_all_dag_resource(session): - all_dag_resource = get_resource_query(session, 'all_dags').first() + all_dag_resource = get_resource_query(session, "all_dags").first() if all_dag_resource: return all_dag_resource all_dag_resource = Resource() - all_dag_resource.name = 'all_dags' + all_dag_resource.name = "all_dags" session.add(all_dag_resource) session.commit() @@ -156,19 +156,19 @@ def migrate_to_new_dag_permissions(db): prefix_individual_dag_permissions(db.session) # Update existing permissions to use `can_read` instead of `can_dag_read` - can_dag_read_action = get_action_query(db.session, 'can_dag_read').first() + can_dag_read_action = get_action_query(db.session, "can_dag_read").first() old_can_dag_read_permissions = get_permission_with_action_query(db.session, can_dag_read_action) - can_read_action = get_or_create_action(db.session, 'can_read') + can_read_action = get_or_create_action(db.session, "can_read") update_permission_action(db.session, old_can_dag_read_permissions, can_read_action) # Update existing permissions to use `can_edit` instead of `can_dag_edit` - can_dag_edit_action = get_action_query(db.session, 'can_dag_edit').first() + can_dag_edit_action = get_action_query(db.session, "can_dag_edit").first() old_can_dag_edit_permissions = get_permission_with_action_query(db.session, can_dag_edit_action) - can_edit_action = get_or_create_action(db.session, 'can_edit') + can_edit_action = get_or_create_action(db.session, "can_edit") update_permission_action(db.session, old_can_dag_edit_permissions, can_edit_action) # Update existing permissions for `all_dags` resource to use `DAGs` resource. 
- all_dags_resource = get_resource_query(db.session, 'all_dags').first() + all_dags_resource = get_resource_query(db.session, "all_dags").first() if all_dags_resource: old_all_dags_permission = get_permission_with_resource_query(db.session, all_dags_resource) dag_resource = get_or_create_dag_resource(db.session) @@ -193,15 +193,15 @@ def undo_migrate_to_new_dag_permissions(session): remove_prefix_in_individual_dag_permissions(session) # Update existing permissions to use `can_dag_read` instead of `can_read` - can_read_action = get_action_query(session, 'can_read').first() + can_read_action = get_action_query(session, "can_read").first() new_can_read_permissions = get_permission_with_action_query(session, can_read_action) - can_dag_read_action = get_or_create_action(session, 'can_dag_read') + can_dag_read_action = get_or_create_action(session, "can_dag_read") update_permission_action(session, new_can_read_permissions, can_dag_read_action) # Update existing permissions to use `can_dag_edit` instead of `can_edit` - can_edit_action = get_action_query(session, 'can_edit').first() + can_edit_action = get_action_query(session, "can_edit").first() new_can_edit_permissions = get_permission_with_action_query(session, can_edit_action) - can_dag_edit_action = get_or_create_action(session, 'can_dag_edit') + can_dag_edit_action = get_or_create_action(session, "can_dag_edit") update_permission_action(session, new_can_edit_permissions, can_dag_edit_action) # Update existing permissions for `DAGs` resource to use `all_dags` resource. diff --git a/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py b/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py index 6942be25cfd18..1748ca3d5f3aa 100644 --- a/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py +++ b/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Resource based permissions. Revision ID: 2c6edca13270 @@ -23,17 +22,19 @@ Create Date: 2020-10-21 00:18:52.529438 """ +from __future__ import annotations + import logging from airflow.security import permissions -from airflow.www.app import create_app +from airflow.www.app import cached_app # revision identifiers, used by Alembic. 
-revision = '2c6edca13270' -down_revision = '849da589634d' +revision = "2c6edca13270" +down_revision = "849da589634d" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" mapping = { @@ -287,7 +288,7 @@ def remap_permissions(): """Apply Map Airflow permissions.""" - appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder for old, new in mapping.items(): (old_resource_name, old_action_name) = old old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name) @@ -312,7 +313,7 @@ def remap_permissions(): def undo_remap_permissions(): """Unapply Map Airflow permissions""" - appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder for old, new in mapping.items(): (new_resource_name, new_action_name) = new[0] new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name) diff --git a/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py b/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py index 4c3f5835dcbfd..db03fae459818 100644 --- a/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py +++ b/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add description field to ``connection`` table Revision ID: 61ec73d9401f @@ -23,33 +22,34 @@ Create Date: 2020-09-10 14:56:30.279248 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '61ec73d9401f' -down_revision = '2c6edca13270' +revision = "61ec73d9401f" +down_revision = "2c6edca13270" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): """Apply Add description field to ``connection`` table""" conn = op.get_bind() - with op.batch_alter_table('connection') as batch_op: + with op.batch_alter_table("connection") as batch_op: if conn.dialect.name == "mysql": # Handles case where on mysql with utf8mb4 this would exceed the size of row # We have to set text type in this migration even if originally it was string # This is permanently fixed in the follow-up migration 64a7d6477aae - batch_op.add_column(sa.Column('description', sa.Text(length=5000), nullable=True)) + batch_op.add_column(sa.Column("description", sa.Text(length=5000), nullable=True)) else: - batch_op.add_column(sa.Column('description', sa.String(length=5000), nullable=True)) + batch_op.add_column(sa.Column("description", sa.String(length=5000), nullable=True)) def downgrade(): """Unapply Add description field to ``connection`` table""" - with op.batch_alter_table('connection', schema=None) as batch_op: - batch_op.drop_column('description') + with op.batch_alter_table("connection", schema=None) as batch_op: + batch_op.drop_column("description") diff --git a/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py b/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py index dba78c04c461a..a397378e4c6b8 100644 --- a/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py +++ b/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Fix description field in ``connection`` to be ``text`` Revision ID: 64a7d6477aae @@ -23,16 +22,17 @@ Create Date: 2020-11-25 08:56:11.866607 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '64a7d6477aae' -down_revision = '61ec73d9401f' +revision = "64a7d6477aae" +down_revision = "61ec73d9401f" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): @@ -43,15 +43,15 @@ def upgrade(): return if conn.dialect.name == "mysql": op.alter_column( - 'connection', - 'description', + "connection", + "description", existing_type=sa.String(length=5000), type_=sa.Text(length=5000), existing_nullable=True, ) else: # postgres does not allow size modifier for text type - op.alter_column('connection', 'description', existing_type=sa.String(length=5000), type_=sa.Text()) + op.alter_column("connection", "description", existing_type=sa.String(length=5000), type_=sa.Text()) def downgrade(): @@ -62,8 +62,8 @@ def downgrade(): return if conn.dialect.name == "mysql": op.alter_column( - 'connection', - 'description', + "connection", + "description", existing_type=sa.Text(5000), type_=sa.String(length=5000), existing_nullable=True, @@ -71,8 +71,8 @@ def downgrade(): else: # postgres does not allow size modifier for text type op.alter_column( - 'connection', - 'description', + "connection", + "description", existing_type=sa.Text(), type_=sa.String(length=5000), existing_nullable=True, diff --git a/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py b/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py index 2b905c8c2916d..c6343b7fc14f1 100644 --- a/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py +++ b/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Change field in ``DagCode`` to ``MEDIUMTEXT`` for MySql Revision ID: e959f08ac86c @@ -23,22 +22,24 @@ Create Date: 2020-12-07 16:31:43.982353 """ +from __future__ import annotations + from alembic import op from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = 'e959f08ac86c' -down_revision = '64a7d6477aae' +revision = "e959f08ac86c" +down_revision = "64a7d6477aae" branch_labels = None depends_on = None -airflow_version = '2.0.0' +airflow_version = "2.0.0" def upgrade(): conn = op.get_bind() if conn.dialect.name == "mysql": op.alter_column( - table_name='dag_code', column_name='source_code', type_=mysql.MEDIUMTEXT, nullable=False + table_name="dag_code", column_name="source_code", type_=mysql.MEDIUMTEXT, nullable=False ) diff --git a/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py b/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py index 51e8f20dbaf6c..b9bc66d01e094 100644 --- a/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py +++ b/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Remove ``can_read`` permission on config resource for ``User`` and ``Viewer`` role Revision ID: 82b7c48c147f @@ -23,17 +22,19 @@ Create Date: 2021-02-04 12:45:58.138224 """ +from __future__ import annotations + import logging from airflow.security import permissions -from airflow.www.app import create_app +from airflow.www.app import cached_app # revision identifiers, used by Alembic. 
-revision = '82b7c48c147f' -down_revision = 'e959f08ac86c' +revision = "82b7c48c147f" +down_revision = "e959f08ac86c" branch_labels = None depends_on = None -airflow_version = '2.0.1' +airflow_version = "2.0.1" def upgrade(): @@ -41,7 +42,7 @@ def upgrade(): log = logging.getLogger() handlers = log.handlers[:] - appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]] can_read_on_config_perm = appbuilder.sm.get_permission( permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG @@ -58,7 +59,7 @@ def upgrade(): def downgrade(): """Add can_read action on config resource for User and Viewer role""" - appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]] can_read_on_config_perm = appbuilder.sm.get_permission( permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG diff --git a/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py b/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py index 77873664f6c20..39a4be4cdd9b0 100644 --- a/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py +++ b/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase size of ``connection.extra`` field to handle multiple RSA keys Revision ID: 449b4072c2da @@ -23,23 +22,24 @@ Create Date: 2020-03-16 19:02:55.337710 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '449b4072c2da' -down_revision = '82b7c48c147f' +revision = "449b4072c2da" +down_revision = "82b7c48c147f" branch_labels = None depends_on = None -airflow_version = '2.0.2' +airflow_version = "2.0.2" def upgrade(): """Apply increase_length_for_connection_password""" - with op.batch_alter_table('connection', schema=None) as batch_op: + with op.batch_alter_table("connection", schema=None) as batch_op: batch_op.alter_column( - 'extra', + "extra", existing_type=sa.VARCHAR(length=5000), type_=sa.TEXT(), existing_nullable=True, @@ -48,9 +48,9 @@ def upgrade(): def downgrade(): """Unapply increase_length_for_connection_password""" - with op.batch_alter_table('connection', schema=None) as batch_op: + with op.batch_alter_table("connection", schema=None) as batch_op: batch_op.alter_column( - 'extra', + "extra", existing_type=sa.TEXT(), type_=sa.VARCHAR(length=5000), existing_nullable=True, diff --git a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py index f5ae34c2977c9..16a2da2c71d84 100644 --- a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py +++ b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Change default ``pool_slots`` to ``1`` Revision ID: 8646922c8a04 @@ -23,25 +22,38 @@ Create Date: 2021-02-23 23:19:22.409973 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '8646922c8a04' -down_revision = '449b4072c2da' +revision = "8646922c8a04" +down_revision = "449b4072c2da" branch_labels = None depends_on = None -airflow_version = '2.0.2' +airflow_version = "2.0.2" def upgrade(): """Change default ``pool_slots`` to ``1`` and make pool_slots not nullable""" + op.execute("UPDATE task_instance SET pool_slots = 1 WHERE pool_slots IS NULL") with op.batch_alter_table("task_instance", schema=None) as batch_op: - batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=False, server_default='1') + batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=False, server_default="1") def downgrade(): """Unapply Change default ``pool_slots`` to ``1``""" - with op.batch_alter_table("task_instance", schema=None) as batch_op: - batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=True, server_default=None) + conn = op.get_bind() + if conn.dialect.name == "mssql": + inspector = sa.inspect(conn.engine) + columns = inspector.get_columns("task_instance") + for col in columns: + if col["name"] == "pool_slots" and col["default"] == "('1')": + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.alter_column( + "pool_slots", existing_type=sa.Integer, nullable=True, server_default=None + ) + else: + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=True, server_default=None) diff --git a/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py b/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py index 487abdb11994b..78a498bf69c09 100644 --- a/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py +++ b/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Rename ``last_scheduler_run`` column in ``DAG`` table to ``last_parsed_time`` Revision ID: 2e42bb497a22 @@ -23,31 +22,32 @@ Create Date: 2021-03-04 19:50:38.880942 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import mssql # revision identifiers, used by Alembic. 
-revision = '2e42bb497a22' -down_revision = '8646922c8a04' +revision = "2e42bb497a22" +down_revision = "8646922c8a04" branch_labels = None depends_on = None -airflow_version = '2.0.2' +airflow_version = "2.0.2" def upgrade(): """Apply Rename ``last_scheduler_run`` column in ``DAG`` table to ``last_parsed_time``""" conn = op.get_bind() if conn.dialect.name == "mssql": - with op.batch_alter_table('dag') as batch_op: + with op.batch_alter_table("dag") as batch_op: batch_op.alter_column( - 'last_scheduler_run', new_column_name='last_parsed_time', type_=mssql.DATETIME2(precision=6) + "last_scheduler_run", new_column_name="last_parsed_time", type_=mssql.DATETIME2(precision=6) ) else: - with op.batch_alter_table('dag') as batch_op: + with op.batch_alter_table("dag") as batch_op: batch_op.alter_column( - 'last_scheduler_run', new_column_name='last_parsed_time', type_=sa.TIMESTAMP(timezone=True) + "last_scheduler_run", new_column_name="last_parsed_time", type_=sa.TIMESTAMP(timezone=True) ) @@ -55,12 +55,12 @@ def downgrade(): """Unapply Rename ``last_scheduler_run`` column in ``DAG`` table to ``last_parsed_time``""" conn = op.get_bind() if conn.dialect.name == "mssql": - with op.batch_alter_table('dag') as batch_op: + with op.batch_alter_table("dag") as batch_op: batch_op.alter_column( - 'last_parsed_time', new_column_name='last_scheduler_run', type_=mssql.DATETIME2(precision=6) + "last_parsed_time", new_column_name="last_scheduler_run", type_=mssql.DATETIME2(precision=6) ) else: - with op.batch_alter_table('dag') as batch_op: + with op.batch_alter_table("dag") as batch_op: batch_op.alter_column( - 'last_parsed_time', new_column_name='last_scheduler_run', type_=sa.TIMESTAMP(timezone=True) + "last_parsed_time", new_column_name="last_scheduler_run", type_=sa.TIMESTAMP(timezone=True) ) diff --git a/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py b/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py index 30e8ca553f8c3..5b3c4b89b4bf1 100644 --- a/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py +++ b/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase maximum length of pool name in ``task_instance`` table to ``256`` characters Revision ID: 90d1635d7b86 @@ -23,32 +22,33 @@ Create Date: 2021-04-05 09:37:54.848731 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '90d1635d7b86' -down_revision = '2e42bb497a22' +revision = "90d1635d7b86" +down_revision = "2e42bb497a22" branch_labels = None depends_on = None -airflow_version = '2.1.0' +airflow_version = "2.1.0" def upgrade(): """Apply Increase maximum length of pool name in ``task_instance`` table to ``256`` characters""" - with op.batch_alter_table('task_instance') as batch_op: - batch_op.alter_column('pool', type_=sa.String(256), nullable=False) + with op.batch_alter_table("task_instance") as batch_op: + batch_op.alter_column("pool", type_=sa.String(256), nullable=False) def downgrade(): """Unapply Increase maximum length of pool name in ``task_instance`` table to ``256`` characters""" conn = op.get_bind() - if conn.dialect.name == 'mssql': - with op.batch_alter_table('task_instance') as batch_op: - batch_op.drop_index('ti_pool') - batch_op.alter_column('pool', type_=sa.String(50), nullable=False) - batch_op.create_index('ti_pool', ['pool']) + if conn.dialect.name == "mssql": + with op.batch_alter_table("task_instance") as batch_op: + batch_op.drop_index("ti_pool") + batch_op.alter_column("pool", type_=sa.String(50), nullable=False) + batch_op.create_index("ti_pool", ["pool"]) else: - with op.batch_alter_table('task_instance') as batch_op: - batch_op.alter_column('pool', type_=sa.String(50), nullable=False) + with op.batch_alter_table("task_instance") as batch_op: + batch_op.alter_column("pool", type_=sa.String(50), nullable=False) diff --git a/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py b/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py index 761a718f32237..14fb651893be1 100644 --- a/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py +++ b/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add description field to ``Variable`` model Revision ID: e165e7455d70 @@ -23,25 +22,26 @@ Create Date: 2021-04-11 22:28:02.107290 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'e165e7455d70' -down_revision = '90d1635d7b86' +revision = "e165e7455d70" +down_revision = "90d1635d7b86" branch_labels = None depends_on = None -airflow_version = '2.1.0' +airflow_version = "2.1.0" def upgrade(): """Apply Add description field to ``Variable`` model""" - with op.batch_alter_table('variable', schema=None) as batch_op: - batch_op.add_column(sa.Column('description', sa.Text(), nullable=True)) + with op.batch_alter_table("variable", schema=None) as batch_op: + batch_op.add_column(sa.Column("description", sa.Text(), nullable=True)) def downgrade(): """Unapply Add description field to ``Variable`` model""" - with op.batch_alter_table('variable', schema=None) as batch_op: - batch_op.drop_column('description') + with op.batch_alter_table("variable", schema=None) as batch_op: + batch_op.drop_column("description") diff --git a/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py b/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py index cd162de566e88..f5e8706c09d54 100644 --- a/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py +++ b/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """Resource based permissions for default ``Flask-AppBuilder`` views Revision ID: a13f7613ad25 @@ -23,17 +22,19 @@ Create Date: 2021-03-20 21:23:05.793378 """ +from __future__ import annotations + import logging from airflow.security import permissions -from airflow.www.app import create_app +from airflow.www.app import cached_app # revision identifiers, used by Alembic. -revision = 'a13f7613ad25' -down_revision = 'e165e7455d70' +revision = "a13f7613ad25" +down_revision = "e165e7455d70" branch_labels = None depends_on = None -airflow_version = '2.1.0' +airflow_version = "2.1.0" mapping = { @@ -139,7 +140,7 @@ def remap_permissions(): """Apply Map Airflow permissions.""" - appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder for old, new in mapping.items(): (old_resource_name, old_action_name) = old old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name) @@ -164,7 +165,7 @@ def remap_permissions(): def undo_remap_permissions(): """Unapply Map Airflow permissions""" - appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder for old, new in mapping.items(): (new_resource_name, new_action_name) = new[0] new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name) diff --git a/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py b/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py index f40daa1035b23..3d27a5d1e30f4 100644 --- a/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py +++ b/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``queued_at`` column in ``dag_run`` table Revision ID: 97cdd93827b8 @@ -23,6 +22,7 @@ Create Date: 2021-06-29 21:53:48.059438 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,19 +30,19 @@ from airflow.migrations.db_types import TIMESTAMP # revision identifiers, used by Alembic. -revision = '97cdd93827b8' -down_revision = 'a13f7613ad25' +revision = "97cdd93827b8" +down_revision = "a13f7613ad25" branch_labels = None depends_on = None -airflow_version = '2.1.3' +airflow_version = "2.1.3" def upgrade(): """Apply Add ``queued_at`` column in ``dag_run`` table""" - op.add_column('dag_run', sa.Column('queued_at', TIMESTAMP, nullable=True)) + op.add_column("dag_run", sa.Column("queued_at", TIMESTAMP, nullable=True)) def downgrade(): """Unapply Add ``queued_at`` column in ``dag_run`` table""" - with op.batch_alter_table('dag_run') as batch_op: - batch_op.drop_column('queued_at') + with op.batch_alter_table("dag_run") as batch_op: + batch_op.drop_column("queued_at") diff --git a/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py b/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py index 49ba327dc00e8..c68e8e8e07c97 100644 --- a/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py +++ b/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """Add ``max_active_runs`` column to ``dag_model`` table Revision ID: 092435bf5d12 @@ -23,27 +22,28 @@ Create Date: 2021-09-06 21:29:24.728923 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op from sqlalchemy import text # revision identifiers, used by Alembic. -revision = '092435bf5d12' -down_revision = '97cdd93827b8' +revision = "092435bf5d12" +down_revision = "97cdd93827b8" branch_labels = None depends_on = None -airflow_version = '2.1.4' +airflow_version = "2.1.4" def upgrade(): """Apply Add ``max_active_runs`` column to ``dag_model`` table""" - op.add_column('dag', sa.Column('max_active_runs', sa.Integer(), nullable=True)) - with op.batch_alter_table('dag_run', schema=None) as batch_op: + op.add_column("dag", sa.Column("max_active_runs", sa.Integer(), nullable=True)) + with op.batch_alter_table("dag_run", schema=None) as batch_op: # Add index to dag_run.dag_id and also add index to dag_run.state where state==running - batch_op.create_index('idx_dag_run_dag_id', ['dag_id']) + batch_op.create_index("idx_dag_run_dag_id", ["dag_id"]) batch_op.create_index( - 'idx_dag_run_running_dags', + "idx_dag_run_running_dags", ["state", "dag_id"], postgresql_where=text("state='running'"), mssql_where=text("state='running'"), @@ -53,9 +53,9 @@ def upgrade(): def downgrade(): """Unapply Add ``max_active_runs`` column to ``dag_model`` table""" - with op.batch_alter_table('dag') as batch_op: - batch_op.drop_column('max_active_runs') - with op.batch_alter_table('dag_run', schema=None) as batch_op: + with op.batch_alter_table("dag") as batch_op: + batch_op.drop_column("max_active_runs") + with op.batch_alter_table("dag_run", schema=None) as batch_op: # Drop index to dag_run.dag_id and also drop index to dag_run.state where state==running - batch_op.drop_index('idx_dag_run_dag_id') - batch_op.drop_index('idx_dag_run_running_dags') + batch_op.drop_index("idx_dag_run_dag_id") + batch_op.drop_index("idx_dag_run_running_dags") diff --git a/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py b/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py index f9e55fce83550..d7c34cc5b0586 100644 --- a/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py +++ b/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add index on state, dag_id for queued ``dagrun`` Revision ID: ccde3e26fe78 @@ -23,23 +22,24 @@ Create Date: 2021-09-08 16:35:34.867711 """ +from __future__ import annotations from alembic import op from sqlalchemy import text # revision identifiers, used by Alembic. 
-revision = 'ccde3e26fe78' -down_revision = '092435bf5d12' +revision = "ccde3e26fe78" +down_revision = "092435bf5d12" branch_labels = None depends_on = None -airflow_version = '2.1.4' +airflow_version = "2.1.4" def upgrade(): """Apply Add index on state, dag_id for queued ``dagrun``""" - with op.batch_alter_table('dag_run') as batch_op: + with op.batch_alter_table("dag_run") as batch_op: batch_op.create_index( - 'idx_dag_run_queued_dags', + "idx_dag_run_queued_dags", ["state", "dag_id"], postgresql_where=text("state='queued'"), mssql_where=text("state='queued'"), @@ -49,5 +49,5 @@ def upgrade(): def downgrade(): """Unapply Add index on state, dag_id for queued ``dagrun``""" - with op.batch_alter_table('dag_run') as batch_op: - batch_op.drop_index('idx_dag_run_queued_dags') + with op.batch_alter_table("dag_run") as batch_op: + batch_op.drop_index("idx_dag_run_queued_dags") diff --git a/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py b/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py index 9fdb56ac9bf16..a04a9ffe70b32 100644 --- a/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py +++ b/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Improve MSSQL compatibility Revision ID: 83f031fd9f1c @@ -23,6 +22,7 @@ Create Date: 2021-04-06 12:22:02.197726 """ +from __future__ import annotations from collections import defaultdict @@ -33,11 +33,11 @@ from airflow.migrations.db_types import TIMESTAMP # revision identifiers, used by Alembic. -revision = '83f031fd9f1c' -down_revision = 'ccde3e26fe78' +revision = "83f031fd9f1c" +down_revision = "ccde3e26fe78" branch_labels = None depends_on = None -airflow_version = '2.2.0' +airflow_version = "2.2.0" def is_table_empty(conn, table_name): @@ -48,10 +48,10 @@ def is_table_empty(conn, table_name): :param table_name: table name :return: Booelan indicating if the table is present """ - return conn.execute(f'select TOP 1 * from {table_name}').first() is None + return conn.execute(f"select TOP 1 * from {table_name}").first() is None -def get_table_constraints(conn, table_name): +def get_table_constraints(conn, table_name) -> dict[tuple[str, str], list[str]]: """ This function return primary and unique constraint along with column name. 
some tables like task_instance @@ -62,7 +62,6 @@ def get_table_constraints(conn, table_name): :param conn: sql connection object :param table_name: table name :return: a dictionary of ((constraint name, constraint type), column name) of table - :rtype: defaultdict(list) """ query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc @@ -87,9 +86,9 @@ def drop_column_constraints(operator, column_name, constraint_dict): for constraint, columns in constraint_dict.items(): if column_name in columns: if constraint[1].lower().startswith("primary"): - operator.drop_constraint(constraint[0], type_='primary') + operator.drop_constraint(constraint[0], type_="primary") elif constraint[1].lower().startswith("unique"): - operator.drop_constraint(constraint[0], type_='unique') + operator.drop_constraint(constraint[0], type_="unique") def create_constraints(operator, column_name, constraint_dict): @@ -146,32 +145,32 @@ def alter_mssql_datetime_column(conn, op, table_name, column_name, nullable): def upgrade(): """Improve compatibility with MSSQL backend""" conn = op.get_bind() - if conn.dialect.name != 'mssql': + if conn.dialect.name != "mssql": return - recreate_mssql_ts_column(conn, op, 'dag_code', 'last_updated') - recreate_mssql_ts_column(conn, op, 'rendered_task_instance_fields', 'execution_date') - alter_mssql_datetime_column(conn, op, 'serialized_dag', 'last_updated', False) + recreate_mssql_ts_column(conn, op, "dag_code", "last_updated") + recreate_mssql_ts_column(conn, op, "rendered_task_instance_fields", "execution_date") + alter_mssql_datetime_column(conn, op, "serialized_dag", "last_updated", False) op.alter_column(table_name="xcom", column_name="timestamp", type_=TIMESTAMP, nullable=False) - with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.alter_column(column_name='end_date', type_=TIMESTAMP, nullable=False) - task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=TIMESTAMP, nullable=False) - task_reschedule_batch_op.alter_column(column_name='start_date', type_=TIMESTAMP, nullable=False) - with op.batch_alter_table('task_fail') as task_fail_batch_op: - task_fail_batch_op.drop_index('idx_task_fail_dag_task_date') + with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op: + task_reschedule_batch_op.alter_column(column_name="end_date", type_=TIMESTAMP, nullable=False) + task_reschedule_batch_op.alter_column(column_name="reschedule_date", type_=TIMESTAMP, nullable=False) + task_reschedule_batch_op.alter_column(column_name="start_date", type_=TIMESTAMP, nullable=False) + with op.batch_alter_table("task_fail") as task_fail_batch_op: + task_fail_batch_op.drop_index("idx_task_fail_dag_task_date") task_fail_batch_op.alter_column(column_name="execution_date", type_=TIMESTAMP, nullable=False) task_fail_batch_op.create_index( - 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_task_fail_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False ) - with op.batch_alter_table('task_instance') as task_instance_batch_op: - task_instance_batch_op.drop_index('ti_state_lkp') + with op.batch_alter_table("task_instance") as task_instance_batch_op: + task_instance_batch_op.drop_index("ti_state_lkp") task_instance_batch_op.create_index( - 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date', 'state'], unique=False + "ti_state_lkp", ["dag_id", "task_id", "execution_date", "state"], unique=False ) - 
constraint_dict = get_table_constraints(conn, 'dag_run') + constraint_dict = get_table_constraints(conn, "dag_run") for constraint, columns in constraint_dict.items(): - if 'dag_id' in columns: + if "dag_id" in columns: if constraint[1].lower().startswith("unique"): - op.drop_constraint(constraint[0], 'dag_run', type_='unique') + op.drop_constraint(constraint[0], "dag_run", type_="unique") # create filtered indexes conn.execute( """CREATE UNIQUE NONCLUSTERED INDEX idx_not_null_dag_id_execution_date @@ -188,25 +187,25 @@ def upgrade(): def downgrade(): """Reverse MSSQL backend compatibility improvements""" conn = op.get_bind() - if conn.dialect.name != 'mssql': + if conn.dialect.name != "mssql": return op.alter_column(table_name="xcom", column_name="timestamp", type_=TIMESTAMP, nullable=True) - with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.alter_column(column_name='end_date', type_=TIMESTAMP, nullable=True) - task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=TIMESTAMP, nullable=True) - task_reschedule_batch_op.alter_column(column_name='start_date', type_=TIMESTAMP, nullable=True) - with op.batch_alter_table('task_fail') as task_fail_batch_op: - task_fail_batch_op.drop_index('idx_task_fail_dag_task_date') + with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op: + task_reschedule_batch_op.alter_column(column_name="end_date", type_=TIMESTAMP, nullable=True) + task_reschedule_batch_op.alter_column(column_name="reschedule_date", type_=TIMESTAMP, nullable=True) + task_reschedule_batch_op.alter_column(column_name="start_date", type_=TIMESTAMP, nullable=True) + with op.batch_alter_table("task_fail") as task_fail_batch_op: + task_fail_batch_op.drop_index("idx_task_fail_dag_task_date") task_fail_batch_op.alter_column(column_name="execution_date", type_=TIMESTAMP, nullable=False) task_fail_batch_op.create_index( - 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False + "idx_task_fail_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False ) - with op.batch_alter_table('task_instance') as task_instance_batch_op: - task_instance_batch_op.drop_index('ti_state_lkp') + with op.batch_alter_table("task_instance") as task_instance_batch_op: + task_instance_batch_op.drop_index("ti_state_lkp") task_instance_batch_op.create_index( - 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False + "ti_state_lkp", ["dag_id", "task_id", "execution_date"], unique=False ) - op.create_unique_constraint('UQ__dag_run__dag_id_run_id', 'dag_run', ['dag_id', 'run_id']) - op.create_unique_constraint('UQ__dag_run__dag_id_execution_date', 'dag_run', ['dag_id', 'execution_date']) - op.drop_index('idx_not_null_dag_id_execution_date', table_name='dag_run') - op.drop_index('idx_not_null_dag_id_run_id', table_name='dag_run') + op.create_unique_constraint("UQ__dag_run__dag_id_run_id", "dag_run", ["dag_id", "run_id"]) + op.create_unique_constraint("UQ__dag_run__dag_id_execution_date", "dag_run", ["dag_id", "execution_date"]) + op.drop_index("idx_not_null_dag_id_execution_date", table_name="dag_run") + op.drop_index("idx_not_null_dag_id_run_id", table_name="dag_run") diff --git a/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py b/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py index 45a4559056418..68ef900a2db98 100644 --- a/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py +++ 
b/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Make XCom primary key columns non-nullable Revision ID: e9304a3141f0 @@ -23,37 +22,39 @@ Create Date: 2021-04-06 13:22:02.197726 """ +from __future__ import annotations + from alembic import op from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. -revision = 'e9304a3141f0' -down_revision = '83f031fd9f1c' +revision = "e9304a3141f0" +down_revision = "83f031fd9f1c" branch_labels = None depends_on = None -airflow_version = '2.2.0' +airflow_version = "2.2.0" def upgrade(): """Apply Make XCom primary key columns non-nullable""" conn = op.get_bind() - with op.batch_alter_table('xcom') as bop: + with op.batch_alter_table("xcom") as bop: bop.alter_column("key", type_=StringID(length=512), nullable=False) bop.alter_column("execution_date", type_=TIMESTAMP, nullable=False) - if conn.dialect.name == 'mssql': - bop.create_primary_key('pk_xcom', ['dag_id', 'task_id', 'key', 'execution_date']) + if conn.dialect.name == "mssql": + bop.create_primary_key("pk_xcom", ["dag_id", "task_id", "key", "execution_date"]) def downgrade(): """Unapply Make XCom primary key columns non-nullable""" conn = op.get_bind() - with op.batch_alter_table('xcom') as bop: + with op.batch_alter_table("xcom") as bop: # regardless of what the model defined, the `key` and `execution_date` # columns were always non-nullable for mysql, sqlite and postgres, so leave them alone - if conn.dialect.name == 'mssql': - bop.drop_constraint('pk_xcom', 'primary') + if conn.dialect.name == "mssql": + bop.drop_constraint("pk_xcom", "primary") # execution_date and key wasn't nullable in the other databases bop.alter_column("key", type_=StringID(length=512), nullable=True) bop.alter_column("execution_date", type_=TIMESTAMP, nullable=True) diff --git a/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py b/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py index f71fdf38821c8..d1816e32ed42b 100644 --- a/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py +++ b/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Rename ``concurrency`` column in ``dag`` table to`` max_active_tasks`` Revision ID: 30867afad44a @@ -23,16 +22,17 @@ Create Date: 2021-06-04 22:11:19.849981 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = '30867afad44a' -down_revision = 'e9304a3141f0' +revision = "30867afad44a" +down_revision = "e9304a3141f0" branch_labels = None depends_on = None -airflow_version = '2.2.0' +airflow_version = "2.2.0" def upgrade(): @@ -42,10 +42,10 @@ def upgrade(): if is_sqlite: op.execute("PRAGMA foreign_keys=off") - with op.batch_alter_table('dag') as batch_op: + with op.batch_alter_table("dag") as batch_op: batch_op.alter_column( - 'concurrency', - new_column_name='max_active_tasks', + "concurrency", + new_column_name="max_active_tasks", type_=sa.Integer(), nullable=False, ) @@ -55,10 +55,10 @@ def upgrade(): def downgrade(): """Unapply Rename ``concurrency`` column in ``dag`` table to`` max_active_tasks``""" - with op.batch_alter_table('dag') as batch_op: + with op.batch_alter_table("dag") as batch_op: batch_op.alter_column( - 'max_active_tasks', - new_column_name='concurrency', + "max_active_tasks", + new_column_name="concurrency", type_=sa.Integer(), nullable=False, ) diff --git a/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py b/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py index 0f83d4ff8dd59..34d32d9d8c914 100644 --- a/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py +++ b/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Adds ``trigger`` table and deferrable operator columns to task instance Revision ID: 54bebd308c5f @@ -23,6 +22,7 @@ Create Date: 2021-04-14 12:56:40.688260 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,41 +30,41 @@ from airflow.utils.sqlalchemy import ExtendedJSON # revision identifiers, used by Alembic. 
-revision = '54bebd308c5f' -down_revision = '30867afad44a' +revision = "54bebd308c5f" +down_revision = "30867afad44a" branch_labels = None depends_on = None -airflow_version = '2.2.0' +airflow_version = "2.2.0" def upgrade(): """Apply Adds ``trigger`` table and deferrable operator columns to task instance""" op.create_table( - 'trigger', - sa.Column('id', sa.Integer(), primary_key=True, nullable=False), - sa.Column('classpath', sa.String(length=1000), nullable=False), - sa.Column('kwargs', ExtendedJSON(), nullable=False), - sa.Column('created_date', sa.DateTime(), nullable=False), - sa.Column('triggerer_id', sa.Integer(), nullable=True), + "trigger", + sa.Column("id", sa.Integer(), primary_key=True, nullable=False), + sa.Column("classpath", sa.String(length=1000), nullable=False), + sa.Column("kwargs", ExtendedJSON(), nullable=False), + sa.Column("created_date", sa.DateTime(), nullable=False), + sa.Column("triggerer_id", sa.Integer(), nullable=True), ) - with op.batch_alter_table('task_instance', schema=None) as batch_op: - batch_op.add_column(sa.Column('trigger_id', sa.Integer())) - batch_op.add_column(sa.Column('trigger_timeout', sa.DateTime())) - batch_op.add_column(sa.Column('next_method', sa.String(length=1000))) - batch_op.add_column(sa.Column('next_kwargs', ExtendedJSON())) + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.add_column(sa.Column("trigger_id", sa.Integer())) + batch_op.add_column(sa.Column("trigger_timeout", sa.DateTime())) + batch_op.add_column(sa.Column("next_method", sa.String(length=1000))) + batch_op.add_column(sa.Column("next_kwargs", ExtendedJSON())) batch_op.create_foreign_key( - 'task_instance_trigger_id_fkey', 'trigger', ['trigger_id'], ['id'], ondelete="CASCADE" + "task_instance_trigger_id_fkey", "trigger", ["trigger_id"], ["id"], ondelete="CASCADE" ) - batch_op.create_index('ti_trigger_id', ['trigger_id']) + batch_op.create_index("ti_trigger_id", ["trigger_id"]) def downgrade(): """Unapply Adds ``trigger`` table and deferrable operator columns to task instance""" - with op.batch_alter_table('task_instance', schema=None) as batch_op: - batch_op.drop_constraint('task_instance_trigger_id_fkey', type_='foreignkey') - batch_op.drop_index('ti_trigger_id') - batch_op.drop_column('trigger_id') - batch_op.drop_column('trigger_timeout') - batch_op.drop_column('next_method') - batch_op.drop_column('next_kwargs') - op.drop_table('trigger') + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.drop_constraint("task_instance_trigger_id_fkey", type_="foreignkey") + batch_op.drop_index("ti_trigger_id") + batch_op.drop_column("trigger_id") + batch_op.drop_column("trigger_timeout") + batch_op.drop_column("next_method") + batch_op.drop_column("next_kwargs") + op.drop_table("trigger") diff --git a/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py b/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py index 1aaa2230ee571..8a4fac006b40c 100644 --- a/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py +++ b/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add data_interval_[start|end] to DagModel and DagRun. 
Revision ID: 142555e44c17 @@ -23,6 +22,7 @@ Create Date: 2021-06-09 08:28:02.089817 """ +from __future__ import annotations from alembic import op from sqlalchemy import Column @@ -34,7 +34,7 @@ down_revision = "54bebd308c5f" branch_labels = None depends_on = None -airflow_version = '2.2.0' +airflow_version = "2.2.0" def upgrade(): diff --git a/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py b/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py index b7bff79edfa6b..57cc7ac70c549 100644 --- a/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py +++ b/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Change ``TaskInstance`` and ``TaskReschedule`` tables from execution_date to run_id. Revision ID: 7b2661a43ba3 @@ -23,6 +22,7 @@ Create Date: 2021-07-15 15:26:12.710749 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -34,33 +34,33 @@ ID_LEN = 250 # revision identifiers, used by Alembic. -revision = '7b2661a43ba3' -down_revision = '142555e44c17' +revision = "7b2661a43ba3" +down_revision = "142555e44c17" branch_labels = None depends_on = None -airflow_version = '2.2.0' +airflow_version = "2.2.0" # Just Enough Table to run the conditions for update. task_instance = table( - 'task_instance', - column('task_id', sa.String), - column('dag_id', sa.String), - column('run_id', sa.String), - column('execution_date', sa.TIMESTAMP), + "task_instance", + column("task_id", sa.String), + column("dag_id", sa.String), + column("run_id", sa.String), + column("execution_date", sa.TIMESTAMP), ) task_reschedule = table( - 'task_reschedule', - column('task_id', sa.String), - column('dag_id', sa.String), - column('run_id', sa.String), - column('execution_date', sa.TIMESTAMP), + "task_reschedule", + column("task_id", sa.String), + column("dag_id", sa.String), + column("run_id", sa.String), + column("execution_date", sa.TIMESTAMP), ) dag_run = table( - 'dag_run', - column('dag_id', sa.String), - column('run_id', sa.String), - column('execution_date', sa.TIMESTAMP), + "dag_run", + column("dag_id", sa.String), + column("run_id", sa.String), + column("execution_date", sa.TIMESTAMP), ) @@ -72,75 +72,77 @@ def upgrade(): dt_type = TIMESTAMP string_id_col_type = StringID() - if dialect_name == 'sqlite': + if dialect_name == "sqlite": naming_convention = { "uq": "%(table_name)s_%(column_0_N_name)s_key", } # The naming_convention force the previously un-named UNIQUE constraints to have the right name with op.batch_alter_table( - 'dag_run', naming_convention=naming_convention, recreate="always" + "dag_run", naming_convention=naming_convention, recreate="always" ) as batch_op: - batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False) - batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False) - batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False) - elif dialect_name == 'mysql': - with op.batch_alter_table('dag_run') as batch_op: + batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=False) + batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=False) + batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False) + elif dialect_name == "mysql": + with op.batch_alter_table("dag_run") as batch_op: 
batch_op.alter_column( - 'dag_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False + "dag_id", existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False ) batch_op.alter_column( - 'run_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False + "run_id", existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False ) - batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False) - batch_op.drop_constraint('dag_id', 'unique') - batch_op.drop_constraint('dag_id_2', 'unique') + batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False) + inspector = sa.inspect(conn.engine) + unique_keys = inspector.get_unique_constraints("dag_run") + for unique_key in unique_keys: + batch_op.drop_constraint(unique_key["name"], type_="unique") batch_op.create_unique_constraint( - 'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date'] + "dag_run_dag_id_execution_date_key", ["dag_id", "execution_date"] ) - batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id']) - elif dialect_name == 'mssql': + batch_op.create_unique_constraint("dag_run_dag_id_run_id_key", ["dag_id", "run_id"]) + elif dialect_name == "mssql": - with op.batch_alter_table('dag_run') as batch_op: - batch_op.drop_index('idx_not_null_dag_id_execution_date') - batch_op.drop_index('idx_not_null_dag_id_run_id') + with op.batch_alter_table("dag_run") as batch_op: + batch_op.drop_index("idx_not_null_dag_id_execution_date") + batch_op.drop_index("idx_not_null_dag_id_run_id") - batch_op.drop_index('dag_id_state') - batch_op.drop_index('idx_dag_run_dag_id') - batch_op.drop_index('idx_dag_run_running_dags') - batch_op.drop_index('idx_dag_run_queued_dags') + batch_op.drop_index("dag_id_state") + batch_op.drop_index("idx_dag_run_dag_id") + batch_op.drop_index("idx_dag_run_running_dags") + batch_op.drop_index("idx_dag_run_queued_dags") - batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False) - batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False) - batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False) + batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=False) + batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False) + batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=False) # _Somehow_ mssql was missing these constraints entirely batch_op.create_unique_constraint( - 'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date'] + "dag_run_dag_id_execution_date_key", ["dag_id", "execution_date"] ) - batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id']) + batch_op.create_unique_constraint("dag_run_dag_id_run_id_key", ["dag_id", "run_id"]) - batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False) - batch_op.create_index('idx_dag_run_dag_id', ['dag_id']) + batch_op.create_index("dag_id_state", ["dag_id", "state"], unique=False) + batch_op.create_index("idx_dag_run_dag_id", ["dag_id"]) batch_op.create_index( - 'idx_dag_run_running_dags', + "idx_dag_run_running_dags", ["state", "dag_id"], mssql_where=sa.text("state='running'"), ) batch_op.create_index( - 'idx_dag_run_queued_dags', + "idx_dag_run_queued_dags", ["state", "dag_id"], mssql_where=sa.text("state='queued'"), ) else: # Make sure DagRun PK columns are non-nullable - with op.batch_alter_table('dag_run', schema=None) as batch_op: - 
batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False) - batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False) - batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False) + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=False) + batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False) + batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=False) # First create column nullable - op.add_column('task_instance', sa.Column('run_id', type_=string_id_col_type, nullable=True)) - op.add_column('task_reschedule', sa.Column('run_id', type_=string_id_col_type, nullable=True)) + op.add_column("task_instance", sa.Column("run_id", type_=string_id_col_type, nullable=True)) + op.add_column("task_reschedule", sa.Column("run_id", type_=string_id_col_type, nullable=True)) # # TaskReschedule has a FK to TaskInstance, so we have to update that before @@ -149,23 +151,23 @@ def upgrade(): update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.run_id) op.execute(update_query) - with op.batch_alter_table('task_reschedule', schema=None) as batch_op: + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: batch_op.alter_column( - 'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False + "run_id", existing_type=string_id_col_type, existing_nullable=True, nullable=False ) - batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', 'foreignkey') + batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", "foreignkey") if dialect_name == "mysql": # Mysql creates an index and a constraint -- we have to drop both - batch_op.drop_index('task_reschedule_dag_task_date_fkey') + batch_op.drop_index("task_reschedule_dag_task_date_fkey") batch_op.alter_column( - 'dag_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False + "dag_id", existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False ) - batch_op.drop_index('idx_task_reschedule_dag_task_date') + batch_op.drop_index("idx_task_reschedule_dag_task_date") # Then update the new column by selecting the right value from DagRun # But first we will drop and recreate indexes to make it faster - if dialect_name == 'postgresql': + if dialect_name == "postgresql": # Recreate task_instance, without execution_date and with dagrun.run_id op.execute( """ @@ -200,87 +202,87 @@ def upgrade(): INNER JOIN dag_run ON dag_run.dag_id = ti.dag_id AND dag_run.execution_date = ti.execution_date; """ ) - op.drop_table('task_instance') - op.rename_table('new_task_instance', 'task_instance') + op.drop_table("task_instance") + op.rename_table("new_task_instance", "task_instance") # Fix up columns after the 'create table as select' - with op.batch_alter_table('task_instance', schema=None) as batch_op: + with op.batch_alter_table("task_instance", schema=None) as batch_op: batch_op.alter_column( - 'pool', existing_type=string_id_col_type, existing_nullable=True, nullable=False + "pool", existing_type=string_id_col_type, existing_nullable=True, nullable=False ) - batch_op.alter_column('max_tries', existing_type=sa.Integer(), server_default="-1") + batch_op.alter_column("max_tries", existing_type=sa.Integer(), server_default="-1") batch_op.alter_column( - 'pool_slots', existing_type=sa.Integer(), existing_nullable=True, nullable=False + "pool_slots", 
existing_type=sa.Integer(), existing_nullable=True, nullable=False ) else: update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.run_id) op.execute(update_query) - with op.batch_alter_table('task_instance', schema=None) as batch_op: - if dialect_name != 'postgresql': + with op.batch_alter_table("task_instance", schema=None) as batch_op: + if dialect_name != "postgresql": # TODO: Is this right for non-postgres? - if dialect_name == 'mssql': + if dialect_name == "mssql": constraints = get_mssql_table_constraints(conn, "task_instance") - pk, _ = constraints['PRIMARY KEY'].popitem() - batch_op.drop_constraint(pk, type_='primary') - elif dialect_name not in ('sqlite'): - batch_op.drop_constraint('task_instance_pkey', type_='primary') - batch_op.drop_index('ti_dag_date') - batch_op.drop_index('ti_state_lkp') - batch_op.drop_column('execution_date') + pk, _ = constraints["PRIMARY KEY"].popitem() + batch_op.drop_constraint(pk, type_="primary") + elif dialect_name not in ("sqlite"): + batch_op.drop_constraint("task_instance_pkey", type_="primary") + batch_op.drop_index("ti_dag_date") + batch_op.drop_index("ti_state_lkp") + batch_op.drop_column("execution_date") # Then make it non-nullable batch_op.alter_column( - 'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False + "run_id", existing_type=string_id_col_type, existing_nullable=True, nullable=False ) batch_op.alter_column( - 'dag_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False + "dag_id", existing_type=string_id_col_type, existing_nullable=True, nullable=False ) - batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'run_id']) + batch_op.create_primary_key("task_instance_pkey", ["dag_id", "task_id", "run_id"]) batch_op.create_foreign_key( - 'task_instance_dag_run_fkey', - 'dag_run', - ['dag_id', 'run_id'], - ['dag_id', 'run_id'], - ondelete='CASCADE', + "task_instance_dag_run_fkey", + "dag_run", + ["dag_id", "run_id"], + ["dag_id", "run_id"], + ondelete="CASCADE", ) - batch_op.create_index('ti_dag_run', ['dag_id', 'run_id']) - batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'run_id', 'state']) - if dialect_name == 'postgresql': - batch_op.create_index('ti_dag_state', ['dag_id', 'state']) - batch_op.create_index('ti_job_id', ['job_id']) - batch_op.create_index('ti_pool', ['pool', 'state', 'priority_weight']) - batch_op.create_index('ti_state', ['state']) + batch_op.create_index("ti_dag_run", ["dag_id", "run_id"]) + batch_op.create_index("ti_state_lkp", ["dag_id", "task_id", "run_id", "state"]) + if dialect_name == "postgresql": + batch_op.create_index("ti_dag_state", ["dag_id", "state"]) + batch_op.create_index("ti_job_id", ["job_id"]) + batch_op.create_index("ti_pool", ["pool", "state", "priority_weight"]) + batch_op.create_index("ti_state", ["state"]) batch_op.create_foreign_key( - 'task_instance_trigger_id_fkey', 'trigger', ['trigger_id'], ['id'], ondelete="CASCADE" + "task_instance_trigger_id_fkey", "trigger", ["trigger_id"], ["id"], ondelete="CASCADE" ) - batch_op.create_index('ti_trigger_id', ['trigger_id']) + batch_op.create_index("ti_trigger_id", ["trigger_id"]) - with op.batch_alter_table('task_reschedule', schema=None) as batch_op: - batch_op.drop_column('execution_date') + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: + batch_op.drop_column("execution_date") batch_op.create_index( - 'idx_task_reschedule_dag_task_run', - ['dag_id', 'task_id', 'run_id'], + "idx_task_reschedule_dag_task_run", + 
["dag_id", "task_id", "run_id"], unique=False, ) # _Now_ there is a unique constraint on the columns in TI we can re-create the FK from TaskReschedule batch_op.create_foreign_key( - 'task_reschedule_ti_fkey', - 'task_instance', - ['dag_id', 'task_id', 'run_id'], - ['dag_id', 'task_id', 'run_id'], - ondelete='CASCADE', + "task_reschedule_ti_fkey", + "task_instance", + ["dag_id", "task_id", "run_id"], + ["dag_id", "task_id", "run_id"], + ondelete="CASCADE", ) # https://docs.microsoft.com/en-us/sql/relational-databases/errors-events/mssqlserver-1785-database-engine-error?view=sql-server-ver15 - ondelete = 'CASCADE' if dialect_name != 'mssql' else 'NO ACTION' + ondelete = "CASCADE" if dialect_name != "mssql" else "NO ACTION" batch_op.create_foreign_key( - 'task_reschedule_dr_fkey', - 'dag_run', - ['dag_id', 'run_id'], - ['dag_id', 'run_id'], + "task_reschedule_dr_fkey", + "dag_run", + ["dag_id", "run_id"], + ["dag_id", "run_id"], ondelete=ondelete, ) @@ -291,8 +293,8 @@ def downgrade(): dt_type = TIMESTAMP string_id_col_type = StringID() - op.add_column('task_instance', sa.Column('execution_date', dt_type, nullable=True)) - op.add_column('task_reschedule', sa.Column('execution_date', dt_type, nullable=True)) + op.add_column("task_instance", sa.Column("execution_date", dt_type, nullable=True)) + op.add_column("task_reschedule", sa.Column("execution_date", dt_type, nullable=True)) update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.execution_date) op.execute(update_query) @@ -300,71 +302,75 @@ def downgrade(): update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.execution_date) op.execute(update_query) - with op.batch_alter_table('task_reschedule', schema=None) as batch_op: - batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False) + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: + batch_op.alter_column("execution_date", existing_type=dt_type, existing_nullable=True, nullable=False) # Can't drop PK index while there is a FK referencing it - batch_op.drop_constraint('task_reschedule_ti_fkey', type_='foreignkey') - batch_op.drop_constraint('task_reschedule_dr_fkey', type_='foreignkey') - batch_op.drop_index('idx_task_reschedule_dag_task_run') - - with op.batch_alter_table('task_instance', schema=None) as batch_op: - batch_op.drop_constraint('task_instance_pkey', type_='primary') - batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False) - if dialect_name != 'mssql': + batch_op.drop_constraint("task_reschedule_ti_fkey", type_="foreignkey") + batch_op.drop_constraint("task_reschedule_dr_fkey", type_="foreignkey") + batch_op.drop_index("idx_task_reschedule_dag_task_run") + + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.drop_constraint("task_instance_pkey", type_="primary") + batch_op.alter_column("execution_date", existing_type=dt_type, existing_nullable=True, nullable=False) + if dialect_name != "mssql": batch_op.alter_column( - 'dag_id', existing_type=string_id_col_type, existing_nullable=False, nullable=True + "dag_id", existing_type=string_id_col_type, existing_nullable=False, nullable=True ) - batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'execution_date']) + batch_op.create_primary_key("task_instance_pkey", ["dag_id", "task_id", "execution_date"]) - batch_op.drop_constraint('task_instance_dag_run_fkey', type_='foreignkey') - batch_op.drop_index('ti_dag_run') - 
batch_op.drop_index('ti_state_lkp') - batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'execution_date', 'state']) - batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False) + batch_op.drop_constraint("task_instance_dag_run_fkey", type_="foreignkey") + batch_op.drop_index("ti_dag_run") + batch_op.drop_index("ti_state_lkp") + batch_op.create_index("ti_state_lkp", ["dag_id", "task_id", "execution_date", "state"]) + batch_op.create_index("ti_dag_date", ["dag_id", "execution_date"], unique=False) - batch_op.drop_column('run_id') + batch_op.drop_column("run_id") - with op.batch_alter_table('task_reschedule', schema=None) as batch_op: - batch_op.drop_column('run_id') + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: + batch_op.drop_column("run_id") batch_op.create_index( - 'idx_task_reschedule_dag_task_date', - ['dag_id', 'task_id', 'execution_date'], + "idx_task_reschedule_dag_task_date", + ["dag_id", "task_id", "execution_date"], unique=False, ) # Can only create FK once there is an index on these columns batch_op.create_foreign_key( - 'task_reschedule_dag_task_date_fkey', - 'task_instance', - ['dag_id', 'task_id', 'execution_date'], - ['dag_id', 'task_id', 'execution_date'], - ondelete='CASCADE', + "task_reschedule_dag_task_date_fkey", + "task_instance", + ["dag_id", "task_id", "execution_date"], + ["dag_id", "task_id", "execution_date"], + ondelete="CASCADE", ) + if dialect_name == "mysql": + batch_op.create_index( + "task_reschedule_dag_task_date_fkey", ["dag_id", "execution_date"], unique=False + ) if dialect_name == "mssql": - with op.batch_alter_table('dag_run', schema=None) as batch_op: - batch_op.drop_constraint('dag_run_dag_id_execution_date_key', 'unique') - batch_op.drop_constraint('dag_run_dag_id_run_id_key', 'unique') - batch_op.drop_index('dag_id_state') - batch_op.drop_index('idx_dag_run_running_dags') - batch_op.drop_index('idx_dag_run_queued_dags') - batch_op.drop_index('idx_dag_run_dag_id') + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.drop_constraint("dag_run_dag_id_execution_date_key", "unique") + batch_op.drop_constraint("dag_run_dag_id_run_id_key", "unique") + batch_op.drop_index("dag_id_state") + batch_op.drop_index("idx_dag_run_running_dags") + batch_op.drop_index("idx_dag_run_queued_dags") + batch_op.drop_index("idx_dag_run_dag_id") - batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=True) - batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True) - batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=True) + batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=True) + batch_op.alter_column("execution_date", existing_type=dt_type, nullable=True) + batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=True) - batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False) - batch_op.create_index('idx_dag_run_dag_id', ['dag_id']) + batch_op.create_index("dag_id_state", ["dag_id", "state"], unique=False) + batch_op.create_index("idx_dag_run_dag_id", ["dag_id"]) batch_op.create_index( - 'idx_dag_run_running_dags', + "idx_dag_run_running_dags", ["state", "dag_id"], mssql_where=sa.text("state='running'"), ) batch_op.create_index( - 'idx_dag_run_queued_dags', + "idx_dag_run_queued_dags", ["state", "dag_id"], mssql_where=sa.text("state='queued'"), ) @@ -379,12 +385,12 @@ def downgrade(): WHERE dag_id IS NOT NULL and run_id is not null""" ) else: - with 
op.batch_alter_table('dag_run', schema=None) as batch_op: - batch_op.drop_index('dag_id_state') - batch_op.alter_column('run_id', existing_type=sa.VARCHAR(length=250), nullable=True) - batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True) - batch_op.alter_column('dag_id', existing_type=sa.VARCHAR(length=250), nullable=True) - batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False) + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.drop_index("dag_id_state") + batch_op.alter_column("run_id", existing_type=sa.VARCHAR(length=250), nullable=True) + batch_op.alter_column("execution_date", existing_type=dt_type, nullable=True) + batch_op.alter_column("dag_id", existing_type=sa.VARCHAR(length=250), nullable=True) + batch_op.create_index("dag_id_state", ["dag_id", "state"], unique=False) def _multi_table_update(dialect_name, target, column): diff --git a/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py b/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py index d401e241b862d..fe931e62ddf77 100644 --- a/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py +++ b/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add has_import_errors column to DagModel Revision ID: be2bfac3da23 @@ -23,24 +22,25 @@ Create Date: 2021-11-04 20:33:11.009547 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = 'be2bfac3da23' -down_revision = '7b2661a43ba3' +revision = "be2bfac3da23" +down_revision = "7b2661a43ba3" branch_labels = None depends_on = None -airflow_version = '2.2.3' +airflow_version = "2.2.3" def upgrade(): """Apply Add has_import_errors column to DagModel""" - op.add_column("dag", sa.Column("has_import_errors", sa.Boolean(), server_default='0')) + op.add_column("dag", sa.Column("has_import_errors", sa.Boolean(), server_default="0")) def downgrade(): """Unapply Add has_import_errors column to DagModel""" - with op.batch_alter_table('dag') as batch_op: - batch_op.drop_column('has_import_errors', mssql_drop_default=True) + with op.batch_alter_table("dag") as batch_op: + batch_op.drop_column("has_import_errors", mssql_drop_default=True) diff --git a/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py b/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py index 3b70acbdab529..ba5f32253164a 100644 --- a/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py +++ b/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Create a ``session`` table to store web session data Revision ID: c381b21cb7e4 @@ -23,30 +22,31 @@ Create Date: 2022-01-25 13:56:35.069429 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = 'c381b21cb7e4' -down_revision = 'be2bfac3da23' +revision = "c381b21cb7e4" +down_revision = "be2bfac3da23" branch_labels = None depends_on = None -airflow_version = '2.2.4' +airflow_version = "2.2.4" -TABLE_NAME = 'session' +TABLE_NAME = "session" def upgrade(): """Apply Create a ``session`` table to store web session data""" op.create_table( TABLE_NAME, - sa.Column('id', sa.Integer()), - sa.Column('session_id', sa.String(255)), - sa.Column('data', sa.LargeBinary()), - sa.Column('expiry', sa.DateTime()), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('session_id'), + sa.Column("id", sa.Integer()), + sa.Column("session_id", sa.String(255)), + sa.Column("data", sa.LargeBinary()), + sa.Column("expiry", sa.DateTime()), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("session_id"), ) diff --git a/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py b/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py index 2d08d7f922a58..447081a40a763 100644 --- a/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py +++ b/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add index for ``dag_id`` column in ``job`` table. Revision ID: 587bdf053233 @@ -23,22 +22,23 @@ Create Date: 2021-12-14 10:20:12.482940 """ +from __future__ import annotations from alembic import op # revision identifiers, used by Alembic. -revision = '587bdf053233' -down_revision = 'c381b21cb7e4' +revision = "587bdf053233" +down_revision = "c381b21cb7e4" branch_labels = None depends_on = None -airflow_version = '2.2.4' +airflow_version = "2.2.4" def upgrade(): """Apply Add index for ``dag_id`` column in ``job`` table.""" - op.create_index('idx_job_dag_id', 'job', ['dag_id'], unique=False) + op.create_index("idx_job_dag_id", "job", ["dag_id"], unique=False) def downgrade(): """Unapply Add index for ``dag_id`` column in ``job`` table.""" - op.drop_index('idx_job_dag_id', table_name='job') + op.drop_index("idx_job_dag_id", table_name="job") diff --git a/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py b/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py index be028e83378fe..b62b3136ae9ce 100644 --- a/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py +++ b/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Increase length of email and username in ``ab_user`` and ``ab_register_user`` table to ``256`` characters Revision ID: 5e3ec427fdd3 @@ -23,6 +22,7 @@ Create Date: 2021-12-01 11:49:26.390210 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -30,53 +30,53 @@ from airflow.migrations.utils import get_mssql_table_constraints # revision identifiers, used by Alembic. 
-revision = '5e3ec427fdd3' -down_revision = '587bdf053233' +revision = "5e3ec427fdd3" +down_revision = "587bdf053233" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" def upgrade(): """Increase length of email from 64 to 256 characters""" - with op.batch_alter_table('ab_user') as batch_op: - batch_op.alter_column('username', type_=sa.String(256)) - batch_op.alter_column('email', type_=sa.String(256)) - with op.batch_alter_table('ab_register_user') as batch_op: - batch_op.alter_column('username', type_=sa.String(256)) - batch_op.alter_column('email', type_=sa.String(256)) + with op.batch_alter_table("ab_user") as batch_op: + batch_op.alter_column("username", type_=sa.String(256)) + batch_op.alter_column("email", type_=sa.String(256)) + with op.batch_alter_table("ab_register_user") as batch_op: + batch_op.alter_column("username", type_=sa.String(256)) + batch_op.alter_column("email", type_=sa.String(256)) def downgrade(): """Revert length of email from 256 to 64 characters""" conn = op.get_bind() - if conn.dialect.name != 'mssql': - with op.batch_alter_table('ab_user') as batch_op: - batch_op.alter_column('username', type_=sa.String(64), nullable=False) - batch_op.alter_column('email', type_=sa.String(64)) - with op.batch_alter_table('ab_register_user') as batch_op: - batch_op.alter_column('username', type_=sa.String(64)) - batch_op.alter_column('email', type_=sa.String(64)) + if conn.dialect.name != "mssql": + with op.batch_alter_table("ab_user") as batch_op: + batch_op.alter_column("username", type_=sa.String(64), nullable=False) + batch_op.alter_column("email", type_=sa.String(64)) + with op.batch_alter_table("ab_register_user") as batch_op: + batch_op.alter_column("username", type_=sa.String(64)) + batch_op.alter_column("email", type_=sa.String(64)) else: # MSSQL doesn't drop implicit unique constraints it created # We need to drop the two unique constraints explicitly - with op.batch_alter_table('ab_user') as batch_op: + with op.batch_alter_table("ab_user") as batch_op: # Drop the unique constraint on username and email - constraints = get_mssql_table_constraints(conn, 'ab_user') - unique_key, _ = constraints['UNIQUE'].popitem() - batch_op.drop_constraint(unique_key, type_='unique') - unique_key, _ = constraints['UNIQUE'].popitem() - batch_op.drop_constraint(unique_key, type_='unique') - batch_op.alter_column('username', type_=sa.String(64), nullable=False) - batch_op.create_unique_constraint(None, ['username']) - batch_op.alter_column('email', type_=sa.String(64)) - batch_op.create_unique_constraint(None, ['email']) + constraints = get_mssql_table_constraints(conn, "ab_user") + unique_key, _ = constraints["UNIQUE"].popitem() + batch_op.drop_constraint(unique_key, type_="unique") + unique_key, _ = constraints["UNIQUE"].popitem() + batch_op.drop_constraint(unique_key, type_="unique") + batch_op.alter_column("username", type_=sa.String(64), nullable=False) + batch_op.create_unique_constraint(None, ["username"]) + batch_op.alter_column("email", type_=sa.String(64)) + batch_op.create_unique_constraint(None, ["email"]) - with op.batch_alter_table('ab_register_user') as batch_op: + with op.batch_alter_table("ab_register_user") as batch_op: # Drop the unique constraint on username and email - constraints = get_mssql_table_constraints(conn, 'ab_register_user') - for k, _ in constraints.get('UNIQUE').items(): - batch_op.drop_constraint(k, type_='unique') - batch_op.alter_column('username', type_=sa.String(64)) - batch_op.create_unique_constraint(None, 
['username']) - batch_op.alter_column('email', type_=sa.String(64)) + constraints = get_mssql_table_constraints(conn, "ab_register_user") + for k, _ in constraints.get("UNIQUE").items(): + batch_op.drop_constraint(k, type_="unique") + batch_op.alter_column("username", type_=sa.String(64)) + batch_op.create_unique_constraint(None, ["username"]) + batch_op.alter_column("email", type_=sa.String(64)) diff --git a/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py b/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py index 811ece55c54bc..9d9d8c72eea10 100644 --- a/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py +++ b/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``timetable_description`` column to DagModel for UI. Revision ID: 786e3737b18f @@ -23,30 +22,31 @@ Create Date: 2021-10-15 13:33:04.754052 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = '786e3737b18f' -down_revision = '5e3ec427fdd3' +revision = "786e3737b18f" +down_revision = "5e3ec427fdd3" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" def upgrade(): """Apply Add ``timetable_description`` column to DagModel for UI.""" - with op.batch_alter_table('dag', schema=None) as batch_op: - batch_op.add_column(sa.Column('timetable_description', sa.String(length=1000), nullable=True)) + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.add_column(sa.Column("timetable_description", sa.String(length=1000), nullable=True)) def downgrade(): """Unapply Add ``timetable_description`` column to DagModel for UI.""" - is_sqlite = bool(op.get_bind().dialect.name == 'sqlite') + is_sqlite = bool(op.get_bind().dialect.name == "sqlite") if is_sqlite: - op.execute('PRAGMA foreign_keys=off') - with op.batch_alter_table('dag') as batch_op: - batch_op.drop_column('timetable_description') + op.execute("PRAGMA foreign_keys=off") + with op.batch_alter_table("dag") as batch_op: + batch_op.drop_column("timetable_description") if is_sqlite: - op.execute('PRAGMA foreign_keys=on') + op.execute("PRAGMA foreign_keys=on") diff --git a/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py b/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py index 05fc2767fa9fe..2fa870943c186 100644 --- a/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py +++ b/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py @@ -15,13 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """Add ``LogTemplate`` table to track changes to config values ``log_filename_template`` Revision ID: f9da662e7089 Revises: 786e3737b18f Create Date: 2021-12-09 06:11:21.044940 """ +from __future__ import annotations from alembic import op from sqlalchemy import Column, ForeignKey, Integer, Text @@ -34,7 +34,7 @@ down_revision = "786e3737b18f" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" def upgrade(): diff --git a/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py b/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py index 741c02d0ff6fb..d11953009aeed 100644 --- a/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py +++ b/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add ``map_index`` column to TaskInstance to identify task-mapping, and a ``task_map`` table to track mapping values from XCom. @@ -23,6 +22,7 @@ Revises: f9da662e7089 Create Date: 2021-12-13 22:59:41.052584 """ +from __future__ import annotations from alembic import op from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, text @@ -35,7 +35,7 @@ down_revision = "f9da662e7089" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" def upgrade(): diff --git a/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py b/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py index f7276d8b2cac2..34e832c4a0e96 100644 --- a/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py +++ b/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """add data_compressed to serialized_dag Revision ID: a3bcd0914482 @@ -23,25 +22,26 @@ Create Date: 2022-02-03 22:40:59.841119 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = 'a3bcd0914482' -down_revision = 'e655c0453f75' +revision = "a3bcd0914482" +down_revision = "e655c0453f75" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" def upgrade(): - with op.batch_alter_table('serialized_dag') as batch_op: - batch_op.alter_column('data', existing_type=sa.JSON, nullable=True) - batch_op.add_column(sa.Column('data_compressed', sa.LargeBinary, nullable=True)) + with op.batch_alter_table("serialized_dag") as batch_op: + batch_op.alter_column("data", existing_type=sa.JSON, nullable=True) + batch_op.add_column(sa.Column("data_compressed", sa.LargeBinary, nullable=True)) def downgrade(): - with op.batch_alter_table('serialized_dag') as batch_op: - batch_op.alter_column('data', existing_type=sa.JSON, nullable=False) - batch_op.drop_column('data_compressed') + with op.batch_alter_table("serialized_dag") as batch_op: + batch_op.alter_column("data", existing_type=sa.JSON, nullable=False) + batch_op.drop_column("data_compressed") diff --git a/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py b/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py index f987a1fe12606..b8b06e802c931 100644 --- a/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py +++ b/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py @@ -15,13 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Switch XCom table to use ``run_id`` and add ``map_index``. Revision ID: c306b5b5ae4a Revises: a3bcd0914482 Create Date: 2022-01-19 03:20:35.329037 """ +from __future__ import annotations + from typing import Sequence from alembic import op @@ -35,7 +36,7 @@ down_revision = "a3bcd0914482" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" metadata = MetaData() @@ -167,8 +168,8 @@ def downgrade(): op.drop_table("xcom") op.rename_table("__airflow_tmp_xcom", "xcom") - if conn.dialect.name == 'mssql': - constraints = get_mssql_table_constraints(conn, 'xcom') - pk, _ = constraints['PRIMARY KEY'].popitem() - op.drop_constraint(pk, 'xcom', type_='primary') + if conn.dialect.name == "mssql": + constraints = get_mssql_table_constraints(conn, "xcom") + pk, _ = constraints["PRIMARY KEY"].popitem() + op.drop_constraint(pk, "xcom", type_="primary") op.create_primary_key("pk_xcom", "xcom", ["dag_id", "task_id", "execution_date", "key"]) diff --git a/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py b/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py index 637abe8aee1b9..09ad52727d0ec 100644 --- a/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py +++ b/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """add callback request table Revision ID: c97c2ab6aa23 @@ -23,6 +22,7 @@ Create Date: 2022-01-28 21:11:11.857010 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -32,25 +32,25 @@ from airflow.utils.sqlalchemy import ExtendedJSON # revision identifiers, used by Alembic. 
-revision = 'c97c2ab6aa23' -down_revision = 'c306b5b5ae4a' +revision = "c97c2ab6aa23" +down_revision = "c306b5b5ae4a" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" -TABLE_NAME = 'callback_request' +TABLE_NAME = "callback_request" def upgrade(): op.create_table( TABLE_NAME, - sa.Column('id', sa.Integer(), nullable=False, primary_key=True), - sa.Column('created_at', TIMESTAMP, default=func.now(), nullable=False), - sa.Column('priority_weight', sa.Integer(), default=1, nullable=False), - sa.Column('callback_data', ExtendedJSON, nullable=False), - sa.Column('callback_type', sa.String(20), nullable=False), - sa.Column('dag_directory', sa.String(length=1000), nullable=True), - sa.PrimaryKeyConstraint('id'), + sa.Column("id", sa.Integer(), nullable=False, primary_key=True), + sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False), + sa.Column("priority_weight", sa.Integer(), default=1, nullable=False), + sa.Column("callback_data", ExtendedJSON, nullable=False), + sa.Column("callback_type", sa.String(20), nullable=False), + sa.Column("dag_directory", sa.String(length=1000), nullable=True), + sa.PrimaryKeyConstraint("id"), ) diff --git a/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py b/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py index f64f95b2b9bde..ff52470f42164 100644 --- a/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py +++ b/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Migrate RTIF to use run_id and map_index Revision ID: 4eaab2fe6582 @@ -23,6 +22,7 @@ Create Date: 2022-03-03 17:48:29.955821 """ +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -35,11 +35,11 @@ ID_LEN = 250 # revision identifiers, used by Alembic. -revision = '4eaab2fe6582' -down_revision = 'c97c2ab6aa23' +revision = "4eaab2fe6582" +down_revision = "c97c2ab6aa23" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" # Just Enough Table to run the conditions for update. 
@@ -49,42 +49,42 @@ def tables(for_downgrade=False): global task_instance, rendered_task_instance_fields, dag_run metadata = sa.MetaData() task_instance = sa.Table( - 'task_instance', + "task_instance", metadata, - sa.Column('task_id', StringID()), - sa.Column('dag_id', StringID()), - sa.Column('run_id', StringID()), - sa.Column('execution_date', TIMESTAMP), + sa.Column("task_id", StringID()), + sa.Column("dag_id", StringID()), + sa.Column("run_id", StringID()), + sa.Column("execution_date", TIMESTAMP), ) rendered_task_instance_fields = sa.Table( - 'rendered_task_instance_fields', + "rendered_task_instance_fields", metadata, - sa.Column('dag_id', StringID()), - sa.Column('task_id', StringID()), - sa.Column('run_id', StringID()), - sa.Column('execution_date', TIMESTAMP), - sa.Column('rendered_fields', sqlalchemy_jsonfield.JSONField(), nullable=False), - sa.Column('k8s_pod_yaml', sqlalchemy_jsonfield.JSONField(), nullable=True), + sa.Column("dag_id", StringID()), + sa.Column("task_id", StringID()), + sa.Column("run_id", StringID()), + sa.Column("execution_date", TIMESTAMP), + sa.Column("rendered_fields", sqlalchemy_jsonfield.JSONField(), nullable=False), + sa.Column("k8s_pod_yaml", sqlalchemy_jsonfield.JSONField(), nullable=True), ) if for_downgrade: rendered_task_instance_fields.append_column( - sa.Column('map_index', sa.Integer(), server_default='-1'), + sa.Column("map_index", sa.Integer(), server_default="-1"), ) rendered_task_instance_fields.append_constraint( ForeignKeyConstraint( - ['dag_id', 'run_id'], + ["dag_id", "run_id"], ["dag_run.dag_id", "dag_run.run_id"], - name='rtif_dag_run_fkey', + name="rtif_dag_run_fkey", ondelete="CASCADE", ), ) dag_run = sa.Table( - 'dag_run', + "dag_run", metadata, - sa.Column('dag_id', StringID()), - sa.Column('run_id', StringID()), - sa.Column('execution_date', TIMESTAMP), + sa.Column("dag_id", StringID()), + sa.Column("run_id", StringID()), + sa.Column("execution_date", TIMESTAMP), ) @@ -110,44 +110,44 @@ def upgrade(): tables() dialect_name = op.get_bind().dialect.name - with op.batch_alter_table('rendered_task_instance_fields') as batch_op: - batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False)) + with op.batch_alter_table("rendered_task_instance_fields") as batch_op: + batch_op.add_column(sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False)) rendered_task_instance_fields.append_column( - sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False) + sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False) ) - batch_op.add_column(sa.Column('run_id', type_=StringID(), nullable=True)) + batch_op.add_column(sa.Column("run_id", type_=StringID(), nullable=True)) update_query = _multi_table_update( dialect_name, rendered_task_instance_fields, rendered_task_instance_fields.c.run_id ) op.execute(update_query) with op.batch_alter_table( - 'rendered_task_instance_fields', copy_from=rendered_task_instance_fields + "rendered_task_instance_fields", copy_from=rendered_task_instance_fields ) as batch_op: - if dialect_name == 'mssql': - constraints = get_mssql_table_constraints(op.get_bind(), 'rendered_task_instance_fields') - pk, _ = constraints['PRIMARY KEY'].popitem() - batch_op.drop_constraint(pk, type_='primary') - elif dialect_name != 'sqlite': - batch_op.drop_constraint('rendered_task_instance_fields_pkey', type_='primary') - batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False) - batch_op.drop_column('execution_date') + 
if dialect_name == "mssql": + constraints = get_mssql_table_constraints(op.get_bind(), "rendered_task_instance_fields") + pk, _ = constraints["PRIMARY KEY"].popitem() + batch_op.drop_constraint(pk, type_="primary") + elif dialect_name != "sqlite": + batch_op.drop_constraint("rendered_task_instance_fields_pkey", type_="primary") + batch_op.alter_column("run_id", existing_type=StringID(), existing_nullable=True, nullable=False) + batch_op.drop_column("execution_date") batch_op.create_primary_key( - 'rendered_task_instance_fields_pkey', ['dag_id', 'task_id', 'run_id', 'map_index'] + "rendered_task_instance_fields_pkey", ["dag_id", "task_id", "run_id", "map_index"] ) batch_op.create_foreign_key( - 'rtif_ti_fkey', - 'task_instance', - ['dag_id', 'task_id', 'run_id', 'map_index'], - ['dag_id', 'task_id', 'run_id', 'map_index'], - ondelete='CASCADE', + "rtif_ti_fkey", + "task_instance", + ["dag_id", "task_id", "run_id", "map_index"], + ["dag_id", "task_id", "run_id", "map_index"], + ondelete="CASCADE", ) def downgrade(): tables(for_downgrade=True) dialect_name = op.get_bind().dialect.name - op.add_column('rendered_task_instance_fields', sa.Column('execution_date', TIMESTAMP, nullable=True)) + op.add_column("rendered_task_instance_fields", sa.Column("execution_date", TIMESTAMP, nullable=True)) update_query = _multi_table_update( dialect_name, rendered_task_instance_fields, rendered_task_instance_fields.c.execution_date @@ -155,14 +155,14 @@ def downgrade(): op.execute(update_query) with op.batch_alter_table( - 'rendered_task_instance_fields', copy_from=rendered_task_instance_fields + "rendered_task_instance_fields", copy_from=rendered_task_instance_fields ) as batch_op: - batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False) - if dialect_name != 'sqlite': - batch_op.drop_constraint('rtif_ti_fkey', type_='foreignkey') - batch_op.drop_constraint('rendered_task_instance_fields_pkey', type_='primary') + batch_op.alter_column("execution_date", existing_type=TIMESTAMP, nullable=False) + if dialect_name != "sqlite": + batch_op.drop_constraint("rtif_ti_fkey", type_="foreignkey") + batch_op.drop_constraint("rendered_task_instance_fields_pkey", type_="primary") batch_op.create_primary_key( - 'rendered_task_instance_fields_pkey', ['dag_id', 'task_id', 'execution_date'] + "rendered_task_instance_fields_pkey", ["dag_id", "task_id", "execution_date"] ) - batch_op.drop_column('map_index', mssql_drop_default=True) - batch_op.drop_column('run_id') + batch_op.drop_column("map_index", mssql_drop_default=True) + batch_op.drop_column("run_id") diff --git a/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py index 303e1eda58241..d65cf6b548947 100644 --- a/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py +++ b/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add map_index to TaskFail Drop index idx_task_fail_dag_task_date @@ -28,8 +27,7 @@ Revises: 4eaab2fe6582 Create Date: 2022-03-14 10:31:11.220720 """ - -from typing import List +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -38,11 +36,11 @@ from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. 
-revision = '48925b2719cb' -down_revision = '4eaab2fe6582' +revision = "48925b2719cb" +down_revision = "4eaab2fe6582" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" ID_LEN = 250 @@ -51,29 +49,29 @@ def tables(): global task_instance, task_fail, dag_run metadata = sa.MetaData() task_instance = sa.Table( - 'task_instance', + "task_instance", metadata, - sa.Column('task_id', StringID()), - sa.Column('dag_id', StringID()), - sa.Column('run_id', StringID()), - sa.Column('map_index', sa.Integer(), server_default='-1'), - sa.Column('execution_date', TIMESTAMP), + sa.Column("task_id", StringID()), + sa.Column("dag_id", StringID()), + sa.Column("run_id", StringID()), + sa.Column("map_index", sa.Integer(), server_default="-1"), + sa.Column("execution_date", TIMESTAMP), ) task_fail = sa.Table( - 'task_fail', + "task_fail", metadata, - sa.Column('dag_id', StringID()), - sa.Column('task_id', StringID()), - sa.Column('run_id', StringID()), - sa.Column('map_index', StringID()), - sa.Column('execution_date', TIMESTAMP), + sa.Column("dag_id", StringID()), + sa.Column("task_id", StringID()), + sa.Column("run_id", StringID()), + sa.Column("map_index", StringID()), + sa.Column("execution_date", TIMESTAMP), ) dag_run = sa.Table( - 'dag_run', + "dag_run", metadata, - sa.Column('dag_id', StringID()), - sa.Column('run_id', StringID()), - sa.Column('execution_date', TIMESTAMP), + sa.Column("dag_id", StringID()), + sa.Column("run_id", StringID()), + sa.Column("execution_date", TIMESTAMP), ) @@ -81,7 +79,7 @@ def _update_value_from_dag_run( dialect_name: str, target_table: sa.Table, target_column: ColumnElement, - join_columns: List[str], + join_columns: list[str], ) -> Update: """ Grabs a value from the source table ``dag_run`` and updates target with this value. 
@@ -108,51 +106,51 @@ def upgrade(): tables() dialect_name = op.get_bind().dialect.name - op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') + op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail") - with op.batch_alter_table('task_fail') as batch_op: - batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False)) - batch_op.add_column(sa.Column('run_id', type_=StringID(), nullable=True)) + with op.batch_alter_table("task_fail") as batch_op: + batch_op.add_column(sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False)) + batch_op.add_column(sa.Column("run_id", type_=StringID(), nullable=True)) update_query = _update_value_from_dag_run( dialect_name=dialect_name, target_table=task_fail, target_column=task_fail.c.run_id, - join_columns=['dag_id', 'execution_date'], + join_columns=["dag_id", "execution_date"], ) op.execute(update_query) - with op.batch_alter_table('task_fail') as batch_op: - batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False) - batch_op.drop_column('execution_date') + with op.batch_alter_table("task_fail") as batch_op: + batch_op.alter_column("run_id", existing_type=StringID(), existing_nullable=True, nullable=False) + batch_op.drop_column("execution_date") batch_op.create_foreign_key( - 'task_fail_ti_fkey', - 'task_instance', - ['dag_id', 'task_id', 'run_id', 'map_index'], - ['dag_id', 'task_id', 'run_id', 'map_index'], - ondelete='CASCADE', + "task_fail_ti_fkey", + "task_instance", + ["dag_id", "task_id", "run_id", "map_index"], + ["dag_id", "task_id", "run_id", "map_index"], + ondelete="CASCADE", ) def downgrade(): tables() dialect_name = op.get_bind().dialect.name - op.add_column('task_fail', sa.Column('execution_date', TIMESTAMP, nullable=True)) + op.add_column("task_fail", sa.Column("execution_date", TIMESTAMP, nullable=True)) update_query = _update_value_from_dag_run( dialect_name=dialect_name, target_table=task_fail, target_column=task_fail.c.execution_date, - join_columns=['dag_id', 'run_id'], + join_columns=["dag_id", "run_id"], ) op.execute(update_query) - with op.batch_alter_table('task_fail', copy_from=task_fail) as batch_op: - batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False) - if dialect_name != 'sqlite': - batch_op.drop_constraint('task_fail_ti_fkey', type_='foreignkey') - batch_op.drop_column('map_index', mssql_drop_default=True) - batch_op.drop_column('run_id') + with op.batch_alter_table("task_fail") as batch_op: + batch_op.alter_column("execution_date", existing_type=TIMESTAMP, nullable=False) + if dialect_name != "sqlite": + batch_op.drop_constraint("task_fail_ti_fkey", type_="foreignkey") + batch_op.drop_column("map_index", mssql_drop_default=True) + batch_op.drop_column("run_id") op.create_index( - index_name='idx_task_fail_dag_task_date', - table_name='task_fail', - columns=['dag_id', 'task_id', 'execution_date'], + index_name="idx_task_fail_dag_task_date", + table_name="task_fail", + columns=["dag_id", "task_id", "execution_date"], unique=False, ) diff --git a/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py b/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py index 81f6f1f34492b..13823ac354011 100644 --- a/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py +++ 
b/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Update migration for FAB tables to add missing constraints Revision ID: 909884dea523 @@ -23,7 +22,7 @@ Create Date: 2022-03-21 08:33:01.635688 """ - +from __future__ import annotations import sqlalchemy as sa from alembic import op @@ -31,78 +30,78 @@ from airflow.migrations.utils import get_mssql_table_constraints # revision identifiers, used by Alembic. -revision = '909884dea523' -down_revision = '48925b2719cb' +revision = "909884dea523" +down_revision = "48925b2719cb" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" def upgrade(): """Apply Update migration for FAB tables to add missing constraints""" conn = op.get_bind() - if conn.dialect.name == 'sqlite': - op.execute('PRAGMA foreign_keys=OFF') - with op.batch_alter_table('ab_view_menu', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('ab_view_menu_name_uq'), ['name']) - op.execute('PRAGMA foreign_keys=ON') - elif conn.dialect.name == 'mysql': - with op.batch_alter_table('ab_register_user', schema=None) as batch_op: - batch_op.alter_column('username', existing_type=sa.String(256), nullable=False) - batch_op.alter_column('email', existing_type=sa.String(256), nullable=False) - with op.batch_alter_table('ab_user', schema=None) as batch_op: - batch_op.alter_column('username', existing_type=sa.String(256), nullable=False) - batch_op.alter_column('email', existing_type=sa.String(256), nullable=False) - elif conn.dialect.name == 'mssql': - with op.batch_alter_table('ab_register_user') as batch_op: + if conn.dialect.name == "sqlite": + op.execute("PRAGMA foreign_keys=OFF") + with op.batch_alter_table("ab_view_menu", schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f("ab_view_menu_name_uq"), ["name"]) + op.execute("PRAGMA foreign_keys=ON") + elif conn.dialect.name == "mysql": + with op.batch_alter_table("ab_register_user", schema=None) as batch_op: + batch_op.alter_column("username", existing_type=sa.String(256), nullable=False) + batch_op.alter_column("email", existing_type=sa.String(256), nullable=False) + with op.batch_alter_table("ab_user", schema=None) as batch_op: + batch_op.alter_column("username", existing_type=sa.String(256), nullable=False) + batch_op.alter_column("email", existing_type=sa.String(256), nullable=False) + elif conn.dialect.name == "mssql": + with op.batch_alter_table("ab_register_user") as batch_op: # Drop the unique constraint on username and email - constraints = get_mssql_table_constraints(conn, 'ab_register_user') - for k, _ in constraints.get('UNIQUE').items(): - batch_op.drop_constraint(k, type_='unique') - batch_op.alter_column('username', existing_type=sa.String(256), nullable=False) - batch_op.create_unique_constraint(None, ['username']) - batch_op.alter_column('email', existing_type=sa.String(256), nullable=False) - with op.batch_alter_table('ab_user') as batch_op: + constraints = get_mssql_table_constraints(conn, "ab_register_user") + for k, _ in constraints.get("UNIQUE").items(): + batch_op.drop_constraint(k, type_="unique") + batch_op.alter_column("username", existing_type=sa.String(256), nullable=False) + batch_op.create_unique_constraint(None, ["username"]) + batch_op.alter_column("email", existing_type=sa.String(256), nullable=False) + with 
op.batch_alter_table("ab_user") as batch_op: # Drop the unique constraint on username and email - constraints = get_mssql_table_constraints(conn, 'ab_user') - for k, _ in constraints.get('UNIQUE').items(): - batch_op.drop_constraint(k, type_='unique') - batch_op.alter_column('username', existing_type=sa.String(256), nullable=False) - batch_op.create_unique_constraint(None, ['username']) - batch_op.alter_column('email', existing_type=sa.String(256), nullable=False) - batch_op.create_unique_constraint(None, ['email']) + constraints = get_mssql_table_constraints(conn, "ab_user") + for k, _ in constraints.get("UNIQUE").items(): + batch_op.drop_constraint(k, type_="unique") + batch_op.alter_column("username", existing_type=sa.String(256), nullable=False) + batch_op.create_unique_constraint(None, ["username"]) + batch_op.alter_column("email", existing_type=sa.String(256), nullable=False) + batch_op.create_unique_constraint(None, ["email"]) def downgrade(): """Unapply Update migration for FAB tables to add missing constraints""" conn = op.get_bind() - if conn.dialect.name == 'sqlite': - op.execute('PRAGMA foreign_keys=OFF') - with op.batch_alter_table('ab_view_menu', schema=None) as batch_op: - batch_op.drop_constraint('ab_view_menu_name_uq', type_='unique') - op.execute('PRAGMA foreign_keys=ON') - elif conn.dialect.name == 'mysql': - with op.batch_alter_table('ab_user', schema=None) as batch_op: - batch_op.alter_column('email', existing_type=sa.String(256), nullable=True) - batch_op.alter_column('username', existing_type=sa.String(256), nullable=True, unique=True) - with op.batch_alter_table('ab_register_user', schema=None) as batch_op: - batch_op.alter_column('email', existing_type=sa.String(256), nullable=True) - batch_op.alter_column('username', existing_type=sa.String(256), nullable=True, unique=True) - elif conn.dialect.name == 'mssql': - with op.batch_alter_table('ab_register_user') as batch_op: + if conn.dialect.name == "sqlite": + op.execute("PRAGMA foreign_keys=OFF") + with op.batch_alter_table("ab_view_menu", schema=None) as batch_op: + batch_op.drop_constraint("ab_view_menu_name_uq", type_="unique") + op.execute("PRAGMA foreign_keys=ON") + elif conn.dialect.name == "mysql": + with op.batch_alter_table("ab_user", schema=None) as batch_op: + batch_op.alter_column("email", existing_type=sa.String(256), nullable=True) + batch_op.alter_column("username", existing_type=sa.String(256), nullable=True, unique=True) + with op.batch_alter_table("ab_register_user", schema=None) as batch_op: + batch_op.alter_column("email", existing_type=sa.String(256), nullable=True) + batch_op.alter_column("username", existing_type=sa.String(256), nullable=True, unique=True) + elif conn.dialect.name == "mssql": + with op.batch_alter_table("ab_register_user") as batch_op: # Drop the unique constraint on username and email - constraints = get_mssql_table_constraints(conn, 'ab_register_user') - for k, _ in constraints.get('UNIQUE').items(): - batch_op.drop_constraint(k, type_='unique') - batch_op.alter_column('username', existing_type=sa.String(256), nullable=False, unique=True) - batch_op.create_unique_constraint(None, ['username']) - batch_op.alter_column('email', existing_type=sa.String(256), nullable=False, unique=True) - with op.batch_alter_table('ab_user') as batch_op: + constraints = get_mssql_table_constraints(conn, "ab_register_user") + for k, _ in constraints.get("UNIQUE").items(): + batch_op.drop_constraint(k, type_="unique") + batch_op.alter_column("username", existing_type=sa.String(256), 
nullable=False, unique=True) + batch_op.create_unique_constraint(None, ["username"]) + batch_op.alter_column("email", existing_type=sa.String(256), nullable=False, unique=True) + with op.batch_alter_table("ab_user") as batch_op: # Drop the unique constraint on username and email - constraints = get_mssql_table_constraints(conn, 'ab_user') - for k, _ in constraints.get('UNIQUE').items(): - batch_op.drop_constraint(k, type_='unique') - batch_op.alter_column('username', existing_type=sa.String(256), nullable=True) - batch_op.create_unique_constraint(None, ['username']) - batch_op.alter_column('email', existing_type=sa.String(256), nullable=True, unique=True) - batch_op.create_unique_constraint(None, ['email']) + constraints = get_mssql_table_constraints(conn, "ab_user") + for k, _ in constraints.get("UNIQUE").items(): + batch_op.drop_constraint(k, type_="unique") + batch_op.alter_column("username", existing_type=sa.String(256), nullable=True) + batch_op.create_unique_constraint(None, ["username"]) + batch_op.alter_column("email", existing_type=sa.String(256), nullable=True, unique=True) + batch_op.create_unique_constraint(None, ["email"]) diff --git a/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py b/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py index f48cd90cfa1a3..96c52b03967b3 100644 --- a/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py +++ b/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py @@ -15,13 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add map_index to Log. Revision ID: 75d5ed6c2b43 Revises: 909884dea523 Create Date: 2022-03-15 16:35:54.816863 """ +from __future__ import annotations + from alembic import op from sqlalchemy import Column, Integer @@ -30,7 +31,7 @@ down_revision = "909884dea523" branch_labels = None depends_on = None -airflow_version = '2.3.0' +airflow_version = "2.3.0" def upgrade(): diff --git a/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py b/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py index 03a9f43c3ad62..7875d1eceb30d 100644 --- a/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py +++ b/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py @@ -15,41 +15,41 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -"""Update dag.default_view to grid +"""Update dag.default_view to grid. Revision ID: b1b348e02d07 Revises: 75d5ed6c2b43 Create Date: 2022-04-19 17:25:00.872220 """ +from __future__ import annotations from alembic import op from sqlalchemy import String from sqlalchemy.sql import column, table # revision identifiers, used by Alembic. 
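In the FAB-constraints migration above, the MSSQL branch first has to discover the auto-generated names of the unique constraints before it can drop them, which is what get_mssql_table_constraints does by querying INFORMATION_SCHEMA. A generic discover-then-drop sketch using SQLAlchemy's inspector is shown below; it is only an approximation, since get_unique_constraints is not implemented by every dialect (presumably why the migration uses its own helper for MSSQL):

    from __future__ import annotations

    import sqlalchemy as sa
    from alembic import op

    def drop_unnamed_unique_constraints(table_name: str) -> None:
        """Drop unique constraints whose names were generated by the database."""
        inspector = sa.inspect(op.get_bind())
        uniques = inspector.get_unique_constraints(table_name)
        # batch_alter_table keeps this working on SQLite, which rebuilds the
        # table instead of emitting ALTER TABLE ... DROP CONSTRAINT.
        with op.batch_alter_table(table_name) as batch_op:
            for uc in uniques:
                batch_op.drop_constraint(uc["name"], type_="unique")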
-revision = 'b1b348e02d07' -down_revision = '75d5ed6c2b43' +revision = "b1b348e02d07" +down_revision = "75d5ed6c2b43" branch_labels = None -depends_on = '75d5ed6c2b43' -airflow_version = '2.3.0' +depends_on = "75d5ed6c2b43" +airflow_version = "2.3.0" -dag = table('dag', column('default_view', String)) +dag = table("dag", column("default_view", String)) def upgrade(): op.execute( dag.update() - .where(dag.c.default_view == op.inline_literal('tree')) - .values({'default_view': op.inline_literal('grid')}) + .where(dag.c.default_view == op.inline_literal("tree")) + .values({"default_view": op.inline_literal("grid")}) ) def downgrade(): op.execute( dag.update() - .where(dag.c.default_view == op.inline_literal('grid')) - .values({'default_view': op.inline_literal('tree')}) + .where(dag.c.default_view == op.inline_literal("grid")) + .values({"default_view": op.inline_literal("tree")}) ) diff --git a/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py b/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py index 2023a3c294fee..bfe958520be11 100644 --- a/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py +++ b/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add index for ``event`` column in ``log`` table. Revision ID: 1de7bc13c950 @@ -23,22 +22,23 @@ Create Date: 2022-05-10 18:18:43.484829 """ +from __future__ import annotations from alembic import op # revision identifiers, used by Alembic. -revision = '1de7bc13c950' -down_revision = 'b1b348e02d07' +revision = "1de7bc13c950" +down_revision = "b1b348e02d07" branch_labels = None depends_on = None -airflow_version = '2.3.1' +airflow_version = "2.3.1" def upgrade(): """Apply Add index for ``event`` column in ``log`` table.""" - op.create_index('idx_log_event', 'log', ['event'], unique=False) + op.create_index("idx_log_event", "log", ["event"], unique=False) def downgrade(): """Unapply Add index for ``event`` column in ``log`` table.""" - op.drop_index('idx_log_event', table_name='log') + op.drop_index("idx_log_event", table_name="log") diff --git a/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py b/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py index 55d9e9754e532..6ca9b2afb9528 100644 --- a/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py +++ b/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Add cascade to dag_tag foreign key Revision ID: 3c94c427fdf6 @@ -23,62 +22,60 @@ Create Date: 2022-05-03 09:47:41.957710 """ +from __future__ import annotations from alembic import op +from sqlalchemy import inspect from airflow.migrations.utils import get_mssql_table_constraints # revision identifiers, used by Alembic. 
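The default_view update above shows the usual way to run a data migration without importing application models: a migration-local table()/column() stub describes only the columns the statement touches, and op.inline_literal() keeps the values renderable when the migration is generated as offline SQL. A minimal sketch of the same pattern against a hypothetical ui_settings table:

    from alembic import op
    from sqlalchemy import String
    from sqlalchemy.sql import column, table

    # Migration-local stub: only the column the UPDATE touches is declared,
    # so the migration never depends on current application models.
    ui_settings = table("ui_settings", column("theme", String))  # hypothetical table

    def upgrade():
        op.execute(
            ui_settings.update()
            .where(ui_settings.c.theme == op.inline_literal("classic"))
            .values({"theme": op.inline_literal("modern")})
        )

    def downgrade():
        op.execute(
            ui_settings.update()
            .where(ui_settings.c.theme == op.inline_literal("modern"))
            .values({"theme": op.inline_literal("classic")})
        )

Freezing the table shape inside the migration means later changes to the real model cannot silently alter what this revision does.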
-revision = '3c94c427fdf6' -down_revision = '1de7bc13c950' +revision = "3c94c427fdf6" +down_revision = "1de7bc13c950" branch_labels = None depends_on = None -airflow_version = '2.3.2' +airflow_version = "2.3.2" def upgrade(): """Apply Add cascade to dag_tag foreignkey""" conn = op.get_bind() - if conn.dialect.name == 'sqlite': - naming_convention = { - "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - } + if conn.dialect.name in ["sqlite", "mysql"]: + inspector = inspect(conn.engine) + foreignkey = inspector.get_foreign_keys("dag_tag") with op.batch_alter_table( - 'dag_tag', naming_convention=naming_convention, recreate='always' + "dag_tag", ) as batch_op: - batch_op.drop_constraint('fk_dag_tag_dag_id_dag', type_='foreignkey') + batch_op.drop_constraint(foreignkey[0]["name"], type_="foreignkey") batch_op.create_foreign_key( - "dag_tag_dag_id_fkey", 'dag', ['dag_id'], ['dag_id'], ondelete='CASCADE' + "dag_tag_dag_id_fkey", "dag", ["dag_id"], ["dag_id"], ondelete="CASCADE" ) else: - with op.batch_alter_table('dag_tag') as batch_op: - if conn.dialect.name == 'mssql': - constraints = get_mssql_table_constraints(conn, 'dag_tag') - Fk, _ = constraints['FOREIGN KEY'].popitem() - batch_op.drop_constraint(Fk, type_='foreignkey') - if conn.dialect.name == 'postgresql': - batch_op.drop_constraint('dag_tag_dag_id_fkey', type_='foreignkey') - if conn.dialect.name == 'mysql': - batch_op.drop_constraint('dag_tag_ibfk_1', type_='foreignkey') - + with op.batch_alter_table("dag_tag") as batch_op: + if conn.dialect.name == "mssql": + constraints = get_mssql_table_constraints(conn, "dag_tag") + Fk, _ = constraints["FOREIGN KEY"].popitem() + batch_op.drop_constraint(Fk, type_="foreignkey") + if conn.dialect.name == "postgresql": + batch_op.drop_constraint("dag_tag_dag_id_fkey", type_="foreignkey") batch_op.create_foreign_key( - "dag_tag_dag_id_fkey", 'dag', ['dag_id'], ['dag_id'], ondelete='CASCADE' + "dag_tag_dag_id_fkey", "dag", ["dag_id"], ["dag_id"], ondelete="CASCADE" ) def downgrade(): """Unapply Add cascade to dag_tag foreignkey""" conn = op.get_bind() - if conn.dialect.name == 'sqlite': - with op.batch_alter_table('dag_tag') as batch_op: - batch_op.drop_constraint('dag_tag_dag_id_fkey', type_='foreignkey') - batch_op.create_foreign_key("fk_dag_tag_dag_id_dag", 'dag', ['dag_id'], ['dag_id']) + if conn.dialect.name == "sqlite": + with op.batch_alter_table("dag_tag") as batch_op: + batch_op.drop_constraint("dag_tag_dag_id_fkey", type_="foreignkey") + batch_op.create_foreign_key("fk_dag_tag_dag_id_dag", "dag", ["dag_id"], ["dag_id"]) else: - with op.batch_alter_table('dag_tag') as batch_op: - batch_op.drop_constraint('dag_tag_dag_id_fkey', type_='foreignkey') + with op.batch_alter_table("dag_tag") as batch_op: + batch_op.drop_constraint("dag_tag_dag_id_fkey", type_="foreignkey") batch_op.create_foreign_key( None, - 'dag', - ['dag_id'], - ['dag_id'], + "dag", + ["dag_id"], + ["dag_id"], ) diff --git a/airflow/migrations/versions/0111_2_3_3_add_indexes_for_cascade_deletes.py b/airflow/migrations/versions/0111_2_3_3_add_indexes_for_cascade_deletes.py new file mode 100644 index 0000000000000..d2c39177fafcd --- /dev/null +++ b/airflow/migrations/versions/0111_2_3_3_add_indexes_for_cascade_deletes.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add indexes for CASCADE deletes on task_instance + +Some databases don't add indexes on the FK columns so we have to add them for performance on CASCADE deletes. + +Revision ID: f5fcbda3e651 +Revises: 3c94c427fdf6 +Create Date: 2022-06-15 18:04:54.081789 + +""" +from __future__ import annotations + +from alembic import context, op + +# revision identifiers, used by Alembic. +revision = "f5fcbda3e651" +down_revision = "3c94c427fdf6" +branch_labels = None +depends_on = None +airflow_version = "2.3.3" + + +def _mysql_tables_where_indexes_already_present(conn): + """ + If user downgraded and is upgrading again, we have to check for existing + indexes on mysql because we can't (and don't) drop them as part of the + downgrade. + """ + to_check = [ + ("xcom", "idx_xcom_task_instance"), + ("task_reschedule", "idx_task_reschedule_dag_run"), + ("task_fail", "idx_task_fail_task_instance"), + ] + tables = set() + for tbl, idx in to_check: + if conn.execute(f"show indexes from {tbl} where Key_name = '{idx}'").first(): + tables.add(tbl) + return tables + + +def upgrade(): + """Apply Add indexes for CASCADE deletes""" + conn = op.get_bind() + tables_to_skip = set() + + # mysql requires indexes for FKs, so adding had the effect of renaming, and we cannot remove. + if conn.dialect.name == "mysql" and not context.is_offline_mode(): + tables_to_skip.update(_mysql_tables_where_indexes_already_present(conn)) + + if "task_fail" not in tables_to_skip: + with op.batch_alter_table("task_fail", schema=None) as batch_op: + batch_op.create_index("idx_task_fail_task_instance", ["dag_id", "task_id", "run_id", "map_index"]) + + if "task_reschedule" not in tables_to_skip: + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: + batch_op.create_index("idx_task_reschedule_dag_run", ["dag_id", "run_id"]) + + if "xcom" not in tables_to_skip: + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.create_index("idx_xcom_task_instance", ["dag_id", "task_id", "run_id", "map_index"]) + + +def downgrade(): + """Unapply Add indexes for CASCADE deletes""" + conn = op.get_bind() + + # mysql requires indexes for FKs, so adding had the effect of renaming, and we cannot remove. 
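The cascade-indexes migration above probes MySQL with a raw SHOW INDEXES query because a prior downgrade cannot actually remove those indexes there, so the upgrade has to be idempotent. A dialect-neutral sketch of the same create-only-if-missing idea, using the SQLAlchemy inspector rather than the raw query (an approximation, not the code used above):

    from __future__ import annotations

    import sqlalchemy as sa
    from alembic import op

    def create_index_if_missing(table_name: str, index_name: str, columns: list[str]) -> None:
        # Skip creation when an index with this name already exists, e.g.
        # because an earlier downgrade was unable to drop it.
        existing = {idx["name"] for idx in sa.inspect(op.get_bind()).get_indexes(table_name)}
        if index_name in existing:
            return
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.create_index(index_name, columns)

Note that the migration above also skips the probe in offline (--sql) mode, since there is no live connection to inspect.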
+ if conn.dialect.name == "mysql": + return + + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.drop_index("idx_xcom_task_instance") + + with op.batch_alter_table("task_reschedule", schema=None) as batch_op: + batch_op.drop_index("idx_task_reschedule_dag_run") + + with op.batch_alter_table("task_fail", schema=None) as batch_op: + batch_op.drop_index("idx_task_fail_task_instance") diff --git a/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py b/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py new file mode 100644 index 0000000000000..2bb999b21b350 --- /dev/null +++ b/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add DagWarning model + +Revision ID: 424117c37d18 +Revises: 3c94c427fdf6 +Create Date: 2022-04-27 15:57:36.736743 +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.db_types import TIMESTAMP, StringID + +# revision identifiers, used by Alembic. + + +revision = "424117c37d18" +down_revision = "f5fcbda3e651" +branch_labels = None +depends_on = None +airflow_version = "2.4.0" + + +def upgrade(): + """Apply Add DagWarning model""" + op.create_table( + "dag_warning", + sa.Column("dag_id", StringID(), primary_key=True), + sa.Column("warning_type", sa.String(length=50), primary_key=True), + sa.Column("message", sa.Text(), nullable=False), + sa.Column("timestamp", TIMESTAMP, nullable=False), + sa.ForeignKeyConstraint( + ("dag_id",), + ["dag.dag_id"], + name="dcw_dag_id_fkey", + ondelete="CASCADE", + ), + ) + + +def downgrade(): + """Unapply Add DagWarning model""" + op.drop_table("dag_warning") diff --git a/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py b/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py new file mode 100644 index 0000000000000..8beebc80a9aa7 --- /dev/null +++ b/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py @@ -0,0 +1,262 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""compare types between ORM and DB. + +Revision ID: 44b7034f6bdc +Revises: 424117c37d18 +Create Date: 2022-05-31 09:16:44.558754 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.db_types import TIMESTAMP + +# revision identifiers, used by Alembic. +revision = "44b7034f6bdc" +down_revision = "424117c37d18" +branch_labels = None +depends_on = None +airflow_version = "2.4.0" + + +def upgrade(): + """Apply compare types between ORM and DB.""" + conn = op.get_bind() + with op.batch_alter_table("connection", schema=None) as batch_op: + batch_op.alter_column( + "extra", + existing_type=sa.TEXT(), + type_=sa.Text(), + existing_nullable=True, + ) + with op.batch_alter_table("log_template", schema=None) as batch_op: + batch_op.alter_column( + "created_at", existing_type=sa.DateTime(), type_=TIMESTAMP(), existing_nullable=False + ) + + with op.batch_alter_table("serialized_dag", schema=None) as batch_op: + # drop server_default + batch_op.alter_column( + "dag_hash", + existing_type=sa.String(32), + server_default=None, + type_=sa.String(32), + existing_nullable=False, + ) + with op.batch_alter_table("trigger", schema=None) as batch_op: + batch_op.alter_column( + "created_date", existing_type=sa.DateTime(), type_=TIMESTAMP(), existing_nullable=False + ) + + if conn.dialect.name != "sqlite": + return + with op.batch_alter_table("serialized_dag", schema=None) as batch_op: + batch_op.alter_column("fileloc_hash", existing_type=sa.Integer, type_=sa.BigInteger()) + # Some sqlite date are not in db_types.TIMESTAMP. Convert these to TIMESTAMP. + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.alter_column( + "last_pickled", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + batch_op.alter_column( + "last_expired", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with op.batch_alter_table("dag_pickle", schema=None) as batch_op: + batch_op.alter_column( + "created_dttm", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.alter_column( + "execution_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=False + ) + batch_op.alter_column( + "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + batch_op.alter_column( + "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with op.batch_alter_table("import_error", schema=None) as batch_op: + batch_op.alter_column( + "timestamp", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with op.batch_alter_table("job", schema=None) as batch_op: + batch_op.alter_column( + "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + batch_op.alter_column( + "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + batch_op.alter_column( + "latest_heartbeat", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with op.batch_alter_table("log", schema=None) as batch_op: + batch_op.alter_column("dttm", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True) + batch_op.alter_column( + "execution_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with 
op.batch_alter_table("serialized_dag", schema=None) as batch_op: + batch_op.alter_column( + "last_updated", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=False + ) + + with op.batch_alter_table("sla_miss", schema=None) as batch_op: + batch_op.alter_column( + "execution_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=False + ) + batch_op.alter_column( + "timestamp", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with op.batch_alter_table("task_fail", schema=None) as batch_op: + batch_op.alter_column( + "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + batch_op.alter_column( + "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.alter_column( + "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + batch_op.alter_column( + "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + batch_op.alter_column( + "queued_dttm", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True + ) + + +def downgrade(): + """Unapply compare types between ORM and DB.""" + with op.batch_alter_table("connection", schema=None) as batch_op: + batch_op.alter_column( + "extra", + existing_type=sa.Text(), + type_=sa.TEXT(), + existing_nullable=True, + ) + with op.batch_alter_table("log_template", schema=None) as batch_op: + batch_op.alter_column( + "created_at", existing_type=TIMESTAMP(), type_=sa.DateTime(), existing_nullable=False + ) + with op.batch_alter_table("serialized_dag", schema=None) as batch_op: + # add server_default + batch_op.alter_column( + "dag_hash", + existing_type=sa.String(32), + server_default="Hash not calculated yet", + type_=sa.String(32), + existing_nullable=False, + ) + with op.batch_alter_table("trigger", schema=None) as batch_op: + batch_op.alter_column( + "created_date", existing_type=TIMESTAMP(), type_=sa.DateTime(), existing_nullable=False + ) + conn = op.get_bind() + + if conn.dialect.name != "sqlite": + return + with op.batch_alter_table("serialized_dag", schema=None) as batch_op: + batch_op.alter_column("fileloc_hash", existing_type=sa.BigInteger, type_=sa.Integer()) + # Change these column back to sa.DATETIME() + with op.batch_alter_table("task_instance", schema=None) as batch_op: + batch_op.alter_column( + "queued_dttm", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + + with op.batch_alter_table("task_fail", schema=None) as batch_op: + batch_op.alter_column( + "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + + with op.batch_alter_table("sla_miss", schema=None) as batch_op: + batch_op.alter_column( + "timestamp", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "execution_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=False + ) + + with op.batch_alter_table("serialized_dag", schema=None) as batch_op: + batch_op.alter_column( + "last_updated", existing_type=TIMESTAMP(), 
type_=sa.DATETIME(), existing_nullable=False + ) + + with op.batch_alter_table("log", schema=None) as batch_op: + batch_op.alter_column( + "execution_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column("dttm", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True) + + with op.batch_alter_table("job", schema=None) as batch_op: + batch_op.alter_column( + "latest_heartbeat", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + + with op.batch_alter_table("import_error", schema=None) as batch_op: + batch_op.alter_column( + "timestamp", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.alter_column( + "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "execution_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=False + ) + + with op.batch_alter_table("dag_pickle", schema=None) as batch_op: + batch_op.alter_column( + "created_dttm", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.alter_column( + "last_expired", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) + batch_op.alter_column( + "last_pickled", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True + ) diff --git a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py new file mode 100644 index 0000000000000..b6671a8c075a2 --- /dev/null +++ b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py @@ -0,0 +1,190 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
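The compare-types migration above mostly converts columns that older SQLite databases still store as sa.DATETIME() to the project's TIMESTAMP type, one batch_alter_table per table. Reduced to a single column, and substituting a plain sa.TIMESTAMP for Airflow's db_types.TIMESTAMP, the pattern looks roughly like this (table and column names are placeholders):

    import sqlalchemy as sa
    from alembic import op

    def upgrade():
        # batch_alter_table lets SQLite apply the type change by copying the
        # table; existing_type records what the column currently is so the
        # reverse operation in downgrade() stays symmetric.
        with op.batch_alter_table("event_log", schema=None) as batch_op:  # placeholder table
            batch_op.alter_column(
                "created_at",  # placeholder column
                existing_type=sa.DATETIME(),
                type_=sa.TIMESTAMP(timezone=True),
                existing_nullable=True,
            )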
+"""Add Dataset model + +Revision ID: 0038cd0c28b4 +Revises: 44b7034f6bdc +Create Date: 2022-06-22 14:37:20.880672 + +""" +from __future__ import annotations + +import sqlalchemy as sa +import sqlalchemy_jsonfield +from alembic import op +from sqlalchemy import Integer, String, func + +from airflow.migrations.db_types import TIMESTAMP, StringID +from airflow.settings import json + +revision = "0038cd0c28b4" +down_revision = "44b7034f6bdc" +branch_labels = None +depends_on = None +airflow_version = "2.4.0" + + +def _create_dataset_table(): + op.create_table( + "dataset", + sa.Column("id", Integer, primary_key=True, autoincrement=True), + sa.Column( + "uri", + String(length=3000).with_variant( + String( + length=3000, + # latin1 allows for more indexed length in mysql + # and this field should only be ascii chars + collation="latin1_general_cs", + ), + "mysql", + ), + nullable=False, + ), + sa.Column("extra", sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}), + sa.Column("created_at", TIMESTAMP, nullable=False), + sa.Column("updated_at", TIMESTAMP, nullable=False), + sqlite_autoincrement=True, # ensures PK values not reused + ) + op.create_index("idx_uri_unique", "dataset", ["uri"], unique=True) + + +def _create_dag_schedule_dataset_reference_table(): + op.create_table( + "dag_schedule_dataset_reference", + sa.Column("dataset_id", Integer, primary_key=True, nullable=False), + sa.Column("dag_id", StringID(), primary_key=True, nullable=False), + sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False), + sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False), + sa.ForeignKeyConstraint( + ("dataset_id",), + ["dataset.id"], + name="dsdr_dataset_fkey", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + columns=("dag_id",), + refcolumns=["dag.dag_id"], + name="dsdr_dag_id_fkey", + ondelete="CASCADE", + ), + ) + + +def _create_task_outlet_dataset_reference_table(): + op.create_table( + "task_outlet_dataset_reference", + sa.Column("dataset_id", Integer, primary_key=True, nullable=False), + sa.Column("dag_id", StringID(), primary_key=True, nullable=False), + sa.Column("task_id", StringID(), primary_key=True, nullable=False), + sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False), + sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False), + sa.ForeignKeyConstraint( + ("dataset_id",), + ["dataset.id"], + name="todr_dataset_fkey", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + columns=("dag_id",), + refcolumns=["dag.dag_id"], + name="todr_dag_id_fkey", + ondelete="CASCADE", + ), + ) + + +def _create_dataset_dag_run_queue_table(): + op.create_table( + "dataset_dag_run_queue", + sa.Column("dataset_id", Integer, primary_key=True, nullable=False), + sa.Column("target_dag_id", StringID(), primary_key=True, nullable=False), + sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False), + sa.ForeignKeyConstraint( + ("dataset_id",), + ["dataset.id"], + name="ddrq_dataset_fkey", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ("target_dag_id",), + ["dag.dag_id"], + name="ddrq_dag_fkey", + ondelete="CASCADE", + ), + ) + + +def _create_dataset_event_table(): + op.create_table( + "dataset_event", + sa.Column("id", Integer, primary_key=True, autoincrement=True), + sa.Column("dataset_id", Integer, nullable=False), + sa.Column("extra", sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}), + sa.Column("source_task_id", String(250), nullable=True), + sa.Column("source_dag_id", String(250), 
nullable=True), + sa.Column("source_run_id", String(250), nullable=True), + sa.Column("source_map_index", sa.Integer(), nullable=True, server_default="-1"), + sa.Column("timestamp", TIMESTAMP, nullable=False), + sqlite_autoincrement=True, # ensures PK values not reused + ) + op.create_index("idx_dataset_id_timestamp", "dataset_event", ["dataset_id", "timestamp"]) + + +def _create_dataset_event_dag_run_table(): + op.create_table( + "dagrun_dataset_event", + sa.Column("dag_run_id", sa.Integer(), nullable=False), + sa.Column("event_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["dag_run_id"], + ["dag_run.id"], + name=op.f("dagrun_dataset_events_dag_run_id_fkey"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["event_id"], + ["dataset_event.id"], + name=op.f("dagrun_dataset_events_event_id_fkey"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("dag_run_id", "event_id", name=op.f("dagrun_dataset_events_pkey")), + ) + with op.batch_alter_table("dagrun_dataset_event") as batch_op: + batch_op.create_index("idx_dagrun_dataset_events_dag_run_id", ["dag_run_id"], unique=False) + batch_op.create_index("idx_dagrun_dataset_events_event_id", ["event_id"], unique=False) + + +def upgrade(): + """Apply Add Dataset model""" + _create_dataset_table() + _create_dag_schedule_dataset_reference_table() + _create_task_outlet_dataset_reference_table() + _create_dataset_dag_run_queue_table() + _create_dataset_event_table() + _create_dataset_event_dag_run_table() + + +def downgrade(): + """Unapply Add Dataset model""" + op.drop_table("dag_schedule_dataset_reference") + op.drop_table("task_outlet_dataset_reference") + op.drop_table("dataset_dag_run_queue") + op.drop_table("dagrun_dataset_event") + op.drop_table("dataset_event") + op.drop_table("dataset") diff --git a/airflow/migrations/versions/0115_2_4_0_remove_smart_sensors.py b/airflow/migrations/versions/0115_2_4_0_remove_smart_sensors.py new file mode 100644 index 0000000000000..2f1c9994345fe --- /dev/null +++ b/airflow/migrations/versions/0115_2_4_0_remove_smart_sensors.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Remove smart sensors + +Revision ID: f4ff391becb5 +Revises: 0038cd0c28b4 +Create Date: 2022-08-03 11:33:44.777945 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import func +from sqlalchemy.sql import column, table + +from airflow.migrations.db_types import TIMESTAMP, StringID + +# revision identifiers, used by Alembic. 
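The dataset table above types uri as a 3000-character string everywhere, but layers a latin1_general_cs collation on top for MySQL only, via with_variant, so the unique index on uri stays within MySQL's index-length limits (those limits are measured in bytes, and latin1 is a one-byte charset). A small sketch of the variant mechanism with illustrative names:

    import sqlalchemy as sa

    # Base type used on every backend, with a MySQL-only variant on top.
    uri_type = sa.String(length=3000).with_variant(
        sa.String(length=3000, collation="latin1_general_cs"),
        "mysql",
    )

    metadata = sa.MetaData()
    resource = sa.Table(  # illustrative table, not one of Airflow's
        "resource", metadata,
        sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
        sa.Column("uri", uri_type, nullable=False),
    )
    sa.Index("idx_resource_uri_unique", resource.c.uri, unique=True)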
+revision = "f4ff391becb5" +down_revision = "0038cd0c28b4" +branch_labels = None +depends_on = None +airflow_version = "2.4.0" + + +def upgrade(): + """Apply Remove smart sensors""" + op.drop_table("sensor_instance") + + """Minimal model definition for migrations""" + task_instance = table("task_instance", column("state", sa.String)) + op.execute(task_instance.update().where(task_instance.c.state == "sensing").values({"state": "failed"})) + + +def downgrade(): + """Unapply Remove smart sensors""" + op.create_table( + "sensor_instance", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("execution_date", TIMESTAMP, nullable=False), + sa.Column("state", sa.String(length=20), nullable=True), + sa.Column("try_number", sa.Integer(), nullable=True), + sa.Column("start_date", TIMESTAMP, nullable=True), + sa.Column("operator", sa.String(length=1000), nullable=False), + sa.Column("op_classpath", sa.String(length=1000), nullable=False), + sa.Column("hashcode", sa.BigInteger(), nullable=False), + sa.Column("shardcode", sa.Integer(), nullable=False), + sa.Column("poke_context", sa.Text(), nullable=False), + sa.Column("execution_context", sa.Text(), nullable=True), + sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False), + sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ti_primary_key", "sensor_instance", ["dag_id", "task_id", "execution_date"], unique=True) + op.create_index("si_hashcode", "sensor_instance", ["hashcode"], unique=False) + op.create_index("si_shardcode", "sensor_instance", ["shardcode"], unique=False) + op.create_index("si_state_shard", "sensor_instance", ["state", "shardcode"], unique=False) + op.create_index("si_updated_at", "sensor_instance", ["updated_at"], unique=False) diff --git a/airflow/migrations/versions/0116_2_4_0_add_dag_owner_attributes_table.py b/airflow/migrations/versions/0116_2_4_0_add_dag_owner_attributes_table.py new file mode 100644 index 0000000000000..5b4d8c2e6baa2 --- /dev/null +++ b/airflow/migrations/versions/0116_2_4_0_add_dag_owner_attributes_table.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add dag_owner_attributes table + +Revision ID: 1486deb605b4 +Revises: f4ff391becb5 +Create Date: 2022-08-04 16:59:45.406589 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.db_types import StringID + +# revision identifiers, used by Alembic. 
+revision = "1486deb605b4" +down_revision = "f4ff391becb5" +branch_labels = None +depends_on = None +airflow_version = "2.4.0" + + +def upgrade(): + """Apply Add ``DagOwnerAttributes`` table""" + op.create_table( + "dag_owner_attributes", + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("owner", sa.String(length=500), nullable=False), + sa.Column("link", sa.String(length=500), nullable=False), + sa.ForeignKeyConstraint(["dag_id"], ["dag.dag_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("dag_id", "owner"), + ) + + +def downgrade(): + """Unapply Add Dataset model""" + op.drop_table("dag_owner_attributes") diff --git a/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py new file mode 100644 index 0000000000000..6846d4ebb8943 --- /dev/null +++ b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add processor_subdir column to DagModel, SerializedDagModel and CallbackRequest tables. + +Revision ID: ecb43d2a1842 +Revises: 1486deb605b4 +Create Date: 2022-08-26 11:30:11.249580 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "ecb43d2a1842" +down_revision = "1486deb605b4" +branch_labels = None +depends_on = None +airflow_version = "2.4.0" + + +def upgrade(): + """Apply add processor_subdir to DagModel and SerializedDagModel""" + conn = op.get_bind() + + with op.batch_alter_table("dag") as batch_op: + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column("processor_subdir", sa.Text(length=2000), nullable=True)) + else: + batch_op.add_column(sa.Column("processor_subdir", sa.String(length=2000), nullable=True)) + + with op.batch_alter_table("serialized_dag") as batch_op: + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column("processor_subdir", sa.Text(length=2000), nullable=True)) + else: + batch_op.add_column(sa.Column("processor_subdir", sa.String(length=2000), nullable=True)) + + with op.batch_alter_table("callback_request") as batch_op: + batch_op.drop_column("dag_directory") + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column("processor_subdir", sa.Text(length=2000), nullable=True)) + else: + batch_op.add_column(sa.Column("processor_subdir", sa.String(length=2000), nullable=True)) + + +def downgrade(): + """Unapply Add processor_subdir to DagModel and SerializedDagModel""" + conn = op.get_bind() + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.drop_column("processor_subdir") + + with op.batch_alter_table("serialized_dag", schema=None) as batch_op: + batch_op.drop_column("processor_subdir") + + with op.batch_alter_table("callback_request") as batch_op: + batch_op.drop_column("processor_subdir") + if conn.dialect.name == "mysql": + batch_op.add_column(sa.Column("dag_directory", sa.Text(length=1000), nullable=True)) + else: + batch_op.add_column(sa.Column("dag_directory", sa.String(length=1000), nullable=True)) diff --git a/airflow/migrations/versions/0118_2_4_2_add_missing_autoinc_fab.py b/airflow/migrations/versions/0118_2_4_2_add_missing_autoinc_fab.py new file mode 100644 index 0000000000000..6c4f010e2f353 --- /dev/null +++ b/airflow/migrations/versions/0118_2_4_2_add_missing_autoinc_fab.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add missing auto-increment to columns on FAB tables + +Revision ID: b0d31815b5a6 +Revises: ecb43d2a1842 +Create Date: 2022-10-05 13:16:45.638490 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b0d31815b5a6" +down_revision = "ecb43d2a1842" +branch_labels = None +depends_on = None +airflow_version = "2.4.2" + + +def upgrade(): + """Apply migration. + + If these columns are already of the right type (i.e. 
created by our + migration in 1.10.13 rather than FAB itself in an earlier version), this + migration will issue an alter statement to change them to what they already + are -- i.e. its a no-op. + + These tables are small (100 to low 1k rows at most), so it's not too costly + to change them. + """ + conn = op.get_bind() + if conn.dialect.name in ["mssql", "sqlite"]: + # 1.10.12 didn't support SQL Server, so it couldn't have gotten this wrong --> nothing to correct + # SQLite autoinc was "implicit" for an INTEGER NOT NULL PRIMARY KEY + return + + for table in ( + "ab_permission", + "ab_view_menu", + "ab_role", + "ab_permission_view", + "ab_permission_view_role", + "ab_user", + "ab_user_role", + "ab_register_user", + ): + with op.batch_alter_table(table) as batch: + kwargs = {} + if conn.dialect.name == "postgresql": + kwargs["server_default"] = sa.Sequence(f"{table}_id_seq").next_value() + else: + kwargs["autoincrement"] = True + batch.alter_column("id", existing_type=sa.Integer(), existing_nullable=False, **kwargs) + + +def downgrade(): + """Unapply add_missing_autoinc_fab""" + # No downgrade needed, these _should_ have applied from 1.10.13 but didn't due to a previous bug! diff --git a/airflow/migrations/versions/0119_2_4_3_add_case_insensitive_unique_constraint_for_username.py b/airflow/migrations/versions/0119_2_4_3_add_case_insensitive_unique_constraint_for_username.py new file mode 100644 index 0000000000000..69b52f0864339 --- /dev/null +++ b/airflow/migrations/versions/0119_2_4_3_add_case_insensitive_unique_constraint_for_username.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add case-insensitive unique constraint for username + +Revision ID: e07f49787c9d +Revises: b0d31815b5a6 +Create Date: 2022-10-25 17:29:46.432326 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
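The auto-increment repair above is dialect-specific: MSSQL and SQLite are skipped for the reasons its docstring gives, MySQL only needs autoincrement=True, and PostgreSQL needs the column default pointed back at the table's identity sequence. Isolated to the PostgreSQL branch, and assuming the conventional <table>_id_seq sequence name, the fix is:

    import sqlalchemy as sa
    from alembic import op

    def restore_postgres_autoincrement(table: str) -> None:
        # Re-attach the identity sequence as the column default, roughly:
        #   ALTER TABLE <table> ALTER COLUMN id SET DEFAULT nextval('<table>_id_seq')
        with op.batch_alter_table(table) as batch_op:
            batch_op.alter_column(
                "id",
                existing_type=sa.Integer(),
                existing_nullable=False,
                server_default=sa.Sequence(f"{table}_id_seq").next_value(),
            )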
+revision = "e07f49787c9d" +down_revision = "b0d31815b5a6" +branch_labels = None +depends_on = None +airflow_version = "2.4.3" + + +def upgrade(): + """Apply Add case-insensitive unique constraint""" + conn = op.get_bind() + if conn.dialect.name == "postgresql": + op.create_index("idx_ab_user_username", "ab_user", [sa.text("LOWER(username)")], unique=True) + op.create_index( + "idx_ab_register_user_username", "ab_register_user", [sa.text("LOWER(username)")], unique=True + ) + elif conn.dialect.name == "sqlite": + with op.batch_alter_table("ab_user") as batch_op: + batch_op.alter_column( + "username", + existing_type=sa.String(64), + _type=sa.String(64, collation="NOCASE"), + unique=True, + nullable=False, + ) + with op.batch_alter_table("ab_register_user") as batch_op: + batch_op.alter_column( + "username", + existing_type=sa.String(64), + _type=sa.String(64, collation="NOCASE"), + unique=True, + nullable=False, + ) + + +def downgrade(): + """Unapply Add case-insensitive unique constraint""" + conn = op.get_bind() + if conn.dialect.name == "postgresql": + op.drop_index("idx_ab_user_username", table_name="ab_user") + op.drop_index("idx_ab_register_user_username", table_name="ab_register_user") + elif conn.dialect.name == "sqlite": + with op.batch_alter_table("ab_user") as batch_op: + batch_op.alter_column( + "username", + existing_type=sa.String(64, collation="NOCASE"), + _type=sa.String(64), + unique=True, + nullable=False, + ) + with op.batch_alter_table("ab_register_user") as batch_op: + batch_op.alter_column( + "username", + existing_type=sa.String(64, collation="NOCASE"), + _type=sa.String(64), + unique=True, + nullable=False, + ) diff --git a/airflow/migrations/versions/0120_2_5_0_add_updated_at_to_dagrun_and_ti.py b/airflow/migrations/versions/0120_2_5_0_add_updated_at_to_dagrun_and_ti.py new file mode 100644 index 0000000000000..f51f89001e82b --- /dev/null +++ b/airflow/migrations/versions/0120_2_5_0_add_updated_at_to_dagrun_and_ti.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add updated_at column to DagRun and TaskInstance + +Revision ID: ee8d93fcc81e +Revises: e07f49787c9d +Create Date: 2022-09-08 19:08:37.623121 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.db_types import TIMESTAMP + +# revision identifiers, used by Alembic. 
+revision = "ee8d93fcc81e" +down_revision = "e07f49787c9d" +branch_labels = None +depends_on = None +airflow_version = "2.5.0" + + +def upgrade(): + """Apply add updated_at column to DagRun and TaskInstance""" + with op.batch_alter_table("task_instance") as batch_op: + batch_op.add_column(sa.Column("updated_at", TIMESTAMP, default=sa.func.now)) + + with op.batch_alter_table("dag_run") as batch_op: + batch_op.add_column(sa.Column("updated_at", TIMESTAMP, default=sa.func.now)) + + +def downgrade(): + """Unapply add updated_at column to DagRun and TaskInstance""" + with op.batch_alter_table("task_instance") as batch_op: + batch_op.drop_column("updated_at") + + with op.batch_alter_table("dag_run") as batch_op: + batch_op.drop_column("updated_at") diff --git a/airflow/migrations/versions/0121_2_5_0_add_dagrunnote_and_taskinstancenote.py b/airflow/migrations/versions/0121_2_5_0_add_dagrunnote_and_taskinstancenote.py new file mode 100644 index 0000000000000..d13fba38aa70c --- /dev/null +++ b/airflow/migrations/versions/0121_2_5_0_add_dagrunnote_and_taskinstancenote.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add DagRunNote and TaskInstanceNote + +Revision ID: 1986afd32c1b +Revises: ee8d93fcc81e +Create Date: 2022-11-22 21:49:05.843439 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.db_types import StringID +from airflow.utils.sqlalchemy import UtcDateTime + +# revision identifiers, used by Alembic. 
+revision = "1986afd32c1b" +down_revision = "ee8d93fcc81e" +branch_labels = None +depends_on = None +airflow_version = "2.5.0" + + +def upgrade(): + """Apply Add DagRunNote and TaskInstanceNote""" + op.create_table( + "dag_run_note", + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("dag_run_id", sa.Integer(), nullable=False), + sa.Column( + "content", sa.String(length=1000).with_variant(sa.Text(length=1000), "mysql"), nullable=True + ), + sa.Column("created_at", UtcDateTime(timezone=True), nullable=False), + sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ("dag_run_id",), ["dag_run.id"], name="dag_run_note_dr_fkey", ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(("user_id",), ["ab_user.id"], name="dag_run_note_user_fkey"), + sa.PrimaryKeyConstraint("dag_run_id", name=op.f("dag_run_note_pkey")), + ) + + op.create_table( + "task_instance_note", + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("task_id", StringID(), nullable=False), + sa.Column("dag_id", StringID(), nullable=False), + sa.Column("run_id", StringID(), nullable=False), + sa.Column("map_index", sa.Integer(), nullable=False), + sa.Column( + "content", sa.String(length=1000).with_variant(sa.Text(length=1000), "mysql"), nullable=True + ), + sa.Column("created_at", UtcDateTime(timezone=True), nullable=False), + sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint( + "task_id", "dag_id", "run_id", "map_index", name=op.f("task_instance_note_pkey") + ), + sa.ForeignKeyConstraint( + ("dag_id", "task_id", "run_id", "map_index"), + [ + "task_instance.dag_id", + "task_instance.task_id", + "task_instance.run_id", + "task_instance.map_index", + ], + name="task_instance_note_ti_fkey", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint(("user_id",), ["ab_user.id"], name="task_instance_note_user_fkey"), + ) + + +def downgrade(): + """Unapply Add DagRunNote and TaskInstanceNote""" + op.drop_table("task_instance_note") + op.drop_table("dag_run_note") diff --git a/airflow/migrations/versions/0122_2_5_0_add_is_orphaned_to_datasetmodel.py b/airflow/migrations/versions/0122_2_5_0_add_is_orphaned_to_datasetmodel.py new file mode 100644 index 0000000000000..f4355402c84d6 --- /dev/null +++ b/airflow/migrations/versions/0122_2_5_0_add_is_orphaned_to_datasetmodel.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add is_orphaned to DatasetModel + +Revision ID: 290244fb8b83 +Revises: 1986afd32c1b +Create Date: 2022-11-22 00:12:53.432961 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "290244fb8b83" +down_revision = "1986afd32c1b" +branch_labels = None +depends_on = None +airflow_version = "2.5.0" + + +def upgrade(): + """Add is_orphaned to DatasetModel""" + with op.batch_alter_table("dataset") as batch_op: + batch_op.add_column( + sa.Column( + "is_orphaned", + sa.Boolean, + default=False, + nullable=False, + server_default="0", + ) + ) + + +def downgrade(): + """Remove is_orphaned from DatasetModel""" + with op.batch_alter_table("dataset") as batch_op: + batch_op.drop_column("is_orphaned", mssql_drop_default=True) diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index c36dfb1bca0ba..c750757cb048d 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -16,33 +16,9 @@ # specific language governing permissions and limitations # under the License. """Airflow models""" -from typing import Union - -from airflow.models.base import ID_LEN, Base -from airflow.models.baseoperator import BaseOperator, BaseOperatorLink -from airflow.models.connection import Connection -from airflow.models.dag import DAG, DagModel, DagTag -from airflow.models.dagbag import DagBag -from airflow.models.dagpickle import DagPickle -from airflow.models.dagrun import DagRun -from airflow.models.db_callback_request import DbCallbackRequest -from airflow.models.errors import ImportError -from airflow.models.log import Log -from airflow.models.mappedoperator import MappedOperator -from airflow.models.operator import Operator -from airflow.models.param import Param -from airflow.models.pool import Pool -from airflow.models.renderedtifields import RenderedTaskInstanceFields -from airflow.models.sensorinstance import SensorInstance -from airflow.models.skipmixin import SkipMixin -from airflow.models.slamiss import SlaMiss -from airflow.models.taskfail import TaskFail -from airflow.models.taskinstance import TaskInstance, clear_task_instances -from airflow.models.taskreschedule import TaskReschedule -from airflow.models.trigger import Trigger -from airflow.models.variable import Variable -from airflow.models.xcom import XCOM_RETURN_KEY, XCom +from __future__ import annotations +# Do not add new models to this -- this is for compat only __all__ = [ "DAG", "ID_LEN", @@ -52,6 +28,7 @@ "BaseOperatorLink", "Connection", "DagBag", + "DagWarning", "DagModel", "DagPickle", "DagRun", @@ -64,7 +41,6 @@ "Param", "Pool", "RenderedTaskInstanceFields", - "SensorInstance", "SkipMixin", "SlaMiss", "TaskFail", @@ -75,3 +51,96 @@ "XCom", "clear_task_instances", ] + + +from typing import TYPE_CHECKING + + +def import_all_models(): + for name in __lazy_imports: + __getattr__(name) + + import airflow.jobs.backfill_job + import airflow.jobs.base_job + import airflow.jobs.local_task_job + import airflow.jobs.scheduler_job + import airflow.jobs.triggerer_job + import airflow.models.dagwarning + import airflow.models.dataset + import airflow.models.serialized_dag + import airflow.models.tasklog + import airflow.www.fab_security.sqla.models + + +def __getattr__(name): + # PEP-562: Lazy loaded attributes on python modules + path = __lazy_imports.get(name) + if not path: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + from airflow.utils.module_loading import import_string + + val = import_string(f"{path}.{name}") + # Store for next time + globals()[name] = val + return val + + +__lazy_imports = { + "DAG": "airflow.models.dag", + "ID_LEN": "airflow.models.base", + "XCOM_RETURN_KEY": "airflow.models.xcom", + "Base": "airflow.models.base", + "BaseOperator": 
"airflow.models.baseoperator", + "BaseOperatorLink": "airflow.models.baseoperator", + "Connection": "airflow.models.connection", + "DagBag": "airflow.models.dagbag", + "DagModel": "airflow.models.dag", + "DagPickle": "airflow.models.dagpickle", + "DagRun": "airflow.models.dagrun", + "DagTag": "airflow.models.dag", + "DbCallbackRequest": "airflow.models.db_callback_request", + "ImportError": "airflow.models.errors", + "Log": "airflow.models.log", + "MappedOperator": "airflow.models.mappedoperator", + "Operator": "airflow.models.operator", + "Param": "airflow.models.param", + "Pool": "airflow.models.pool", + "RenderedTaskInstanceFields": "airflow.models.renderedtifields", + "SkipMixin": "airflow.models.skipmixin", + "SlaMiss": "airflow.models.slamiss", + "TaskFail": "airflow.models.taskfail", + "TaskInstance": "airflow.models.taskinstance", + "TaskReschedule": "airflow.models.taskreschedule", + "Trigger": "airflow.models.trigger", + "Variable": "airflow.models.variable", + "XCom": "airflow.models.xcom", + "clear_task_instances": "airflow.models.taskinstance", +} + +if TYPE_CHECKING: + # I was unable to get mypy to respect a airflow/models/__init__.pyi, so + # having to resort back to this hacky method + from airflow.models.base import ID_LEN, Base + from airflow.models.baseoperator import BaseOperator, BaseOperatorLink + from airflow.models.connection import Connection + from airflow.models.dag import DAG, DagModel, DagTag + from airflow.models.dagbag import DagBag + from airflow.models.dagpickle import DagPickle + from airflow.models.dagrun import DagRun + from airflow.models.db_callback_request import DbCallbackRequest + from airflow.models.errors import ImportError + from airflow.models.log import Log + from airflow.models.mappedoperator import MappedOperator + from airflow.models.operator import Operator + from airflow.models.param import Param + from airflow.models.pool import Pool + from airflow.models.renderedtifields import RenderedTaskInstanceFields + from airflow.models.skipmixin import SkipMixin + from airflow.models.slamiss import SlaMiss + from airflow.models.taskfail import TaskFail + from airflow.models.taskinstance import TaskInstance, clear_task_instances + from airflow.models.taskreschedule import TaskReschedule + from airflow.models.trigger import Trigger + from airflow.models.variable import Variable + from airflow.models.xcom import XCOM_RETURN_KEY, XCom diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index cb566ea6f9557..d693f8bfc95cb 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -15,34 +15,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import datetime import inspect -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Sequence, - Set, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence -from airflow.compat.functools import cached_property +from airflow.compat.functools import cache, cached_property from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskmixin import DAGNode from airflow.utils.context import Context from airflow.utils.helpers import render_template_as_native, render_template_to_string from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.mixins import ResolveMixin from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import skip_locked, with_row_locks +from airflow.utils.state import State, TaskInstanceState +from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule @@ -54,6 +45,7 @@ from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.models.dag import DAG + from airflow.models.mappedoperator import MappedOperator from airflow.models.operator import Operator from airflow.models.taskinstance import TaskInstance @@ -72,11 +64,15 @@ conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) ) DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS -DEFAULT_TASK_EXECUTION_TIMEOUT: Optional[datetime.timedelta] = conf.gettimedelta( +DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( "core", "default_task_execution_timeout" ) +class NotMapped(Exception): + """Raise if a task is neither mapped nor has any parent mapped groups.""" + + class AbstractOperator(LoggingMixin, DAGNode): """Common implementation for operators, including unmapped and mapped. @@ -90,13 +86,13 @@ class AbstractOperator(LoggingMixin, DAGNode): :meta private: """ - operator_class: Union[Type["BaseOperator"], Dict[str, Any]] + operator_class: type[BaseOperator] | dict[str, Any] weight_rule: str priority_weight: int # Defines the operator level extra links. - operator_extra_links: Collection["BaseOperatorLink"] + operator_extra_links: Collection[BaseOperatorLink] # For derived classes to define which fields will get jinjaified. template_fields: Collection[str] # Defines which files extensions to look for in the templated fields. 
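Much of the churn in this file (and the modules below) comes from the added `from __future__ import annotations`, which enables postponed evaluation of annotations (PEP 563): annotations are stored as strings and never evaluated at import time, so `dict[str, Any]`, `frozenset[str]`, and `X | None` can be written in annotations even on interpreters older than 3.9/3.10. A tiny illustration, with a made-up function, of the behaviour being relied on:

```python
# Runs on Python 3.7+ even though ``list[str]`` and ``int | None`` are only
# valid as *evaluated* expressions on newer interpreters, because postponed
# evaluation keeps annotations as plain strings.
from __future__ import annotations


def pick(values: list[str], index: int | None = None) -> str | None:
    if index is None:
        return None
    return values[index]


# Annotations stay unevaluated strings:
# {'values': 'list[str]', 'index': 'int | None', 'return': 'str | None'}
print(pick.__annotations__)
```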
@@ -105,32 +101,39 @@ class AbstractOperator(LoggingMixin, DAGNode): owner: str task_id: str - HIDE_ATTRS_FROM_UI: ClassVar[FrozenSet[str]] = frozenset( + outlets: list + inlets: list + + HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset( ( - 'log', - 'dag', # We show dag_id, don't need to show this too - 'node_id', # Duplicates task_id - 'task_group', # Doesn't have a useful repr, no point showing in UI - 'inherits_from_empty_operator', # impl detail + "log", + "dag", # We show dag_id, don't need to show this too + "node_id", # Duplicates task_id + "task_group", # Doesn't have a useful repr, no point showing in UI + "inherits_from_empty_operator", # impl detail # For compatibility with TG, for operators these are just the current task, no point showing - 'roots', - 'leaves', + "roots", + "leaves", # These lists are already shown via *_task_ids - 'upstream_list', - 'downstream_list', + "upstream_list", + "downstream_list", # Not useful, implementation detail, already shown elsewhere - 'global_operator_extra_link_dict', - 'operator_extra_link_dict', + "global_operator_extra_link_dict", + "operator_extra_link_dict", ) ) - def get_dag(self) -> "Optional[DAG]": + def get_dag(self) -> DAG | None: raise NotImplementedError() @property def task_type(self) -> str: raise NotImplementedError() + @property + def operator_name(self) -> str: + raise NotImplementedError() + @property def inherits_from_empty_operator(self) -> bool: raise NotImplementedError() @@ -147,7 +150,7 @@ def dag_id(self) -> str: def node_id(self) -> str: return self.task_id - def get_template_env(self) -> "jinja2.Environment": + def get_template_env(self) -> jinja2.Environment: """Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG.""" # This is imported locally since Jinja2 is heavy and we don't need it # for most of the functionalities. 
It is imported by get_template_env() @@ -189,7 +192,7 @@ def resolve_template_files(self) -> None: self.log.exception("Failed to get source %s", item) self.prepare_template() - def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: + def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: """Get direct relative IDs to the current task, upstream or downstream.""" if upstream: return self.upstream_task_ids @@ -198,32 +201,116 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: def get_flat_relative_ids( self, upstream: bool = False, - found_descendants: Optional[Set[str]] = None, - ) -> Set[str]: + found_descendants: set[str] | None = None, + ) -> set[str]: """Get a flat set of relative IDs, upstream or downstream.""" dag = self.get_dag() if not dag: return set() - if not found_descendants: + if found_descendants is None: found_descendants = set() - relative_ids = self.get_direct_relative_ids(upstream) - for relative_id in relative_ids: - if relative_id not in found_descendants: - found_descendants.add(relative_id) - relative_task = dag.task_dict[relative_id] - relative_task.get_flat_relative_ids(upstream, found_descendants) + task_ids_to_trace = self.get_direct_relative_ids(upstream) + while task_ids_to_trace: + task_ids_to_trace_next: set[str] = set() + for task_id in task_ids_to_trace: + if task_id in found_descendants: + continue + task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) + found_descendants.add(task_id) + task_ids_to_trace = task_ids_to_trace_next return found_descendants - def get_flat_relatives(self, upstream: bool = False) -> Collection["Operator"]: + def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]: """Get a flat list of relatives, either upstream or downstream.""" dag = self.get_dag() if not dag: return set() return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream)] + def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: + """Return mapped nodes that are direct dependencies of the current task. + + For now, this walks the entire DAG to find mapped nodes that has this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record an DAG node's all downstream nodes instead. + + Note that this does not guarantee the returned tasks actually use the + current task for task mapping, but only checks those task are mapped + operators, and are downstreams of the current task. + + To get a list of tasks that uses the current task for task mapping, use + :meth:`iter_mapped_dependants` instead. + """ + from airflow.models.mappedoperator import MappedOperator + from airflow.utils.task_group import TaskGroup + + def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: + """Recursively walk children in a task group. + + This yields all direct children (including both tasks and task + groups), and all children of any task groups. 
+ """ + for key, child in group.children.items(): + yield key, child + if isinstance(child, TaskGroup): + yield from _walk_group(child) + + dag = self.get_dag() + if not dag: + raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG") + for key, child in _walk_group(dag.task_group): + if key == self.node_id: + continue + if not isinstance(child, (MappedOperator, MappedTaskGroup)): + continue + if self.node_id in child.upstream_task_ids: + yield child + + def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]: + """Return mapped nodes that depend on the current task the expansion. + + For now, this walks the entire DAG to find mapped nodes that has this + current task as an upstream. We cannot use ``downstream_list`` since it + only contains operators, not task groups. In the future, we should + provide a way to record an DAG node's all downstream nodes instead. + """ + return ( + downstream + for downstream in self._iter_all_mapped_downstreams() + if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) + ) + + def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: + """Return mapped task groups this task belongs to. + + Groups are returned from the closest to the outmost. + + :meta private: + """ + parent = self.task_group + while parent is not None: + if isinstance(parent, MappedTaskGroup): + yield parent + parent = parent.task_group + + def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: + """:meta private:""" + return next(self.iter_mapped_task_groups(), None) + + def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator: + """Get the "normal" operator from current abstract operator. + + MappedOperator uses this to unmap itself based on the map index. A non- + mapped operator (i.e. BaseOperator subclass) simply returns itself. + + :meta private: + """ + raise NotImplementedError() + @property def priority_weight_total(self) -> int: """ @@ -252,9 +339,9 @@ def priority_weight_total(self) -> int: ) @cached_property - def operator_extra_link_dict(self) -> Dict[str, Any]: + def operator_extra_link_dict(self) -> dict[str, Any]: """Returns dictionary of all extra links for the operator""" - op_extra_links_from_plugin: Dict[str, Any] = {} + op_extra_links_from_plugin: dict[str, Any] = {} from airflow import plugins_manager plugins_manager.initialize_extra_operators_links_plugins() @@ -271,7 +358,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]: return operator_extra_links_all @cached_property - def global_operator_extra_link_dict(self) -> Dict[str, Any]: + def global_operator_extra_link_dict(self) -> dict[str, Any]: """Returns dictionary of all global extra links""" from airflow import plugins_manager @@ -281,10 +368,10 @@ def global_operator_extra_link_dict(self) -> Dict[str, Any]: return {link.name: link for link in plugins_manager.global_operator_extra_links} @cached_property - def extra_links(self) -> List[str]: + def extra_links(self) -> list[str]: return list(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict)) - def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]: + def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None: """For an operator, gets the URLs that the ``extra_links`` entry points to. 
:meta private: @@ -295,33 +382,198 @@ def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]: :param link_name: The name of the link we're looking for the URL for. Should be one of the options specified in ``extra_links``. """ - link: Optional["BaseOperatorLink"] = self.operator_extra_link_dict.get(link_name) + link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name) if not link: link = self.global_operator_extra_link_dict.get(link_name) if not link: return None - # Check for old function signature + parameters = inspect.signature(link.get_link).parameters - args = [name for name, p in parameters.items() if p.kind != p.VAR_KEYWORD] - if "ti_key" in args: - return link.get_link(self, ti_key=ti.key) + old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD) + + if old_signature: + return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc] + return link.get_link(self.unmap(None), ti_key=ti.key) + + @cache + def get_parse_time_mapped_ti_count(self) -> int: + """Number of mapped task instances that can be created on DAG run creation. + + This only considers literal mapped arguments, and would return *None* + when any non-literal values are used for mapping. + + :raise NotFullyPopulated: If non-literal mapped arguments are encountered. + :raise NotMapped: If the operator is neither mapped, nor has any parent + mapped task groups. + :return: Total number of mapped TIs this task should have. + """ + group = self.get_closest_mapped_task_group() + if group is None: + raise NotMapped + return group.get_parse_time_mapped_ti_count() + + def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: + """Number of mapped TaskInstances that can be created at run time. + + This considers both literal and non-literal mapped arguments, and the + result is therefore available when all depended tasks have finished. The + return value should be identical to ``parse_time_mapped_ti_count`` if + all mapped arguments are literal. + + :raise NotFullyPopulated: If upstream tasks are not all complete yet. + :raise NotMapped: If the operator is neither mapped, nor has any parent + mapped task groups. + :return: Total number of mapped TIs this task should have. + """ + group = self.get_closest_mapped_task_group() + if group is None: + raise NotMapped + return group.get_mapped_ti_count(run_id, session=session) + + def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]: + """Create the mapped task instances for mapped task. + + :raise NotMapped: If this task does not need expansion. + :return: The newly created mapped task instances (if any) in ascending + order by map index, and the maximum map index value. + """ + from sqlalchemy import func, or_ + + from airflow.models.baseoperator import BaseOperator + from airflow.models.mappedoperator import MappedOperator + from airflow.models.taskinstance import TaskInstance + from airflow.settings import task_instance_mutation_hook + + if not isinstance(self, (BaseOperator, MappedOperator)): + raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}") + + try: + total_length: int | None = self.get_mapped_ti_count(run_id, session=session) + except NotFullyPopulated as e: + # It's possible that the upstream tasks are not yet done, but we + # don't have upstream of upstreams in partial DAGs (possible in the + # mini-scheduler), so we ignore this exception. 
+ if not self.dag or not self.dag.partial: + self.log.error( + "Cannot expand %r for run %s; missing upstream values: %s", + self, + run_id, + sorted(e.missing), + ) + total_length = None + + state: TaskInstanceState | None = None + unmapped_ti: TaskInstance | None = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == -1, + or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), + ) + .one_or_none() + ) + + all_expanded_tis: list[TaskInstance] = [] + + if unmapped_ti: + # The unmapped task instance still exists and is unfinished, i.e. we + # haven't tried to run it before. + if total_length is None: + # If the DAG is partial, it's likely that the upstream tasks + # are not done yet, so the task can't fail yet. + if not self.dag or not self.dag.partial: + unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED + elif total_length < 1: + # If the upstream maps this to a zero-length value, simply mark + # the unmapped task instance as SKIPPED (if needed). + self.log.info( + "Marking %s as SKIPPED since the map has %d values to expand", + unmapped_ti, + total_length, + ) + unmapped_ti.state = TaskInstanceState.SKIPPED + else: + zero_index_ti_exists = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == 0, + ) + .count() + > 0 + ) + if not zero_index_ti_exists: + # Otherwise convert this into the first mapped index, and create + # TaskInstance for other indexes. + unmapped_ti.map_index = 0 + self.log.debug("Updated in place to become %s", unmapped_ti) + all_expanded_tis.append(unmapped_ti) + session.flush() + else: + self.log.debug("Deleting the original task instance: %s", unmapped_ti) + session.delete(unmapped_ti) + state = unmapped_ti.state + + if total_length is None or total_length < 1: + # Nothing to fixup. + indexes_to_map: Iterable[int] = () else: - return link.get_link(self, ti.dag_run.logical_date) # type: ignore[misc] - return None + # Only create "missing" ones. + current_max_mapping = ( + session.query(func.max(TaskInstance.map_index)) + .filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == run_id, + ) + .scalar() + ) + indexes_to_map = range(current_max_mapping + 1, total_length) + + for index in indexes_to_map: + # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings. + ti = TaskInstance(self, run_id=run_id, map_index=index, state=state) + self.log.debug("Expanding TIs upserted %s", ti) + task_instance_mutation_hook(ti) + ti = session.merge(ti) + ti.refresh_from_task(self) # session.merge() loses task information. + all_expanded_tis.append(ti) + + # Coerce the None case to 0 -- these two are almost treated identically, + # except the unmapped ti (if exists) is marked to different states. + total_expanded_ti_count = total_length or 0 + + # Any (old) task instances with inapplicable indexes (>= the total + # number we need) are set to "REMOVED". 
+ query = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index >= total_expanded_ti_count, + ) + to_update = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session)) + for ti in to_update: + ti.state = TaskInstanceState.REMOVED + session.flush() + return all_expanded_tis, total_expanded_ti_count - 1 def render_template_fields( self, context: Context, - jinja_env: Optional["jinja2.Environment"] = None, - ) -> Optional["BaseOperator"]: - """Template all attributes listed in template_fields. + jinja_env: jinja2.Environment | None = None, + ) -> None: + """Template all attributes listed in *self.template_fields*. If the operator is mapped, this should return the unmapped, fully rendered, and map-expanded operator. The mapped operator should not be - modified. + modified. However, *context* may be modified in-place to reference the + unmapped operator for template rendering. - If the operator is not mapped, this should modify the operator in-place - and return either *None* (for backwards compatibility) or *self*. + If the operator is not mapped, this should modify the operator in-place. """ raise NotImplementedError() @@ -331,10 +583,10 @@ def _do_render_template_fields( parent: Any, template_fields: Iterable[str], context: Context, - jinja_env: "jinja2.Environment", - seen_oids: Set, + jinja_env: jinja2.Environment, + seen_oids: set[int], *, - session: "Session" = NEW_SESSION, + session: Session = NEW_SESSION, ) -> None: for attr_name in template_fields: try: @@ -346,20 +598,30 @@ def _do_render_template_fields( ) if not value: continue - rendered_content = self.render_template( - value, - context, - jinja_env, - seen_oids, - ) - setattr(parent, attr_name, rendered_content) + try: + rendered_content = self.render_template( + value, + context, + jinja_env, + seen_oids, + ) + except Exception: + self.log.exception( + "Exception rendering Jinja template for task '%s', field '%s'. Template: %r", + self.task_id, + attr_name, + value, + ) + raise + else: + setattr(parent, attr_name, rendered_content) def render_template( self, content: Any, context: Context, - jinja_env: Optional["jinja2.Environment"] = None, - seen_oids: Optional[Set] = None, + jinja_env: jinja2.Environment | None = None, + seen_oids: set[int] | None = None, ) -> Any: """Render a templated string. @@ -379,12 +641,17 @@ def render_template( value = content del content + if seen_oids is not None: + oids = seen_oids + else: + oids = set() + + if id(value) in oids: + return value + if not jinja_env: jinja_env = self.get_template_env() - from airflow.models.param import DagParam - from airflow.models.xcom_arg import XComArg - if isinstance(value, str): if any(value.endswith(ext) for ext in self.template_ext): # A filepath. template = jinja_env.get_template(value) @@ -395,26 +662,22 @@ def render_template( return render_template_as_native(template, context) return render_template_to_string(template, context) - if isinstance(value, (DagParam, XComArg)): + if isinstance(value, ResolveMixin): return value.resolve(context) # Fast path for common built-in collections. if value.__class__ is tuple: - return tuple(self.render_template(element, context, jinja_env) for element in value) + return tuple(self.render_template(element, context, jinja_env, oids) for element in value) elif isinstance(value, tuple): # Special case for named tuples. 
- return value.__class__(*(self.render_template(el, context, jinja_env) for el in value)) + return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value)) elif isinstance(value, list): - return [self.render_template(element, context, jinja_env) for element in value] + return [self.render_template(element, context, jinja_env, oids) for element in value] elif isinstance(value, dict): - return {key: self.render_template(value, context, jinja_env) for key, value in value.items()} + return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()} elif isinstance(value, set): - return {self.render_template(element, context, jinja_env) for element in value} + return {self.render_template(element, context, jinja_env, oids) for element in value} # More complex collections. - if seen_oids is None: - oids = set() - else: - oids = seen_oids self._render_nested_template_fields(value, context, jinja_env, oids) return value @@ -422,8 +685,8 @@ def _render_nested_template_fields( self, value: Any, context: Context, - jinja_env: "jinja2.Environment", - seen_oids: Set[int], + jinja_env: jinja2.Environment, + seen_oids: set[int], ) -> None: if id(value) in seen_oids: return diff --git a/airflow/models/base.py b/airflow/models/base.py index 478bd904eb83c..b2587cc1767b8 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import Any @@ -25,9 +26,26 @@ SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA") -metadata = ( - None if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() else MetaData(schema=SQL_ALCHEMY_SCHEMA) -) +# For more information about what the tokens in the naming convention +# below mean, see: +# https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.MetaData.params.naming_convention +naming_convention = { + "ix": "idx_%(column_0_N_label)s", + "uq": "%(table_name)s_%(column_0_N_name)s_uq", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", +} + + +def _get_schema(): + if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace(): + return None + return SQL_ALCHEMY_SCHEMA + + +metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention) + Base: Any = declarative_base(metadata=metadata) ID_LEN = 250 @@ -35,9 +53,9 @@ def get_id_collation_args(): """Get SQLAlchemy args to use for COLLATION""" - collation = conf.get('database', 'sql_engine_collation_for_ids', fallback=None) + collation = conf.get("database", "sql_engine_collation_for_ids", fallback=None) if collation: - return {'collation': collation} + return {"collation": collation} else: # Automatically use utf8mb3_bin collation for mysql # This is backwards-compatible. 
All our IDS are ASCII anyway so even if @@ -49,9 +67,9 @@ def get_id_collation_args(): # # We cannot use session/dialect as at this point we are trying to determine the right connection # parameters, so we use the connection - conn = conf.get('database', 'sql_alchemy_conn', fallback='') - if conn.startswith('mysql') or conn.startswith("mariadb"): - return {'collation': 'utf8mb3_bin'} + conn = conf.get("database", "sql_alchemy_conn", fallback="") + if conn.startswith("mysql") or conn.startswith("mariadb"): + return {"collation": "utf8mb3_bin"} return {} diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 3208c435769b3..9448c4a415db9 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """Base operator for all operators.""" +from __future__ import annotations + import abc import collections import collections.abc @@ -35,14 +37,9 @@ Callable, ClassVar, Collection, - Dict, - FrozenSet, Iterable, List, - Optional, Sequence, - Set, - Tuple, Type, TypeVar, Union, @@ -56,7 +53,7 @@ from sqlalchemy.orm.exc import NoResultFound from airflow.configuration import conf -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred from airflow.lineage import apply_lineage, prepare_lineage from airflow.models.abstractoperator import ( DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -87,6 +84,7 @@ from airflow.triggers.base import BaseTrigger from airflow.utils import timezone from airflow.utils.context import Context +from airflow.utils.decorators import fixup_decorator_warning_stack from airflow.utils.helpers import validate_key from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session @@ -98,6 +96,7 @@ from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstanceKey + from airflow.models.xcom_arg import XComArg from airflow.utils.task_group import TaskGroup ScheduleInterval = Union[str, timedelta, relativedelta] @@ -105,12 +104,12 @@ TaskPreExecuteHook = Callable[[Context], None] TaskPostExecuteHook = Callable[[Context, Any], None] -T = TypeVar('T', bound=FunctionType) +T = TypeVar("T", bound=FunctionType) logger = logging.getLogger("airflow.models.baseoperator.BaseOperator") -def parse_retries(retries: Any) -> Optional[int]: +def parse_retries(retries: Any) -> int | None: if retries is None or isinstance(retries, int): return retries try: @@ -121,20 +120,20 @@ def parse_retries(retries: Any) -> Optional[int]: return parsed_retries -def coerce_timedelta(value: Union[float, timedelta], *, key: str) -> timedelta: +def coerce_timedelta(value: float | timedelta, *, key: str) -> timedelta: if isinstance(value, timedelta): return value logger.debug("%s isn't a timedelta object, assuming secs", key) return timedelta(seconds=value) -def coerce_resources(resources: Optional[Dict[str, Any]]) -> Optional[Resources]: +def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: if resources is None: return None return Resources(**resources) -def _get_parent_defaults(dag: Optional["DAG"], task_group: Optional["TaskGroup"]) -> Tuple[dict, ParamsDict]: +def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: if not dag: return {}, ParamsDict() dag_args = copy.copy(dag.default_args) @@ -147,11 +146,11 @@ def _get_parent_defaults(dag: 
Optional["DAG"], task_group: Optional["TaskGroup"] def get_merged_defaults( - dag: Optional["DAG"], - task_group: Optional["TaskGroup"], - task_params: Optional[dict], - task_default_args: Optional[dict], -) -> Tuple[dict, ParamsDict]: + dag: DAG | None, + task_group: TaskGroup | None, + task_params: dict | None, + task_default_args: dict | None, +) -> tuple[dict, ParamsDict]: args, params = _get_parent_defaults(dag, task_group) if task_params: if not isinstance(task_params, collections.abc.Mapping): @@ -172,7 +171,7 @@ class _PartialDescriptor: class_method = None def __get__( - self, obj: "BaseOperator", cls: "Optional[Type[BaseOperator]]" = None + self, obj: BaseOperator, cls: type[BaseOperator] | None = None ) -> Callable[..., OperatorPartial]: # Call this "partial" so it looks nicer in stack traces. def partial(**kwargs): @@ -185,46 +184,46 @@ def partial(**kwargs): # This is what handles the actual mapping. def partial( - operator_class: Type["BaseOperator"], + operator_class: type[BaseOperator], *, task_id: str, - dag: Optional["DAG"] = None, - task_group: Optional["TaskGroup"] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, owner: str = DEFAULT_OWNER, - email: Union[None, str, Iterable[str]] = None, - params: Optional[dict] = None, - resources: Optional[Dict[str, Any]] = None, + email: None | str | Iterable[str] = None, + params: dict | None = None, + resources: dict[str, Any] | None = None, trigger_rule: str = DEFAULT_TRIGGER_RULE, depends_on_past: bool = False, ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, wait_for_downstream: bool = False, - retries: Optional[int] = DEFAULT_RETRIES, + retries: int | None = DEFAULT_RETRIES, queue: str = DEFAULT_QUEUE, - pool: Optional[str] = None, + pool: str | None = None, pool_slots: int = DEFAULT_POOL_SLOTS, - execution_timeout: Optional[timedelta] = DEFAULT_TASK_EXECUTION_TIMEOUT, - max_retry_delay: Union[None, timedelta, float] = None, - retry_delay: Union[timedelta, float] = DEFAULT_RETRY_DELAY, + execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, + max_retry_delay: None | timedelta | float = None, + retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, retry_exponential_backoff: bool = False, priority_weight: int = DEFAULT_PRIORITY_WEIGHT, weight_rule: str = DEFAULT_WEIGHT_RULE, - sla: Optional[timedelta] = None, - max_active_tis_per_dag: Optional[int] = None, - on_execute_callback: Optional[TaskStateChangeCallback] = None, - on_failure_callback: Optional[TaskStateChangeCallback] = None, - on_success_callback: Optional[TaskStateChangeCallback] = None, - on_retry_callback: Optional[TaskStateChangeCallback] = None, - run_as_user: Optional[str] = None, - executor_config: Optional[Dict] = None, - inlets: Optional[Any] = None, - outlets: Optional[Any] = None, - doc: Optional[str] = None, - doc_md: Optional[str] = None, - doc_json: Optional[str] = None, - doc_yaml: Optional[str] = None, - doc_rst: Optional[str] = None, + sla: timedelta | None = None, + max_active_tis_per_dag: int | None = None, + on_execute_callback: TaskStateChangeCallback | None = None, + on_failure_callback: TaskStateChangeCallback | None = None, + on_success_callback: TaskStateChangeCallback | None = None, + on_retry_callback: TaskStateChangeCallback | None = None, + run_as_user: str | None = None, + executor_config: dict | None = None, + inlets: Any | 
None = None, + outlets: Any | None = None, + doc: str | None = None, + doc_md: str | None = None, + doc_json: str | None = None, + doc_yaml: str | None = None, + doc_rst: str | None = None, **kwargs, ) -> OperatorPartial: from airflow.models.dag import DagContext @@ -239,7 +238,7 @@ def partial( task_id = task_group.child_id(task_id) # Merge DAG and task group level defaults into user-supplied values. - partial_kwargs, default_params = get_merged_defaults( + partial_kwargs, partial_params = get_merged_defaults( dag=dag, task_group=task_group, task_params=params, @@ -255,7 +254,6 @@ def partial( partial_kwargs.setdefault("end_date", end_date) partial_kwargs.setdefault("owner", owner) partial_kwargs.setdefault("email", email) - partial_kwargs.setdefault("params", default_params) partial_kwargs.setdefault("trigger_rule", trigger_rule) partial_kwargs.setdefault("depends_on_past", depends_on_past) partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past) @@ -278,8 +276,8 @@ def partial( partial_kwargs.setdefault("on_success_callback", on_success_callback) partial_kwargs.setdefault("run_as_user", run_as_user) partial_kwargs.setdefault("executor_config", executor_config) - partial_kwargs.setdefault("inlets", inlets) - partial_kwargs.setdefault("outlets", outlets) + partial_kwargs.setdefault("inlets", inlets or []) + partial_kwargs.setdefault("outlets", outlets or []) partial_kwargs.setdefault("resources", resources) partial_kwargs.setdefault("doc", doc) partial_kwargs.setdefault("doc_json", doc_json) @@ -306,7 +304,11 @@ def partial( partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {} partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"]) - return OperatorPartial(operator_class=operator_class, kwargs=partial_kwargs) + return OperatorPartial( + operator_class=operator_class, + kwargs=partial_kwargs, + params=partial_params, + ) class BaseOperatorMeta(abc.ABCMeta): @@ -331,7 +333,7 @@ def _apply_defaults(cls, func: T) -> T: non_variadic_params = { name: param for (name, param) in sig_cache.parameters.items() - if param.name != 'self' and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) } non_optional_args = { name @@ -339,28 +341,10 @@ def _apply_defaults(cls, func: T) -> T: if param.default == param.empty and name != "task_id" } - class autostacklevel_warn: - def __init__(self): - self.warnings = __import__('warnings') - - def __getattr__(self, name): - return getattr(self.warnings, name) - - def __dir__(self): - return dir(self.warnings) - - def warn(self, message, category=None, stacklevel=1, source=None): - self.warnings.warn(message, category, stacklevel + 2, source) - - if func.__globals__.get('warnings') is sys.modules['warnings']: - # Yes, this is slightly hacky, but it _automatically_ sets the right - # stacklevel parameter to `warnings.warn` to ignore the decorator. Now - # that the decorator is applied automatically, this makes the needed - # stacklevel parameter less confusing. 
- func.__globals__['warnings'] = autostacklevel_warn() + fixup_decorator_warning_stack(func) @functools.wraps(func) - def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: + def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: from airflow.models.dag import DagContext from airflow.utils.task_group import TaskGroupContext @@ -372,8 +356,8 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: getattr(self, "_BaseOperator__from_mapped", False), ) - dag: Optional[DAG] = kwargs.get('dag') or DagContext.get_current_dag() - task_group: Optional[TaskGroup] = kwargs.get('task_group') + dag: DAG | None = kwargs.get("dag") or DagContext.get_current_dag() + task_group: TaskGroup | None = kwargs.get("task_group") if dag and not task_group: task_group = TaskGroupContext.get_current_task_group(dag) @@ -398,12 +382,12 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: if merged_params: kwargs["params"] = merged_params - hook = getattr(self, '_hook_apply_defaults', None) + hook = getattr(self, "_hook_apply_defaults", None) if hook: args, kwargs = hook(**kwargs, default_args=default_args) - default_args = kwargs.pop('default_args', {}) + default_args = kwargs.pop("default_args", {}) - if not hasattr(self, '_BaseOperator__init_kwargs'): + if not hasattr(self, "_BaseOperator__init_kwargs"): self._BaseOperator__init_kwargs = {} self._BaseOperator__from_mapped = instantiated_from_mapped @@ -412,8 +396,9 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: # Store the args passed to init -- we need them to support task.map serialzation! self._BaseOperator__init_kwargs.update(kwargs) # type: ignore - if not instantiated_from_mapped: - # Set upstream task defined by XComArgs passed to template fields of the operator. + # Set upstream task defined by XComArgs passed to template fields of the operator. + # BUT: only do this _ONCE_, not once for each class in the hierarchy + if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] self.set_xcomargs_dependencies() # Mark instance as instantiated. self._BaseOperator__instantiated = True @@ -428,8 +413,8 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: def __new__(cls, name, bases, namespace, **kwargs): new_cls = super().__new__(cls, name, bases, namespace, **kwargs) with contextlib.suppress(KeyError): - # Update the partial descriptor with the class method so it call call the actual function (but let - # subclasses override it if they need to) + # Update the partial descriptor with the class method, so it calls the actual function + # (but let subclasses override it if they need to) partial_desc = vars(new_cls)["partial"] if isinstance(partial_desc, _PartialDescriptor): partial_desc.class_method = classmethod(partial) @@ -463,7 +448,7 @@ class derived from this one results in the creation of a task object, (e.g. user/person/team/role name) to clarify ownership is recommended. :param email: the 'to' email address(es) used in email alerts. This can be a single email or multiple ones. Multiple addresses can be specified as a - comma or semi-colon separated string or by passing a list of strings. + comma or semicolon separated string or by passing a list of strings. 
:param email_on_retry: Indicates whether email alerts should be sent when a task is retried :param email_on_failure: Indicates whether email alerts should be sent when @@ -576,7 +561,7 @@ class derived from this one results in the creation of a task object, |experimental| :param trigger_rule: defines the rule by which dependencies are applied for the task to get triggered. Options are: - ``{ all_success | all_failed | all_done | all_skipped | one_success | + ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done | one_failed | none_failed | none_failed_min_one_success | none_skipped | always}`` default is ``all_success``. Options can be set as string or using the constants defined in the static class @@ -620,54 +605,54 @@ class derived from this one results in the creation of a task object, template_fields: Sequence[str] = () template_ext: Sequence[str] = () - template_fields_renderers: Dict[str, str] = {} + template_fields_renderers: dict[str, str] = {} # Defines the color in the UI - ui_color: str = '#fff' - ui_fgcolor: str = '#000' + ui_color: str = "#fff" + ui_fgcolor: str = "#000" pool: str = "" # base list which includes all the attrs that don't need deep copy. - _base_operator_shallow_copy_attrs: Tuple[str, ...] = ( - 'user_defined_macros', - 'user_defined_filters', - 'params', - '_log', + _base_operator_shallow_copy_attrs: tuple[str, ...] = ( + "user_defined_macros", + "user_defined_filters", + "params", + "_log", ) # each operator should override this class attr for shallow copy attrs. shallow_copy_attrs: Sequence[str] = () # Defines the operator level extra links - operator_extra_links: Collection['BaseOperatorLink'] = () + operator_extra_links: Collection[BaseOperatorLink] = () # The _serialized_fields are lazily loaded when get_serialized_fields() method is called - __serialized_fields: Optional[FrozenSet[str]] = None + __serialized_fields: frozenset[str] | None = None partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore _comps = { - 'task_id', - 'dag_id', - 'owner', - 'email', - 'email_on_retry', - 'retry_delay', - 'retry_exponential_backoff', - 'max_retry_delay', - 'start_date', - 'end_date', - 'depends_on_past', - 'wait_for_downstream', - 'priority_weight', - 'sla', - 'execution_timeout', - 'on_execute_callback', - 'on_failure_callback', - 'on_success_callback', - 'on_retry_callback', - 'do_xcom_push', + "task_id", + "dag_id", + "owner", + "email", + "email_on_retry", + "retry_delay", + "retry_exponential_backoff", + "max_retry_delay", + "start_date", + "end_date", + "depends_on_past", + "wait_for_downstream", + "priority_weight", + "sla", + "execution_timeout", + "on_execute_callback", + "on_failure_callback", + "on_success_callback", + "on_retry_callback", + "do_xcom_push", } # Defines if the operator supports lineage without manual definitions @@ -677,25 +662,20 @@ class derived from this one results in the creation of a task object, __instantiated = False # List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task # when mapping - __init_kwargs: Dict[str, Any] + __init_kwargs: dict[str, Any] # Set to True before calling execute method _lock_for_execution = False - _dag: Optional["DAG"] = None - task_group: Optional["TaskGroup"] = None + _dag: DAG | None = None + task_group: TaskGroup | None = None # subdag parameter is only set for SubDagOperator. 
# Setting it to None by default as other Operators do not have that field - subdag: Optional["DAG"] = None - - start_date: Optional[pendulum.DateTime] = None - end_date: Optional[pendulum.DateTime] = None + subdag: DAG | None = None - # How operator-mapping arguments should be validated. If True, a default validation implementation that - # calls the operator's constructor is used. If False, the operator should implement its own validation - # logic (default implementation is 'pass' i.e. no validation whatsoever). - mapped_arguments_validated_by_init: ClassVar[bool] = False + start_date: pendulum.DateTime | None = None + end_date: pendulum.DateTime | None = None # Set to True for an operator instantiated by a mapped operator. __from_mapped = False @@ -704,49 +684,49 @@ def __init__( self, task_id: str, owner: str = DEFAULT_OWNER, - email: Optional[Union[str, Iterable[str]]] = None, - email_on_retry: bool = conf.getboolean('email', 'default_email_on_retry', fallback=True), - email_on_failure: bool = conf.getboolean('email', 'default_email_on_failure', fallback=True), - retries: Optional[int] = DEFAULT_RETRIES, - retry_delay: Union[timedelta, float] = DEFAULT_RETRY_DELAY, + email: str | Iterable[str] | None = None, + email_on_retry: bool = conf.getboolean("email", "default_email_on_retry", fallback=True), + email_on_failure: bool = conf.getboolean("email", "default_email_on_failure", fallback=True), + retries: int | None = DEFAULT_RETRIES, + retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, retry_exponential_backoff: bool = False, - max_retry_delay: Optional[Union[timedelta, float]] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + max_retry_delay: timedelta | float | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, depends_on_past: bool = False, ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, wait_for_downstream: bool = False, - dag: Optional['DAG'] = None, - params: Optional[Dict] = None, - default_args: Optional[Dict] = None, + dag: DAG | None = None, + params: dict | None = None, + default_args: dict | None = None, priority_weight: int = DEFAULT_PRIORITY_WEIGHT, weight_rule: str = DEFAULT_WEIGHT_RULE, queue: str = DEFAULT_QUEUE, - pool: Optional[str] = None, + pool: str | None = None, pool_slots: int = DEFAULT_POOL_SLOTS, - sla: Optional[timedelta] = None, - execution_timeout: Optional[timedelta] = DEFAULT_TASK_EXECUTION_TIMEOUT, - on_execute_callback: Optional[TaskStateChangeCallback] = None, - on_failure_callback: Optional[TaskStateChangeCallback] = None, - on_success_callback: Optional[TaskStateChangeCallback] = None, - on_retry_callback: Optional[TaskStateChangeCallback] = None, - pre_execute: Optional[TaskPreExecuteHook] = None, - post_execute: Optional[TaskPostExecuteHook] = None, + sla: timedelta | None = None, + execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, + on_execute_callback: TaskStateChangeCallback | None = None, + on_failure_callback: TaskStateChangeCallback | None = None, + on_success_callback: TaskStateChangeCallback | None = None, + on_retry_callback: TaskStateChangeCallback | None = None, + pre_execute: TaskPreExecuteHook | None = None, + post_execute: TaskPostExecuteHook | None = None, trigger_rule: str = DEFAULT_TRIGGER_RULE, - resources: Optional[Dict[str, Any]] = None, - run_as_user: Optional[str] = None, - task_concurrency: Optional[int] = None, - max_active_tis_per_dag: Optional[int] = None, - executor_config: Optional[Dict] = None, 
+ resources: dict[str, Any] | None = None, + run_as_user: str | None = None, + task_concurrency: int | None = None, + max_active_tis_per_dag: int | None = None, + executor_config: dict | None = None, do_xcom_push: bool = True, - inlets: Optional[Any] = None, - outlets: Optional[Any] = None, - task_group: Optional["TaskGroup"] = None, - doc: Optional[str] = None, - doc_md: Optional[str] = None, - doc_json: Optional[str] = None, - doc_yaml: Optional[str] = None, - doc_rst: Optional[str] = None, + inlets: Any | None = None, + outlets: Any | None = None, + task_group: TaskGroup | None = None, + doc: str | None = None, + doc_md: str | None = None, + doc_json: str | None = None, + doc_yaml: str | None = None, + doc_rst: str | None = None, **kwargs, ): from airflow.models.dag import DagContext @@ -756,17 +736,18 @@ def __init__( super().__init__() + kwargs.pop("_airflow_mapped_validation_only", None) if kwargs: - if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): + if not conf.getboolean("operators", "ALLOW_ILLEGAL_ARGUMENTS"): raise AirflowException( f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " f"Invalid arguments were:\n**kwargs: {kwargs}", ) warnings.warn( - f'Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). ' - 'Support for passing such arguments will be dropped in future. ' - f'Invalid arguments were:\n**kwargs: {kwargs}', - category=PendingDeprecationWarning, + f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " + "Support for passing such arguments will be dropped in future. " + f"Invalid arguments were:\n**kwargs: {kwargs}", + category=RemovedInAirflow3Warning, stacklevel=3, ) validate_key(task_id) @@ -785,7 +766,7 @@ def __init__( if execution_timeout is not None and not isinstance(execution_timeout, timedelta): raise ValueError( - f'execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}' + f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}" ) self.execution_timeout = execution_timeout @@ -818,7 +799,7 @@ def __init__( if trigger_rule == "dummy": warnings.warn( "dummy Trigger Rule is deprecated. Please use `TriggerRule.ALWAYS`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) trigger_rule = TriggerRule.ALWAYS @@ -827,7 +808,7 @@ def __init__( warnings.warn( "none_failed_or_skipped Trigger Rule is deprecated. " "Please use `none_failed_min_one_success`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) trigger_rule = TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS @@ -838,7 +819,7 @@ def __init__( f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'." ) - self.trigger_rule = TriggerRule(trigger_rule) + self.trigger_rule: TriggerRule = TriggerRule(trigger_rule) self.depends_on_past: bool = depends_on_past self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past self.wait_for_downstream: bool = wait_for_downstream @@ -854,7 +835,7 @@ def __init__( ) # At execution_time this becomes a normal dict - self.params: Union[ParamsDict, dict] = ParamsDict(params) + self.params: ParamsDict | dict = ParamsDict(params) if priority_weight is not None and not isinstance(priority_weight, int): raise AirflowException( f"`priority_weight` for task '{self.task_id}' only accepts integers, " @@ -873,11 +854,11 @@ def __init__( # TODO: Remove in Airflow 3.0 warnings.warn( "The 'task_concurrency' parameter is deprecated. 
Please use 'max_active_tis_per_dag'.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) max_active_tis_per_dag = task_concurrency - self.max_active_tis_per_dag: Optional[int] = max_active_tis_per_dag + self.max_active_tis_per_dag: int | None = max_active_tis_per_dag self.do_xcom_push = do_xcom_push self.doc_md = doc_md @@ -886,8 +867,8 @@ def __init__( self.doc_rst = doc_rst self.doc = doc - self.upstream_task_ids: Set[str] = set() - self.downstream_task_ids: Set[str] = set() + self.upstream_task_ids: set[str] = set() + self.downstream_task_ids: set[str] = set() if dag: self.dag = dag @@ -895,14 +876,11 @@ def __init__( self._log = logging.getLogger("airflow.task.operators") # Lineage - self.inlets: List = [] - self.outlets: List = [] - - self._inlets: List = [] - self._outlets: List = [] + self.inlets: list = [] + self.outlets: list = [] if inlets: - self._inlets = ( + self.inlets = ( inlets if isinstance(inlets, list) else [ @@ -911,7 +889,7 @@ def __init__( ) if outlets: - self._outlets = ( + self.outlets = ( outlets if isinstance(outlets, list) else [ @@ -954,11 +932,11 @@ def __hash__(self): def __or__(self, other): """ Called for [This Operator] | [Operator], The inlets of other - will be set to pickup the outlets from this operator. Other will + will be set to pick up the outlets from this operator. Other will be set as a downstream task of this operator. """ if isinstance(other, BaseOperator): - if not self._outlets and not self.supports_lineage: + if not self.outlets and not self.supports_lineage: raise ValueError("No outlets defined for this operator") other.add_inlets([self.task_id]) self.set_downstream(other) @@ -1014,33 +992,33 @@ def __setattr__(self, key, value): def add_inlets(self, inlets: Iterable[Any]): """Sets inlets to this operator""" - self._inlets.extend(inlets) + self.inlets.extend(inlets) def add_outlets(self, outlets: Iterable[Any]): """Defines the outlets of this operator""" - self._outlets.extend(outlets) + self.outlets.extend(outlets) def get_inlet_defs(self): - """:return: list of inlets defined for this operator""" - return self._inlets + """:meta private:""" + return self.inlets def get_outlet_defs(self): - """:return: list of outlets defined for this operator""" - return self._outlets + """:meta private:""" + return self.outlets - def get_dag(self) -> "Optional[DAG]": + def get_dag(self) -> DAG | None: return self._dag @property # type: ignore[override] - def dag(self) -> 'DAG': # type: ignore[override] + def dag(self) -> DAG: # type: ignore[override] """Returns the Operator's DAG if set, otherwise raises an error""" if self._dag: return self._dag else: - raise AirflowException(f'Operator {self} has not been assigned to a DAG yet') + raise AirflowException(f"Operator {self} has not been assigned to a DAG yet") @dag.setter - def dag(self, dag: Optional['DAG']): + def dag(self, dag: DAG | None): """ Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok. 
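The `__or__` docstring touched above describes the lineage shorthand: `this | other` requires `this` to declare outlets (or support lineage), records `this.task_id` as an inlet of `other`, and makes `other` a downstream task. A hedged usage sketch under those rules; the dag_id, file path, and bash commands are made up:

```python
# Illustrative only -- not part of the PR. Shows the ``|`` lineage shorthand
# whose behaviour is defined in BaseOperator.__or__ above.
from datetime import datetime

from airflow import DAG
from airflow.lineage.entities import File
from airflow.operators.bash import BashOperator

with DAG(dag_id="lineage_pipe_demo", start_date=datetime(2022, 1, 1), schedule=None):
    extract = BashOperator(
        task_id="extract",
        bash_command="echo extract",
        outlets=[File(url="/tmp/raw.csv")],  # without outlets, ``extract | transform`` raises ValueError
    )
    transform = BashOperator(task_id="transform", bash_command="echo transform")

    # Adds extract's task_id to transform's inlets and sets
    # transform downstream of extract.
    extract | transform
```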
@@ -1051,7 +1029,7 @@ def dag(self, dag: Optional['DAG']): self._dag = None return if not isinstance(dag, DAG): - raise TypeError(f'Expected DAG; received {dag.__class__.__name__}') + raise TypeError(f"Expected DAG; received {dag.__class__.__name__}") elif self.has_dag() and self.dag is not dag: raise AirflowException(f"The DAG assigned to {self} can not be changed.") @@ -1068,7 +1046,7 @@ def has_dag(self): """Returns True if the Operator has been assigned to a DAG.""" return self._dag is not None - deps: FrozenSet[BaseTIDep] = frozenset( + deps: frozenset[BaseTIDep] = frozenset( { NotInRetryPeriodDep(), PrevDagrunDep(), @@ -1082,7 +1060,7 @@ def has_dag(self): extended/overridden by subclasses. """ - def prepare_for_execution(self) -> "BaseOperator": + def prepare_for_execution(self) -> BaseOperator: """ Lock task for execution to disable custom action in __setattr__ and returns a copy of the task @@ -1146,18 +1124,17 @@ def post_execute(self, context: Any, result: Any = None): def on_kill(self) -> None: """ - Override this method to cleanup subprocesses when a task instance + Override this method to clean up subprocesses when a task instance gets killed. Any use of the threading, subprocess or multiprocessing - module within an operator needs to be cleaned up or it will leave + module within an operator needs to be cleaned up, or it will leave ghost processes behind. """ def __deepcopy__(self, memo): - """ - Hack sorting double chained task lists by task_id to avoid hitting - max_depth on deepcopy operations. - """ + # Hack sorting double chained task lists by task_id to avoid hitting + # max_depth on deepcopy operations. sys.setrecursionlimit(5000) # TODO fix this in a better way + cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result @@ -1165,15 +1142,19 @@ def __deepcopy__(self, memo): shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs for k, v in self.__dict__.items(): + if k == "_BaseOperator__instantiated": + # Don't set this until the _end_, as it changes behaviour of __setattr__ + continue if k not in shallow_copy: setattr(result, k, copy.deepcopy(v, memo)) else: setattr(result, k, copy.copy(v)) + result.__instantiated = self.__instantiated return result def __getstate__(self): state = dict(self.__dict__) - del state['_log'] + del state["_log"] return state @@ -1184,25 +1165,24 @@ def __setstate__(self, state): def render_template_fields( self, context: Context, - jinja_env: Optional["jinja2.Environment"] = None, - ) -> Optional["BaseOperator"]: - """Template all attributes listed in template_fields. + jinja_env: jinja2.Environment | None = None, + ) -> None: + """Template all attributes listed in *self.template_fields*. This mutates the attributes in-place and is irreversible. - :param context: Dict with values to apply on content - :param jinja_env: Jinja environment + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja's environment to use for rendering. 
""" if not jinja_env: jinja_env = self.get_template_env() self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) - return self @provide_session def clear( self, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + start_date: datetime | None = None, + end_date: datetime | None = None, upstream: bool = False, downstream: bool = False, session: Session = NEW_SESSION, @@ -1236,10 +1216,10 @@ def clear( @provide_session def get_task_instances( self, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + start_date: datetime | None = None, + end_date: datetime | None = None, session: Session = NEW_SESSION, - ) -> List[TaskInstance]: + ) -> list[TaskInstance]: """Get task instances related to this task for a specific date range.""" from airflow.models import DagRun @@ -1258,8 +1238,8 @@ def get_task_instances( @provide_session def run( self, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + start_date: datetime | None = None, + end_date: datetime | None = None, ignore_first_depends_on_past: bool = True, ignore_ti_state: bool = False, mark_success: bool = False, @@ -1314,7 +1294,7 @@ def run( def dry_run(self) -> None: """Performs dry run for the operator - just render template fields.""" - self.log.info('Dry run') + self.log.info("Dry run") for field in self.template_fields: try: content = getattr(self, field) @@ -1325,10 +1305,10 @@ def dry_run(self) -> None: ) if content and isinstance(content, str): - self.log.info('Rendering template for %s', field) + self.log.info("Rendering template for %s", field) self.log.info(content) - def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: + def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]: """ Get list of the direct relatives to the current task, upstream or downstream. @@ -1342,7 +1322,7 @@ def __repr__(self): return "".format(self=self) @property - def operator_class(self) -> Type["BaseOperator"]: # type: ignore[override] + def operator_class(self) -> type[BaseOperator]: # type: ignore[override] return self.__class__ @property @@ -1351,17 +1331,25 @@ def task_type(self) -> str: return self.__class__.__name__ @property - def roots(self) -> List["BaseOperator"]: + def operator_name(self) -> str: + """@property: use a more friendly display name for the operator, if set""" + try: + return self.custom_operator_name # type: ignore + except AttributeError: + return self.task_type + + @property + def roots(self) -> list[BaseOperator]: """Required by DAGNode.""" return [self] @property - def leaves(self) -> List["BaseOperator"]: + def leaves(self) -> list[BaseOperator]: """Required by DAGNode.""" return [self] @property - def output(self): + def output(self) -> XComArg: """Returns reference to XCom pushed by current operator""" from airflow.models.xcom_arg import XComArg @@ -1372,7 +1360,7 @@ def xcom_push( context: Any, key: str, value: Any, - execution_date: Optional[datetime] = None, + execution_date: datetime | None = None, ) -> None: """ Make an XCom available for tasks to pull. @@ -1385,15 +1373,15 @@ def xcom_push( this date. This can be used, for example, to send a message to a task on a future date without it being immediately visible. 
""" - context['ti'].xcom_push(key=key, value=value, execution_date=execution_date) + context["ti"].xcom_push(key=key, value=value, execution_date=execution_date) @staticmethod def xcom_pull( context: Any, - task_ids: Optional[Union[str, List[str]]] = None, - dag_id: Optional[str] = None, + task_ids: str | list[str] | None = None, + dag_id: str | None = None, key: str = XCOM_RETURN_KEY, - include_prior_dates: Optional[bool] = None, + include_prior_dates: bool | None = None, ) -> Any: """ Pull XComs that optionally meet certain criteria. @@ -1421,7 +1409,7 @@ def xcom_pull( execution_date are returned. If True, XComs from previous dates are returned as well. """ - return context['ti'].xcom_pull( + return context["ti"].xcom_pull( key=key, task_ids=task_ids, dag_id=dag_id, include_prior_dates=include_prior_dates ) @@ -1437,61 +1425,54 @@ def get_serialized_fields(cls): # Exception in SerializedDAG.serialize_dag() call. DagContext.push_context_managed_dag(None) cls.__serialized_fields = frozenset( - vars(BaseOperator(task_id='test')).keys() + vars(BaseOperator(task_id="test")).keys() - { - 'inlets', - 'outlets', - 'upstream_task_ids', - 'default_args', - 'dag', - '_dag', - 'label', - '_BaseOperator__instantiated', - '_BaseOperator__init_kwargs', - '_BaseOperator__from_mapped', + "upstream_task_ids", + "default_args", + "dag", + "_dag", + "label", + "_BaseOperator__instantiated", + "_BaseOperator__init_kwargs", + "_BaseOperator__from_mapped", } | { # Class level defaults need to be added to this list - 'start_date', - 'end_date', - '_task_type', - 'subdag', - 'ui_color', - 'ui_fgcolor', - 'template_ext', - 'template_fields', - 'template_fields_renderers', - 'params', + "start_date", + "end_date", + "_task_type", + "_operator_name", + "subdag", + "ui_color", + "ui_fgcolor", + "template_ext", + "template_fields", + "template_fields_renderers", + "params", } ) DagContext.pop_context_managed_dag() return cls.__serialized_fields - def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Required by DAGNode.""" return DagAttributeTypes.OP, self.task_id - def is_smart_sensor_compatible(self): - """Return if this operator can use smart service. Default False.""" - return False - - is_mapped: ClassVar[bool] = False - @property def inherits_from_empty_operator(self): """Used to determine if an Operator is inherited from EmptyOperator""" # This looks like `isinstance(self, EmptyOperator) would work, but this also # needs to cope when `self` is a Serialized instance of a EmptyOperator or one - # of its sub-classes (which don't inherit from anything but BaseOperator). - return getattr(self, '_is_empty', False) + # of its subclasses (which don't inherit from anything but BaseOperator). 
+ return getattr(self, "_is_empty", False) def defer( self, *, trigger: BaseTrigger, method_name: str, - kwargs: Optional[Dict[str, Any]] = None, - timeout: Optional[timedelta] = None, + kwargs: dict[str, Any] | None = None, + timeout: timedelta | None = None, ): """ Marks this Operator as being "deferred" - that is, suspending its @@ -1502,13 +1483,7 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - @classmethod - def validate_mapped_arguments(cls, **kwargs: Any) -> None: - """Validate arguments when this operator is being mapped.""" - if cls.mapped_arguments_validated_by_init: - cls(**kwargs, _airflow_from_mapped=True) - - def unmap(self) -> "BaseOperator": + def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator: """:meta private:""" return self @@ -1517,7 +1492,7 @@ def unmap(self) -> "BaseOperator": Chainable = Union[DependencyMixin, Sequence[DependencyMixin]] -def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None: +def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: r""" Given a number of tasks, builds a dependency chain. @@ -1588,12 +1563,12 @@ def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None: .. code-block:: python - chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, t2()) + chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3()) is equivalent to:: / "branch one" -> x1 \ - t1 -> t2 -> x3 + t1 -> task_group1 -> x3 \ "branch two" -> x2 / .. code-block:: python @@ -1634,13 +1609,13 @@ def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None: down_task.set_upstream(up_task) continue if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence): - raise TypeError(f'Chain not supported between instances of {type(up_task)} and {type(down_task)}') + raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}") up_task_list = up_task down_task_list = down_task if len(up_task_list) != len(down_task_list): raise AirflowException( - f'Chain not supported for different length Iterable. ' - f'Got {len(up_task_list)} and {len(down_task_list)}.' + f"Chain not supported for different length Iterable. " + f"Got {len(up_task_list)} and {len(down_task_list)}." ) for up_t, down_t in zip(up_task_list, down_task_list): up_t.set_downstream(down_t) @@ -1648,7 +1623,7 @@ def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None: def cross_downstream( from_tasks: Sequence[DependencyMixin], - to_tasks: Union[DependencyMixin, Sequence[DependencyMixin]], + to_tasks: DependencyMixin | Sequence[DependencyMixin], ): r""" Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks. @@ -1747,11 +1722,19 @@ def cross_downstream( task.set_downstream(to_tasks) +# pyupgrade assumes all type annotations can be lazily evaluated, but this is +# not the case for attrs-decorated classes, since cattrs needs to evaluate the +# annotation expressions at runtime, and Python before 3.9.0 does not lazily +# evaluate those. Putting the expression in a top-level assignment statement +# communicates this runtime requirement to pyupgrade. 
+BaseOperatorClassList = List[Type[BaseOperator]] + + @attr.s(auto_attribs=True) class BaseOperatorLink(metaclass=ABCMeta): """Abstract base class that defines how we get an operator link.""" - operators: ClassVar[List[Type[BaseOperator]]] = [] + operators: ClassVar[BaseOperatorClassList] = [] """ This property will be used by Airflow Plugins to find the Operators to which you want to assign this Operator Link @@ -1762,21 +1745,16 @@ class BaseOperatorLink(metaclass=ABCMeta): @property @abstractmethod def name(self) -> str: - """ - Name of the link. This will be the button name on the task UI. - - :return: link name - """ + """Name of the link. This will be the button name on the task UI.""" @abstractmethod - def get_link(self, operator: AbstractOperator, *, ti_key: "TaskInstanceKey") -> str: - """ - Link to external system. + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: + """Link to external system. Note: The old signature of this function was ``(self, operator, dttm: datetime)``. That is still supported at runtime but is deprecated. - :param operator: airflow operator - :param ti_key: TaskInstance ID to return link for + :param operator: The Airflow operator object this link is associated to. + :param ti_key: TaskInstance ID to return link for. :return: link to external system """ diff --git a/airflow/models/connection.py b/airflow/models/connection.py index b21e68b73b356..5f7406d9d4835 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -15,23 +15,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json import logging import warnings from json import JSONDecodeError -from typing import Dict, Optional, Union -from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse +from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit from sqlalchemy import Boolean, Column, Integer, String, Text from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import reconstructor, synonym from airflow.configuration import ensure_secrets_loaded -from airflow.exceptions import AirflowException, AirflowNotFoundException +from airflow.exceptions import AirflowException, AirflowNotFoundException, RemovedInAirflow3Warning from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet -from airflow.providers_manager import ProvidersManager from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.module_loading import import_string @@ -41,7 +40,7 @@ def parse_netloc_to_hostname(*args, **kwargs): """This method is deprecated.""" - warnings.warn("This method is deprecated.", DeprecationWarning) + warnings.warn("This method is deprecated.", RemovedInAirflow3Warning) return _parse_netloc_to_hostname(*args, **kwargs) @@ -49,8 +48,8 @@ def parse_netloc_to_hostname(*args, **kwargs): # See: https://issues.apache.org/jira/browse/AIRFLOW-3615 def _parse_netloc_to_hostname(uri_parts): """Parse a URI string to get correct Hostname.""" - hostname = unquote(uri_parts.hostname or '') - if '/' in hostname: + hostname = unquote(uri_parts.hostname or "") + if "/" in hostname: hostname = uri_parts.netloc if "@" in hostname: hostname = hostname.rsplit("@", 1)[1] @@ -83,35 +82,35 @@ class Connection(Base, LoggingMixin): :param uri: URI address describing connection parameters. 
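To make the new-style get_link(operator, *, ti_key) signature above concrete, a hypothetical operator link; the class name and URL pattern are invented for illustration.

from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.taskinstance import TaskInstanceKey


class ConsoleLink(BaseOperatorLink):
    name = "External Console"  # button label shown in the task UI

    def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
        # illustrative URL; a real link would point at the external system's run page
        return f"https://console.example.com/{ti_key.dag_id}/{ti_key.run_id}/{ti_key.task_id}"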
""" - EXTRA_KEY = '__extra__' + EXTRA_KEY = "__extra__" __tablename__ = "connection" id = Column(Integer(), primary_key=True) conn_id = Column(String(ID_LEN), unique=True, nullable=False) conn_type = Column(String(500), nullable=False) - description = Column(Text(5000)) + description = Column(Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite")) host = Column(String(500)) schema = Column(String(500)) login = Column(String(500)) - _password = Column('password', String(5000)) + _password = Column("password", String(5000)) port = Column(Integer()) is_encrypted = Column(Boolean, unique=False, default=False) is_extra_encrypted = Column(Boolean, unique=False, default=False) - _extra = Column('extra', Text()) + _extra = Column("extra", Text()) def __init__( self, - conn_id: Optional[str] = None, - conn_type: Optional[str] = None, - description: Optional[str] = None, - host: Optional[str] = None, - login: Optional[str] = None, - password: Optional[str] = None, - schema: Optional[str] = None, - port: Optional[int] = None, - extra: Optional[Union[str, dict]] = None, - uri: Optional[str] = None, + conn_id: str | None = None, + conn_type: str | None = None, + description: str | None = None, + host: str | None = None, + login: str | None = None, + password: str | None = None, + schema: str | None = None, + port: int | None = None, + extra: str | dict | None = None, + uri: str | None = None, ): super().__init__() self.conn_id = conn_id @@ -155,14 +154,14 @@ def _validate_extra(extra, conn_id) -> None: "Encountered JSON value in `extra` which does not parse as a dictionary in " f"connection {conn_id!r}. From Airflow 3.0, the `extra` field must contain a JSON " "representation of a Python dict.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=3, ) except json.JSONDecodeError: warnings.warn( f"Encountered non-JSON in `extra` field for connection {conn_id!r}. Support for " "non-JSON `extra` will be removed in Airflow 3.0", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return None @@ -175,20 +174,21 @@ def on_db_load(self): def parse_from_uri(self, **uri): """This method is deprecated. Please use uri parameter in constructor.""" warnings.warn( - "This method is deprecated. Please use uri parameter in constructor.", DeprecationWarning + "This method is deprecated. 
Please use uri parameter in constructor.", + RemovedInAirflow3Warning, ) self._parse_from_uri(**uri) @staticmethod def _normalize_conn_type(conn_type): - if conn_type == 'postgresql': - conn_type = 'postgres' - elif '-' in conn_type: - conn_type = conn_type.replace('-', '_') + if conn_type == "postgresql": + conn_type = "postgres" + elif "-" in conn_type: + conn_type = conn_type.replace("-", "_") return conn_type def _parse_from_uri(self, uri: str): - uri_parts = urlparse(uri) + uri_parts = urlsplit(uri) conn_type = uri_parts.scheme self.conn_type = self._normalize_conn_type(conn_type) self.host = _parse_netloc_to_hostname(uri_parts) @@ -206,35 +206,38 @@ def _parse_from_uri(self, uri: str): def get_uri(self) -> str: """Return connection in URI format""" - if '_' in self.conn_type: + if self.conn_type and "_" in self.conn_type: self.log.warning( "Connection schemes (type: %s) shall not contain '_' according to RFC3986.", self.conn_type, ) - uri = f"{str(self.conn_type).lower().replace('_', '-')}://" + if self.conn_type: + uri = f"{self.conn_type.lower().replace('_', '-')}://" + else: + uri = "//" - authority_block = '' + authority_block = "" if self.login is not None: - authority_block += quote(self.login, safe='') + authority_block += quote(self.login, safe="") if self.password is not None: - authority_block += ':' + quote(self.password, safe='') + authority_block += ":" + quote(self.password, safe="") - if authority_block > '': - authority_block += '@' + if authority_block > "": + authority_block += "@" uri += authority_block - host_block = '' + host_block = "" if self.host: - host_block += quote(self.host, safe='') + host_block += quote(self.host, safe="") if self.port: - if host_block > '': - host_block += f':{self.port}' + if host_block == "" and authority_block == "": + host_block += f"@:{self.port}" else: - host_block += f'@:{self.port}' + host_block += f":{self.port}" if self.schema: host_block += f"/{quote(self.schema, safe='')}" @@ -243,17 +246,17 @@ def get_uri(self) -> str: if self.extra: try: - query: Optional[str] = urlencode(self.extra_dejson) + query: str | None = urlencode(self.extra_dejson) except TypeError: query = None if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)): - uri += '?' + query + uri += ("?" if self.schema else "/?") + query else: - uri += '?' + urlencode({self.EXTRA_KEY: self.extra}) + uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra}) return uri - def get_password(self) -> Optional[str]: + def get_password(self) -> str | None: """Return encrypted password.""" if self._password and self.is_encrypted: fernet = get_fernet() @@ -262,23 +265,23 @@ def get_password(self) -> Optional[str]: f"Can't decrypt encrypted password for login={self.login} " f"FERNET_KEY configuration is missing" ) - return fernet.decrypt(bytes(self._password, 'utf-8')).decode() + return fernet.decrypt(bytes(self._password, "utf-8")).decode() else: return self._password - def set_password(self, value: Optional[str]): + def set_password(self, value: str | None): """Encrypt password and set in object attribute.""" if value: fernet = get_fernet() - self._password = fernet.encrypt(bytes(value, 'utf-8')).decode() + self._password = fernet.encrypt(bytes(value, "utf-8")).decode() self.is_encrypted = fernet.is_encrypted @declared_attr def password(cls): """Password. 
The value is decrypted/encrypted when reading/setting the value.""" - return synonym('_password', descriptor=property(cls.get_password, cls.set_password)) + return synonym("_password", descriptor=property(cls.get_password, cls.set_password)) - def get_extra(self) -> Dict: + def get_extra(self) -> str: """Return encrypted extra-data.""" if self._extra and self.is_extra_encrypted: fernet = get_fernet() @@ -287,7 +290,7 @@ def get_extra(self) -> Dict: f"Can't decrypt `extra` params for login={self.login}, " f"FERNET_KEY configuration is missing" ) - extra_val = fernet.decrypt(bytes(self._extra, 'utf-8')).decode() + extra_val = fernet.decrypt(bytes(self._extra, "utf-8")).decode() else: extra_val = self._extra if extra_val: @@ -299,7 +302,7 @@ def set_extra(self, value: str): if value: self._validate_extra(value, self.conn_id) fernet = get_fernet() - self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode() + self._extra = fernet.encrypt(bytes(value, "utf-8")).decode() self.is_extra_encrypted = fernet.is_encrypted else: self._extra = value @@ -308,18 +311,20 @@ def set_extra(self, value: str): @declared_attr def extra(cls): """Extra data. The value is decrypted/encrypted when reading/setting the value.""" - return synonym('_extra', descriptor=property(cls.get_extra, cls.set_extra)) + return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra)) def rotate_fernet_key(self): """Encrypts data with a new key. See: :ref:`security/fernet`""" fernet = get_fernet() if self._password and self.is_encrypted: - self._password = fernet.rotate(self._password.encode('utf-8')).decode() + self._password = fernet.rotate(self._password.encode("utf-8")).decode() if self._extra and self.is_extra_encrypted: - self._extra = fernet.rotate(self._extra.encode('utf-8')).decode() + self._extra = fernet.rotate(self._extra.encode("utf-8")).decode() def get_hook(self, *, hook_params=None): """Return hook based on conn_type""" + from airflow.providers_manager import ProvidersManager + hook = ProvidersManager().hooks.get(self.conn_type, None) if hook is None: @@ -339,7 +344,7 @@ def get_hook(self, *, hook_params=None): return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params) def __repr__(self): - return self.conn_id or '' + return self.conn_id or "" def log_info(self): """ @@ -349,7 +354,7 @@ def log_info(self): warnings.warn( "This method is deprecated. You can read each field individually or " "use the default representation (__repr__).", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return ( @@ -366,7 +371,7 @@ def debug_info(self): warnings.warn( "This method is deprecated. 
You can read each field individually or " "use the default representation (__repr__).", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return ( @@ -377,10 +382,10 @@ def debug_info(self): def test_connection(self): """Calls out get_hook method and executes test_connection method on that.""" - status, message = False, '' + status, message = False, "" try: hook = self.get_hook() - if getattr(hook, 'test_connection', False): + if getattr(hook, "test_connection", False): status, message = hook.test_connection() else: message = ( @@ -392,7 +397,7 @@ def test_connection(self): return status, message @property - def extra_dejson(self) -> Dict: + def extra_dejson(self) -> dict: """Returns the extra property by deserializing json.""" obj = {} if self.extra: @@ -408,7 +413,7 @@ def extra_dejson(self) -> Dict: return obj @classmethod - def get_connection_from_secrets(cls, conn_id: str) -> 'Connection': + def get_connection_from_secrets(cls, conn_id: str) -> Connection: """ Get connection by conn_id. @@ -422,26 +427,26 @@ def get_connection_from_secrets(cls, conn_id: str) -> 'Connection': return conn except Exception: log.exception( - 'Unable to retrieve connection from secrets backend (%s). ' - 'Checking subsequent secrets backend.', + "Unable to retrieve connection from secrets backend (%s). " + "Checking subsequent secrets backend.", type(secrets_backend).__name__, ) raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") @classmethod - def from_json(cls, value, conn_id=None) -> 'Connection': + def from_json(cls, value, conn_id=None) -> Connection: kwargs = json.loads(value) - extra = kwargs.pop('extra', None) + extra = kwargs.pop("extra", None) if extra: - kwargs['extra'] = extra if isinstance(extra, str) else json.dumps(extra) - conn_type = kwargs.pop('conn_type', None) + kwargs["extra"] = extra if isinstance(extra, str) else json.dumps(extra) + conn_type = kwargs.pop("conn_type", None) if conn_type: - kwargs['conn_type'] = cls._normalize_conn_type(conn_type) - port = kwargs.pop('port', None) + kwargs["conn_type"] = cls._normalize_conn_type(conn_type) + port = kwargs.pop("port", None) if port: try: - kwargs['port'] = int(port) + kwargs["port"] = int(port) except ValueError: raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.") return Connection(conn_id=conn_id, **kwargs) diff --git a/airflow/models/crypto.py b/airflow/models/crypto.py index b57c53732df32..d1a2b06a60599 100644 --- a/airflow/models/crypto.py +++ b/airflow/models/crypto.py @@ -15,10 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
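A small sketch pulling together the Connection changes above (from_json coercing the port, get_uri assembling the URI); the connection values are illustrative.

from airflow.models.connection import Connection

conn = Connection.from_json(
    '{"conn_type": "postgres", "host": "db.example.com", "login": "user",'
    ' "password": "pass", "port": "5432", "schema": "analytics"}',
    conn_id="my_db",
)
assert conn.port == 5432  # "5432" is coerced to an int by from_json
print(conn.get_uri())     # postgres://user:pass@db.example.com:5432/analytics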
-import logging -from typing import Optional +from __future__ import annotations -from cryptography.fernet import Fernet, MultiFernet +import logging from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -58,7 +57,7 @@ def encrypt(self, b): return b -_fernet = None # type: Optional[FernetProtocol] +_fernet: FernetProtocol | None = None def get_fernet(): @@ -71,19 +70,21 @@ def get_fernet(): :return: Fernet object :raises: airflow.exceptions.AirflowException if there's a problem trying to load Fernet """ + from cryptography.fernet import Fernet, MultiFernet + global _fernet if _fernet: return _fernet try: - fernet_key = conf.get('core', 'FERNET_KEY') + fernet_key = conf.get("core", "FERNET_KEY") if not fernet_key: log.warning("empty cryptography key - values will not be stored encrypted.") _fernet = NullFernet() else: _fernet = MultiFernet( - [Fernet(fernet_part.encode('utf-8')) for fernet_part in fernet_key.split(',')] + [Fernet(fernet_part.encode("utf-8")) for fernet_part in fernet_key.split(",")] ) _fernet.is_encrypted = True except (ValueError, TypeError) as value_error: diff --git a/airflow/models/dag.py b/airflow/models/dag.py index ece03bafe8c92..4cb59754468d2 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations +import collections import copy import functools import itertools @@ -28,6 +30,7 @@ import traceback import warnings import weakref +from collections import deque from datetime import datetime, timedelta from inspect import signature from typing import ( @@ -35,25 +38,23 @@ Any, Callable, Collection, - Dict, - FrozenSet, + Deque, Iterable, + Iterator, List, - Optional, Sequence, - Set, - Tuple, - Type, Union, cast, overload, ) +from urllib.parse import urlsplit import jinja2 import pendulum from dateutil.relativedelta import relativedelta from pendulum.tz.timezone import Timezone -from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_ +from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, and_, case, func, not_, or_ +from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import backref, joinedload, relationship from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session @@ -62,28 +63,36 @@ import airflow.templates from airflow import settings, utils from airflow.compat.functools import cached_property -from airflow.configuration import conf -from airflow.exceptions import AirflowException, DuplicateTaskIdFound, TaskNotFound +from airflow.configuration import conf, secrets_backend_list +from airflow.exceptions import ( + AirflowDagInconsistent, + AirflowException, + AirflowSkipException, + DuplicateTaskIdFound, + RemovedInAirflow3Warning, + TaskNotFound, +) from airflow.models.abstractoperator import AbstractOperator -from airflow.models.base import ID_LEN, Base -from airflow.models.dagbag import DagBag +from airflow.models.base import Base, StringID from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import DagRun from airflow.models.operator import Operator from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances +from airflow.secrets.local_filesystem import 
LocalFilesystemBackend from airflow.security import permissions from airflow.stats import Stats from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable -from airflow.timetables.simple import NullTimetable, OnceTimetable +from airflow.timetables.simple import DatasetTriggeredTimetable, NullTimetable, OnceTimetable from airflow.typing_compat import Literal from airflow.utils import timezone from airflow.utils.dag_cycle_tester import check_cycle from airflow.utils.dates import cron_presets, date_range as utils_date_range +from airflow.utils.decorators import fixup_decorator_warning_stack from airflow.utils.file import correct_maybe_zipped -from airflow.utils.helpers import exactly_one, validate_key +from airflow.utils.helpers import at_most_one, exactly_one, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks @@ -91,16 +100,21 @@ from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType if TYPE_CHECKING: + from types import ModuleType + + from airflow.datasets import Dataset from airflow.decorators import TaskDecoratorCollection + from airflow.models.dagbag import DagBag from airflow.models.slamiss import SlaMiss from airflow.utils.task_group import TaskGroup log = logging.getLogger(__name__) -DEFAULT_VIEW_PRESETS = ['grid', 'graph', 'duration', 'gantt', 'landing_times'] -ORIENTATION_PRESETS = ['LR', 'TB', 'RL', 'BT'] +DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"] +ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"] +TAG_MAX_LEN = 100 DagStateChangeCallback = Callable[[Context], None] ScheduleInterval = Union[None, str, timedelta, relativedelta] @@ -109,6 +123,9 @@ # but Mypy cannot handle that right now. Track progress of PEP 661 for progress. # See also: https://discuss.python.org/t/9126/7 ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval] +ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]] + +SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None] # Backward compatibility: If neither schedule_interval nor timetable is @@ -142,7 +159,7 @@ def _get_model_data_interval( instance: Any, start_field_name: str, end_field_name: str, -) -> Optional[DataInterval]: +) -> DataInterval | None: start = timezone.coerce_datetime(getattr(instance, start_field_name)) end = timezone.coerce_datetime(getattr(instance, end_field_name)) if start is None: @@ -172,7 +189,7 @@ def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone) -> Timet def get_last_dagrun(dag_id, session, include_externally_triggered=False): """ Returns the last dag run for a dag, None if there was none. - Last dag run can be any type of run eg. scheduled or backfilled. + Last dag run can be any type of run e.g. scheduled or backfilled. Overridden DagRuns are ignored. """ DR = DagRun @@ -183,6 +200,49 @@ def get_last_dagrun(dag_id, session, include_externally_triggered=False): return query.first() +def get_dataset_triggered_next_run_info( + dag_ids: list[str], *, session: Session +) -> dict[str, dict[str, int | str]]: + """ + Given a list of dag_ids, get string representing how close any that are dataset triggered are + their next run, e.g. 
"1 of 2 datasets updated" + """ + from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel + + return { + x.dag_id: { + "uri": x.uri, + "ready": x.ready, + "total": x.total, + } + for x in session.query( + DagScheduleDatasetReference.dag_id, + # This is a dirty hack to workaround group by requiring an aggregate, since grouping by dataset + # is not what we want to do here...but it works + case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"), + func.count().label("total"), + func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"), + ) + .join( + DDRQ, + and_( + DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id, + DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id, + ), + isouter=True, + ) + .join( + DatasetModel, + DatasetModel.id == DagScheduleDatasetReference.dataset_id, + ) + .group_by( + DagScheduleDatasetReference.dag_id, + ) + .filter(DagScheduleDatasetReference.dag_id.in_(dag_ids)) + .all() + } + + @functools.total_ordering class DAG(LoggingMixin): """ @@ -199,19 +259,25 @@ class DAG(LoggingMixin): Note that if you plan to use time zones all the dates provided should be pendulum dates. See :ref:`timezone_aware_dags`. + .. versionadded:: 2.4 + The *schedule* argument to specify either time-based scheduling logic + (timetable), or dataset-driven triggers. + + .. deprecated:: 2.4 + The arguments *schedule_interval* and *timetable*. Their functionalities + are merged into the new *schedule* argument. + :param dag_id: The id of the DAG; must consist exclusively of alphanumeric characters, dashes, dots and underscores (all ASCII) :param description: The description for the DAG to e.g. be shown on the webserver - :param schedule_interval: Defines how often that DAG runs, this - timedelta object gets added to your latest task instance's - execution_date to figure out the next schedule - :param timetable: Specify which timetable to use (in which case schedule_interval - must not be set). See :doc:`/howto/timetable` for more information + :param schedule: Defines the rules according to which DAG runs are scheduled. Can + accept cron string, timedelta object, Timetable, or list of Dataset objects. + See also :doc:`/howto/timetable`. :param start_date: The timestamp from which the scheduler will attempt to backfill :param end_date: A date beyond which your DAG won't run, leave to None - for open ended scheduling - :param template_searchpath: This list of folders (non relative) + for open-ended scheduling + :param template_searchpath: This list of folders (non-relative) defines where jinja will look for your templates. Order matters. Note that jinja/airflow includes the path of your DAG file by default @@ -279,21 +345,25 @@ class DAG(LoggingMixin): to render templates as native Python types. If False, a Jinja ``Environment`` is used to render templates as string values. :param tags: List of tags to help filtering DAGs in the UI. + :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI. + Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. 
+ e.g: {"dag_owner": "https://airflow.apache.org/"} + :param auto_register: Automatically register this DAG when it is used in a ``with`` block """ _comps = { - 'dag_id', - 'task_ids', - 'parent_dag', - 'start_date', - 'end_date', - 'schedule_interval', - 'fileloc', - 'template_searchpath', - 'last_loaded', + "dag_id", + "task_ids", + "parent_dag", + "start_date", + "end_date", + "schedule_interval", + "fileloc", + "template_searchpath", + "last_loaded", } - __serialized_fields: Optional[FrozenSet[str]] = None + __serialized_fields: frozenset[str] | None = None fileloc: str """ @@ -303,44 +373,51 @@ class DAG(LoggingMixin): from a ZIP file or other DAG distribution format. """ - parent_dag: Optional["DAG"] = None # Gets set when DAGs are loaded + parent_dag: DAG | None = None # Gets set when DAGs are loaded + # NOTE: When updating arguments here, please also keep arguments in @dag() + # below in sync. (Search for 'def dag(' in this file.) def __init__( self, dag_id: str, - description: Optional[str] = None, + description: str | None = None, + schedule: ScheduleArg = NOTSET, schedule_interval: ScheduleIntervalArg = NOTSET, - timetable: Optional[Timetable] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - full_filepath: Optional[str] = None, - template_searchpath: Optional[Union[str, Iterable[str]]] = None, - template_undefined: Type[jinja2.StrictUndefined] = jinja2.StrictUndefined, - user_defined_macros: Optional[Dict] = None, - user_defined_filters: Optional[Dict] = None, - default_args: Optional[Dict] = None, - concurrency: Optional[int] = None, - max_active_tasks: int = conf.getint('core', 'max_active_tasks_per_dag'), - max_active_runs: int = conf.getint('core', 'max_active_runs_per_dag'), - dagrun_timeout: Optional[timedelta] = None, - sla_miss_callback: Optional[ - Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None] - ] = None, - default_view: str = conf.get_mandatory_value('webserver', 'dag_default_view').lower(), - orientation: str = conf.get_mandatory_value('webserver', 'dag_orientation'), - catchup: bool = conf.getboolean('scheduler', 'catchup_by_default'), - on_success_callback: Optional[DagStateChangeCallback] = None, - on_failure_callback: Optional[DagStateChangeCallback] = None, - doc_md: Optional[str] = None, - params: Optional[Dict] = None, - access_control: Optional[Dict] = None, - is_paused_upon_creation: Optional[bool] = None, - jinja_environment_kwargs: Optional[Dict] = None, + timetable: Timetable | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, + full_filepath: str | None = None, + template_searchpath: str | Iterable[str] | None = None, + template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, + user_defined_macros: dict | None = None, + user_defined_filters: dict | None = None, + default_args: dict | None = None, + concurrency: int | None = None, + max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), + max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), + dagrun_timeout: timedelta | None = None, + sla_miss_callback: SLAMissCallback | None = None, + default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), + orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), + catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), + on_success_callback: DagStateChangeCallback | None = None, + on_failure_callback: DagStateChangeCallback | None = None, + doc_md: str | 
None = None, + params: dict | None = None, + access_control: dict | None = None, + is_paused_upon_creation: bool | None = None, + jinja_environment_kwargs: dict | None = None, render_template_as_native_obj: bool = False, - tags: Optional[List[str]] = None, + tags: list[str] | None = None, + owner_links: dict[str, str] | None = None, + auto_register: bool = True, ): from airflow.utils.task_group import TaskGroup + if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): + raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters") + + self.owner_links = owner_links if owner_links else {} self.user_defined_macros = user_defined_macros self.user_defined_filters = user_defined_filters if default_args and not isinstance(default_args, dict): @@ -349,9 +426,9 @@ def __init__( params = params or {} # merging potentially conflicting default_args['params'] into params - if 'params' in self.default_args: - params.update(self.default_args['params']) - del self.default_args['params'] + if "params" in self.default_args: + params.update(self.default_args["params"]) + del self.default_args["params"] # check self.params and convert them into ParamsDict self.params = ParamsDict(params) @@ -359,7 +436,7 @@ def __init__( if full_filepath: warnings.warn( "Passing full_filepath to DAG() is deprecated and has no effect", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) @@ -370,29 +447,29 @@ def __init__( # TODO: Remove in Airflow 3.0 warnings.warn( "The 'concurrency' parameter is deprecated. Please use 'max_active_tasks'.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) max_active_tasks = concurrency self._max_active_tasks = max_active_tasks - self._pickle_id: Optional[int] = None + self._pickle_id: int | None = None self._description = description # set file location to caller source path back = sys._getframe().f_back self.fileloc = back.f_code.co_filename if back else "" - self.task_dict: Dict[str, Operator] = {} + self.task_dict: dict[str, Operator] = {} # set timezone from start_date tz = None if start_date and start_date.tzinfo: tzinfo = None if start_date.tzinfo else settings.TIMEZONE tz = pendulum.instance(start_date, tz=tzinfo).timezone - elif 'start_date' in self.default_args and self.default_args['start_date']: - date = self.default_args['start_date'] + elif "start_date" in self.default_args and self.default_args["start_date"]: + date = self.default_args["start_date"] if not isinstance(date, datetime): date = timezone.parse(date) - self.default_args['start_date'] = date + self.default_args["start_date"] = date start_date = date tzinfo = None if date.tzinfo else settings.TIMEZONE @@ -400,62 +477,96 @@ def __init__( self.timezone = tz or settings.TIMEZONE # Apply the timezone we settled on to end_date if it wasn't supplied - if 'end_date' in self.default_args and self.default_args['end_date']: - if isinstance(self.default_args['end_date'], str): - self.default_args['end_date'] = timezone.parse( - self.default_args['end_date'], timezone=self.timezone + if "end_date" in self.default_args and self.default_args["end_date"]: + if isinstance(self.default_args["end_date"], str): + self.default_args["end_date"] = timezone.parse( + self.default_args["end_date"], timezone=self.timezone ) self.start_date = timezone.convert_to_utc(start_date) self.end_date = timezone.convert_to_utc(end_date) # also convert tasks - if 'start_date' in self.default_args: - self.default_args['start_date'] = timezone.convert_to_utc(self.default_args['start_date']) - if 'end_date' in 
self.default_args: - self.default_args['end_date'] = timezone.convert_to_utc(self.default_args['end_date']) + if "start_date" in self.default_args: + self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"]) + if "end_date" in self.default_args: + self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"]) + + # sort out DAG's scheduling behavior + scheduling_args = [schedule_interval, timetable, schedule] + if not at_most_one(*scheduling_args): + raise ValueError("At most one allowed for args 'schedule_interval', 'timetable', and 'schedule'.") + if schedule_interval is not NOTSET: + warnings.warn( + "Param `schedule_interval` is deprecated and will be removed in a future release. " + "Please use `schedule` instead. ", + RemovedInAirflow3Warning, + stacklevel=2, + ) + if timetable is not None: + warnings.warn( + "Param `timetable` is deprecated and will be removed in a future release. " + "Please use `schedule` instead. ", + RemovedInAirflow3Warning, + stacklevel=2, + ) - # Calculate the DAG's timetable. - if timetable is None: - self.timetable = create_timetable(schedule_interval, self.timezone) - if isinstance(schedule_interval, ArgNotSet): - schedule_interval = DEFAULT_SCHEDULE_INTERVAL - self.schedule_interval: ScheduleInterval = schedule_interval - elif isinstance(schedule_interval, ArgNotSet): + self.timetable: Timetable + self.schedule_interval: ScheduleInterval + self.dataset_triggers: Collection[Dataset] = [] + + if isinstance(schedule, Collection) and not isinstance(schedule, str): + from airflow.datasets import Dataset + + if not all(isinstance(x, Dataset) for x in schedule): + raise ValueError("All elements in 'schedule' should be datasets") + self.dataset_triggers = list(schedule) + elif isinstance(schedule, Timetable): + timetable = schedule + elif schedule is not NOTSET: + schedule_interval = schedule + + if self.dataset_triggers: + self.timetable = DatasetTriggeredTimetable() + self.schedule_interval = self.timetable.summary + elif timetable: self.timetable = timetable self.schedule_interval = self.timetable.summary else: - raise TypeError("cannot specify both 'schedule_interval' and 'timetable'") + if isinstance(schedule_interval, ArgNotSet): + schedule_interval = DEFAULT_SCHEDULE_INTERVAL + self.schedule_interval = schedule_interval + self.timetable = create_timetable(schedule_interval, self.timezone) if isinstance(template_searchpath, str): template_searchpath = [template_searchpath] self.template_searchpath = template_searchpath self.template_undefined = template_undefined self.last_loaded = timezone.utcnow() - self.safe_dag_id = dag_id.replace('.', '__dot__') + self.safe_dag_id = dag_id.replace(".", "__dot__") self.max_active_runs = max_active_runs self.dagrun_timeout = dagrun_timeout self.sla_miss_callback = sla_miss_callback if default_view in DEFAULT_VIEW_PRESETS: self._default_view: str = default_view - elif default_view == 'tree': + elif default_view == "tree": warnings.warn( "`default_view` of 'tree' has been renamed to 'grid' -- please update your DAG", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) - self._default_view = 'grid' + self._default_view = "grid" else: raise AirflowException( - f'Invalid values of dag.default_view: only support ' - f'{DEFAULT_VIEW_PRESETS}, but get {default_view}' + f"Invalid values of dag.default_view: only support " + f"{DEFAULT_VIEW_PRESETS}, but get {default_view}" ) if orientation in ORIENTATION_PRESETS: self.orientation = orientation else: raise 
AirflowException( - f'Invalid values of dag.orientation: only support ' - f'{ORIENTATION_PRESETS}, but get {orientation}' + f"Invalid values of dag.orientation: only support " + f"{ORIENTATION_PRESETS}, but get {orientation}" ) self.catchup = catchup @@ -466,23 +577,96 @@ def __init__( # Keeps track of any extra edge metadata (sparse; will not contain all # edges, so do not iterate over it for that). Outer key is upstream # task ID, inner key is downstream task ID. - self.edge_info: Dict[str, Dict[str, EdgeInfoType]] = {} + self.edge_info: dict[str, dict[str, EdgeInfoType]] = {} # To keep it in parity with Serialized DAGs # and identify if DAG has on_*_callback without actually storing them in Serialized JSON self.has_on_success_callback = self.on_success_callback is not None self.has_on_failure_callback = self.on_failure_callback is not None - self.doc_md = doc_md - self._access_control = DAG._upgrade_outdated_dag_access_control(access_control) self.is_paused_upon_creation = is_paused_upon_creation + self.auto_register = auto_register self.jinja_environment_kwargs = jinja_environment_kwargs self.render_template_as_native_obj = render_template_as_native_obj + + self.doc_md = self.get_doc_md(doc_md) + self.tags = tags or [] self._task_group = TaskGroup.create_root(self) self.validate_schedule_and_params() + wrong_links = dict(self.iter_invalid_owner_links()) + if wrong_links: + raise AirflowException( + "Wrong link format was used for the owner. Use a valid link \n" + f"Bad formatted links are: {wrong_links}" + ) + + # this will only be set at serialization time + # it's only use is for determining the relative + # fileloc based only on the serialize dag + self._processor_dags_folder = None + + def get_doc_md(self, doc_md: str | None) -> str | None: + if doc_md is None: + return doc_md + + env = self.get_template_env(force_sandboxed=True) + + if not doc_md.endswith(".md"): + template = jinja2.Template(doc_md) + else: + try: + template = env.get_template(doc_md) + except jinja2.exceptions.TemplateNotFound: + return f""" + # Templating Error! + Not able to find the template file: `{doc_md}`. + """ + + return template.render() + + def _check_schedule_interval_matches_timetable(self) -> bool: + """Check ``schedule_interval`` and ``timetable`` match. + + This is done as a part of the DAG validation done before it's bagged, to + guard against the DAG's ``timetable`` (or ``schedule_interval``) from + being changed after it's created, e.g. + + .. code-block:: python + + dag1 = DAG("d1", timetable=MyTimetable()) + dag1.schedule_interval = "@once" + + dag2 = DAG("d2", schedule="@once") + dag2.timetable = MyTimetable() + + Validation is done by creating a timetable and check its summary matches + ``schedule_interval``. The logic is not bullet-proof, especially if a + custom timetable does not provide a useful ``summary``. But this is the + best we can do. + """ + if self.schedule_interval == self.timetable.summary: + return True + try: + timetable = create_timetable(self.schedule_interval, self.timezone) + except ValueError: + return False + return timetable.summary == self.timetable.summary + + def validate(self): + """Validate the DAG has a coherent setup. + + This is called by the DAG bag before bagging the DAG. 
+ """ + if not self._check_schedule_interval_matches_timetable(): + raise AirflowDagInconsistent( + f"inconsistent schedule: timetable {self.timetable.summary!r} " + f"does not match schedule_interval {self.schedule_interval!r}", + ) + self.params.validate() + self.timetable.validate() def __repr__(self): return f"" @@ -504,7 +688,7 @@ def __hash__(self): hash_components = [type(self)] for c in self._comps: # task_ids returns a list and lists can't be hashed - if c == 'task_ids': + if c == "task_ids": val = tuple(self.task_dict.keys()) else: val = getattr(self, c, None) @@ -546,7 +730,7 @@ def _upgrade_outdated_dag_access_control(access_control=None): warnings.warn( "The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. " "Please use 'can_read' and 'can_edit', respectively.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=3, ) @@ -555,19 +739,19 @@ def _upgrade_outdated_dag_access_control(access_control=None): def date_range( self, start_date: pendulum.DateTime, - num: Optional[int] = None, - end_date: Optional[datetime] = None, - ) -> List[datetime]: + num: int | None = None, + end_date: datetime | None = None, + ) -> list[datetime]: message = "`DAG.date_range()` is deprecated." if num is not None: - warnings.warn(message, category=DeprecationWarning, stacklevel=2) + warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) return utils_date_range( start_date=start_date, num=num, delta=self.normalized_schedule_interval ) message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead." - warnings.warn(message, category=DeprecationWarning, stacklevel=2) + warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) if end_date is None: coerced_end_date = timezone.utcnow() else: @@ -578,7 +762,7 @@ def date_range( def is_fixed_time_schedule(self): warnings.warn( "`DAG.is_fixed_time_schedule()` is deprecated.", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) try: @@ -595,7 +779,7 @@ def following_schedule(self, dttm): """ warnings.warn( "`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) data_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm)) @@ -609,21 +793,21 @@ def previous_schedule(self, dttm): warnings.warn( "`DAG.previous_schedule()` is deprecated.", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) if not isinstance(self.timetable, _DataIntervalTimetable): return None return self.timetable._get_prev(timezone.coerce_datetime(dttm)) - def get_next_data_interval(self, dag_model: "DagModel") -> Optional[DataInterval]: + def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None: """Get the data interval of the next scheduled run. For compatibility, this method infers the data interval from the DAG's schedule if the run does not have an explicit one set, which is possible for runs created prior to AIP-39. - This function is private to Airflow core and should not be depended as a + This function is private to Airflow core and should not be depended on as a part of the Python API. 
:meta private: @@ -648,7 +832,7 @@ def get_run_data_interval(self, run: DagRun) -> DataInterval: schedule if the run does not have an explicit one set, which is possible for runs created prior to AIP-39. - This function is private to Airflow core and should not be depended as a + This function is private to Airflow core and should not be depended on as a part of the Python API. :meta private: @@ -686,10 +870,10 @@ def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: def next_dagrun_info( self, - last_automated_dagrun: Union[None, datetime, DataInterval], + last_automated_dagrun: None | datetime | DataInterval, *, restricted: bool = True, - ) -> Optional[DagRunInfo]: + ) -> DagRunInfo | None: """Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. This calculates what time interval the next DagRun should operate on @@ -716,7 +900,7 @@ def next_dagrun_info( if isinstance(last_automated_dagrun, datetime): warnings.warn( "Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) data_interval = self.infer_automated_data_interval( @@ -742,10 +926,10 @@ def next_dagrun_info( info = None return info - def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]): + def next_dagrun_after_date(self, date_last_automated_dagrun: pendulum.DateTime | None): warnings.warn( "`DAG.next_dagrun_after_date()` is deprecated. Please use `DAG.next_dagrun_info()` instead.", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) if date_last_automated_dagrun is None: @@ -776,7 +960,7 @@ def _time_restriction(self) -> TimeRestriction: def iter_dagrun_infos_between( self, - earliest: Optional[pendulum.DateTime], + earliest: pendulum.DateTime | None, latest: pendulum.DateTime, *, align: bool = True, @@ -856,7 +1040,7 @@ def iter_dagrun_infos_between( ) break - def get_run_dates(self, start_date, end_date=None): + def get_run_dates(self, start_date, end_date=None) -> list: """ Returns a list of dates between the interval received as parameter using this dag's schedule interval. Returned dates can be used for execution dates. @@ -864,11 +1048,10 @@ def get_run_dates(self, start_date, end_date=None): :param start_date: The start date of the interval. :param end_date: The end date of the interval. Defaults to ``timezone.utcnow()``. :return: A list of dates within the interval following the dag's schedule. - :rtype: list """ warnings.warn( "`DAG.get_run_dates()` is deprecated. 
Please use `DAG.iter_dagrun_infos_between()` instead.", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) earliest = timezone.coerce_datetime(start_date) @@ -881,16 +1064,16 @@ def get_run_dates(self, start_date, end_date=None): def normalize_schedule(self, dttm): warnings.warn( "`DAG.normalize_schedule()` is deprecated.", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) following = self.following_schedule(dttm) if not following: # in case of @once return dttm with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) previous_of_following = self.previous_schedule(following) if previous_of_following != dttm: return following @@ -928,7 +1111,7 @@ def full_filepath(self) -> str: """:meta private:""" warnings.warn( "DAG.full_filepath is deprecated in favour of fileloc", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.fileloc @@ -937,7 +1120,7 @@ def full_filepath(self) -> str: def full_filepath(self, value) -> None: warnings.warn( "DAG.full_filepath is deprecated in favour of fileloc", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) self.fileloc = value @@ -947,7 +1130,7 @@ def concurrency(self) -> int: # TODO: Remove in Airflow 3.0 warnings.warn( "The 'DAG.concurrency' attribute is deprecated. Please use 'DAG.max_active_tasks'.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self._max_active_tasks @@ -973,7 +1156,7 @@ def access_control(self, value): self._access_control = DAG._upgrade_outdated_dag_access_control(value) @property - def description(self) -> Optional[str]: + def description(self) -> str | None: return self._description @property @@ -981,7 +1164,7 @@ def default_view(self) -> str: return self._default_view @property - def pickle_id(self) -> Optional[int]: + def pickle_id(self) -> int | None: return self._pickle_id @pickle_id.setter @@ -999,26 +1182,28 @@ def param(self, name: str, default: Any = NOTSET) -> DagParam: return DagParam(current_dag=self, name=name, default=default) @property - def tasks(self) -> List[Operator]: + def tasks(self) -> list[Operator]: return list(self.task_dict.values()) @tasks.setter def tasks(self, val): - raise AttributeError('DAG.tasks can not be modified. Use dag.add_task() instead.') + raise AttributeError("DAG.tasks can not be modified. 
Use dag.add_task() instead.") @property - def task_ids(self) -> List[str]: + def task_ids(self) -> list[str]: return list(self.task_dict.keys()) @property - def task_group(self) -> "TaskGroup": + def task_group(self) -> TaskGroup: return self._task_group @property def filepath(self) -> str: """:meta private:""" warnings.warn( - "filepath is deprecated, use relative_fileloc instead", DeprecationWarning, stacklevel=2 + "filepath is deprecated, use relative_fileloc instead", + RemovedInAirflow3Warning, + stacklevel=2, ) return str(self.relative_fileloc) @@ -1027,7 +1212,11 @@ def relative_fileloc(self) -> pathlib.Path: """File location of the importable dag 'file' relative to the configured DAGs folder.""" path = pathlib.Path(self.fileloc) try: - return path.relative_to(settings.DAGS_FOLDER) + rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER) + if rel_path == pathlib.Path("."): + return path + else: + return rel_path except ValueError: # Not relative to DAGS_FOLDER. return path @@ -1043,7 +1232,6 @@ def owner(self) -> str: Return list of all owners found in DAG tasks. :return: Comma separated list of owners in DAG tasks - :rtype: str """ return ", ".join({t.owner for t in self.tasks}) @@ -1069,18 +1257,18 @@ def concurrency_reached(self): """This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.""" warnings.warn( "This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.get_concurrency_reached() @provide_session - def get_is_active(self, session=NEW_SESSION) -> Optional[None]: + def get_is_active(self, session=NEW_SESSION) -> None: """Returns a boolean indicating whether this DAG is active""" return session.query(DagModel.is_active).filter(DagModel.dag_id == self.dag_id).scalar() @provide_session - def get_is_paused(self, session=NEW_SESSION) -> Optional[None]: + def get_is_paused(self, session=NEW_SESSION) -> None: """Returns a boolean indicating whether this DAG is paused""" return session.query(DagModel.is_paused).filter(DagModel.dag_id == self.dag_id).scalar() @@ -1089,7 +1277,7 @@ def is_paused(self): """This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.""" warnings.warn( "This attribute is deprecated. 
Please use `airflow.models.DAG.get_is_paused` method.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.get_is_paused() @@ -1098,12 +1286,12 @@ def is_paused(self): def normalized_schedule_interval(self) -> ScheduleInterval: warnings.warn( "DAG.normalized_schedule_interval() is deprecated.", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets: _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval) - elif self.schedule_interval == '@once': + elif self.schedule_interval == "@once": _schedule_interval = None else: _schedule_interval = self.schedule_interval @@ -1127,12 +1315,12 @@ def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION """ callback = self.on_success_callback if success else self.on_failure_callback if callback: - self.log.info('Executing dag callback function: %s', callback) + self.log.info("Executing dag callback function: %s", callback) tis = dagrun.get_task_instances(session=session) ti = tis[-1] # get first TaskInstance of DagRun ti.task = self.get_task(ti.task_id) context = ti.get_template_context(session=session) - context.update({'reason': reason}) + context.update({"reason": reason}) try: callback(context) except Exception: @@ -1179,8 +1367,8 @@ def get_num_active_runs(self, external_trigger=None, only_running=True, session= @provide_session def get_dagrun( self, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, + execution_date: datetime | None = None, + run_id: str | None = None, session: Session = NEW_SESSION, ): """ @@ -1224,7 +1412,7 @@ def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION): return dagruns @provide_session - def get_latest_execution_date(self, session: Session = NEW_SESSION) -> Optional[pendulum.DateTime]: + def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None: """Returns the latest date for which at least one dag run exists""" return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar() @@ -1233,7 +1421,7 @@ def latest_execution_date(self): """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.""" warnings.warn( "This attribute is deprecated. 
Please use `airflow.models.DAG.get_latest_execution_date`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.get_latest_execution_date() @@ -1250,8 +1438,8 @@ def subdags(self): isinstance(task, SubDagOperator) or # TODO remove in Airflow 2.0 - type(task).__name__ == 'SubDagOperator' - or task.task_type == 'SubDagOperator' + type(task).__name__ == "SubDagOperator" + or task.task_type == "SubDagOperator" ): subdag_lst.append(task.subdag) subdag_lst += task.subdag.subdags @@ -1270,10 +1458,10 @@ def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environme # Default values (for backward compatibility) jinja_env_options = { - 'loader': jinja2.FileSystemLoader(searchpath), - 'undefined': self.template_undefined, - 'extensions': ["jinja2.ext.do"], - 'cache_size': 0, + "loader": jinja2.FileSystemLoader(searchpath), + "undefined": self.template_undefined, + "extensions": ["jinja2.ext.do"], + "cache_size": 0, } if self.jinja_environment_kwargs: jinja_env_options.update(self.jinja_environment_kwargs) @@ -1306,7 +1494,7 @@ def get_task_instances_before( num: int, *, session: Session = NEW_SESSION, - ) -> List[TaskInstance]: + ) -> list[TaskInstance]: """Get ``num`` task instances before (including) ``base_date``. The returned list may contain exactly ``num`` task instances. It can @@ -1314,7 +1502,7 @@ def get_task_instances_before( ``base_date``, or more if there are manual task runs between the requested period, which does not count toward ``num``. """ - min_date: Optional[datetime] = ( + min_date: datetime | None = ( session.query(DagRun.execution_date) .filter( DagRun.dag_id == self.dag_id, @@ -1333,11 +1521,11 @@ def get_task_instances_before( @provide_session def get_task_instances( self, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - state: Optional[List[TaskInstanceState]] = None, + start_date: datetime | None = None, + end_date: datetime | None = None, + state: list[TaskInstanceState] | None = None, session: Session = NEW_SESSION, - ) -> List[TaskInstance]: + ) -> list[TaskInstance]: if not start_date: start_date = (timezone.utcnow() - timedelta(30)).replace( hour=0, minute=0, second=0, microsecond=0 @@ -1360,17 +1548,17 @@ def get_task_instances( def _get_task_instances( self, *, - task_ids: Optional[Collection[Union[str, Tuple[str, int]]]], - start_date: Optional[datetime], - end_date: Optional[datetime], - run_id: Optional[str], - state: Union[TaskInstanceState, Sequence[TaskInstanceState]], + task_ids: Collection[str | tuple[str, int]] | None, + start_date: datetime | None, + end_date: datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Optional[Collection[Union[str, Tuple[str, int]]]], + exclude_task_ids: Collection[str | tuple[str, int]] | None, session: Session, - dag_bag: Optional["DagBag"] = ..., + dag_bag: DagBag | None = ..., ) -> Iterable[TaskInstance]: ... 
# pragma: no cover @@ -1378,43 +1566,43 @@ def _get_task_instances( def _get_task_instances( self, *, - task_ids: Optional[Collection[Union[str, Tuple[str, int]]]], + task_ids: Collection[str | tuple[str, int]] | None, as_pk_tuple: Literal[True], - start_date: Optional[datetime], - end_date: Optional[datetime], - run_id: Optional[str], - state: Union[TaskInstanceState, Sequence[TaskInstanceState]], + start_date: datetime | None, + end_date: datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Optional[Collection[Union[str, Tuple[str, int]]]], + exclude_task_ids: Collection[str | tuple[str, int]] | None, session: Session, - dag_bag: Optional["DagBag"] = ..., + dag_bag: DagBag | None = ..., recursion_depth: int = ..., max_recursion_depth: int = ..., - visited_external_tis: Set[TaskInstanceKey] = ..., - ) -> Set["TaskInstanceKey"]: + visited_external_tis: set[TaskInstanceKey] = ..., + ) -> set[TaskInstanceKey]: ... # pragma: no cover def _get_task_instances( self, *, - task_ids: Optional[Collection[Union[str, Tuple[str, int]]]], + task_ids: Collection[str | tuple[str, int]] | None, as_pk_tuple: Literal[True, None] = None, - start_date: Optional[datetime], - end_date: Optional[datetime], - run_id: Optional[str], - state: Union[TaskInstanceState, Sequence[TaskInstanceState]], + start_date: datetime | None, + end_date: datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], include_subdags: bool, include_parentdag: bool, include_dependent_dags: bool, - exclude_task_ids: Optional[Collection[Union[str, Tuple[str, int]]]], + exclude_task_ids: Collection[str | tuple[str, int]] | None, session: Session, - dag_bag: Optional["DagBag"] = None, + dag_bag: DagBag | None = None, recursion_depth: int = 0, - max_recursion_depth: Optional[int] = None, - visited_external_tis: Optional[Set[TaskInstanceKey]] = None, - ) -> Union[Iterable[TaskInstance], Set[TaskInstanceKey]]: + max_recursion_depth: int | None = None, + visited_external_tis: set[TaskInstanceKey] | None = None, + ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: TI = TaskInstance # If we are looking at subdags/dependent dags we want to avoid UNION calls @@ -1423,7 +1611,7 @@ def _get_task_instances( # # This will be empty if we are only looking at one dag, in which case # we can return the filtered TI query object directly. - result: Set[TaskInstanceKey] = set() + result: set[TaskInstanceKey] = set() # Do we want full objects, or just the primary columns? if as_pk_tuple: @@ -1481,7 +1669,7 @@ def _get_task_instances( visited_external_tis = set() p_dag = self.parent_dag.partial_subset( - task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]), + task_ids_or_regex=r"^{}$".format(self.dag_id.split(".")[1]), include_upstream=False, include_downstream=True, ) @@ -1553,6 +1741,8 @@ def _get_task_instances( for tii in external_tis: if not dag_bag: + from airflow.models.dagbag import DagBag + dag_bag = DagBag(read_dags_from_db=True) external_dag = dag_bag.get_dag(tii.dag_id, session=session) if not external_dag: @@ -1585,7 +1775,7 @@ def _get_task_instances( if result or as_pk_tuple: # Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.) 
if as_pk_tuple: - result.update(TaskInstanceKey(*cols) for cols in tis.all()) + result.update(TaskInstanceKey(**cols._mapping) for cols in tis.all()) else: result.update(ti.key for ti in tis) @@ -1618,9 +1808,9 @@ def set_task_instance_state( self, *, task_id: str, - map_indexes: Optional[Collection[int]] = None, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, + map_indexes: Collection[int] | None = None, + execution_date: datetime | None = None, + run_id: str | None = None, state: TaskInstanceState, upstream: bool = False, downstream: bool = False, @@ -1628,7 +1818,7 @@ def set_task_instance_state( past: bool = False, commit: bool = True, session=NEW_SESSION, - ) -> List[TaskInstance]: + ) -> list[TaskInstance]: """ Set the state of a TaskInstance to the given state, and clear its downstream tasks that are in failed or upstream_failed state. @@ -1653,7 +1843,7 @@ def set_task_instance_state( task = self.get_task(task_id) task.dag = self - tasks_to_set_state: List[Union[Operator, Tuple[Operator, int]]] + tasks_to_set_state: list[Operator | tuple[Operator, int]] if map_indexes is None: tasks_to_set_state = [task] else: @@ -1709,12 +1899,12 @@ def set_task_instance_state( return altered @property - def roots(self) -> List[Operator]: + def roots(self) -> list[Operator]: """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" return [task for task in self.tasks if not task.upstream_list] @property - def leaves(self) -> List[Operator]: + def leaves(self) -> list[Operator]: """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" return [task for task in self.tasks if not task.downstream_list] @@ -1741,13 +1931,13 @@ def set_dag_runs_state( self, state: str = State.RUNNING, session: Session = NEW_SESSION, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - dag_ids: List[str] = [], + start_date: datetime | None = None, + end_date: datetime | None = None, + dag_ids: list[str] = [], ) -> None: warnings.warn( "This method is deprecated and will be removed in a future version.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=3, ) dag_ids = dag_ids or [self.dag_id] @@ -1756,14 +1946,14 @@ def set_dag_runs_state( query = query.filter(DagRun.execution_date >= start_date) if end_date: query = query.filter(DagRun.execution_date <= end_date) - query.update({DagRun.state: state}, synchronize_session='fetch') + query.update({DagRun.state: state}, synchronize_session="fetch") @provide_session def clear( self, - task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, + task_ids: Collection[str | tuple[str, int]] | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, only_failed: bool = False, only_running: bool = False, confirm_prompt: bool = False, @@ -1774,10 +1964,10 @@ def clear( session: Session = NEW_SESSION, get_tis: bool = False, recursion_depth: int = 0, - max_recursion_depth: Optional[int] = None, - dag_bag: Optional["DagBag"] = None, - exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(), - ) -> Union[int, Iterable[TaskInstance]]: + max_recursion_depth: int | None = None, + dag_bag: DagBag | None = None, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + ) -> int | Iterable[TaskInstance]: """ Clears a set of task instances associated with 
the current dag for a specified date range. @@ -1802,7 +1992,7 @@ def clear( if get_tis: warnings.warn( "Passing `get_tis` to dag.clear() is deprecated. Use `dry_run` parameter instead.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) dry_run = True @@ -1810,13 +2000,13 @@ def clear( if recursion_depth: warnings.warn( "Passing `recursion_depth` to dag.clear() is deprecated.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) if max_recursion_depth: warnings.warn( "Passing `max_recursion_depth` to dag.clear() is deprecated.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) @@ -1937,12 +2127,12 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k not in ('user_defined_macros', 'user_defined_filters', '_log'): + if k not in ("user_defined_macros", "user_defined_filters", "_log"): setattr(result, k, copy.deepcopy(v, memo)) result.user_defined_macros = self.user_defined_macros result.user_defined_filters = self.user_defined_filters - if hasattr(self, '_log'): + if hasattr(self, "_log"): result._log = self._log return result @@ -1950,14 +2140,14 @@ def sub_dag(self, *args, **kwargs): """This method is deprecated in favor of partial_subset""" warnings.warn( "This method is deprecated and will be removed in a future version. Please use partial_subset", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.partial_subset(*args, **kwargs) def partial_subset( self, - task_ids_or_regex: Union[str, re.Pattern, Iterable[str]], + task_ids_or_regex: str | re.Pattern | Iterable[str], include_downstream=False, include_upstream=True, include_direct_upstream=False, @@ -1976,7 +2166,6 @@ def partial_subset( :param include_direct_upstream: Include all tasks directly upstream of matched and downstream (if include_downstream = True) tasks """ - from airflow.models.baseoperator import BaseOperator from airflow.models.mappedoperator import MappedOperator @@ -1990,14 +2179,14 @@ def partial_subset( else: matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] - also_include: List[Operator] = [] + also_include: list[Operator] = [] for t in matched_tasks: if include_downstream: also_include.extend(t.get_flat_relatives(upstream=False)) if include_upstream: also_include.extend(t.get_flat_relatives(upstream=True)) - direct_upstreams: List[Operator] = [] + direct_upstreams: list[Operator] = [] if include_direct_upstream: for t in itertools.chain(matched_tasks, also_include): upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) @@ -2006,7 +2195,7 @@ def partial_subset( # Compiling the unique list of tasks that made the cut # Make sure to not recursively deepcopy the dag or task_group while copying the task. 
# task_group is reset later - def _deepcopy_task(t) -> "Operator": + def _deepcopy_task(t) -> Operator: memo.setdefault(id(t.task_group), None) return copy.deepcopy(t, memo) @@ -2064,6 +2253,13 @@ def filter_task_group(group, parent_group): def has_task(self, task_id: str): return task_id in self.task_dict + def has_task_group(self, task_group_id: str) -> bool: + return task_group_id in self.task_group_dict + + @cached_property + def task_group_dict(self): + return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} + def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: if task_id in self.task_dict: return self.task_dict[task_id] @@ -2075,16 +2271,16 @@ def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: def pickle_info(self): d = {} - d['is_picklable'] = True + d["is_picklable"] = True try: dttm = timezone.utcnow() pickled = pickle.dumps(self) - d['pickle_len'] = len(pickled) - d['pickling_duration'] = str(timezone.utcnow() - dttm) + d["pickle_len"] = len(pickled) + d["pickling_duration"] = str(timezone.utcnow() - dttm) except Exception as e: self.log.debug(e) - d['is_picklable'] = False - d['stacktrace'] = traceback.format_exc() + d["is_picklable"] = False + d["stacktrace"] = traceback.format_exc() return d @provide_session @@ -2115,7 +2311,7 @@ def get_downstream(task, level=0): get_downstream(t) @property - def task(self) -> "TaskDecoratorCollection": + def task(self) -> TaskDecoratorCollection: from airflow.decorators import task return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) @@ -2126,6 +2322,8 @@ def add_task(self, task: Operator) -> None: :param task: the task you want to add """ + from airflow.utils.task_group import TaskGroupContext + if not self.start_date and not task.start_date: raise AirflowException("DAG is missing the start_date parameter") # if the task has no start date, assign it the same as the DAG @@ -2144,15 +2342,22 @@ def add_task(self, task: Operator) -> None: elif task.end_date and self.end_date: task.end_date = min(task.end_date, self.end_date) + task_id = task.task_id + if not task.task_group: + task_group = TaskGroupContext.get_current_task_group(self) + if task_group: + task_id = task_group.child_id(task_id) + task_group.add(task) + if ( - task.task_id in self.task_dict and self.task_dict[task.task_id] is not task - ) or task.task_id in self._task_group.used_group_ids: - raise DuplicateTaskIdFound(f"Task id '{task.task_id}' has already been added to the DAG") + task_id in self.task_dict and self.task_dict[task_id] is not task + ) or task_id in self._task_group.used_group_ids: + raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") else: - self.task_dict[task.task_id] = task + self.task_dict[task_id] = task task.dag = self # Add task_id to used_group_ids to prevent group_id and task_id collisions. 
- self._task_group.used_group_ids.add(task.task_id) + self._task_group.used_group_ids.add(task_id) self.task_count = len(self.task_dict) @@ -2169,7 +2374,7 @@ def _remove_task(self, task_id: str) -> None: # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this # doesn't guard against that task = self.task_dict.pop(task_id) - tg = getattr(task, 'task_group', None) + tg = getattr(task, "task_group", None) if tg: tg._remove(task) @@ -2182,7 +2387,7 @@ def run( mark_success=False, local=False, executor=None, - donot_pickle=conf.getboolean('core', 'donot_pickle'), + donot_pickle=conf.getboolean("core", "donot_pickle"), ignore_task_deps=False, ignore_first_depends_on_past=True, pool=None, @@ -2193,6 +2398,7 @@ def run( run_backwards=False, run_at_least_once=False, continue_on_failures=False, + disable_retry=False, ): """ Runs the DAG. @@ -2243,6 +2449,7 @@ def run( run_backwards=run_backwards, run_at_least_once=run_at_least_once, continue_on_failures=continue_on_failures, + disable_retry=disable_retry, ) job.run() @@ -2256,20 +2463,96 @@ def cli(self): args = parser.parse_args() args.func(args, self) + @provide_session + def test( + self, + execution_date: datetime | None = None, + run_conf: dict[str, Any] | None = None, + conn_file_path: str | None = None, + variable_file_path: str | None = None, + session: Session = NEW_SESSION, + ) -> None: + """ + Execute one single DagRun for a given DAG and execution date. + + :param execution_date: execution date for the DAG run + :param run_conf: configuration to pass to newly created dagrun + :param conn_file_path: file path to a connection file in either yaml or json + :param variable_file_path: file path to a variable file in either yaml or json + :param session: database connection (optional) + """ + + def add_logger_if_needed(ti: TaskInstance): + """ + Add a formatted logger to the taskinstance so all logs are surfaced to the command line instead + of into a task file. Since this is a local test run, it is much better for the user to see logs + in the command line, rather than needing to search for a log file. 
+ Args: + ti: The taskinstance that will receive a logger + + """ + format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s") + handler = logging.StreamHandler(sys.stdout) + handler.level = logging.INFO + handler.setFormatter(format) + # only add log handler once + if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers): + self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id) + ti.log.addHandler(handler) + + if conn_file_path or variable_file_path: + local_secrets = LocalFilesystemBackend( + variables_file_path=variable_file_path, connections_file_path=conn_file_path + ) + secrets_backend_list.insert(0, local_secrets) + + execution_date = execution_date or timezone.utcnow() + self.log.debug("Clearing existing task instances for execution date %s", execution_date) + self.clear( + start_date=execution_date, + end_date=execution_date, + dag_run_state=False, # type: ignore + session=session, + ) + self.log.debug("Getting dagrun for dag %s", self.dag_id) + dr: DagRun = _get_or_create_dagrun( + dag=self, + start_date=execution_date, + execution_date=execution_date, + run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date), + session=session, + conf=run_conf, + ) + + tasks = self.task_dict + self.log.debug("starting dagrun") + # Instead of starting a scheduler, we run the minimal loop possible to check + # for task readiness and dependency management. This is notably faster + # than creating a BackfillJob and allows us to surface logs to the user + while dr.state == State.RUNNING: + schedulable_tis, _ = dr.update_state(session=session) + for ti in schedulable_tis: + add_logger_if_needed(ti) + ti.task = tasks[ti.task_id] + _run_task(ti, session=session) + if conn_file_path or variable_file_path: + # Remove the local variables we have added to the secrets_backend_list + secrets_backend_list.pop(0) + @provide_session def create_dagrun( self, state: DagRunState, - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, - start_date: Optional[datetime] = None, - external_trigger: Optional[bool] = False, - conf: Optional[dict] = None, - run_type: Optional[DagRunType] = None, - session=NEW_SESSION, - dag_hash: Optional[str] = None, - creating_job_id: Optional[int] = None, - data_interval: Optional[Tuple[datetime, datetime]] = None, + execution_date: datetime | None = None, + run_id: str | None = None, + start_date: datetime | None = None, + external_trigger: bool | None = False, + conf: dict | None = None, + run_type: DagRunType | None = None, + session: Session = NEW_SESSION, + dag_hash: str | None = None, + creating_job_id: int | None = None, + data_interval: tuple[datetime, datetime] | None = None, ): """ Creates a dag run from this dag including the tasks associated with this dag. 
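For orientation, a minimal usage sketch of the new `DAG.test()` helper introduced in the hunk above. The DAG id and task below are invented for illustration, and running this assumes an initialized Airflow metadata database, since `test()` clears and executes a real DagRun in-process while streaming task logs to stdout:

    import datetime

    from airflow.decorators import task
    from airflow.models.dag import DAG

    with DAG(
        dag_id="example_local_test",  # illustrative id, not part of this patch
        start_date=datetime.datetime(2022, 1, 1),
        schedule=None,
    ) as dag:

        @task
        def hello():
            print("hello from dag.test()")

        hello()

    if __name__ == "__main__":
        # Runs one DagRun locally, without a scheduler or a BackfillJob.
        dag.test()
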
@@ -2287,15 +2570,46 @@ def create_dagrun( :param dag_hash: Hash of Serialized DAG :param data_interval: Data interval of the DagRun """ + logical_date = timezone.coerce_datetime(execution_date) + + if data_interval and not isinstance(data_interval, DataInterval): + data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) + + if data_interval is None and logical_date is not None: + warnings.warn( + "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated", + RemovedInAirflow3Warning, + stacklevel=3, + ) + if run_type == DagRunType.MANUAL: + data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) + else: + data_interval = self.infer_automated_data_interval(logical_date) + + if run_type is None or isinstance(run_type, DagRunType): + pass + elif isinstance(run_type, str): # Compatibility: run_type used to be a str. + run_type = DagRunType(run_type) + else: + raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}") + if run_id: # Infer run_type from run_id if needed. if not isinstance(run_id, str): raise ValueError(f"`run_id` should be a str, not {type(run_id)}") - if not run_type: - run_type = DagRunType.from_run_id(run_id) - elif run_type and execution_date is not None: # Generate run_id from run_type and execution_date. - if not isinstance(run_type, DagRunType): - raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}") - run_id = DagRun.generate_run_id(run_type, execution_date) + inferred_run_type = DagRunType.from_run_id(run_id) + if run_type is None: + # No explicit type given, use the inferred type. + run_type = inferred_run_type + elif run_type == DagRunType.MANUAL and inferred_run_type != DagRunType.MANUAL: + # Prevent a manual run from using an ID that looks like a scheduled run. + raise ValueError( + f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " + f"is reserved for {inferred_run_type.value} runs" + ) + elif run_type and logical_date is not None: # Generate run_id from run_type and execution_date. + run_id = self.timetable.generate_run_id( + run_type=run_type, logical_date=logical_date, data_interval=data_interval + ) else: raise AirflowException( "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`" @@ -2305,21 +2619,9 @@ def create_dagrun( warnings.warn( "Using forward slash ('/') in a DAG run ID is deprecated. Note that this character " "also makes the run impossible to retrieve via Airflow's REST API.", - DeprecationWarning, - stacklevel=3, - ) - - logical_date = timezone.coerce_datetime(execution_date) - if data_interval is None and logical_date is not None: - warnings.warn( - "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=3, ) - if run_type == DagRunType.MANUAL: - data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) - else: - data_interval = self.infer_automated_data_interval(logical_date) # create a copy of params before validating copied_params = copy.deepcopy(self.params) @@ -2352,18 +2654,27 @@ def create_dagrun( @classmethod @provide_session - def bulk_sync_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): + def bulk_sync_to_db( + cls, + dags: Collection[DAG], + session=NEW_SESSION, + ): """This method is deprecated in favor of bulk_write_to_db""" warnings.warn( "This method is deprecated and will be removed in a future version. 
Please use bulk_write_to_db", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) - return cls.bulk_write_to_db(dags, session) + return cls.bulk_write_to_db(dags=dags, session=session) @classmethod @provide_session - def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): + def bulk_write_to_db( + cls, + dags: Collection[DAG], + processor_subdir: str | None = None, + session=NEW_SESSION, + ): """ Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including calculated fields. @@ -2378,16 +2689,18 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): log.info("Sync %s DAGs", len(dags)) dag_by_ids = {dag.dag_id: dag for dag in dags} + dag_ids = set(dag_by_ids.keys()) query = ( session.query(DagModel) .options(joinedload(DagModel.tags, innerjoin=False)) .filter(DagModel.dag_id.in_(dag_ids)) + .options(joinedload(DagModel.schedule_dataset_references)) + .options(joinedload(DagModel.task_outlet_dataset_references)) ) - orm_dags: List[DagModel] = with_row_locks(query, of=DagModel, session=session).all() - - existing_dag_ids = {orm_dag.dag_id for orm_dag in orm_dags} - missing_dag_ids = dag_ids.difference(existing_dag_ids) + orm_dags: list[DagModel] = with_row_locks(query, of=DagModel, session=session).all() + existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags} + missing_dag_ids = dag_ids.difference(existing_dags) for missing_dag_id in missing_dag_ids: orm_dag = DagModel(dag_id=missing_dag_id) @@ -2403,7 +2716,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): most_recent_subq = ( session.query(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date")) .filter( - DagRun.dag_id.in_(existing_dag_ids), + DagRun.dag_id.in_(existing_dags), or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED), ) .group_by(DagRun.dag_id) @@ -2417,7 +2730,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): # Get number of active dagruns for all dags we are processing as a single query. 
- num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dag_ids, session=session) + num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session) filelocs = [] @@ -2438,13 +2751,14 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): orm_dag.last_parsed_time = timezone.utcnow() orm_dag.default_view = dag.default_view orm_dag.description = dag.description - orm_dag.schedule_interval = dag.schedule_interval orm_dag.max_active_tasks = dag.max_active_tasks orm_dag.max_active_runs = dag.max_active_runs orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag is not None for t in dag.tasks) + orm_dag.schedule_interval = dag.schedule_interval orm_dag.timetable_description = dag.timetable.description + orm_dag.processor_subdir = processor_subdir - run: Optional[DagRun] = most_recent_runs.get(dag.dag_id) + run: DagRun | None = most_recent_runs.get(dag.dag_id) if run is None: data_interval = None else: @@ -2467,17 +2781,119 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION): orm_dag.tags.append(dag_tag_orm) session.add(dag_tag_orm) + orm_dag_links = orm_dag.dag_owner_links or [] + for orm_dag_link in orm_dag_links: + if orm_dag_link not in dag.owner_links: + session.delete(orm_dag_link) + for owner_name, owner_link in dag.owner_links.items(): + dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link) + session.add(dag_owner_orm) + DagCode.bulk_sync_to_db(filelocs, session=session) + from airflow.datasets import Dataset + from airflow.models.dataset import ( + DagScheduleDatasetReference, + DatasetModel, + TaskOutletDatasetReference, + ) + + dag_references = collections.defaultdict(set) + outlet_references = collections.defaultdict(set) + # We can't use a set here as we want to preserve order + outlet_datasets: dict[Dataset, None] = {} + input_datasets: dict[Dataset, None] = {} + + # here we go through dags and tasks to check for dataset references + # if there are now None and previously there were some, we delete them + # if there are now *any*, we add them to the above data structures, and + # later we'll persist them to the database. + for dag in dags: + curr_orm_dag = existing_dags.get(dag.dag_id) + if not dag.dataset_triggers: + if curr_orm_dag and curr_orm_dag.schedule_dataset_references: + curr_orm_dag.schedule_dataset_references = [] + for dataset in dag.dataset_triggers: + dag_references[dag.dag_id].add(dataset.uri) + input_datasets[DatasetModel.from_public(dataset)] = None + curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references + for task in dag.tasks: + dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)] + if not dataset_outlets: + if curr_outlet_references: + this_task_outlet_refs = [ + x + for x in curr_outlet_references + if x.dag_id == dag.dag_id and x.task_id == task.task_id + ] + for ref in this_task_outlet_refs: + curr_outlet_references.remove(ref) + for d in dataset_outlets: + outlet_references[(task.dag_id, task.task_id)].add(d.uri) + outlet_datasets[DatasetModel.from_public(d)] = None + all_datasets = outlet_datasets + all_datasets.update(input_datasets) + + # store datasets + stored_datasets = {} + for dataset in all_datasets: + stored_dataset = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).first() + if stored_dataset: + # Some datasets may have been previously unreferenced, and therefore orphaned by the + # scheduler. 
But if we're here, then we have found that dataset again in our DAGs, which + # means that it is no longer an orphan, so set is_orphaned to False. + stored_dataset.is_orphaned = expression.false() + stored_datasets[stored_dataset.uri] = stored_dataset + else: + session.add(dataset) + stored_datasets[dataset.uri] = dataset + + session.flush() # this is required to ensure each dataset has its PK loaded + + del all_datasets + + # reconcile dag-schedule-on-dataset references + for dag_id, uri_list in dag_references.items(): + dag_refs_needed = { + DagScheduleDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id) + for uri in uri_list + } + dag_refs_stored = set( + existing_dags.get(dag_id) + and existing_dags.get(dag_id).schedule_dataset_references # type: ignore + or [] + ) + dag_refs_to_add = {x for x in dag_refs_needed if x not in dag_refs_stored} + session.bulk_save_objects(dag_refs_to_add) + for obj in dag_refs_stored - dag_refs_needed: + session.delete(obj) + + existing_task_outlet_refs_dict = collections.defaultdict(set) + for dag_id, orm_dag in existing_dags.items(): + for todr in orm_dag.task_outlet_dataset_references: + existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) + + # reconcile task-outlet-dataset references + for (dag_id, task_id), uri_list in outlet_references.items(): + task_refs_needed = { + TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id) + for uri in uri_list + } + task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)] + task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored} + session.bulk_save_objects(task_refs_to_add) + for obj in task_refs_stored - task_refs_needed: + session.delete(obj) + # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller # decide when to commit session.flush() for dag in dags: - cls.bulk_write_to_db(dag.subdags, session=session) + cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session) @provide_session - def sync_to_db(self, session=NEW_SESSION): + def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION): """ Save attributes about this DAG to the DB. Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a @@ -2485,12 +2901,12 @@ def sync_to_db(self, session=NEW_SESSION): :return: None """ - self.bulk_write_to_db([self], session) + self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session) def get_default_view(self): """This is only there for backward compatible jinja2 templates""" if self.default_view is None: - return conf.get('webserver', 'dag_default_view').lower() + return conf.get("webserver", "dag_default_view").lower() else: return self.default_view @@ -2538,7 +2954,7 @@ def deactivate_stale_dags(expiration_date, session=NEW_SESSION): @staticmethod @provide_session - def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSION): + def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSION) -> int: """ Returns the number of task instances in the given DAG. 
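To make the dataset-reference reconciliation added to `bulk_write_to_db` above concrete, here is a small sketch of the user-facing pattern that creates those rows (the DAG ids and dataset URI are invented): declaring `outlets` on a task yields `TaskOutletDatasetReference` entries, and scheduling a DAG on a list of datasets yields `DagScheduleDatasetReference` entries when the DAGs are written to the database:

    import datetime

    from airflow.datasets import Dataset
    from airflow.decorators import task
    from airflow.models.dag import DAG

    example_dataset = Dataset("s3://bucket/example.csv")  # illustrative URI

    with DAG(
        dag_id="dataset_producer",
        start_date=datetime.datetime(2022, 1, 1),
        schedule=None,
    ) as producer:

        @task(outlets=[example_dataset])
        def write_example():
            ...  # producing the data is out of scope for this sketch

        write_example()

    with DAG(
        dag_id="dataset_consumer",
        start_date=datetime.datetime(2022, 1, 1),
        schedule=[example_dataset],  # dataset-triggered scheduling
    ) as consumer:

        @task
        def read_example():
            ...

        read_example()
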
@@ -2547,7 +2963,6 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSI :param task_ids: A list of valid task IDs for the given DAG :param states: A list of states to filter by if supplied :return: The number of running tasks - :rtype: int """ qry = session.query(func.count(TaskInstance.task_id)).filter( TaskInstance.dag_id == dag_id, @@ -2574,28 +2989,32 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSI def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" if not cls.__serialized_fields: - cls.__serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - { - 'parent_dag', - '_old_context_manager_dags', - 'safe_dag_id', - 'last_loaded', - 'user_defined_filters', - 'user_defined_macros', - 'partial', - 'params', - '_pickle_id', - '_log', - 'task_dict', - 'template_searchpath', - 'sla_miss_callback', - 'on_success_callback', - 'on_failure_callback', - 'template_undefined', - 'jinja_environment_kwargs', + exclusion_list = { + "parent_dag", + "schedule_dataset_references", + "task_outlet_dataset_references", + "_old_context_manager_dags", + "safe_dag_id", + "last_loaded", + "user_defined_filters", + "user_defined_macros", + "partial", + "params", + "_pickle_id", + "_log", + "task_dict", + "template_searchpath", + "sla_miss_callback", + "on_success_callback", + "on_failure_callback", + "template_undefined", + "jinja_environment_kwargs", # has_on_*_callback are only stored if the value is True, as the default is False - 'has_on_success_callback', - 'has_on_failure_callback', + "has_on_success_callback", + "has_on_failure_callback", + "auto_register", } + cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list return cls.__serialized_fields def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: @@ -2632,15 +3051,28 @@ def validate_schedule_and_params(self): "DAG Schedule must be None, if there are any required params without default values" ) + def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]: + """Parses a given link, and verifies if it's a valid URL, or a 'mailto' link. + Returns an iterator of invalid (owner, link) pairs. + """ + for owner, link in self.owner_links.items(): + result = urlsplit(link) + if result.scheme == "mailto": + # netloc is not existing for 'mailto' link, so we are checking that the path is parsed + if not result.path: + yield result.path, link + elif not result.scheme or not result.netloc: + yield owner, link + class DagTag(Base): """A tag name per dag, to allow quick filtering in the DAG view.""" __tablename__ = "dag_tag" - name = Column(String(100), primary_key=True) + name = Column(String(TAG_MAX_LEN), primary_key=True) dag_id = Column( - String(ID_LEN), - ForeignKey('dag.dag_id', name='dag_tag_dag_id_fkey', ondelete='CASCADE'), + StringID(), + ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"), primary_key=True, ) @@ -2648,6 +3080,33 @@ def __repr__(self): return self.name +class DagOwnerAttributes(Base): + """ + Table defining different owner attributes. 
For example, a link for an owner that will be passed as + a hyperlink to the DAGs view + """ + + __tablename__ = "dag_owner_attributes" + dag_id = Column( + StringID(), + ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"), + nullable=False, + primary_key=True, + ) + owner = Column(String(500), primary_key=True, nullable=False) + link = Column(String(500), nullable=False) + + def __repr__(self): + return f"" + + @classmethod + def get_all(cls, session) -> dict[str, dict[str, str]]: + dag_links: dict = collections.defaultdict(dict) + for obj in session.query(cls): + dag_links[obj.dag_id].update({obj.owner: obj.link}) + return dag_links + + class DagModel(Base): """Table containing DAG properties""" @@ -2655,11 +3114,11 @@ class DagModel(Base): """ These items are stored in the database for state related information """ - dag_id = Column(String(ID_LEN), primary_key=True) - root_dag_id = Column(String(ID_LEN)) + dag_id = Column(StringID(), primary_key=True) + root_dag_id = Column(StringID()) # A DAG can be paused from the UI / DB # Set this default value of is_paused based on a configuration value! - is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation') + is_paused_at_creation = conf.getboolean("core", "dags_are_paused_at_creation") is_paused = Column(Boolean, default=is_paused_at_creation) # Whether the DAG is a subdag is_subdag = Column(Boolean, default=False) @@ -2681,6 +3140,8 @@ class DagModel(Base): # packaged DAG, it will point to the subpath of the DAG within the # associated zip. fileloc = Column(String(2000)) + # The base directory used by Dag Processor that parsed this dag. + processor_subdir = Column(String(2000), nullable=True) # String representing the owners owners = Column(String(2000)) # Description of the dag @@ -2691,15 +3152,18 @@ class DagModel(Base): schedule_interval = Column(Interval) # Timetable/Schedule Interval description timetable_description = Column(String(1000), nullable=True) - # Tags for view filter - tags = relationship('DagTag', cascade='all, delete, delete-orphan', backref=backref('dag')) + tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag")) + # Dag owner links for DAGs view + dag_owner_links = relationship( + "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag") + ) max_active_tasks = Column(Integer, nullable=False) max_active_runs = Column(Integer, nullable=True) has_task_concurrency_limits = Column(Boolean, nullable=False) - has_import_errors = Column(Boolean(), default=False) + has_import_errors = Column(Boolean(), default=False, server_default="0") # The logical date of the next dag run. 
next_dagrun = Column(UtcDateTime) @@ -2712,15 +3176,23 @@ class DagModel(Base): next_dagrun_create_after = Column(UtcDateTime) __table_args__ = ( - Index('idx_root_dag_id', root_dag_id, unique=False), - Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False), + Index("idx_root_dag_id", root_dag_id, unique=False), + Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False), ) parent_dag = relationship( "DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id] ) - - NUM_DAGS_PER_DAGRUN_QUERY = conf.getint('scheduler', 'max_dagruns_to_create_per_loop', fallback=10) + schedule_dataset_references = relationship( + "DagScheduleDatasetReference", + cascade="all, delete, delete-orphan", + ) + schedule_datasets = association_proxy("schedule_dataset_references", "dataset") + task_outlet_dataset_references = relationship( + "TaskOutletDatasetReference", + cascade="all, delete, delete-orphan", + ) + NUM_DAGS_PER_DAGRUN_QUERY = conf.getint("scheduler", "max_dagruns_to_create_per_loop", fallback=10) def __init__(self, concurrency=None, **kwargs): super().__init__(**kwargs) @@ -2728,15 +3200,15 @@ def __init__(self, concurrency=None, **kwargs): if concurrency: warnings.warn( "The 'DagModel.concurrency' parameter is deprecated. Please use 'max_active_tasks'.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) self.max_active_tasks = concurrency else: - self.max_active_tasks = conf.getint('core', 'max_active_tasks_per_dag') + self.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag") if self.max_active_runs is None: - self.max_active_runs = conf.getint('core', 'max_active_runs_per_dag') + self.max_active_runs = conf.getint("core", "max_active_runs_per_dag") if self.has_task_concurrency_limits is None: # Be safe -- this will be updated later once the DAG is parsed @@ -2746,7 +3218,7 @@ def __repr__(self): return f"" @property - def next_dagrun_data_interval(self) -> Optional[DataInterval]: + def next_dagrun_data_interval(self) -> DataInterval | None: return _get_model_data_interval( self, "next_dagrun_data_interval_start", @@ -2754,7 +3226,7 @@ def next_dagrun_data_interval(self) -> Optional[DataInterval]: ) @next_dagrun_data_interval.setter - def next_dagrun_data_interval(self, value: Optional[Tuple[datetime, datetime]]) -> None: + def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) -> None: if value is None: self.next_dagrun_data_interval_start = self.next_dagrun_data_interval_end = None else: @@ -2774,28 +3246,19 @@ def get_dagmodel(dag_id, session=NEW_SESSION): def get_current(cls, dag_id, session=NEW_SESSION): return session.query(cls).filter(cls.dag_id == dag_id).first() - @staticmethod - @provide_session - def get_all_paused_dag_ids(session: Session = NEW_SESSION) -> Set[str]: - """Get a set of paused DAG ids""" - paused_dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_paused == expression.true()).all() - - paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids} - return paused_dag_ids - @provide_session def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): return get_last_dagrun( self.dag_id, session=session, include_externally_triggered=include_externally_triggered ) - def get_is_paused(self, *, session: Optional[Session] = None) -> bool: + def get_is_paused(self, *, session: Session | None = None) -> bool: """Provide interface compatibility to 'DAG'.""" return self.is_paused @staticmethod @provide_session - def 
get_paused_dag_ids(dag_ids: List[str], session: Session = NEW_SESSION) -> Set[str]: + def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> set[str]: """ Given a list of dag_ids, get a set of Paused Dag Ids @@ -2819,14 +3282,14 @@ def get_default_view(self) -> str: have a value """ # This is for backwards-compatibility with old dags that don't have None as default_view - return self.default_view or conf.get_mandatory_value('webserver', 'dag_default_view').lower() + return self.default_view or conf.get_mandatory_value("webserver", "dag_default_view").lower() @property def safe_dag_id(self): - return self.dag_id.replace('.', '__dot__') + return self.dag_id.replace(".", "__dot__") @property - def relative_fileloc(self) -> Optional[pathlib.Path]: + def relative_fileloc(self) -> pathlib.Path | None: """File location of the importable dag 'file' relative to the configured DAGs folder.""" if self.fileloc is None: return None @@ -2852,13 +3315,13 @@ def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session if including_subdags: filter_query.append(DagModel.root_dag_id == self.dag_id) session.query(DagModel).filter(or_(*filter_query)).update( - {DagModel.is_paused: is_paused}, synchronize_session='fetch' + {DagModel.is_paused: is_paused}, synchronize_session="fetch" ) session.commit() @classmethod @provide_session - def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=NEW_SESSION): + def deactivate_deleted_dags(cls, alive_dag_filelocs: list[str], session=NEW_SESSION): """ Set ``is_active=False`` on the DAGs for which the DAG files have been removed. @@ -2876,36 +3339,74 @@ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=NEW_SESS continue @classmethod - def dags_needing_dagruns(cls, session: Session): + def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]: """ Return (and lock) a list of Dag objects that are due to create a new DagRun. - This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query, + This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. """ - # TODO[HA]: Bake this query, it is run _A lot_ - # We limit so that _one_ scheduler doesn't try to do all the creation - # of dag runs + from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ + + # these dag ids are triggered by datasets, and they are ready to go. 
+ dataset_triggered_dag_info = { + x.dag_id: (x.first_queued_time, x.last_queued_time) + for x in session.query( + DagScheduleDatasetReference.dag_id, + func.max(DDRQ.created_at).label("last_queued_time"), + func.min(DDRQ.created_at).label("first_queued_time"), + ) + .join(DagScheduleDatasetReference.queue_records, isouter=True) + .group_by(DagScheduleDatasetReference.dag_id) + .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0))) + .all() + } + dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys()) + if dataset_triggered_dag_ids: + exclusion_list = { + x.dag_id + for x in ( + session.query(DagModel.dag_id) + .join(DagRun.dag_model) + .filter(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING))) + .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids)) + .group_by(DagModel.dag_id) + .having(func.count() >= func.max(DagModel.max_active_runs)) + .all() + ) + } + if exclusion_list: + dataset_triggered_dag_ids -= exclusion_list + dataset_triggered_dag_info = { + k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list + } + # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs query = ( session.query(cls) .filter( cls.is_paused == expression.false(), cls.is_active == expression.true(), cls.has_import_errors == expression.false(), - cls.next_dagrun_create_after <= func.now(), + or_( + cls.next_dagrun_create_after <= func.now(), + cls.dag_id.in_(dataset_triggered_dag_ids), + ), ) .order_by(cls.next_dagrun_create_after) .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY) ) - return with_row_locks(query, of=cls, session=session, **skip_locked(session=session)) + return ( + with_row_locks(query, of=cls, session=session, **skip_locked(session=session)), + dataset_triggered_dag_info, + ) def calculate_dagrun_date_fields( self, dag: DAG, - most_recent_dag_run: Union[None, datetime, DataInterval], + most_recent_dag_run: None | datetime | DataInterval, ) -> None: """ Calculate ``next_dagrun`` and `next_dagrun_create_after`` @@ -2914,12 +3415,12 @@ def calculate_dagrun_date_fields( :param most_recent_dag_run: DataInterval (or datetime) of most recent run of this dag, or none if not yet scheduled. """ - most_recent_data_interval: Optional[DataInterval] + most_recent_data_interval: DataInterval | None if isinstance(most_recent_dag_run, datetime): warnings.warn( "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. " "Provide a data interval instead.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) most_recent_data_interval = dag.infer_automated_data_interval(most_recent_dag_run) @@ -2940,8 +3441,49 @@ def calculate_dagrun_date_fields( self.next_dagrun_create_after, ) - -def dag(*dag_args, **dag_kwargs): + @provide_session + def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None: + if self.schedule_interval != "Dataset": + return None + return get_dataset_triggered_next_run_info([self.dag_id], session=session)[self.dag_id] + + +# NOTE: Please keep the list of arguments in sync with DAG.__init__. +# Only exception: dag_id here should have a default value, but not in DAG. 
+def dag( + dag_id: str = "", + description: str | None = None, + schedule: ScheduleArg = NOTSET, + schedule_interval: ScheduleIntervalArg = NOTSET, + timetable: Timetable | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, + full_filepath: str | None = None, + template_searchpath: str | Iterable[str] | None = None, + template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, + user_defined_macros: dict | None = None, + user_defined_filters: dict | None = None, + default_args: dict | None = None, + concurrency: int | None = None, + max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), + max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), + dagrun_timeout: timedelta | None = None, + sla_miss_callback: SLAMissCallback | None = None, + default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), + orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), + catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), + on_success_callback: DagStateChangeCallback | None = None, + on_failure_callback: DagStateChangeCallback | None = None, + doc_md: str | None = None, + params: dict | None = None, + access_control: dict | None = None, + is_paused_upon_creation: bool | None = None, + jinja_environment_kwargs: dict | None = None, + render_template_as_native_obj: bool = False, + tags: list[str] | None = None, + owner_links: dict[str, str] | None = None, + auto_register: bool = True, +) -> Callable[[Callable], Callable[..., DAG]]: """ Python dag decorator. Wraps a function into an Airflow DAG. Accepts kwargs for operator kwarg. Can be used to parameterize DAGs. @@ -2950,26 +3492,51 @@ def dag(*dag_args, **dag_kwargs): :param dag_kwargs: Kwargs for DAG object. """ - def wrapper(f: Callable): - # Get dag initializer signature and bind it to validate that dag_args, and dag_kwargs are correct - dag_sig = signature(DAG.__init__) - dag_bound_args = dag_sig.bind_partial(*dag_args, **dag_kwargs) - + def wrapper(f: Callable) -> Callable[..., DAG]: @functools.wraps(f) def factory(*args, **kwargs): # Generate signature for decorated function and bind the arguments when called - # we do this to extract parameters so we can annotate them on the DAG object. + # we do this to extract parameters, so we can annotate them on the DAG object. # In addition, this fails if we are missing any args/kwargs with TypeError as expected. f_sig = signature(f).bind(*args, **kwargs) # Apply defaults to capture default values if set. 
f_sig.apply_defaults() - # Set function name as dag_id if not set - dag_id = dag_bound_args.arguments.get('dag_id', f.__name__) - dag_bound_args.arguments['dag_id'] = dag_id - # Initialize DAG with bound arguments - with DAG(*dag_bound_args.args, **dag_bound_args.kwargs) as dag_obj: + with DAG( + dag_id or f.__name__, + description=description, + schedule_interval=schedule_interval, + timetable=timetable, + start_date=start_date, + end_date=end_date, + full_filepath=full_filepath, + template_searchpath=template_searchpath, + template_undefined=template_undefined, + user_defined_macros=user_defined_macros, + user_defined_filters=user_defined_filters, + default_args=default_args, + concurrency=concurrency, + max_active_tasks=max_active_tasks, + max_active_runs=max_active_runs, + dagrun_timeout=dagrun_timeout, + sla_miss_callback=sla_miss_callback, + default_view=default_view, + orientation=orientation, + catchup=catchup, + on_success_callback=on_success_callback, + on_failure_callback=on_failure_callback, + doc_md=doc_md, + params=params, + access_control=access_control, + is_paused_upon_creation=is_paused_upon_creation, + jinja_environment_kwargs=jinja_environment_kwargs, + render_template_as_native_obj=render_template_as_native_obj, + tags=tags, + schedule=schedule, + owner_links=owner_links, + auto_register=auto_register, + ) as dag_obj: # Set DAG documentation from function documentation. if f.__doc__: dag_obj.doc_md = f.__doc__ @@ -2990,13 +3557,15 @@ def factory(*args, **kwargs): # Return dag object such that it's accessible in Globals. return dag_obj + # Ensure that warnings from inside DAG() are emitted from the caller, not here + fixup_decorator_warning_stack(factory) return factory return wrapper STATICA_HACK = True -globals()['kcah_acitats'[::-1].upper()] = False +globals()["kcah_acitats"[::-1].upper()] = False if STATICA_HACK: # pragma: no cover from airflow.models.serialized_dag import SerializedDagModel @@ -3016,7 +3585,7 @@ class DagContext: with DAG( dag_id="example_dag", default_args=default_args, - schedule_interval="0 0 * * *", + schedule="0 0 * * *", dagrun_timeout=timedelta(minutes=60), ) as dag: ... 
@@ -3026,24 +3595,91 @@ class DagContext: """ - _context_managed_dag: Optional[DAG] = None - _previous_context_managed_dags: List[DAG] = [] + _context_managed_dags: Deque[DAG] = deque() + autoregistered_dags: set[tuple[DAG, ModuleType]] = set() + current_autoregister_module_name: str | None = None @classmethod def push_context_managed_dag(cls, dag: DAG): - if cls._context_managed_dag: - cls._previous_context_managed_dags.append(cls._context_managed_dag) - cls._context_managed_dag = dag + cls._context_managed_dags.appendleft(dag) @classmethod - def pop_context_managed_dag(cls) -> Optional[DAG]: - old_dag = cls._context_managed_dag - if cls._previous_context_managed_dags: - cls._context_managed_dag = cls._previous_context_managed_dags.pop() - else: - cls._context_managed_dag = None - return old_dag + def pop_context_managed_dag(cls) -> DAG | None: + dag = cls._context_managed_dags.popleft() + + # In a few cases around serialization we explicitly push None in to the stack + if cls.current_autoregister_module_name is not None and dag and dag.auto_register: + mod = sys.modules[cls.current_autoregister_module_name] + cls.autoregistered_dags.add((dag, mod)) + + return dag @classmethod - def get_current_dag(cls) -> Optional[DAG]: - return cls._context_managed_dag + def get_current_dag(cls) -> DAG | None: + try: + return cls._context_managed_dags[0] + except IndexError: + return None + + +def _run_task(ti: TaskInstance, session): + """ + Run a single task instance, and push result to Xcom for downstream tasks. Bypasses a lot of + extra steps used in `task.run` to keep our local running as fast as possible + This function is only meant for the `dag.test` function as a helper function. + + Args: + ti: TaskInstance to run + """ + log.info("*****************************************************") + if ti.map_index > 0: + log.info("Running task %s index %d", ti.task_id, ti.map_index) + else: + log.info("Running task %s", ti.task_id) + try: + ti._run_raw_task(session=session) + session.flush() + log.info("%s ran successfully!", ti.task_id) + except AirflowSkipException: + log.info("Task Skipped, continuing") + log.info("*****************************************************") + + +def _get_or_create_dagrun( + dag: DAG, + conf: dict[Any, Any] | None, + start_date: datetime, + execution_date: datetime, + run_id: str, + session: Session, +) -> DagRun: + """ + Create a DAGRun, but only after clearing the previous instance of said dagrun to prevent collisions. + This function is only meant for the `dag.test` function as a helper function. 
+ :param dag: Dag to be used to find dagrun + :param conf: configuration to pass to newly created dagrun + :param start_date: start date of new dagrun, defaults to execution_date + :param execution_date: execution_date for finding the dagrun + :param run_id: run_id to pass to new dagrun + :param session: sqlalchemy session + :return: + """ + log.info("dagrun id: %s", dag.dag_id) + dr: DagRun = ( + session.query(DagRun) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) + .first() + ) + if dr: + session.delete(dr) + session.commit() + dr = dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=execution_date, + run_id=run_id, + start_date=start_date or execution_date, + session=session, + conf=conf, # type: ignore + ) + log.info("created dagrun " + str(dr)) + return dr diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 3673ce095ea16..f78125ebd9d46 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import hashlib import importlib @@ -27,7 +28,7 @@ import warnings import zipfile from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, NamedTuple from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session @@ -39,8 +40,10 @@ AirflowClusterPolicyViolation, AirflowDagCycleException, AirflowDagDuplicatedIdException, + AirflowDagInconsistent, AirflowTimetableInvalid, ParamValidationError, + RemovedInAirflow3Warning, ) from airflow.stats import Stats from airflow.utils import timezone @@ -51,6 +54,7 @@ from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries from airflow.utils.session import provide_session from airflow.utils.timeout import timeout +from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: import pathlib @@ -79,8 +83,6 @@ class DagBag(LoggingMixin): :param dag_folder: the folder to scan to find DAGs :param include_examples: whether to include the examples that ship with airflow or not - :param include_smart_sensor: whether to include the smart sensor native - DAGs that create the smart sensor operators for whole cluster :param read_dags_from_db: Read DAGs from DB if ``True`` is passed. If ``False`` DAGs are read from python files. 
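# Illustrative sketch (not part of the patch above): typical DagBag usage matching
# the parameters documented here. The folder path and dag_id are hypothetical.
# Parsing failures do not raise -- they are collected per file in ``import_errors``.
from airflow.models.dagbag import DagBag

bag = DagBag(dag_folder="/opt/airflow/dags", include_examples=False)
print(bag.dag_ids)          # ids of every DAG that parsed successfully
print(bag.import_errors)    # {fileloc: error message} for files that failed to parse
example_dag = bag.get_dag("my_dag_id")   # refreshed from disk/DB if expired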
:param load_op_links: Should the extra operator link be loaded via plugins when @@ -90,49 +92,58 @@ class DagBag(LoggingMixin): def __init__( self, - dag_folder: Union[str, "pathlib.Path", None] = None, - include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'), - include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'), - safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'), + dag_folder: str | pathlib.Path | None = None, + include_examples: bool | ArgNotSet = NOTSET, + safe_mode: bool | ArgNotSet = NOTSET, read_dags_from_db: bool = False, - store_serialized_dags: Optional[bool] = None, + store_serialized_dags: bool | None = None, load_op_links: bool = True, + collect_dags: bool = True, ): # Avoid circular import from airflow.models.dag import DAG super().__init__() + include_examples = ( + include_examples + if isinstance(include_examples, bool) + else conf.getboolean("core", "LOAD_EXAMPLES") + ) + safe_mode = ( + safe_mode if isinstance(safe_mode, bool) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE") + ) + if store_serialized_dags: warnings.warn( "The store_serialized_dags parameter has been deprecated. " "You should pass the read_dags_from_db parameter.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) read_dags_from_db = store_serialized_dags dag_folder = dag_folder or settings.DAGS_FOLDER self.dag_folder = dag_folder - self.dags: Dict[str, DAG] = {} + self.dags: dict[str, DAG] = {} # the file's last modified timestamp when we last read it - self.file_last_changed: Dict[str, datetime] = {} - self.import_errors: Dict[str, str] = {} + self.file_last_changed: dict[str, datetime] = {} + self.import_errors: dict[str, str] = {} self.has_logged = False self.read_dags_from_db = read_dags_from_db # Only used by read_dags_from_db=True - self.dags_last_fetched: Dict[str, datetime] = {} + self.dags_last_fetched: dict[str, datetime] = {} # Only used by SchedulerJob to compare the dag_hash to identify change in DAGs - self.dags_hash: Dict[str, str] = {} - - self.dagbag_import_error_tracebacks = conf.getboolean('core', 'dagbag_import_error_tracebacks') - self.dagbag_import_error_traceback_depth = conf.getint('core', 'dagbag_import_error_traceback_depth') - self.collect_dags( - dag_folder=dag_folder, - include_examples=include_examples, - include_smart_sensor=include_smart_sensor, - safe_mode=safe_mode, - ) + self.dags_hash: dict[str, str] = {} + + self.dagbag_import_error_tracebacks = conf.getboolean("core", "dagbag_import_error_tracebacks") + self.dagbag_import_error_traceback_depth = conf.getint("core", "dagbag_import_error_traceback_depth") + if collect_dags: + self.collect_dags( + dag_folder=dag_folder, + include_examples=include_examples, + safe_mode=safe_mode, + ) # Should the extra operator link be loaded via plugins? # This flag is set to False in Scheduler so that Extra Operator links are not loaded self.load_op_links = load_op_links @@ -143,19 +154,20 @@ def size(self) -> int: @property def store_serialized_dags(self) -> bool: - """Whether or not to read dags from DB""" + """Whether to read dags from DB""" warnings.warn( "The store_serialized_dags property has been deprecated. Use read_dags_from_db instead.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.read_dags_from_db @property - def dag_ids(self) -> List[str]: + def dag_ids(self) -> list[str]: """ + Get DAG ids. 
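# Illustrative sketch (not part of the patch above): a standalone reduction of the
# NOTSET sentinel pattern used by DagBag.__init__. Defaulting to a sentinel instead
# of calling conf.getboolean() in the signature means configuration is read lazily,
# only when the caller did not pass an explicit value. All names below are made up.
from __future__ import annotations


class _NotSetType:
    """Sentinel type standing in for 'no value was passed'."""


NOTSET = _NotSetType()


def _config_default() -> bool:
    return True  # hypothetical stand-in for conf.getboolean("core", "LOAD_EXAMPLES")


def resolve_include_examples(include_examples: bool | _NotSetType = NOTSET) -> bool:
    # An explicit boolean wins; the sentinel falls back to configuration.
    return include_examples if isinstance(include_examples, bool) else _config_default()


assert resolve_include_examples() is True
assert resolve_include_examples(False) is False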
+ :return: a list of DAG IDs in this bag - :rtype: List[unicode] """ return list(self.dags.keys()) @@ -164,7 +176,7 @@ def get_dag(self, dag_id, session: Session = None): """ Gets the DAG out of the dictionary, and refreshes it if expired - :param dag_id: DAG Id + :param dag_id: DAG ID """ # Avoid circular import from airflow.models.dag import DagModel @@ -262,9 +274,12 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): Given a path to a python module or zip file, this method imports the module and look for dag objects within it. """ + from airflow.models.dag import DagContext + # if the source file no longer exists in the DB or in the filesystem, # return an empty list # todo: raise exception? + if filepath is None or not os.path.isfile(filepath): return [] @@ -282,6 +297,9 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): self.log.exception(e) return [] + # Ensure we don't pick up anything else we didn't mean to + DagContext.autoregistered_dags.clear() + if filepath.endswith(".py") or not zipfile.is_zipfile(filepath): mods = self._load_modules_from_file(filepath, safe_mode) else: @@ -293,6 +311,8 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): return found_dags def _load_modules_from_file(self, filepath, safe_mode): + from airflow.models.dag import DagContext + if not might_contain_dag(filepath, safe_mode): # Don't want to spam user with skip messages if not self.has_logged: @@ -302,12 +322,14 @@ def _load_modules_from_file(self, filepath, safe_mode): self.log.debug("Importing %s", filepath) org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1]) - path_hash = hashlib.sha1(filepath.encode('utf-8')).hexdigest() - mod_name = f'unusual_prefix_{path_hash}_{org_mod_name}' + path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest() + mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}" if mod_name in sys.modules: del sys.modules[mod_name] + DagContext.current_autoregister_module_name = mod_name + def parse(mod_name, filepath): try: loader = importlib.machinery.SourceFileLoader(mod_name, filepath) @@ -317,6 +339,7 @@ def parse(mod_name, filepath): loader.exec_module(new_module) return [new_module] except Exception as e: + DagContext.autoregistered_dags.clear() self.log.exception("Failed to import: %s", filepath) if self.dagbag_import_error_tracebacks: self.import_errors[filepath] = traceback.format_exc( @@ -330,7 +353,7 @@ def parse(mod_name, filepath): if not isinstance(dagbag_import_timeout, (int, float)): raise TypeError( - f'Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float' + f"Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float" ) if dagbag_import_timeout <= 0: # no parsing timeout @@ -346,6 +369,8 @@ def parse(mod_name, filepath): return parse(mod_name, filepath) def _load_modules_from_zip(self, filepath, safe_mode): + from airflow.models.dag import DagContext + mods = [] with zipfile.ZipFile(filepath) as current_zip_file: for zip_info in current_zip_file.infolist(): @@ -356,7 +381,7 @@ def _load_modules_from_zip(self, filepath, safe_mode): if head: continue - if mod_name == '__init__': + if mod_name == "__init__": self.log.warning("Found __init__.%s at root of %s", ext, filepath) self.log.debug("Reading %s from %s", zip_info.filename, filepath) @@ -374,11 +399,13 @@ def _load_modules_from_zip(self, filepath, safe_mode): if mod_name in sys.modules: del sys.modules[mod_name] + DagContext.current_autoregister_module_name = mod_name 
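# Illustrative sketch (not part of the patch above): how DagBag derives a unique
# module name per DAG file, mirroring the mod_name construction in
# _load_modules_from_file(), so that two files both named ``etl.py`` in different
# folders do not collide in sys.modules. The path is hypothetical.
import hashlib
import os

filepath = "/opt/airflow/dags/team_a/etl.py"
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])   # "etl"
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
# -> "unusual_prefix_<40-hex-char sha1>_etl"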
try: sys.path.insert(0, filepath) current_module = importlib.import_module(mod_name) mods.append(current_module) except Exception as e: + DagContext.autoregistered_dags.clear() fileloc = os.path.join(filepath, zip_info.filename) self.log.exception("Failed to import: %s", fileloc) if self.dagbag_import_error_tracebacks: @@ -393,34 +420,39 @@ def _load_modules_from_zip(self, filepath, safe_mode): return mods def _process_modules(self, filepath, mods, file_last_changed_on_disk): - from airflow.models.dag import DAG # Avoid circular import + from airflow.models.dag import DAG, DagContext # Avoid circular import + + top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)} - top_level_dags = ((o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)) + top_level_dags.update(DagContext.autoregistered_dags) + + DagContext.current_autoregister_module_name = None + DagContext.autoregistered_dags.clear() found_dags = [] for (dag, mod) in top_level_dags: dag.fileloc = mod.__file__ try: - dag.timetable.validate() - # validate dag params - dag.params.validate() + dag.validate() self.bag_dag(dag=dag, root_dag=dag) - found_dags.append(dag) - found_dags += dag.subdags except AirflowTimetableInvalid as exception: self.log.exception("Failed to bag_dag: %s", dag.fileloc) self.import_errors[dag.fileloc] = f"Invalid timetable expression: {exception}" self.file_last_changed[dag.fileloc] = file_last_changed_on_disk except ( + AirflowClusterPolicyViolation, AirflowDagCycleException, AirflowDagDuplicatedIdException, - AirflowClusterPolicyViolation, + AirflowDagInconsistent, ParamValidationError, ) as exception: self.log.exception("Failed to bag_dag: %s", dag.fileloc) self.import_errors[dag.fileloc] = str(exception) self.file_last_changed[dag.fileloc] = file_last_changed_on_disk + else: + found_dags.append(dag) + found_dags += dag.subdags return found_dags def bag_dag(self, dag, root_dag): @@ -468,10 +500,10 @@ def _bag_dag(self, *, dag, root_dag, recursive): existing=self.dags[dag.dag_id].fileloc, ) self.dags[dag.dag_id] = dag - self.log.debug('Loaded DAG %s', dag) + self.log.debug("Loaded DAG %s", dag) except (AirflowDagCycleException, AirflowDagDuplicatedIdException): # There was an error in bagging the dag. 
Remove it from the list of dags - self.log.exception('Exception bagging dag: %s', dag.dag_id) + self.log.exception("Exception bagging dag: %s", dag.dag_id) # Only necessary at the root level since DAG.subdags automatically # performs DFS to search through all subdags if recursive: @@ -482,11 +514,10 @@ def _bag_dag(self, *, dag, root_dag, recursive): def collect_dags( self, - dag_folder: Union[str, "pathlib.Path", None] = None, + dag_folder: str | pathlib.Path | None = None, only_if_updated: bool = True, - include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'), - include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'), - safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'), + include_examples: bool = conf.getboolean("core", "LOAD_EXAMPLES"), + safe_mode: bool = conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE"), ): """ Given a file path or a folder, this method looks for python modules, @@ -515,7 +546,6 @@ def collect_dags( dag_folder, safe_mode=safe_mode, include_examples=include_examples, - include_smart_sensor=include_smart_sensor, ): try: file_parse_start_dttm = timezone.utcnow() @@ -524,7 +554,7 @@ def collect_dags( file_parse_end_dttm = timezone.utcnow() stats.append( FileLoadStat( - file=filepath.replace(settings.DAGS_FOLDER, ''), + file=filepath.replace(settings.DAGS_FOLDER, ""), duration=file_parse_end_dttm - file_parse_start_dttm, dag_num=len(found_dags), task_num=sum(len(dag.tasks) for dag in found_dags), @@ -540,7 +570,7 @@ def collect_dags_from_db(self): """Collects DAGs from database.""" from airflow.models.serialized_dag import SerializedDagModel - with Stats.timer('collect_db_dags'): + with Stats.timer("collect_db_dags"): self.log.info("Filling up the DagBag from database") # The dagbag contains all rows in serialized_dag table. 
Deleted DAGs are deleted @@ -579,7 +609,7 @@ def dagbag_report(self): return report @provide_session - def sync_to_db(self, session: Session = None): + def sync_to_db(self, processor_subdir: str | None = None, session: Session = None): """Save attributes about list of DAG to the DB.""" # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag from airflow.models.dag import DAG @@ -626,7 +656,9 @@ def _serialize_dag_capturing_errors(dag, session): for dag in self.dags.values(): serialize_errors.extend(_serialize_dag_capturing_errors(dag, session)) - DAG.bulk_write_to_db(self.dags.values(), session=session) + DAG.bulk_write_to_db( + self.dags.values(), processor_subdir=processor_subdir, session=session + ) except OperationalError: session.rollback() raise @@ -640,6 +672,8 @@ def _sync_perm_for_dag(self, dag, session: Session = None): from airflow.security.permissions import DAG_ACTIONS, resource_name_for_dag from airflow.www.fab_security.sqla.models import Action, Permission, Resource + root_dag_id = dag.parent_dag.dag_id if dag.parent_dag else dag.dag_id + def needs_perms(dag_id: str) -> bool: dag_resource_name = resource_name_for_dag(dag_id) for permission_name in DAG_ACTIONS: @@ -654,9 +688,9 @@ def needs_perms(dag_id: str) -> bool: return True return False - if dag.access_control or needs_perms(dag.dag_id): - self.log.debug("Syncing DAG permissions: %s to the DB", dag.dag_id) + if dag.access_control or needs_perms(root_dag_id): + self.log.debug("Syncing DAG permissions: %s to the DB", root_dag_id) from airflow.www.security import ApplessAirflowSecurityManager security_manager = ApplessAirflowSecurityManager(session=session) - security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control) + security_manager.sync_perm_for_dag(root_dag_id, dag.access_control) diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py index 7322ba92fb76a..47b9588cc8d00 100644 --- a/airflow/models/dagcode.py +++ b/airflow/models/dagcode.py @@ -14,13 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging import os import struct from datetime import datetime -from typing import Iterable, List, Optional +from typing import Iterable from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.sql.expression import literal from airflow.exceptions import AirflowException, DagCodeNotFound @@ -41,15 +44,15 @@ class DagCode(Base): For details on dag serialization see SerializedDagModel """ - __tablename__ = 'dag_code' + __tablename__ = "dag_code" fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False) fileloc = Column(String(2000), nullable=False) # The max length of fileloc exceeds the limit of indexing. 
last_updated = Column(UtcDateTime, nullable=False) - source_code = Column(Text, nullable=False) + source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False) - def __init__(self, full_filepath: str, source_code: Optional[str] = None): + def __init__(self, full_filepath: str, source_code: str | None = None): self.fileloc = full_filepath self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc) self.last_updated = timezone.utcnow() @@ -122,7 +125,7 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None): @classmethod @provide_session - def remove_deleted_code(cls, alive_dag_filelocs: List[str], session=None): + def remove_deleted_code(cls, alive_dag_filelocs: list[str], session=None): """Deletes code not included in alive_dag_filelocs. :param alive_dag_filelocs: file paths of alive DAGs @@ -134,7 +137,7 @@ def remove_deleted_code(cls, alive_dag_filelocs: List[str], session=None): session.query(cls).filter( cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs) - ).delete(synchronize_session='fetch') + ).delete(synchronize_session="fetch") @classmethod @provide_session @@ -166,7 +169,7 @@ def code(cls, fileloc) -> str: @staticmethod def _get_code_from_file(fileloc): - with open_maybe_zipped(fileloc, 'r') as f: + with open_maybe_zipped(fileloc, "r") as f: code = f.read() return code @@ -192,4 +195,4 @@ def dag_fileloc_hash(full_filepath: str) -> int: import hashlib # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed). - return struct.unpack('>Q', hashlib.sha1(full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8 + return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8 diff --git a/airflow/models/dagparam.py b/airflow/models/dagparam.py index 83a2f2c05532b..f20bd078c80a6 100644 --- a/airflow/models/dagparam.py +++ b/airflow/models/dagparam.py @@ -14,15 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module is deprecated. Please use :mod:`airflow.models.param`.""" +from __future__ import annotations import warnings +from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.param import DagParam # noqa warnings.warn( "This module is deprecated. Please use `airflow.models.param`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) diff --git a/airflow/models/dagpickle.py b/airflow/models/dagpickle.py index aa56ce3e5884b..caa319e9840f4 100644 --- a/airflow/models/dagpickle.py +++ b/airflow/models/dagpickle.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
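# Illustrative sketch (not part of the patch above): why DagCode.dag_fileloc_hash()
# shifts the unpacked value right by 8 bits. The last 8 bytes of the SHA-1 digest are
# read as an unsigned 64-bit integer and then reduced to 56 bits so the result always
# fits the signed BigInteger ``fileloc_hash`` column. The path is hypothetical.
import hashlib
import struct

fileloc = "/opt/airflow/dags/example.py"
h = struct.unpack(">Q", hashlib.sha1(fileloc.encode("utf-8")).digest()[-8:])[0] >> 8
assert 0 <= h < 2**56 < 2**63  # always representable as a signed 64-bit integer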
+from __future__ import annotations import dill -from sqlalchemy import Column, Integer, PickleType, Text +from sqlalchemy import BigInteger, Column, Integer, PickleType from airflow.models.base import Base from airflow.utils import timezone @@ -39,13 +40,13 @@ class DagPickle(Base): id = Column(Integer, primary_key=True) pickle = Column(PickleType(pickler=dill)) created_dttm = Column(UtcDateTime, default=timezone.utcnow) - pickle_hash = Column(Text) + pickle_hash = Column(BigInteger) __tablename__ = "dag_pickle" def __init__(self, dag): self.dag_id = dag.dag_id - if hasattr(dag, 'template_env'): + if hasattr(dag, "template_env"): dag.template_env = None self.pickle_hash = hash(dag) self.pickle = dag diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index eeec4d5b099b2..2c736c4c2efe5 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -15,34 +15,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import itertools import os import warnings from collections import defaultdict from datetime import datetime -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generator, - Iterable, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload from sqlalchemy import ( Boolean, Column, ForeignKey, + ForeignKeyConstraint, Index, Integer, PickleType, + PrimaryKeyConstraint, String, + Text, UniqueConstraint, and_, func, @@ -50,6 +42,7 @@ text, ) from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import joinedload, relationship, synonym from sqlalchemy.orm.session import Session @@ -58,14 +51,17 @@ from airflow import settings from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.configuration import conf as airflow_conf -from airflow.exceptions import AirflowException, TaskNotFound -from airflow.models.base import COLLATION_ARGS, ID_LEN, Base -from airflow.models.mappedoperator import MappedOperator +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound +from airflow.listeners.listener import get_listener_manager +from airflow.models.abstractoperator import NotMapped +from airflow.models.base import Base, StringID +from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance as TI from airflow.models.tasklog import LogTemplate from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES +from airflow.typing_compat import Literal from airflow.utils import timezone from airflow.utils.helpers import is_container from airflow.utils.log.logging_mixin import LoggingMixin @@ -78,15 +74,28 @@ from airflow.models.dag import DAG from airflow.models.operator import Operator + CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI]) + TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks] + class TISchedulingDecision(NamedTuple): """Type of return for DagRun.task_instance_scheduling_decisions""" - tis: List[TI] - schedulable_tis: List[TI] + tis: list[TI] + schedulable_tis: list[TI] changed_tis: bool - unfinished_tis: List[TI] - finished_tis: List[TI] + unfinished_tis: 
list[TI] + finished_tis: list[TI] + + +def _creator_note(val): + """Custom creator for the ``note`` association proxy.""" + if isinstance(val, str): + return DagRunNote(content=val) + elif isinstance(val, dict): + return DagRunNote(**val) + else: + return DagRunNote(*val) class DagRun(Base, LoggingMixin): @@ -98,13 +107,13 @@ class DagRun(Base, LoggingMixin): __tablename__ = "dag_run" id = Column(Integer, primary_key=True) - dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) + dag_id = Column(StringID(), nullable=False) queued_at = Column(UtcDateTime) execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) - _state = Column('state', String(50), default=State.QUEUED) - run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) + _state = Column("state", String(50), default=State.QUEUED) + run_id = Column(StringID(), nullable=False) creating_job_id = Column(Integer) external_trigger = Column(Boolean, default=True) run_type = Column(String(50), nullable=False) @@ -123,23 +132,24 @@ class DagRun(Base, LoggingMixin): ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"), default=select([func.max(LogTemplate.__table__.c.id)]), ) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) # Remove this `if` after upgrading Sphinx-AutoAPI if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ: - dag: "Optional[DAG]" + dag: DAG | None else: - dag: "Optional[DAG]" = None + dag: DAG | None = None __table_args__ = ( - Index('dag_id_state', dag_id, _state), - UniqueConstraint('dag_id', 'execution_date', name='dag_run_dag_id_execution_date_key'), - UniqueConstraint('dag_id', 'run_id', name='dag_run_dag_id_run_id_key'), - Index('idx_last_scheduling_decision', last_scheduling_decision), - Index('idx_dag_run_dag_id', dag_id), + Index("dag_id_state", dag_id, _state), + UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"), + UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"), + Index("idx_last_scheduling_decision", last_scheduling_decision), + Index("idx_dag_run_dag_id", dag_id), Index( - 'idx_dag_run_running_dags', - 'state', - 'dag_id', + "idx_dag_run_running_dags", + "state", + "dag_id", postgresql_where=text("state='running'"), mssql_where=text("state='running'"), sqlite_where=text("state='running'"), @@ -147,9 +157,9 @@ class DagRun(Base, LoggingMixin): # since mysql lacks filtered/partial indices, this creates a # duplicate index on mysql. 
Not the end of the world Index( - 'idx_dag_run_queued_dags', - 'state', - 'dag_id', + "idx_dag_run_queued_dags", + "state", + "dag_id", postgresql_where=text("state='queued'"), mssql_where=text("state='queued'"), sqlite_where=text("state='queued'"), @@ -157,29 +167,37 @@ class DagRun(Base, LoggingMixin): ) task_instances = relationship( - TI, back_populates="dag_run", cascade='save-update, merge, delete, delete-orphan' + TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan" + ) + dag_model = relationship( + "DagModel", + primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id", + uselist=False, + viewonly=True, ) + dag_run_note = relationship("DagRunNote", back_populates="dag_run", uselist=False) + note = association_proxy("dag_run_note", "content", creator=_creator_note) DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint( - 'scheduler', - 'max_dagruns_per_loop_to_schedule', + "scheduler", + "max_dagruns_per_loop_to_schedule", fallback=20, ) def __init__( self, - dag_id: Optional[str] = None, - run_id: Optional[str] = None, - queued_at: Union[datetime, None, ArgNotSet] = NOTSET, - execution_date: Optional[datetime] = None, - start_date: Optional[datetime] = None, - external_trigger: Optional[bool] = None, - conf: Optional[Any] = None, - state: Optional[DagRunState] = None, - run_type: Optional[str] = None, - dag_hash: Optional[str] = None, - creating_job_id: Optional[int] = None, - data_interval: Optional[Tuple[datetime, datetime]] = None, + dag_id: str | None = None, + run_id: str | None = None, + queued_at: datetime | None | ArgNotSet = NOTSET, + execution_date: datetime | None = None, + start_date: datetime | None = None, + external_trigger: bool | None = None, + conf: Any | None = None, + state: DagRunState | None = None, + run_type: str | None = None, + dag_hash: str | None = None, + creating_job_id: int | None = None, + data_interval: tuple[datetime, datetime] | None = None, ): if data_interval is None: # Legacy: Only happen for runs created prior to Airflow 2.2. @@ -206,11 +224,14 @@ def __init__( def __repr__(self): return ( - '' + "" ).format( dag_id=self.dag_id, execution_date=self.execution_date, run_id=self.run_id, + state=self.state, + queued_at=self.queued_at, external_trigger=self.external_trigger, ) @@ -232,7 +253,7 @@ def set_state(self, state: DagRunState): @declared_attr def state(self): - return synonym('_state', descriptor=property(self.get_state, self.set_state)) + return synonym("_state", descriptor=property(self.get_state, self.set_state)) @provide_session def refresh_from_db(self, session: Session = NEW_SESSION) -> None: @@ -247,9 +268,9 @@ def refresh_from_db(self, session: Session = NEW_SESSION) -> None: @classmethod @provide_session - def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> Dict[str, int]: + def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> dict[str, int]: """Get the number of active dag runs for each dag.""" - query = session.query(cls.dag_id, func.count('*')) + query = session.query(cls.dag_id, func.count("*")) if dag_ids is not None: # 'set' called to avoid duplicate dag_ids, but converted back to 'list' # because SQLAlchemy doesn't accept a set here. @@ -266,8 +287,8 @@ def next_dagruns_to_examine( cls, state: DagRunState, session: Session, - max_number: Optional[int] = None, - ): + max_number: int | None = None, + ) -> list[DagRun]: """ Return the next DagRuns that the scheduler should attempt to schedule. 
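# Illustrative sketch (not part of the patch above): a standalone reduction of the
# synonym-plus-property pattern used for DagRun.state. The value is stored in the
# ``state`` column via the private ``_state`` attribute, while the public ``state``
# attribute routes through getter/setter logic and still works in queries. The model
# below is made up.
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base, synonym

Base = declarative_base()


class Run(Base):
    __tablename__ = "run"

    id = Column(Integer, primary_key=True)
    _state = Column("state", String(50), default="queued")

    def get_state(self):
        return self._state

    def set_state(self, value):
        # extra bookkeeping (e.g. stamping an end date on terminal states) goes here
        self._state = value

    state = synonym("_state", descriptor=property(get_state, set_state))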
@@ -275,7 +296,6 @@ def next_dagruns_to_examine( query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. - :rtype: list[airflow.models.DagRun] """ from airflow.models.dag import DagModel @@ -285,6 +305,7 @@ def next_dagruns_to_examine( # TODO: Bake this query, it is run _A lot_ query = ( session.query(cls) + .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql") .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB) .join(DagModel, DagModel.dag_id == cls.dag_id) .filter(DagModel.is_paused == false(), DagModel.is_active == true()) @@ -293,7 +314,7 @@ def next_dagruns_to_examine( # For dag runs in the queued state, we check if they have reached the max_active_runs limit # and if so we drop them running_drs = ( - session.query(DagRun.dag_id, func.count(DagRun.state).label('num_running')) + session.query(DagRun.dag_id, func.count(DagRun.state).label("num_running")) .filter(DagRun.state == DagRunState.RUNNING) .group_by(DagRun.dag_id) .subquery() @@ -317,17 +338,17 @@ def next_dagruns_to_examine( @provide_session def find( cls, - dag_id: Optional[Union[str, List[str]]] = None, - run_id: Optional[Iterable[str]] = None, - execution_date: Optional[Union[datetime, Iterable[datetime]]] = None, - state: Optional[DagRunState] = None, - external_trigger: Optional[bool] = None, + dag_id: str | list[str] | None = None, + run_id: Iterable[str] | None = None, + execution_date: datetime | Iterable[datetime] | None = None, + state: DagRunState | None = None, + external_trigger: bool | None = None, no_backfills: bool = False, - run_type: Optional[DagRunType] = None, + run_type: DagRunType | None = None, session: Session = NEW_SESSION, - execution_start_date: Optional[datetime] = None, - execution_end_date: Optional[datetime] = None, - ) -> List["DagRun"]: + execution_start_date: datetime | None = None, + execution_end_date: datetime | None = None, + ) -> list[DagRun]: """ Returns a set of dag runs for the given search criteria. @@ -381,7 +402,7 @@ def find_duplicate( run_id: str, execution_date: datetime, session: Session = NEW_SESSION, - ) -> Optional['DagRun']: + ) -> DagRun | None: """ Return an existing run for the DAG with a specific run_id or execution_date. @@ -404,14 +425,15 @@ def find_duplicate( @staticmethod def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str: """Generate Run ID based on Run Type and Execution Date""" - return f"{run_type}__{execution_date.isoformat()}" + # _Ensure_ run_type is a DagRunType, not just a string from user code + return DagRunType(run_type).generate_run_id(execution_date) @provide_session def get_task_instances( self, - state: Optional[Iterable[Optional[TaskInstanceState]]] = None, + state: Iterable[TaskInstanceState | None] | None = None, session: Session = NEW_SESSION, - ) -> List[TI]: + ) -> list[TI]: """Returns the task instances for this dag run""" tis = ( session.query(TI) @@ -447,7 +469,7 @@ def get_task_instance( session: Session = NEW_SESSION, *, map_index: int = -1, - ) -> Optional[TI]: + ) -> TI | None: """ Returns the task instance specified by task_id for this dag run @@ -460,7 +482,7 @@ def get_task_instance( .one_or_none() ) - def get_dag(self) -> "DAG": + def get_dag(self) -> DAG: """ Returns the Dag associated with this DagRun. 
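# Illustrative sketch (not part of the patch above): what generate_run_id() produces.
# The new code delegates to DagRunType.generate_run_id(); judging from the removed
# f-string, the format is expected to remain "<run type>__<execution date ISO-8601>".
import pendulum

from airflow.models.dagrun import DagRun
from airflow.utils.types import DagRunType

run_id = DagRun.generate_run_id(DagRunType.SCHEDULED, pendulum.datetime(2022, 1, 1, tz="UTC"))
# e.g. "scheduled__2022-01-01T00:00:00+00:00"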
@@ -473,8 +495,8 @@ def get_dag(self) -> "DAG": @provide_session def get_previous_dagrun( - self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION - ) -> Optional['DagRun']: + self, state: DagRunState | None = None, session: Session = NEW_SESSION + ) -> DagRun | None: """The previous DagRun, if there is one""" filters = [ DagRun.dag_id == self.dag_id, @@ -485,7 +507,7 @@ def get_previous_dagrun( return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first() @provide_session - def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> Optional['DagRun']: + def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> DagRun | None: """The previous, SCHEDULED DagRun, if there is one""" return ( session.query(DagRun) @@ -501,19 +523,38 @@ def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> Optio @provide_session def update_state( self, session: Session = NEW_SESSION, execute_callbacks: bool = True - ) -> Tuple[List[TI], Optional[DagCallbackRequest]]: + ) -> tuple[list[TI], DagCallbackRequest | None]: """ Determines the overall state of the DagRun based on the state of its TaskInstances. :param session: Sqlalchemy ORM Session - :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked - directly (default: true) or recorded as a pending request in the ``callback`` property - :return: Tuple containing tis that can be scheduled in the current loop & `callback` that + :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked + directly (default: true) or recorded as a pending request in the ``returned_callback`` property + :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that needs to be executed """ # Callback to execute in case of Task Failures - callback: Optional[DagCallbackRequest] = None + callback: DagCallbackRequest | None = None + + class _UnfinishedStates(NamedTuple): + tis: Sequence[TI] + + @classmethod + def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates: + return cls(tis=unfinished_tis) + + @property + def should_schedule(self) -> bool: + return ( + bool(self.tis) + and all(not t.task.depends_on_past for t in self.tis) + and all(t.task.max_active_tis_per_dag is None for t in self.tis) + and all(t.state != TaskInstanceState.DEFERRED for t in self.tis) + ) + + def recalculate(self) -> _UnfinishedStates: + return self._replace(tis=[t for t in self.tis if t.state in State.unfinished]) start_dttm = timezone.utcnow() self.last_scheduling_decision = start_dttm @@ -525,72 +566,82 @@ def update_state( schedulable_tis = info.schedulable_tis changed_tis = info.changed_tis finished_tis = info.finished_tis - unfinished_tis = info.unfinished_tis + unfinished = _UnfinishedStates.calculate(info.unfinished_tis) - none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tis) - none_task_concurrency = all(t.task.max_active_tis_per_dag is None for t in unfinished_tis) - none_deferred = all(t.state != State.DEFERRED for t in unfinished_tis) - - if unfinished_tis and none_depends_on_past and none_task_concurrency and none_deferred: + if unfinished.should_schedule: + are_runnable_tasks = schedulable_tis or changed_tis # small speed up - are_runnable_tasks = ( - schedulable_tis - or self._are_premature_tis(unfinished_tis, finished_tis, session) - or changed_tis - ) + if not are_runnable_tasks: + are_runnable_tasks, changed_by_upstream = self._are_premature_tis( + 
unfinished.tis, finished_tis, session + ) + if changed_by_upstream: # Something changed, we need to recalculate! + unfinished = unfinished.recalculate() leaf_task_ids = {t.task_id for t in dag.leaves} leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED] # if all roots finished and at least one failed, the run failed - if not unfinished_tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis): - self.log.error('Marking run %s failed', self) + if not unfinished.tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis): + self.log.error("Marking run %s failed", self) self.set_state(DagRunState.FAILED) + self.notify_dagrun_state_changed(msg="task_failure") + if execute_callbacks: - dag.handle_callback(self, success=False, reason='task_failure', session=session) + dag.handle_callback(self, success=False, reason="task_failure", session=session) elif dag.has_on_failure_callback: + from airflow.models.dag import DagModel + + dag_model = DagModel.get_dagmodel(dag.dag_id, session) callback = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=True, - msg='task_failure', + processor_subdir=dag_model.processor_subdir, + msg="task_failure", ) # if all leaves succeeded and no unfinished tasks, the run succeeded - elif not unfinished_tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis): - self.log.info('Marking run %s successful', self) + elif not unfinished.tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis): + self.log.info("Marking run %s successful", self) self.set_state(DagRunState.SUCCESS) + self.notify_dagrun_state_changed(msg="success") + if execute_callbacks: - dag.handle_callback(self, success=True, reason='success', session=session) + dag.handle_callback(self, success=True, reason="success", session=session) elif dag.has_on_success_callback: + from airflow.models.dag import DagModel + + dag_model = DagModel.get_dagmodel(dag.dag_id, session) callback = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=False, - msg='success', + processor_subdir=dag_model.processor_subdir, + msg="success", ) # if *all tasks* are deadlocked, the run failed - elif ( - unfinished_tis - and none_depends_on_past - and none_task_concurrency - and none_deferred - and not are_runnable_tasks - ): - self.log.error('Deadlock; marking run %s failed', self) + elif unfinished.should_schedule and not are_runnable_tasks: + self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self) self.set_state(DagRunState.FAILED) + self.notify_dagrun_state_changed(msg="all_tasks_deadlocked") + if execute_callbacks: - dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session) + dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session) elif dag.has_on_failure_callback: + from airflow.models.dag import DagModel + + dag_model = DagModel.get_dagmodel(dag.dag_id, session) callback = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=True, - msg='all_tasks_deadlocked', + processor_subdir=dag_model.processor_subdir, + msg="all_tasks_deadlocked", ) # finally, if the roots aren't done, the dag is still running @@ -633,22 +684,23 @@ def update_state( @provide_session def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision: + tis = 
self.get_task_instances(session=session, state=State.task_states) + self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) - schedulable_tis: List[TI] = [] - changed_tis = False + def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]: + """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED.""" + for ti in tis: + try: + ti.task = dag.get_task(ti.task_id) + except TaskNotFound: + if ti.state != State.REMOVED: + self.log.error("Failed to get task for ti %s. Marking it as removed.", ti) + ti.state = State.REMOVED + session.flush() + else: + yield ti - tis = list(self.get_task_instances(session=session, state=State.task_states)) - self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) - dag = self.get_dag() - for ti in tis: - try: - ti.task = dag.get_task(ti.task_id) - except TaskNotFound: - self.log.warning( - "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id - ) - ti.state = State.REMOVED - session.flush() + tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis)) unfinished_tis = [t for t in tis if t.state in State.unfinished] finished_tis = [t for t in tis if t.state in State.finished] @@ -661,12 +713,16 @@ def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> session=session, ) - # During expansion we may change some tis into non-schedulable + # During expansion, we may change some tis into non-schedulable # states, so we need to re-compute. if expansion_happened: + changed_tis = True new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished] finished_tis.extend(t for t in unfinished_tis if t.state in State.finished) unfinished_tis = new_unfinished_tis + else: + schedulable_tis = [] + changed_tis = False return TISchedulingDecision( tis=tis, @@ -676,14 +732,25 @@ def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> finished_tis=finished_tis, ) + def notify_dagrun_state_changed(self, msg: str = ""): + if self.state == DagRunState.RUNNING: + get_listener_manager().hook.on_dag_run_running(dag_run=self, msg=msg) + elif self.state == DagRunState.SUCCESS: + get_listener_manager().hook.on_dag_run_success(dag_run=self, msg=msg) + elif self.state == DagRunState.FAILED: + get_listener_manager().hook.on_dag_run_failed(dag_run=self, msg=msg) + # deliberately not notifying on QUEUED + # we can't get all the state changes on SchedulerJob, BackfillJob + # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that + def _get_ready_tis( self, - schedulable_tis: List[TI], - finished_tis: List[TI], + schedulable_tis: list[TI], + finished_tis: list[TI], session: Session, - ) -> Tuple[List[TI], bool, bool]: + ) -> tuple[list[TI], bool, bool]: old_states = {} - ready_tis: List[TI] = [] + ready_tis: list[TI] = [] changed_tis = False if not schedulable_tis: @@ -691,13 +758,34 @@ def _get_ready_tis( # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter # `schedulable_tis` in place and have the `for` loop pick them up - additional_tis: List[TI] = [] + additional_tis: list[TI] = [] dep_context = DepContext( flag_upstream_failed=True, ignore_unmapped_tasks=True, # Ignore this Dep, as we will expand it if we can. finished_tis=finished_tis, ) + def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: + """Try to expand the ti, if needed. + + If the ti needs expansion, newly created task instances are + returned as well as the original ti. 
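# Illustrative sketch (not part of the patch above): a minimal listener receiving the
# notifications emitted by notify_dagrun_state_changed(). It assumes the listener
# hook specs mirror the calls made there (on_dag_run_running/success/failed) and that
# this module is registered through an Airflow plugin's ``listeners`` attribute.
from airflow.listeners import hookimpl


@hookimpl
def on_dag_run_success(dag_run, msg: str):
    print(f"dag run {dag_run.run_id} succeeded: {msg}")


@hookimpl
def on_dag_run_failed(dag_run, msg: str):
    print(f"dag run {dag_run.run_id} failed: {msg}")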
+ The original ti is also modified in-place and assigned the + ``map_index`` of 0. + + If the ti does not need expansion, either because the task is not + mapped, or has already been expanded, *None* is returned. + """ + if ti.map_index >= 0: # Already expanded, we're good. + return None + try: + expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session) + except NotMapped: # Not a mapped task, nothing needed. + return None + if expanded_tis: + return expanded_tis + return () + # Check dependencies. expansion_happened = False for schedulable in itertools.chain(schedulable_tis, additional_tis): @@ -705,18 +793,20 @@ def _get_ready_tis( if not schedulable.are_dependencies_met(session=session, dep_context=dep_context): old_states[schedulable.key] = old_state continue - # If schedulable is from a mapped task, but not yet expanded, do it - # now. This is called in two places: First and ideally in the mini - # scheduler at the end of LocalTaskJob, and then as an "expansion of - # last resort" in the scheduler to ensure that the mapped task is - # correctly expanded before executed. - if schedulable.map_index < 0 and isinstance(schedulable.task, MappedOperator): - expanded_tis, _ = schedulable.task.expand_mapped_task(self.run_id, session=session) - if expanded_tis: - assert expanded_tis[0] is schedulable - additional_tis.extend(expanded_tis[1:]) - expansion_happened = True - if schedulable.state in SCHEDULEABLE_STATES: + # If schedulable is not yet expanded, try doing it now. This is + # called in two places: First and ideally in the mini scheduler at + # the end of LocalTaskJob, and then as an "expansion of last resort" + # in the scheduler to ensure that the mapped task is correctly + # expanded before executed. Also see _revise_map_indexes_if_mapped + # docstring for additional information. + new_tis = None + if schedulable.map_index < 0: + new_tis = _expand_mapped_task_if_needed(schedulable) + if new_tis is not None: + additional_tis.extend(new_tis) + expansion_happened = True + if new_tis is None and schedulable.state in SCHEDULEABLE_STATES: + ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session)) ready_tis.append(schedulable) # Check if any ti changed state @@ -729,26 +819,24 @@ def _get_ready_tis( def _are_premature_tis( self, - unfinished_tis: List[TI], - finished_tis: List[TI], + unfinished_tis: Sequence[TI], + finished_tis: list[TI], session: Session, - ) -> bool: - # there might be runnable tasks that are up for retry and for some reason(retry delay, etc) are - # not ready yet so we set the flags to count them in - for ut in unfinished_tis: - if ut.are_dependencies_met( - dep_context=DepContext( - flag_upstream_failed=True, - ignore_in_retry_period=True, - ignore_in_reschedule_period=True, - finished_tis=finished_tis, - ), - session=session, - ): - return True - return False + ) -> tuple[bool, bool]: + dep_context = DepContext( + flag_upstream_failed=True, + ignore_in_retry_period=True, + ignore_in_reschedule_period=True, + finished_tis=finished_tis, + ) + # there might be runnable tasks that are up for retry and for some reason(retry delay, etc.) 
are + # not ready yet, so we set the flags to count them in + return ( + any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis), + dep_context.have_changed_ti_states, + ) - def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: List[TI]) -> None: + def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None: """ This is a helper method to emit the true scheduling delay stats, which is defined as the time when the first task in DAG starts minus the expected DAG run datetime. @@ -756,7 +844,7 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: Lis is updated to a completed status (either success or failure). The method will find the first started task within the DAG and calculate the expected DagRun start time (based on dag.execution_date & dag.timetable), and minus these two values to get the delay. - The emitted data may contains outlier (e.g. when the first task was cleared, so + The emitted data may contain outlier (e.g. when the first task was cleared, so the second task's start_date will be used), but we can get rid of the outliers on the stats side through the dashboards tooling built. Note, the stat will only be emitted if the DagRun is a scheduler triggered one @@ -778,7 +866,7 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: Lis ordered_tis_by_start_date = [ti for ti in finished_tis if ti.start_date] ordered_tis_by_start_date.sort(key=lambda ti: ti.start_date, reverse=False) - first_start_date = ordered_tis_by_start_date[0].start_date + first_start_date = ordered_tis_by_start_date[0].start_date if ordered_tis_by_start_date else None if first_start_date: # TODO: Logically, this should be DagRunInfo.run_after, but the # information is not stored on a DagRun, only before the actual @@ -788,45 +876,81 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: Lis data_interval_end = dag.get_run_data_interval(self).end true_delay = first_start_date - data_interval_end if true_delay.total_seconds() > 0: - Stats.timing(f'dagrun.{dag.dag_id}.first_task_scheduling_delay', true_delay) + Stats.timing(f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay) except Exception: - self.log.warning('Failed to record first_task_scheduling_delay metric:', exc_info=True) + self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True) def _emit_duration_stats_for_finished_state(self): if self.state == State.RUNNING: return if self.start_date is None: - self.log.warning('Failed to record duration of %s: start_date is not set.', self) + self.log.warning("Failed to record duration of %s: start_date is not set.", self) return if self.end_date is None: - self.log.warning('Failed to record duration of %s: end_date is not set.', self) + self.log.warning("Failed to record duration of %s: end_date is not set.", self) return duration = self.end_date - self.start_date if self.state == State.SUCCESS: - Stats.timing(f'dagrun.duration.success.{self.dag_id}', duration) + Stats.timing(f"dagrun.duration.success.{self.dag_id}", duration) elif self.state == State.FAILED: - Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration) + Stats.timing(f"dagrun.duration.failed.{self.dag_id}", duration) @provide_session - def verify_integrity(self, session: Session = NEW_SESSION): + def verify_integrity(self, *, session: Session = NEW_SESSION) -> None: """ Verifies the DagRun by checking for removed tasks or tasks 
that are not in the database yet. It will set state to removed or add the task if required. + :missing_indexes: A dictionary of task vs indexes that are missing. :param session: Sqlalchemy ORM Session """ from airflow.settings import task_instance_mutation_hook + # Set for the empty default in airflow.settings -- if it's not set this means it has been changed + # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload. + hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False) + dag = self.get_dag() + task_ids = self._check_for_removed_or_restored_tasks( + dag, task_instance_mutation_hook, session=session + ) + + def task_filter(task: Operator) -> bool: + return task.task_id not in task_ids and ( + self.is_backfill + or task.start_date <= self.execution_date + and (task.end_date is None or self.execution_date <= task.end_date) + ) + + created_counts: dict[str, int] = defaultdict(int) + task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop) + + # Create the missing tasks, including mapped tasks + tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task)) + tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session) + self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session) + + def _check_for_removed_or_restored_tasks( + self, dag: DAG, ti_mutation_hook, *, session: Session + ) -> set[str]: + """ + Check for removed tasks/restored/missing tasks. + + :param dag: DAG object corresponding to the dagrun + :param ti_mutation_hook: task_instance_mutation_hook function + :param session: Sqlalchemy ORM Session + + :return: Task IDs in the DAG run + + """ tis = self.get_task_instances(session=session) # check for removed or restored tasks task_ids = set() for ti in tis: - task_instance_mutation_hook(ti) + ti_mutation_hook(ti) task_ids.add(ti.task_id) - task = None try: task = dag.get_task(ti.task_id) @@ -844,31 +968,15 @@ def verify_integrity(self, session: Session = NEW_SESSION): ti.state = State.REMOVED continue - if not task.is_mapped: + try: + num_mapped_tis = task.get_parse_time_mapped_ti_count() + except NotMapped: continue - task = cast("MappedOperator", task) - num_mapped_tis = task.parse_time_mapped_ti_count - # Check if the number of mapped literals has changed and we need to mark this TI as removed - if num_mapped_tis is not None: - if ti.map_index >= num_mapped_tis: - self.log.debug( - "Removing task '%s' as the map_index is longer than the literal mapping list (%s)", - ti, - num_mapped_tis, - ) - ti.state = State.REMOVED - elif ti.map_index < 0: - self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti) - ti.state = State.REMOVED - else: - self.log.info("Restoring mapped task '%s'", ti) - Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1) - ti.state = State.NONE - else: - # What if it is _now_ dynamically mapped, but wasn't before? - total_length = task.run_time_mapped_ti_count(self.run_id, session=session) - - if total_length is None: + except NotFullyPopulated: + # What if it is _now_ dynamically mapped, but wasn't before? + try: + total_length = task.get_mapped_ti_count(self.run_id, session=session) + except NotFullyPopulated: # Not all upstreams finished, so we can't tell what should be here. Remove everything. 
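# Illustrative sketch (not part of the patch above): the two mapping cases the
# branches above distinguish, written as a small dynamic-task-mapping DAG. Task and
# DAG names are made up.
from __future__ import annotations

from airflow.decorators import dag, task


@dag(schedule=None)
def mapping_example():
    @task
    def make_values() -> list[int]:
        return [1, 2, 3]

    @task
    def add_one(x: int) -> int:
        return x + 1

    @task
    def double(x: int) -> int:
        return x * 2

    # Literal input: the mapped length (3) is known at parse time.
    add_one.expand(x=[10, 20, 30])

    # Upstream output: the length is only known at run time, after make_values has
    # pushed its XCom; until then counting the mapped task instances raises
    # NotFullyPopulated and the single map_index == -1 placeholder is kept.
    double.expand(x=make_values())


mapping_example()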
if ti.map_index >= 0: self.log.debug( @@ -884,23 +992,58 @@ def verify_integrity(self, session: Session = NEW_SESSION): total_length, ) ti.state = State.REMOVED - ... + else: + # Check if the number of mapped literals has changed, and we need to mark this TI as removed. + if ti.map_index >= num_mapped_tis: + self.log.debug( + "Removing task '%s' as the map_index is longer than the literal mapping list (%s)", + ti, + num_mapped_tis, + ) + ti.state = State.REMOVED + elif ti.map_index < 0: + self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti) + ti.state = State.REMOVED - def task_filter(task: "Operator") -> bool: - return task.task_id not in task_ids and ( - self.is_backfill - or task.start_date <= self.execution_date - and (task.end_date is None or self.execution_date <= task.end_date) - ) + return task_ids - created_counts: Dict[str, int] = defaultdict(int) + @overload + def _get_task_creator( + self, + created_counts: dict[str, int], + ti_mutation_hook: Callable, + hook_is_noop: Literal[True], + ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: + ... + + @overload + def _get_task_creator( + self, + created_counts: dict[str, int], + ti_mutation_hook: Callable, + hook_is_noop: Literal[False], + ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: + ... - # Set for the empty default in airflow.settings -- if it's not set this means it has been changed - hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False) + def _get_task_creator( + self, + created_counts: dict[str, int], + ti_mutation_hook: Callable, + hook_is_noop: Literal[True, False], + ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]: + """ + Get the task creator function. + + This function also updates the created_counts dictionary with the number of tasks created. 
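# Illustrative sketch (not part of the patch above): a standalone reduction of the
# @overload + Literal trick used for _get_task_creator(). The literal value of a
# boolean flag selects a precise return type for type checkers, while a single
# runtime implementation serves both cases. Names below are made up.
from __future__ import annotations

from typing import Any, Iterator, Literal, overload


@overload
def make_rows(as_mappings: Literal[True]) -> Iterator[dict[str, Any]]:
    ...


@overload
def make_rows(as_mappings: Literal[False]) -> Iterator[object]:
    ...


def make_rows(as_mappings: bool) -> Iterator[dict[str, Any]] | Iterator[object]:
    if as_mappings:
        return iter([{"id": 1}, {"id": 2}])
    return iter([object(), object()])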
+ + :param created_counts: Dictionary of task_type -> count of created TIs + :param ti_mutation_hook: task_instance_mutation_hook function + :param hook_is_noop: Whether the task_instance_mutation_hook is a noop + """ if hook_is_noop: - def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator: + def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]: created_counts[task.task_type] += 1 for map_index in indexes: yield TI.insert_mapping(self.run_id, task, map_index=map_index) @@ -909,30 +1052,67 @@ def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator: else: - def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator: + def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]: for map_index in indexes: ti = TI(task, run_id=self.run_id, map_index=map_index) - task_instance_mutation_hook(ti) + ti_mutation_hook(ti) created_counts[ti.operator] += 1 yield ti creator = create_ti + return creator - # Create missing tasks -- and expand any MappedOperator that _only_ have literals as input - def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]: - if not task.is_mapped: - return (task, (-1,)) - task = cast("MappedOperator", task) - count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count( - self.run_id, session=session - ) - if not count: - return (task, (-1,)) - return (task, range(count)) + def _create_tasks( + self, + tasks: Iterable[Operator], + task_creator: TaskCreator, + *, + session: Session, + ) -> CreatedTasks: + """ + Create missing tasks -- and expand any MappedOperator that _only_ have literals as input - tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values())) - tasks = itertools.chain.from_iterable(itertools.starmap(creator, tasks_and_map_idxs)) + :param tasks: Tasks to create jobs for in the DAG run + :param task_creator: Function to create task instances + """ + map_indexes: Iterable[int] + for task in tasks: + try: + count = task.get_mapped_ti_count(self.run_id, session=session) + except (NotMapped, NotFullyPopulated): + map_indexes = (-1,) + else: + if count: + map_indexes = range(count) + else: + # Make sure to always create at least one ti; this will be + # marked as REMOVED later at runtime. + map_indexes = (-1,) + yield from task_creator(task, map_indexes) + + def _create_task_instances( + self, + dag_id: str, + tasks: Iterator[dict[str, Any]] | Iterator[TI], + created_counts: dict[str, int], + hook_is_noop: bool, + *, + session: Session, + ) -> None: + """ + Create the necessary task instances from the given tasks. 
+ :param dag_id: DAG ID associated with the dagrun + :param tasks: the tasks to create the task instances from + :param created_counts: a dictionary of number of tasks -> total ti created by the task creator + :param hook_is_noop: whether the task_instance_mutation_hook is noop + :param session: the session to use + + """ + # Fetch the information we need before handling the exception to avoid + # PendingRollbackError due to the session being invalidated on exception + # see https://github.com/apache/superset/pull/530 + run_id = self.run_id try: if hook_is_noop: session.bulk_insert_mappings(TI, tasks) @@ -944,17 +1124,62 @@ def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]] session.flush() except IntegrityError: self.log.info( - 'Hit IntegrityError while creating the TIs for %s- %s', - dag.dag_id, - self.run_id, + "Hit IntegrityError while creating the TIs for %s- %s", + dag_id, + run_id, exc_info=True, ) - self.log.info('Doing session rollback.') + self.log.info("Doing session rollback.") # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() + def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]: + """Check if task increased or reduced in length and handle appropriately. + + Task instances that do not already exist are created and returned if + possible. Expansion only happens if all upstreams are ready; otherwise + we delay expansion to the "last resort". See comments at the call site + for more details. + """ + from airflow.settings import task_instance_mutation_hook + + try: + total_length = task.get_mapped_ti_count(self.run_id, session=session) + except NotMapped: + return # Not a mapped task, don't need to do anything. + except NotFullyPopulated: + return # Upstreams not ready, don't need to revise this yet. + + query = session.query(TI.map_index).filter( + TI.dag_id == self.dag_id, + TI.task_id == task.task_id, + TI.run_id == self.run_id, + ) + existing_indexes = {i for (i,) in query} + + removed_indexes = existing_indexes.difference(range(total_length)) + if removed_indexes: + session.query(TI).filter( + TI.dag_id == self.dag_id, + TI.task_id == task.task_id, + TI.run_id == self.run_id, + TI.map_index.in_(removed_indexes), + ).update({TI.state: TaskInstanceState.REMOVED}) + session.flush() + + for index in range(total_length): + if index in existing_indexes: + continue + ti = TI(task, run_id=self.run_id, map_index=index, state=None) + self.log.debug("Expanding TIs upserted %s", ti) + task_instance_mutation_hook(ti) + ti = session.merge(ti) + ti.refresh_from_task(task) + session.flush() + yield ti + @staticmethod - def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']: + def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None: """ Get a single DAG Run @@ -964,11 +1189,10 @@ def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional :param execution_date: execution date :return: DagRun corresponding to the given dag_id and execution date if one exists. None otherwise. - :rtype: airflow.models.DagRun """ warnings.warn( "This method is deprecated. 
Please use SQLAlchemy directly", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return ( @@ -987,10 +1211,10 @@ def is_backfill(self) -> bool: @classmethod @provide_session - def get_latest_runs(cls, session=None) -> List['DagRun']: + def get_latest_runs(cls, session=None) -> list[DagRun]: """Returns the latest DagRun for each DAG""" subquery = ( - session.query(cls.dag_id, func.max(cls.execution_date).label('execution_date')) + session.query(cls.dag_id, func.max(cls.execution_date).label("execution_date")) .group_by(cls.dag_id) .subquery() ) @@ -1010,7 +1234,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES Each element of ``schedulable_tis`` should have it's ``task`` attribute already set. - Any EmptyOperator without callbacks is instead set straight to the success state. + Any EmptyOperator without callbacks or outlets is instead set straight to the success state. All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it is the caller's responsibility to call this function only with TIs from a single dag run. @@ -1024,6 +1248,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES ti.task.inherits_from_empty_operator and not ti.task.on_execute_callback and not ti.task.on_success_callback + and not ti.task.outlets ): dummy_ti_ids.append(ti.task_id) else: @@ -1065,14 +1290,62 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES return count @provide_session - def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str: + def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate: if self.log_template_id is None: # DagRun created before LogTemplate introduction. - template = session.query(LogTemplate.filename).order_by(LogTemplate.id).limit(1).scalar() + template = session.query(LogTemplate).order_by(LogTemplate.id).first() else: - template = session.query(LogTemplate.filename).filter_by(id=self.log_template_id).scalar() + template = session.query(LogTemplate).get(self.log_template_id) if template is None: raise AirflowException( f"No log_template entry found for ID {self.log_template_id!r}. " f"Please make sure you set up the metadatabase correctly." ) return template + + @provide_session + def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str: + warnings.warn( + "This method is deprecated. 
Please use get_log_template instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + return self.get_log_template(session=session).filename + + +class DagRunNote(Base): + """For storage of arbitrary notes concerning the dagrun instance.""" + + __tablename__ = "dag_run_note" + + user_id = Column(Integer, nullable=True) + dag_run_id = Column(Integer, primary_key=True, nullable=False) + content = Column(String(1000).with_variant(Text(1000), "mysql")) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) + + dag_run = relationship("DagRun", back_populates="dag_run_note") + + __table_args__ = ( + PrimaryKeyConstraint("dag_run_id", name="dag_run_note_pkey"), + ForeignKeyConstraint( + (dag_run_id,), + ["dag_run.id"], + name="dag_run_note_dr_fkey", + ondelete="CASCADE", + ), + ForeignKeyConstraint( + (user_id,), + ["ab_user.id"], + name="dag_run_note_user_fkey", + ), + ) + + def __init__(self, content, user_id=None): + self.content = content + self.user_id = user_id + + def __repr__(self): + prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.dagrun_id} {self.run_id}" + if self.map_index != -1: + prefix += f" map_index={self.map_index}" + return prefix + ">" diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py new file mode 100644 index 0000000000000..db93aafb0fba5 --- /dev/null +++ b/airflow/models/dagwarning.py @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from enum import Enum + +from sqlalchemy import Column, ForeignKeyConstraint, String, Text, false +from sqlalchemy.orm import Session + +from airflow.models.base import Base, StringID +from airflow.utils import timezone +from airflow.utils.retries import retry_db_transaction +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import UtcDateTime + + +class DagWarning(Base): + """ + A table to store DAG warnings. + + DAG warnings are problems that don't rise to the level of failing the DAG parse + but which users should nonetheless be warned about. These warnings are recorded + when parsing DAG and displayed on the Webserver in a flash message. 
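# --- Illustrative sketch (not part of the patch): the deprecation-shim pattern used
# by get_log_filename_template above. The old accessor stays callable, warns with a
# project-specific DeprecationWarning subclass, and delegates to the new API. The
# warning class and config object below are illustrative stand-ins, not Airflow code.
import warnings


class RemovedInV3Warning(DeprecationWarning):
    """Stand-in for a project-specific deprecation warning class."""


class LogConfig:
    filename = "{dag_id}/{task_id}/{try_number}.log"

    def get_log_template(self) -> "LogConfig":
        return self

    def get_log_filename_template(self) -> str:
        warnings.warn(
            "This method is deprecated. Please use get_log_template instead.",
            RemovedInV3Warning,
            stacklevel=2,  # point the warning at the caller, not this shim
        )
        return self.get_log_template().filename


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LogConfig().get_log_filename_template()
assert caught and issubclass(caught[0].category, DeprecationWarning)
# --- end sketch ---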
+ """ + + dag_id = Column(StringID(), primary_key=True) + warning_type = Column(String(50), primary_key=True) + message = Column(Text, nullable=False) + timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow) + + __tablename__ = "dag_warning" + __table_args__ = ( + ForeignKeyConstraint( + ("dag_id",), + ["dag.dag_id"], + name="dcw_dag_id_fkey", + ondelete="CASCADE", + ), + ) + + def __init__(self, dag_id: str, error_type: str, message: str, **kwargs): + super().__init__(**kwargs) + self.dag_id = dag_id + self.warning_type = DagWarningType(error_type).value # make sure valid type + self.message = message + + def __eq__(self, other) -> bool: + return self.dag_id == other.dag_id and self.warning_type == other.warning_type + + def __hash__(self) -> int: + return hash((self.dag_id, self.warning_type)) + + @classmethod + @provide_session + def purge_inactive_dag_warnings(cls, session: Session = NEW_SESSION) -> None: + """ + Deactivate DagWarning records for inactive dags. + + :return: None + """ + cls._purge_inactive_dag_warnings_with_retry(session) + + @classmethod + @retry_db_transaction + def _purge_inactive_dag_warnings_with_retry(cls, session: Session) -> None: + from airflow.models.dag import DagModel + + if session.get_bind().dialect.name == "sqlite": + dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_active == false()) + query = session.query(cls).filter(cls.dag_id.in_(dag_ids)) + else: + query = session.query(cls).filter(cls.dag_id == DagModel.dag_id, DagModel.is_active == false()) + query.delete(synchronize_session=False) + session.commit() + + +class DagWarningType(str, Enum): + """ + Enum for DAG warning types. + + This is the set of allowable values for the ``warning_type`` field + in the DagWarning model. + """ + + NONEXISTENT_POOL = "non-existent pool" diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py new file mode 100644 index 0000000000000..4cd370386c4c0 --- /dev/null +++ b/airflow/models/dataset.py @@ -0,0 +1,338 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from urllib.parse import urlsplit + +import sqlalchemy_jsonfield +from sqlalchemy import ( + Boolean, + Column, + ForeignKey, + ForeignKeyConstraint, + Index, + Integer, + PrimaryKeyConstraint, + String, + Table, + text, +) +from sqlalchemy.orm import relationship + +from airflow.datasets import Dataset +from airflow.models.base import Base, StringID +from airflow.settings import json +from airflow.utils import timezone +from airflow.utils.sqlalchemy import UtcDateTime + + +class DatasetModel(Base): + """ + A table to store datasets. 
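# --- Illustrative sketch (not part of the patch): the validate-through-Enum trick
# used in DagWarning.__init__ above. Constructing the Enum from the raw string
# raises ValueError for unknown values, so only allowed warning types are stored.
# The enum below copies the single member shown in the patch; the helper name is
# illustrative.
from enum import Enum


class DagWarningType(str, Enum):
    NONEXISTENT_POOL = "non-existent pool"


def normalize_warning_type(raw: str) -> str:
    # Round-tripping through the Enum both validates and canonicalizes the value.
    return DagWarningType(raw).value


assert normalize_warning_type("non-existent pool") == "non-existent pool"
try:
    normalize_warning_type("unknown warning")
except ValueError:
    pass  # rejected, as expected
else:
    raise AssertionError("invalid warning type should be rejected")
# --- end sketch ---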
+ + :param uri: a string that uniquely identifies the dataset + :param extra: JSON field for arbitrary extra info + """ + + id = Column(Integer, primary_key=True, autoincrement=True) + uri = Column( + String(length=3000).with_variant( + String( + length=3000, + # latin1 allows for more indexed length in mysql + # and this field should only be ascii chars + collation="latin1_general_cs", + ), + "mysql", + ), + nullable=False, + ) + extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) + is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0") + + consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset") + producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset") + + __tablename__ = "dataset" + __table_args__ = ( + Index("idx_uri_unique", uri, unique=True), + {"sqlite_autoincrement": True}, # ensures PK values not reused + ) + + @classmethod + def from_public(cls, obj: Dataset) -> DatasetModel: + return cls(uri=obj.uri, extra=obj.extra) + + def __init__(self, uri: str, **kwargs): + try: + uri.encode("ascii") + except UnicodeEncodeError: + raise ValueError("URI must be ascii") + parsed = urlsplit(uri) + if parsed.scheme and parsed.scheme.lower() == "airflow": + raise ValueError("Scheme `airflow` is reserved.") + super().__init__(uri=uri, **kwargs) + + def __eq__(self, other): + if isinstance(other, (self.__class__, Dataset)): + return self.uri == other.uri + else: + return NotImplemented + + def __hash__(self): + return hash(self.uri) + + def __repr__(self): + return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})" + + +class DagScheduleDatasetReference(Base): + """References from a DAG to a dataset of which it is a consumer.""" + + dataset_id = Column(Integer, primary_key=True, nullable=False) + dag_id = Column(StringID(), primary_key=True, nullable=False) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) + + dataset = relationship("DatasetModel", back_populates="consuming_dags") + queue_records = relationship( + "DatasetDagRunQueue", + primaryjoin="""and_( + DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id), + DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id), + )""", + cascade="all, delete, delete-orphan", + ) + + __tablename__ = "dag_schedule_dataset_reference" + __table_args__ = ( + PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey", mssql_clustered=True), + ForeignKeyConstraint( + (dataset_id,), + ["dataset.id"], + name="dsdr_dataset_fkey", + ondelete="CASCADE", + ), + ForeignKeyConstraint( + columns=(dag_id,), + refcolumns=["dag.dag_id"], + name="dsdr_dag_id_fkey", + ondelete="CASCADE", + ), + ) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id + else: + return NotImplemented + + def __hash__(self): + return hash(self.__mapper__.primary_key) + + def __repr__(self): + args = [] + for attr in [x.name for x in self.__mapper__.primary_key]: + args.append(f"{attr}={getattr(self, attr)!r}") + return f"{self.__class__.__name__}({', '.join(args)})" + + +class TaskOutletDatasetReference(Base): + 
"""References from a task to a dataset that it updates / produces.""" + + dataset_id = Column(Integer, primary_key=True, nullable=False) + dag_id = Column(StringID(), primary_key=True, nullable=False) + task_id = Column(StringID(), primary_key=True, nullable=False) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) + + dataset = relationship("DatasetModel", back_populates="producing_tasks") + + __tablename__ = "task_outlet_dataset_reference" + __table_args__ = ( + ForeignKeyConstraint( + (dataset_id,), + ["dataset.id"], + name="todr_dataset_fkey", + ondelete="CASCADE", + ), + PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey", mssql_clustered=True), + ForeignKeyConstraint( + columns=(dag_id,), + refcolumns=["dag.dag_id"], + name="todr_dag_id_fkey", + ondelete="CASCADE", + ), + ) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return ( + self.dataset_id == other.dataset_id + and self.dag_id == other.dag_id + and self.task_id == other.task_id + ) + else: + return NotImplemented + + def __hash__(self): + return hash(self.__mapper__.primary_key) + + def __repr__(self): + args = [] + for attr in [x.name for x in self.__mapper__.primary_key]: + args.append(f"{attr}={getattr(self, attr)!r}") + return f"{self.__class__.__name__}({', '.join(args)})" + + +class DatasetDagRunQueue(Base): + """Model for storing dataset events that need processing.""" + + dataset_id = Column(Integer, primary_key=True, nullable=False) + target_dag_id = Column(StringID(), primary_key=True, nullable=False) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + + __tablename__ = "dataset_dag_run_queue" + __table_args__ = ( + PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey", mssql_clustered=True), + ForeignKeyConstraint( + (dataset_id,), + ["dataset.id"], + name="ddrq_dataset_fkey", + ondelete="CASCADE", + ), + ForeignKeyConstraint( + (target_dag_id,), + ["dag.dag_id"], + name="ddrq_dag_fkey", + ondelete="CASCADE", + ), + ) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id + else: + return NotImplemented + + def __hash__(self): + return hash(self.__mapper__.primary_key) + + def __repr__(self): + args = [] + for attr in [x.name for x in self.__mapper__.primary_key]: + args.append(f"{attr}={getattr(self, attr)!r}") + return f"{self.__class__.__name__}({', '.join(args)})" + + +association_table = Table( + "dagrun_dataset_event", + Base.metadata, + Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True), + Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True), + Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"), + Index("idx_dagrun_dataset_events_event_id", "event_id"), +) + + +class DatasetEvent(Base): + """ + A table to store datasets events. 
+ + :param dataset_id: reference to DatasetModel record + :param extra: JSON field for arbitrary extra info + :param source_task_id: the task_id of the TI which updated the dataset + :param source_dag_id: the dag_id of the TI which updated the dataset + :param source_run_id: the run_id of the TI which updated the dataset + :param source_map_index: the map_index of the TI which updated the dataset + :param timestamp: the time the event was logged + + We use relationships instead of foreign keys so that dataset events are not deleted even + if the foreign key object is. + """ + + id = Column(Integer, primary_key=True, autoincrement=True) + dataset_id = Column(Integer, nullable=False) + extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) + source_task_id = Column(StringID(), nullable=True) + source_dag_id = Column(StringID(), nullable=True) + source_run_id = Column(StringID(), nullable=True) + source_map_index = Column(Integer, nullable=True, server_default=text("-1")) + timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + + __tablename__ = "dataset_event" + __table_args__ = ( + Index("idx_dataset_id_timestamp", dataset_id, timestamp), + {"sqlite_autoincrement": True}, # ensures PK values not reused + ) + + created_dagruns = relationship( + "DagRun", + secondary=association_table, + backref="consumed_dataset_events", + ) + + source_task_instance = relationship( + "TaskInstance", + primaryjoin="""and_( + DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id), + DatasetEvent.source_run_id == foreign(TaskInstance.run_id), + DatasetEvent.source_task_id == foreign(TaskInstance.task_id), + DatasetEvent.source_map_index == foreign(TaskInstance.map_index), + )""", + viewonly=True, + lazy="select", + uselist=False, + ) + source_dag_run = relationship( + "DagRun", + primaryjoin="""and_( + DatasetEvent.source_dag_id == foreign(DagRun.dag_id), + DatasetEvent.source_run_id == foreign(DagRun.run_id), + )""", + viewonly=True, + lazy="select", + uselist=False, + ) + dataset = relationship( + DatasetModel, + primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)", + viewonly=True, + lazy="select", + uselist=False, + ) + + @property + def uri(self): + return self.dataset.uri + + def __repr__(self) -> str: + args = [] + for attr in [ + "id", + "dataset_id", + "extra", + "source_task_id", + "source_dag_id", + "source_run_id", + "source_map_index", + ]: + args.append(f"{attr}={getattr(self, attr)!r}") + return f"{self.__class__.__name__}({', '.join(args)})" diff --git a/airflow/models/db_callback_request.py b/airflow/models/db_callback_request.py index 4fdd36a71be4b..5a264dee9abc0 100644 --- a/airflow/models/db_callback_request.py +++ b/airflow/models/db_callback_request.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations from importlib import import_module @@ -36,11 +37,12 @@ class DbCallbackRequest(Base): priority_weight = Column(Integer(), nullable=False) callback_data = Column(ExtendedJSON, nullable=False) callback_type = Column(String(20), nullable=False) - dag_directory = Column(String(1000), nullable=True) + processor_subdir = Column(String(2000), nullable=True) def __init__(self, priority_weight: int, callback: CallbackRequest): self.created_at = timezone.utcnow() self.priority_weight = priority_weight + self.processor_subdir = callback.processor_subdir self.callback_data = callback.to_json() self.callback_type = callback.__class__.__name__ @@ -48,5 +50,5 @@ def get_callback_request(self) -> CallbackRequest: module = import_module("airflow.callbacks.callback_requests") callback_class = getattr(module, self.callback_type) # Get the function (from the instance) that we need to call - from_json = getattr(callback_class, 'from_json') + from_json = getattr(callback_class, "from_json") return from_json(self.callback_data) diff --git a/airflow/models/errors.py b/airflow/models/errors.py index 9718c063a28f0..974d9b8eebbe5 100644 --- a/airflow/models/errors.py +++ b/airflow/models/errors.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from sqlalchemy import Column, Integer, String, Text diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py new file mode 100644 index 0000000000000..8a9a3d874012d --- /dev/null +++ b/airflow/models/expandinput.py @@ -0,0 +1,284 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import collections.abc +import functools +import operator +from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union + +import attr + +from airflow.typing_compat import TypeGuard +from airflow.utils.context import Context +from airflow.utils.mixins import ResolveMixin +from airflow.utils.session import NEW_SESSION, provide_session + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.models.operator import Operator + from airflow.models.xcom_arg import XComArg + +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). +OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, Dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. 
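# --- Illustrative sketch (not part of the patch): the deserialization pattern in
# DbCallbackRequest.get_callback_request above. Only the class *name* and a JSON
# payload are stored in the table; the class is looked up again at read time with
# import_module + getattr. The class below is a stand-in using json, not Airflow's
# real callback classes, and the module lookup here uses this file itself.
from __future__ import annotations

import json
from importlib import import_module


class ToyCallbackRequest:
    def __init__(self, full_filepath: str, processor_subdir: str | None = None):
        self.full_filepath = full_filepath
        self.processor_subdir = processor_subdir

    def to_json(self) -> str:
        return json.dumps(self.__dict__)

    @classmethod
    def from_json(cls, data: str) -> "ToyCallbackRequest":
        return cls(**json.loads(data))


# What gets persisted: a type name and an opaque JSON blob.
stored_type = ToyCallbackRequest.__name__
stored_data = ToyCallbackRequest("/dags/example.py", processor_subdir="/dags").to_json()

# Reading it back: resolve the class dynamically, then call its from_json.
module = import_module(__name__)  # the patch uses "airflow.callbacks.callback_requests"
callback_class = getattr(module, stored_type)
restored = callback_class.from_json(stored_data)
assert restored.processor_subdir == "/dags"
# --- end sketch ---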
+OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] + + +@attr.define(kw_only=True) +class MappedArgument(ResolveMixin): + """Stand-in stub for task-group-mapping arguments. + + This is very similar to an XComArg, but resolved differently. Declared here + (instead of in the task group module) to avoid import cycles. + """ + + _input: ExpandInput + _key: str + + def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: + # TODO (AIP-42): Implement run-time task map length inspection. This is + # needed when we implement task mapping inside a mapped task group. + raise NotImplementedError() + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + yield from self._input.iter_references() + + @provide_session + def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any: + data, _ = self._input.resolve(context, session=session) + return data[self._key] + + +# To replace tedious isinstance() checks. +def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: + from airflow.models.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) + + +# To replace tedious isinstance() checks. +def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.models.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +# To replace tedious isinstance() checks. +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: + from airflow.models.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg)) + + +class NotFullyPopulated(RuntimeError): + """Raise when ``get_map_lengths`` cannot populate all mapping metadata. + + This is generally due to not all upstream tasks have finished when the + function is called. + """ + + def __init__(self, missing: set[str]) -> None: + self.missing = missing + + def __str__(self) -> str: + keys = ", ".join(repr(k) for k in sorted(self.missing)) + return f"Failed to populate all mapping metadata; missing: {keys}" + + +class DictOfListsExpandInput(NamedTuple): + """Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand(**kwargs)``. + """ + + value: dict[str, OperatorExpandArgument] + + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + """Generate kwargs with values available on parse-time.""" + return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + + def get_parse_time_mapped_ti_count(self) -> int: + if not self.value: + return 0 + literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()] + if len(literal_values) != len(self.value): + literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs()) + raise NotFullyPopulated(set(self.value).difference(literal_keys)) + return functools.reduce(operator.mul, literal_values, 1) + + def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]: + """Return dict of argument name to map length. + + If any arguments are not known right now (upstream task not finished), + they will not be present in the dict. + """ + # TODO: This initiates one database call for each XComArg. Would it be + # more efficient to do one single db call and unpack the value here? 
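# --- Illustrative sketch (not part of the patch): the counting rule implemented by
# DictOfListsExpandInput above. The number of expanded task instances is the
# product of the per-argument map lengths (a cross product over expand() keyword
# arguments). Plain lists stand in for resolved XCom values here.
import functools
import operator

expand_value = {"x": [1, 2], "y": ["a", "b", "c"], "z": [True]}

lengths = {key: len(val) for key, val in expand_value.items()}
total = functools.reduce(operator.mul, lengths.values(), 1)

assert lengths == {"x": 2, "y": 3, "z": 1}
assert total == 6  # 2 * 3 * 1 expanded task instances
# --- end sketch ---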
+ def _get_length(v: OperatorExpandArgument) -> int | None: + if _needs_run_time_resolution(v): + return v.get_task_map_length(run_id, session=session) + # Unfortunately a user-defined TypeGuard cannot apply negative type + # narrowing. https://github.com/python/typing/discussions/1013 + if TYPE_CHECKING: + assert isinstance(v, Sized) + return len(v) + + map_lengths_iterator = ((k, _get_length(v)) for k, v in self.value.items()) + + map_lengths = {k: v for k, v in map_lengths_iterator if v is not None} + if len(map_lengths) < len(self.value): + raise NotFullyPopulated(set(self.value).difference(map_lengths)) + return map_lengths + + def get_total_map_length(self, run_id: str, *, session: Session) -> int: + if not self.value: + return 0 + lengths = self._get_map_lengths(run_id, session=session) + return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1) + + def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any: + if _needs_run_time_resolution(value): + value = value.resolve(context, session=session) + map_index = context["ti"].map_index + if map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + all_lengths = self._get_map_lengths(context["run_id"], session=session) + + def _find_index_for_this_field(index: int) -> int: + # Need to use the original user input to retain argument order. + for mapped_key in reversed(list(self.value)): + mapped_length = all_lengths[mapped_key] + if mapped_length < 1: + raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") + if mapped_key == key: + return index % mapped_length + index //= mapped_length + return -1 + + found_index = _find_index_for_this_field(map_index) + if found_index < 0: + return value + if isinstance(value, collections.abc.Sequence): + return value[found_index] + if not isinstance(value, dict): + raise TypeError(f"can't map over value of type {type(value)}") + for i, (k, v) in enumerate(value.items()): + if i == found_index: + return k, v + raise IndexError(f"index {map_index} is over mapped length") + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + for x in self.value.values(): + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: + data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()} + literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} + resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} + return data, resolved_oids + + +def _describe_type(value: Any) -> str: + if value is None: + return "None" + return type(value).__name__ + + +class ListOfDictsExpandInput(NamedTuple): + """Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand_kwargs(xcom_arg)``. 
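# --- Illustrative sketch (not part of the patch): the index arithmetic in
# _find_index_for_this_field above. A single flat map_index is decomposed into one
# index per expand() argument, iterating the arguments in reverse declaration order
# so the last argument varies fastest. Lengths and keys below are plain literals.
def decompose(map_index: int, lengths: dict) -> dict:
    per_field = {}
    index = map_index
    for key in reversed(list(lengths)):  # last expand() kwarg varies fastest
        per_field[key] = index % lengths[key]
        index //= lengths[key]
    return per_field


field_lengths = {"x": 2, "y": 3}  # e.g. expand(x=[_, _], y=[_, _, _])
assert decompose(0, field_lengths) == {"x": 0, "y": 0}
assert decompose(1, field_lengths) == {"x": 0, "y": 1}
assert decompose(3, field_lengths) == {"x": 1, "y": 0}
assert decompose(5, field_lengths) == {"x": 1, "y": 2}
# --- end sketch ---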
+ """ + + value: OperatorExpandKwargsArgument + + def get_parse_time_mapped_ti_count(self) -> int: + if isinstance(self.value, collections.abc.Sized): + return len(self.value) + raise NotFullyPopulated({"expand_kwargs() argument"}) + + def get_total_map_length(self, run_id: str, *, session: Session) -> int: + if isinstance(self.value, collections.abc.Sized): + return len(self.value) + length = self.value.get_task_map_length(run_id, session=session) + if length is None: + raise NotFullyPopulated({"expand_kwargs() argument"}) + return length + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + if isinstance(self.value, XComArg): + yield from self.value.iter_references() + else: + for x in self.value: + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: + map_index = context["ti"].map_index + if map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + + mapping: Any + if isinstance(self.value, collections.abc.Sized): + mapping = self.value[map_index] + if not isinstance(mapping, collections.abc.Mapping): + mapping = mapping.resolve(context, session) + else: + mappings = self.value.resolve(context, session) + if not isinstance(mappings, collections.abc.Sequence): + raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") + mapping = mappings[map_index] + + if not isinstance(mapping, collections.abc.Mapping): + raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") + + for key in mapping: + if not isinstance(key, str): + raise ValueError( + f"expand_kwargs() input dict keys must all be str, " + f"but {key!r} is of type {_describe_type(key)}" + ) + return mapping, {id(v) for v in mapping.values()} + + +EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value. + +_EXPAND_INPUT_TYPES = { + "dict-of-lists": DictOfListsExpandInput, + "list-of-dicts": ListOfDictsExpandInput, +} + + +def get_map_type_key(expand_input: ExpandInput) -> str: + return next(k for k, v in _EXPAND_INPUT_TYPES.items() if v == type(expand_input)) + + +def create_expand_input(kind: str, value: Any) -> ExpandInput: + return _EXPAND_INPUT_TYPES[kind](value) diff --git a/airflow/models/log.py b/airflow/models/log.py index b2a5639dcdd57..7994abc4638ef 100644 --- a/airflow/models/log.py +++ b/airflow/models/log.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
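# --- Illustrative sketch (not part of the patch): the small two-way registry at the
# end of expandinput.py above. A kind string maps to an ExpandInput class for
# deserialization, and the reverse lookup recovers the kind from an instance's type
# for serialization. Toy NamedTuple classes stand in for the real ExpandInput types.
from typing import Any, NamedTuple


class DictOfLists(NamedTuple):
    value: dict


class ListOfDicts(NamedTuple):
    value: list


_TYPES = {"dict-of-lists": DictOfLists, "list-of-dicts": ListOfDicts}


def get_map_type_key(expand_input: Any) -> str:
    return next(k for k, v in _TYPES.items() if v == type(expand_input))


def create_expand_input(kind: str, value: Any) -> Any:
    return _TYPES[kind](value)


round_tripped = create_expand_input("dict-of-lists", {"x": [1, 2]})
assert isinstance(round_tripped, DictOfLists)
assert get_map_type_key(round_tripped) == "dict-of-lists"
# --- end sketch ---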
+from __future__ import annotations -from sqlalchemy import Column, Index, Integer, String, Text, text +from sqlalchemy import Column, Index, Integer, String, Text from airflow.models.base import Base, StringID from airflow.utils import timezone @@ -32,15 +33,15 @@ class Log(Base): dttm = Column(UtcDateTime) dag_id = Column(StringID()) task_id = Column(StringID()) - map_index = Column(Integer, server_default=text('NULL')) + map_index = Column(Integer) event = Column(String(30)) execution_date = Column(UtcDateTime) owner = Column(String(500)) extra = Column(Text) __table_args__ = ( - Index('idx_log_dag', dag_id), - Index('idx_log_event', event), + Index("idx_log_dag", dag_id), + Index("idx_log_event", event), ) def __init__(self, event, task_instance=None, owner=None, extra=None, **kwargs): @@ -55,15 +56,19 @@ def __init__(self, event, task_instance=None, owner=None, extra=None, **kwargs): self.task_id = task_instance.task_id self.execution_date = task_instance.execution_date self.map_index = task_instance.map_index - task_owner = task_instance.task.owner + if getattr(task_instance, "task", None): + task_owner = task_instance.task.owner - if 'task_id' in kwargs: - self.task_id = kwargs['task_id'] - if 'dag_id' in kwargs: - self.dag_id = kwargs['dag_id'] - if kwargs.get('execution_date'): - self.execution_date = kwargs['execution_date'] - if 'map_index' in kwargs: - self.map_index = kwargs['map_index'] + if "task_id" in kwargs: + self.task_id = kwargs["task_id"] + if "dag_id" in kwargs: + self.dag_id = kwargs["dag_id"] + if kwargs.get("execution_date"): + self.execution_date = kwargs["execution_date"] + if "map_index" in kwargs: + self.map_index = kwargs["map_index"] self.owner = owner or task_owner + + def __str__(self) -> str: + return f"Log({self.event}, {self.task_id}, {self.owner}, {self.extra})" diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index fe18a97cc1414..99e2b67f50018 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -15,38 +15,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import collections import collections.abc +import contextlib +import copy import datetime -import functools -import operator import warnings -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Collection, - Dict, - FrozenSet, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union import attr import pendulum -from sqlalchemy import func, or_ from sqlalchemy.orm.session import Session from airflow import settings -from airflow.compat.functools import cache, cached_property +from airflow.compat.functools import cache from airflow.exceptions import AirflowException, UnmappableOperator from airflow.models.abstractoperator import ( DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -59,17 +43,26 @@ DEFAULT_TRIGGER_RULE, DEFAULT_WEIGHT_RULE, AbstractOperator, + NotMapped, TaskStateChangeCallback, ) +from airflow.models.expandinput import ( + DictOfListsExpandInput, + ExpandInput, + ListOfDictsExpandInput, + OperatorExpandArgument, + OperatorExpandKwargsArgument, + is_mappable, +) +from airflow.models.param import ParamsDict from airflow.models.pool import Pool from airflow.serialization.enums import DagAttributeTypes from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded from airflow.typing_compat import Literal -from airflow.utils.context import Context -from airflow.utils.helpers import is_container +from airflow.utils.context import Context, context_update_for_unmapped +from airflow.utils.helpers import is_container, prevent_duplicates from airflow.utils.operator_resources import Resources -from airflow.utils.state import State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET @@ -79,29 +72,13 @@ from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.models.dag import DAG from airflow.models.operator import Operator - from airflow.models.taskinstance import TaskInstance from airflow.models.xcom_arg import XComArg from airflow.utils.task_group import TaskGroup - # BaseOperator.expand() can be called on an XComArg, sequence, or dict (not - # any mapping since we need the value to be ordered). - Mappable = Union[XComArg, Sequence, dict] - ValidationSource = Union[Literal["expand"], Literal["partial"]] -MAPPABLE_LITERAL_TYPES = (dict, list) - - -# For isinstance() check. 
-@cache -def get_mappable_types() -> Tuple[type, ...]: - from airflow.models.xcom_arg import XComArg - - return (XComArg,) + MAPPABLE_LITERAL_TYPES - - -def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, value: Dict[str, Any]) -> None: +def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: # use a dict so order of args is same as code order unknown_args = value.copy() for klass in op.mro(): @@ -116,7 +93,7 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va continue if value is NOTSET: continue - if isinstance(value, get_mappable_types()): + if is_mappable(value): continue type_name = type(value).__name__ error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}" @@ -132,22 +109,13 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va raise TypeError(f"{op.__name__}.{func}() got {error}") -def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> None: - duplicated_keys = set(kwargs1).intersection(kwargs2) - if not duplicated_keys: - return - if len(duplicated_keys) == 1: - raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}") - duplicated_keys_display = ", ".join(sorted(duplicated_keys)) - raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") - - def ensure_xcomarg_return_value(arg: Any) -> None: from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg if isinstance(arg, XComArg): - if arg.key != XCOM_RETURN_KEY: - raise ValueError(f"cannot map over XCom with custom key {arg.key!r} from {arg.operator}") + for operator, key in arg.iter_references(): + if key != XCOM_RETURN_KEY: + raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}") elif not is_container(arg): return elif isinstance(arg, collections.abc.Mapping): @@ -167,8 +135,9 @@ class OperatorPartial: create a ``MappedOperator`` to add into the DAG. """ - operator_class: Type["BaseOperator"] - kwargs: Dict[str, Any] + operator_class: type[BaseOperator] + kwargs: dict[str, Any] + params: ParamsDict | dict _expand_called: bool = False # Set when expand() is called to ease user debugging. @@ -191,34 +160,50 @@ def __del__(self): task_id = f"at {hex(id(self))}" warnings.warn(f"Task {task_id} was never mapped!") - def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": + def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: if not mapped_kwargs: raise TypeError("no arguments to expand against") - return self._expand(**mapped_kwargs) + validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") + # Since the input is already checked at parse time, we can set strict + # to False to skip the checks on execution. 
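# --- Hedged usage sketch (not part of the patch) for the two entry points defined
# on OperatorPartial above, assuming an Airflow 2.4-style environment where
# BashOperator is available: expand() maps over per-argument lists (cross product),
# while expand_kwargs() maps over a list of complete kwargs dicts. DAG and task ids
# are illustrative.
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(dag_id="mapping_sketch", start_date=datetime(2022, 1, 1), schedule=None):
    # One task instance per element of the list passed to expand().
    BashOperator.partial(task_id="per_value").expand(
        bash_command=["echo 1", "echo 2", "echo 3"],
    )

    # One task instance per dict; each dict supplies that instance's kwargs.
    BashOperator.partial(task_id="per_kwargs").expand_kwargs(
        [
            {"bash_command": "echo a", "env": {"NAME": "a"}},
            {"bash_command": "echo b", "env": {"NAME": "b"}},
        ]
    )
# --- end sketch ---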
+ return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) - def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": - self._expand_called = True + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: + from airflow.models.xcom_arg import XComArg + if isinstance(kwargs, collections.abc.Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, collections.abc.Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) + + def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: from airflow.operators.empty import EmptyOperator - validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) - prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") - ensure_xcomarg_return_value(mapped_kwargs) + self._expand_called = True + ensure_xcomarg_return_value(expand_input.value) partial_kwargs = self.kwargs.copy() task_id = partial_kwargs.pop("task_id") - params = partial_kwargs.pop("params") dag = partial_kwargs.pop("dag") task_group = partial_kwargs.pop("task_group") start_date = partial_kwargs.pop("start_date") end_date = partial_kwargs.pop("end_date") + try: + operator_name = self.operator_class.custom_operator_name # type: ignore + except AttributeError: + operator_name = self.operator_class.__name__ + op = MappedOperator( operator_class=self.operator_class, - mapped_kwargs=mapped_kwargs, + expand_input=expand_input, partial_kwargs=partial_kwargs, task_id=task_id, - params=params, + params=self.params, deps=MappedOperator.deps_for(self.operator_class), operator_extra_links=self.operator_class.operator_extra_links, template_ext=self.operator_class.template_ext, @@ -229,18 +214,30 @@ def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": is_empty=issubclass(self.operator_class, EmptyOperator), task_module=self.operator_class.__module__, task_type=self.operator_class.__name__, + operator_name=operator_name, dag=dag, task_group=task_group, start_date=start_date, end_date=end_date, - # For classic operators, this points to mapped_kwargs because kwargs + disallow_kwargs_override=strict, + # For classic operators, this points to expand_input because kwargs # to BaseOperator.expand() contribute to operator arguments. - expansion_kwargs_attr="mapped_kwargs", + expand_input_attr="expand_input", ) return op -@attr.define(kw_only=True) +@attr.define( + kw_only=True, + # Disable custom __getstate__ and __setstate__ generation since it interacts + # badly with Airflow's DAG serialization and pickling. When a mapped task is + # deserialized, subclasses are coerced into MappedOperator, but when it goes + # through DAG pickling, all attributes defined in the subclasses are dropped + # by attrs's custom state management. Since attrs does not do anything too + # special here (the logic is only important for slots=True), we use Python's + # built-in implementation, which works (as proven by good old BaseOperator). + getstate_setstate=False, +) class MappedOperator(AbstractOperator): """Object representing a mapped operator in a DAG.""" @@ -249,45 +246,52 @@ class MappedOperator(AbstractOperator): # can be used to create an unmapped operator for execution. 
For an operator # recreated from a serialized DAG, however, this holds the serialized data # that can be used to unmap this into a SerializedBaseOperator. - operator_class: Union[Type["BaseOperator"], Dict[str, Any]] + operator_class: type[BaseOperator] | dict[str, Any] - mapped_kwargs: Dict[str, "Mappable"] - partial_kwargs: Dict[str, Any] + expand_input: ExpandInput + partial_kwargs: dict[str, Any] # Needed for serialization. task_id: str - params: Optional[dict] - deps: FrozenSet[BaseTIDep] - operator_extra_links: Collection["BaseOperatorLink"] + params: ParamsDict | dict + deps: frozenset[BaseTIDep] + operator_extra_links: Collection[BaseOperatorLink] template_ext: Sequence[str] template_fields: Collection[str] - template_fields_renderers: Dict[str, str] + template_fields_renderers: dict[str, str] ui_color: str ui_fgcolor: str _is_empty: bool _task_module: str _task_type: str + _operator_name: str + + dag: DAG | None + task_group: TaskGroup | None + start_date: pendulum.DateTime | None + end_date: pendulum.DateTime | None + upstream_task_ids: set[str] = attr.ib(factory=set, init=False) + downstream_task_ids: set[str] = attr.ib(factory=set, init=False) - dag: Optional["DAG"] - task_group: Optional["TaskGroup"] - start_date: Optional[pendulum.DateTime] - end_date: Optional[pendulum.DateTime] - upstream_task_ids: Set[str] = attr.ib(factory=set, init=False) - downstream_task_ids: Set[str] = attr.ib(factory=set, init=False) + _disallow_kwargs_override: bool + """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. - _expansion_kwargs_attr: str + If *False*, values from ``expand_input`` under duplicate keys override those + under corresponding keys in ``partial_kwargs``. + """ + + _expand_input_attr: str """Where to get kwargs to calculate expansion length against. This should be a name to call ``getattr()`` on. """ - is_mapped: ClassVar[bool] = True subdag: None = None # Since we don't support SubDagOperator, this is always None. - HIDE_ATTRS_FROM_UI: ClassVar[FrozenSet[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( + HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( ( - 'parse_time_mapped_ti_count', - 'operator_class', + "parse_time_mapped_ti_count", + "operator_class", ) ) @@ -300,17 +304,18 @@ def __repr__(self): def __attrs_post_init__(self): from airflow.models.xcom_arg import XComArg - self._validate_argument_count() + if self.get_closest_mapped_task_group() is not None: + raise NotImplementedError("operator expansion in an expanded task group is not yet supported") + if self.task_group: self.task_group.add(self) if self.dag: self.dag.add_task(self) - for k, v in self.mapped_kwargs.items(): - XComArg.apply_upstream_relationship(self, v) + XComArg.apply_upstream_relationship(self, self.expand_input.value) for k, v in self.partial_kwargs.items(): if k in self.template_fields: XComArg.apply_upstream_relationship(self, v) - if self.partial_kwargs.get('sla') is not None: + if self.partial_kwargs.get("sla") is not None: raise AirflowException( f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task " f"{self.task_id!r}." @@ -323,7 +328,7 @@ def get_serialized_fields(cls): return frozenset(attr.fields_dict(MappedOperator)) - { "dag", "deps", - "is_mapped", + "expand_input", # This is needed to be able to accept XComArg. 
"subdag", "task_group", "upstream_task_ids", @@ -331,7 +336,7 @@ def get_serialized_fields(cls): @staticmethod @cache - def deps_for(operator_class: Type["BaseOperator"]) -> FrozenSet[BaseTIDep]: + def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]: operator_deps = operator_class.deps if not isinstance(operator_deps, collections.abc.Set): raise UnmappableOperator( @@ -340,23 +345,15 @@ def deps_for(operator_class: Type["BaseOperator"]) -> FrozenSet[BaseTIDep]: ) return operator_deps | {MappedTaskIsExpanded()} - def _validate_argument_count(self) -> None: - """Validate mapping arguments by unmapping with mocked values. - - This ensures the user passed enough arguments in the DAG definition for - the operator to work in the task runner. This does not guarantee the - arguments are *valid* (that depends on the actual mapping values), but - makes sure there are *enough* of them. - """ - if not isinstance(self.operator_class, type): - return # No need to validate deserialized operator. - self.operator_class.validate_mapped_arguments(**self._get_unmap_kwargs()) - @property def task_type(self) -> str: """Implementing Operator.""" return self._task_type + @property + def operator_name(self) -> str: + return self._operator_name + @property def inherits_from_empty_operator(self) -> bool: """Implementing Operator.""" @@ -377,7 +374,7 @@ def owner(self) -> str: # type: ignore[override] return self.partial_kwargs.get("owner", DEFAULT_OWNER) @property - def email(self) -> Union[None, str, Iterable[str]]: + def email(self) -> None | str | Iterable[str]: return self.partial_kwargs.get("email") @property @@ -398,7 +395,7 @@ def wait_for_downstream(self) -> bool: return bool(self.partial_kwargs.get("wait_for_downstream")) @property - def retries(self) -> Optional[int]: + def retries(self) -> int | None: return self.partial_kwargs.get("retries", DEFAULT_RETRIES) @property @@ -410,15 +407,15 @@ def pool(self) -> str: return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME) @property - def pool_slots(self) -> Optional[str]: + def pool_slots(self) -> str | None: return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) @property - def execution_timeout(self) -> Optional[datetime.timedelta]: + def execution_timeout(self) -> datetime.timedelta | None: return self.partial_kwargs.get("execution_timeout") @property - def max_retry_delay(self) -> Optional[datetime.timedelta]: + def max_retry_delay(self) -> datetime.timedelta | None: return self.partial_kwargs.get("max_retry_delay") @property @@ -438,109 +435,175 @@ def weight_rule(self) -> int: # type: ignore[override] return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) @property - def sla(self) -> Optional[datetime.timedelta]: + def sla(self) -> datetime.timedelta | None: return self.partial_kwargs.get("sla") @property - def max_active_tis_per_dag(self) -> Optional[int]: + def max_active_tis_per_dag(self) -> int | None: return self.partial_kwargs.get("max_active_tis_per_dag") @property - def resources(self) -> Optional[Resources]: + def resources(self) -> Resources | None: return self.partial_kwargs.get("resources") @property - def on_execute_callback(self) -> Optional[TaskStateChangeCallback]: + def on_execute_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_execute_callback") + @on_execute_callback.setter + def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None: + self.partial_kwargs["on_execute_callback"] = value + @property - def on_failure_callback(self) 
-> Optional[TaskStateChangeCallback]: + def on_failure_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_failure_callback") + @on_failure_callback.setter + def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None: + self.partial_kwargs["on_failure_callback"] = value + @property - def on_retry_callback(self) -> Optional[TaskStateChangeCallback]: + def on_retry_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_retry_callback") + @on_retry_callback.setter + def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None: + self.partial_kwargs["on_retry_callback"] = value + @property - def on_success_callback(self) -> Optional[TaskStateChangeCallback]: + def on_success_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_success_callback") + @on_success_callback.setter + def on_success_callback(self, value: TaskStateChangeCallback | None) -> None: + self.partial_kwargs["on_success_callback"] = value + @property - def run_as_user(self) -> Optional[str]: + def run_as_user(self) -> str | None: return self.partial_kwargs.get("run_as_user") @property def executor_config(self) -> dict: return self.partial_kwargs.get("executor_config", {}) - @property - def inlets(self) -> Optional[Any]: - return self.partial_kwargs.get("inlets", None) + @property # type: ignore[override] + def inlets(self) -> list[Any]: # type: ignore[override] + return self.partial_kwargs.get("inlets", []) - @property - def outlets(self) -> Optional[Any]: - return self.partial_kwargs.get("outlets", None) + @inlets.setter + def inlets(self, value: list[Any]) -> None: # type: ignore[override] + self.partial_kwargs["inlets"] = value + + @property # type: ignore[override] + def outlets(self) -> list[Any]: # type: ignore[override] + return self.partial_kwargs.get("outlets", []) + + @outlets.setter + def outlets(self, value: list[Any]) -> None: # type: ignore[override] + self.partial_kwargs["outlets"] = value @property - def doc(self) -> Optional[str]: + def doc(self) -> str | None: return self.partial_kwargs.get("doc") @property - def doc_md(self) -> Optional[str]: + def doc_md(self) -> str | None: return self.partial_kwargs.get("doc_md") @property - def doc_json(self) -> Optional[str]: + def doc_json(self) -> str | None: return self.partial_kwargs.get("doc_json") @property - def doc_yaml(self) -> Optional[str]: + def doc_yaml(self) -> str | None: return self.partial_kwargs.get("doc_yaml") @property - def doc_rst(self) -> Optional[str]: + def doc_rst(self) -> str | None: return self.partial_kwargs.get("doc_rst") - def get_dag(self) -> Optional["DAG"]: + def get_dag(self) -> DAG | None: """Implementing Operator.""" return self.dag - def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: + @property + def output(self) -> XComArg: + """Returns reference to XCom pushed by current operator""" + from airflow.models.xcom_arg import XComArg + + return XComArg(operator=self) + + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Implementing DAGNode.""" return DagAttributeTypes.OP, self.task_id - def _get_unmap_kwargs(self) -> Dict[str, Any]: + def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: + """Get the kwargs to create the unmapped operator. + + This exists because taskflow operators expand against op_kwargs, not the + entire operator kwargs dict. 
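# --- Illustrative sketch (not part of the patch): the getter/setter pattern added
# above. MappedOperator attributes such as on_failure_callback are not real instance
# attributes but views over the partial_kwargs dict, so the new setters must write
# back into that dict. The class below is a toy stand-in, not the real MappedOperator.
from __future__ import annotations

from typing import Any, Callable


class ToyMappedOperator:
    def __init__(self, **partial_kwargs: Any) -> None:
        self.partial_kwargs = dict(partial_kwargs)

    @property
    def on_failure_callback(self) -> Callable | None:
        return self.partial_kwargs.get("on_failure_callback")

    @on_failure_callback.setter
    def on_failure_callback(self, value: Callable | None) -> None:
        self.partial_kwargs["on_failure_callback"] = value


op = ToyMappedOperator(retries=2)
assert op.on_failure_callback is None
op.on_failure_callback = lambda context: None
assert "on_failure_callback" in op.partial_kwargs
# --- end sketch ---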
+ """ + return self._get_specified_expand_input().resolve(context, session) + + def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: + """Get init kwargs to unmap the underlying operator class. + + :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. + """ + if strict: + prevent_duplicates( + self.partial_kwargs, + mapped_kwargs, + fail_reason="unmappable or already specified", + ) + + # If params appears in the mapped kwargs, we need to merge it into the + # partial params, overriding existing keys. + params = copy.copy(self.params) + with contextlib.suppress(KeyError): + params.update(mapped_kwargs["params"]) + + # Ordering is significant; mapped kwargs should override partial ones, + # and the specially handled params should be respected. return { "task_id": self.task_id, "dag": self.dag, "task_group": self.task_group, - "params": self.params, "start_date": self.start_date, "end_date": self.end_date, **self.partial_kwargs, - **self.mapped_kwargs, + **mapped_kwargs, + "params": params, } - def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator": - """ - Get the "normal" Operator after applying the current mapping. + def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: + """Get the "normal" Operator after applying the current mapping. - If ``operator_class`` is not a class (i.e. this DAG has been deserialized) then this will return a - SerializedBaseOperator that aims to "look like" the real operator. + The *resolve* argument is only used if ``operator_class`` is a real + class, i.e. if this operator is not serialized. If ``operator_class`` is + not a class (i.e. this DAG has been deserialized), this returns a + SerializedBaseOperator that "looks like" the actual unmapping result. - :param unmap_kwargs: Override the args to pass to the Operator constructor. Only used when - ``operator_class`` is still an actual class. + If *resolve* is a two-tuple (context, session), the information is used + to resolve the mapped arguments into init arguments. If it is a mapping, + no resolving happens, the mapping directly provides those init arguments + resolved from mapped kwargs. :meta private: """ if isinstance(self.operator_class, type): - # We can't simply specify task_id here because BaseOperator further + if isinstance(resolve, collections.abc.Mapping): + kwargs = resolve + elif resolve is not None: + kwargs, _ = self._expand_mapped_kwargs(*resolve) + else: + raise RuntimeError("cannot unmap a non-serialized operator without context") + kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) + op = self.operator_class(**kwargs, _airflow_from_mapped=True) + # We need to overwrite task_id here because BaseOperator further # mangles the task_id based on the task hierarchy (namely, group_id - # is prepended, and '__N' appended to deduplicate). Instead of - # recreating the whole logic here, we just overwrite task_id later. - if unmap_kwargs is None: - unmap_kwargs = self._get_unmap_kwargs() - op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True) + # is prepended, and '__N' appended to deduplicate). This is hacky, + # but better than duplicating the whole mangling logic. op.task_id = self.task_id return op @@ -550,311 +613,77 @@ def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator" # mapped operator to a new SerializedBaseOperator instance. 
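# --- Illustrative sketch (not part of the patch): the merge order in
# _get_unmap_kwargs above. Mapped kwargs override partial kwargs, while "params" is
# merged key-by-key instead of being replaced wholesale. Plain dicts stand in for
# ParamsDict; the helper name is illustrative.
import contextlib
import copy


def merge_unmap_kwargs(partial_kwargs: dict, mapped_kwargs: dict, params: dict) -> dict:
    merged_params = copy.copy(params)
    with contextlib.suppress(KeyError):
        merged_params.update(mapped_kwargs["params"])
    # Ordering is significant: mapped values win, and params is handled last.
    return {**partial_kwargs, **mapped_kwargs, "params": merged_params}


result = merge_unmap_kwargs(
    partial_kwargs={"retries": 3, "queue": "default"},
    mapped_kwargs={"queue": "mapped", "params": {"limit": 10}},
    params={"limit": 1, "owner_note": "kept"},
)
assert result["queue"] == "mapped"      # mapped value wins over partial
assert result["retries"] == 3           # untouched partial value survives
assert result["params"] == {"limit": 10, "owner_note": "kept"}  # merged, not replaced
# --- end sketch ---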
from airflow.serialization.serialized_objects import SerializedBaseOperator - op = SerializedBaseOperator(task_id=self.task_id, _airflow_from_mapped=True) + op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) SerializedBaseOperator.populate_operator(op, self.operator_class) return op - def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]: - """The kwargs to calculate expansion length against.""" - return getattr(self, self._expansion_kwargs_attr) + def _get_specified_expand_input(self) -> ExpandInput: + """Input received from the expand call on the operator.""" + return getattr(self, self._expand_input_attr) - def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]: - """Return dict of argument name to map length. + def prepare_for_execution(self) -> MappedOperator: + # Since a mapped operator cannot be used for execution, and an unmapped + # BaseOperator needs to be created later (see render_template_fields), + # we don't need to create a copy of the MappedOperator here. + return self - If any arguments are not known right now (upstream task not finished) they will not be present in the - dict. - """ - # TODO: Find a way to cache this. - from airflow.models.taskmap import TaskMap - from airflow.models.xcom import XCom + def iter_mapped_dependencies(self) -> Iterator[Operator]: + """Upstream dependencies that provide XComs used by this task for task mapping.""" from airflow.models.xcom_arg import XComArg - expansion_kwargs = self._get_expansion_kwargs() - - # Populate literal mapped arguments first. - map_lengths: Dict[str, int] = collections.defaultdict(int) - map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg)) - - # Build a reverse mapping of what arguments each task contributes to. - mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set) - non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set) - for k, v in expansion_kwargs.items(): - if not isinstance(v, XComArg): - continue - if v.operator.is_mapped: - mapped_dep_keys[v.operator.task_id].add(k) - else: - non_mapped_dep_keys[v.operator.task_id].add(k) - # TODO: It's not possible now, but in the future (AIP-42 Phase 2) - # we will add support for depending on one single mapped task - # instance. When that happens, we need to further analyze the mapped - # case to contain only tasks we depend on "as a whole", and put - # those we only depend on individually to the non-mapped lookup. - - # Collect lengths from unmapped upstreams. - taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter( - TaskMap.dag_id == self.dag_id, - TaskMap.run_id == run_id, - TaskMap.task_id.in_(non_mapped_dep_keys), - TaskMap.map_index < 0, - ) - for task_id, length in taskmap_query: - for mapped_arg_name in non_mapped_dep_keys[task_id]: - map_lengths[mapped_arg_name] += length - - # Collect lengths from mapped upstreams. 
- xcom_query = ( - session.query(XCom.task_id, func.count(XCom.map_index)) - .group_by(XCom.task_id) - .filter( - XCom.dag_id == self.dag_id, - XCom.run_id == run_id, - XCom.task_id.in_(mapped_dep_keys), - XCom.map_index >= 0, - ) - ) - for task_id, length in xcom_query: - for mapped_arg_name in mapped_dep_keys[task_id]: - map_lengths[mapped_arg_name] += length - return map_lengths + for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): + yield operator @cache - def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]: - """Return dict of argument name to map length, or throw if some are not resolvable""" - expansion_kwargs = self._get_expansion_kwargs() - map_lengths = self._get_map_lengths(run_id, session=session) - if len(map_lengths) < len(expansion_kwargs): - keys = ", ".join(repr(k) for k in sorted(set(expansion_kwargs).difference(map_lengths))) - raise RuntimeError(f"Failed to populate all mapping metadata; missing: {keys}") - - return map_lengths - - def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]: - """Create the mapped task instances for mapped task. - - :return: The newly created mapped TaskInstances (if any) in ascending order by map index, and the - maximum map_index. - """ - from airflow.models.taskinstance import TaskInstance - from airflow.settings import task_instance_mutation_hook - - total_length = functools.reduce( - operator.mul, self._resolve_map_lengths(run_id, session=session).values() - ) - - state: Optional[TaskInstanceState] = None - unmapped_ti: Optional[TaskInstance] = ( - session.query(TaskInstance) - .filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == run_id, - TaskInstance.map_index == -1, - or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), - ) - .one_or_none() - ) + def get_parse_time_mapped_ti_count(self) -> int: + current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() + try: + parent_count = super().get_parse_time_mapped_ti_count() + except NotMapped: + return current_count + return parent_count * current_count - all_expanded_tis: List[TaskInstance] = [] - - if unmapped_ti: - # The unmapped task instance still exists and is unfinished, i.e. we - # haven't tried to run it before. - if total_length < 1: - # If the upstream maps this to a zero-length value, simply marked the - # unmapped task instance as SKIPPED (if needed). - self.log.info( - "Marking %s as SKIPPED since the map has %d values to expand", - unmapped_ti, - total_length, - ) - unmapped_ti.state = TaskInstanceState.SKIPPED - else: - # Otherwise convert this into the first mapped index, and create - # TaskInstance for other indexes. - unmapped_ti.map_index = 0 - self.log.debug("Updated in place to become %s", unmapped_ti) - all_expanded_tis.append(unmapped_ti) - state = unmapped_ti.state - indexes_to_map = range(1, total_length) - else: - # Only create "missing" ones. - current_max_mapping = ( - session.query(func.max(TaskInstance.map_index)) - .filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == run_id, - ) - .scalar() - ) - indexes_to_map = range(current_max_mapping + 1, total_length) - - for index in indexes_to_map: - # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings. 
- ti = TaskInstance(self, run_id=run_id, map_index=index, state=state) - self.log.debug("Expanding TIs upserted %s", ti) - task_instance_mutation_hook(ti) - ti = session.merge(ti) - ti.refresh_from_task(self) # session.merge() loses task information. - all_expanded_tis.append(ti) - - # Set to "REMOVED" any (old) TaskInstances with map indices greater - # than the current map value - session.query(TaskInstance).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == run_id, - TaskInstance.map_index >= total_length, - ).update({TaskInstance.state: TaskInstanceState.REMOVED}) - - session.flush() - return all_expanded_tis, total_length - - def prepare_for_execution(self) -> "MappedOperator": - # Since a mapped operator cannot be used for execution, and an unmapped - # BaseOperator needs to be created later (see render_template_fields), - # we don't need to create a copy of the MappedOperator here. - return self + def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: + current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session) + try: + parent_count = super().get_mapped_ti_count(run_id, session=session) + except NotMapped: + return current_count + return parent_count * current_count def render_template_fields( self, context: Context, - jinja_env: Optional["jinja2.Environment"] = None, - ) -> Optional["BaseOperator"]: - """Template all attributes listed in template_fields. + jinja_env: jinja2.Environment | None = None, + ) -> None: + """Template all attributes listed in *self.template_fields*. - Different from the BaseOperator implementation, this renders the - template fields on the *unmapped* BaseOperator. + This updates *context* to reference the map-expanded task and relevant + information, without modifying the mapped operator. The expanded task + in *context* is then rendered in-place. - :param context: Dict with values to apply on content - :param jinja_env: Jinja environment - :return: The unmapped, populated BaseOperator + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja environment to use for rendering. """ if not jinja_env: jinja_env = self.get_template_env() - # Before we unmap we have to resolve the mapped arguments, otherwise the real operator constructor - # could be called with an XComArg, rather than the value it resolves to. - # - # We also need to resolve _all_ mapped arguments, even if they aren't marked as templated - kwargs = self._get_unmap_kwargs() - - template_fields = set(self.template_fields) - - # Ideally we'd like to pass in session as an argument to this function, but since operators _could_ - # override this we can't easily change this function signature. - # We can't use @provide_session, as that closes and expunges everything, which we don't want to do - # when we are so "deep" in the weeds here. - # - # Nor do we want to close the session -- that would expunge all the things from the internal cache - # which we don't want to do either + + # Ideally we'd like to pass in session as an argument to this function, + # but we can't easily change this function signature since operators + # could override this. We can't use @provide_session since it closes and + # expunges everything, which we don't want to do when we are so "deep" + # in the weeds here. We don't close this session for the same reason. 
session = settings.Session() - self._resolve_expansion_kwargs(kwargs, template_fields, context, session) - unmapped_task = self.unmap(unmap_kwargs=kwargs) + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) + unmapped_task = self.unmap(mapped_kwargs) + context_update_for_unmapped(context, unmapped_task) + self._do_render_template_fields( parent=unmapped_task, - template_fields=template_fields, + template_fields=self.template_fields, context=context, jinja_env=jinja_env, - seen_oids=set(), + seen_oids=seen_oids, session=session, ) - return unmapped_task - - def _resolve_expansion_kwargs( - self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session - ) -> None: - """Update mapped fields in place in kwargs dict""" - from airflow.models.xcom_arg import XComArg - - expansion_kwargs = self._get_expansion_kwargs() - - for k, v in expansion_kwargs.items(): - if isinstance(v, XComArg): - v = v.resolve(context, session=session) - v = self._expand_mapped_field(k, v, context, session=session) - template_fields.discard(k) - kwargs[k] = v - - def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any: - map_index = context["ti"].map_index - if map_index < 0: - return value - expansion_kwargs = self._get_expansion_kwargs() - all_lengths = self._resolve_map_lengths(context["run_id"], session=session) - - def _find_index_for_this_field(index: int) -> int: - # Need to use self.mapped_kwargs for the original argument order. - for mapped_key in reversed(list(expansion_kwargs)): - mapped_length = all_lengths[mapped_key] - if mapped_length < 1: - raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") - if mapped_key == key: - return index % mapped_length - index //= mapped_length - return -1 - - found_index = _find_index_for_this_field(map_index) - if found_index < 0: - return value - if isinstance(value, collections.abc.Sequence): - return value[found_index] - if not isinstance(value, dict): - raise TypeError(f"can't map over value of type {type(value)}") - for i, (k, v) in enumerate(value.items()): - if i == found_index: - return k, v - raise IndexError(f"index {map_index} is over mapped length") - - def iter_mapped_dependencies(self) -> Iterator["Operator"]: - """Upstream dependencies that provide XComs used by this task for task mapping.""" - from airflow.models.xcom_arg import XComArg - - for ref in XComArg.iter_xcom_args(self._get_expansion_kwargs()): - yield ref.operator - - @cached_property - def parse_time_mapped_ti_count(self) -> Optional[int]: - """ - Number of mapped TaskInstances that can be created at DagRun create time. - - :return: None if non-literal mapped arg encountered, or else total number of mapped TIs this task - should have - """ - total = 0 - - for value in self._get_expansion_kwargs().values(): - if not isinstance(value, MAPPABLE_LITERAL_TYPES): - # None literal type encountered, so give up - return None - if total == 0: - total = len(value) - else: - total *= len(value) - return total - - @cache - def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> Optional[int]: - """ - Number of mapped TaskInstances that can be created at run time, or None if upstream tasks are not - complete yet. 
- - :return: None if upstream tasks are not complete yet, or else total number of mapped TIs this task - should have - """ - - lengths = self._get_map_lengths(run_id, session=session) - expansion_kwargs = self._get_expansion_kwargs() - - if not lengths or not expansion_kwargs: - return None - - total = 1 - for name in expansion_kwargs: - val = lengths.get(name) - if val is None: - return None - total *= val - - return total diff --git a/airflow/models/operator.py b/airflow/models/operator.py index b95e8dc0036b1..7352ecbdcb021 100644 --- a/airflow/models/operator.py +++ b/airflow/models/operator.py @@ -15,12 +15,34 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import Union +from airflow.models.abstractoperator import AbstractOperator from airflow.models.baseoperator import BaseOperator from airflow.models.mappedoperator import MappedOperator +from airflow.typing_compat import TypeGuard Operator = Union[BaseOperator, MappedOperator] -__all__ = ["Operator"] + +def needs_expansion(task: AbstractOperator) -> TypeGuard[Operator]: + """Whether a task needs expansion at runtime. + + A task needs expansion if it either + + * Is a mapped operator, or + * Is in a mapped task group. + + This is implemented as a free function (instead of a property) so we can + make it a type guard. + """ + if isinstance(task, MappedOperator): + return True + if task.get_closest_mapped_task_group() is not None: + return True + return False + + +__all__ = ["Operator", "needs_expansion"] diff --git a/airflow/models/param.py b/airflow/models/param.py index fcbe7a0f931c6..d944ef20311d3 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -14,18 +14,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
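The `needs_expansion` helper introduced in airflow/models/operator.py above returns a `TypeGuard`, so callers get static narrowing for free. A minimal, hypothetical usage sketch (the `describe` function below is invented for illustration and is not part of this change):

# Hypothetical caller: shows how the TypeGuard returned by needs_expansion()
# lets a type checker narrow AbstractOperator to Operator
# (BaseOperator | MappedOperator) inside the True branch.
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.operator import needs_expansion


def describe(task: AbstractOperator) -> str:
    if needs_expansion(task):
        # Static checkers treat ``task`` as ``Operator`` here.
        return f"{task.task_id} is expanded into mapped task instances at runtime"
    return f"{task.task_id} runs as a single task instance"

Implementing this as a free function rather than a property is what makes the type-guard behaviour possible, as the docstring in the diff notes.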
+from __future__ import annotations + import contextlib import copy import json +import logging import warnings -from typing import TYPE_CHECKING, Any, Dict, ItemsView, MutableMapping, Optional, ValuesView +from typing import TYPE_CHECKING, Any, ItemsView, Iterable, MutableMapping, ValuesView -from airflow.exceptions import AirflowException, ParamValidationError +from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning from airflow.utils.context import Context +from airflow.utils.mixins import ResolveMixin from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.models.dag import DAG + from airflow.models.dagrun import DagRun + from airflow.models.operator import Operator + +logger = logging.getLogger(__name__) class Param: @@ -39,16 +47,16 @@ class Param: default & description will form the schema """ - CLASS_IDENTIFIER = '__class' + CLASS_IDENTIFIER = "__class" - def __init__(self, default: Any = NOTSET, description: Optional[str] = None, **kwargs): + def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs): if default is not NOTSET: self._warn_if_not_json(default) self.value = default self.description = description - self.schema = kwargs.pop('schema') if 'schema' in kwargs else kwargs + self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs - def __copy__(self) -> "Param": + def __copy__(self) -> Param: return Param(self.value, self.description, schema=self.schema) @staticmethod @@ -59,7 +67,7 @@ def _warn_if_not_json(value): warnings.warn( "The use of non-json-serializable params is deprecated and will be removed in " "a future release", - DeprecationWarning, + RemovedInAirflow3Warning, ) def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any: @@ -96,7 +104,7 @@ def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any: def dump(self) -> dict: """Dump the Param as a dictionary""" - out_dict = {self.CLASS_IDENTIFIER: f'{self.__module__}.{self.__class__.__name__}'} + out_dict = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"} out_dict.update(self.__dict__) return out_dict @@ -112,14 +120,14 @@ class ParamsDict(MutableMapping[str, Any]): dictionary implicitly and ideally not needed to be used directly. 
""" - __slots__ = ['__dict', 'suppress_exception'] + __slots__ = ["__dict", "suppress_exception"] - def __init__(self, dict_obj: Optional[Dict] = None, suppress_exception: bool = False): + def __init__(self, dict_obj: dict | None = None, suppress_exception: bool = False): """ :param dict_obj: A dict or dict like object to init ParamsDict :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict """ - params_dict: Dict[str, Param] = {} + params_dict: dict[str, Param] = {} dict_obj = dict_obj or {} for k, v in dict_obj.items(): if not isinstance(v, Param): @@ -129,10 +137,20 @@ def __init__(self, dict_obj: Optional[Dict] = None, suppress_exception: bool = F self.__dict = params_dict self.suppress_exception = suppress_exception - def __copy__(self) -> "ParamsDict": + def __bool__(self) -> bool: + return bool(self.__dict) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ParamsDict): + return self.dump() == other.dump() + if isinstance(other, dict): + return self.dump() == other + return NotImplemented + + def __copy__(self) -> ParamsDict: return ParamsDict(self.__dict, self.suppress_exception) - def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "ParamsDict": + def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict: return ParamsDict(copy.deepcopy(self.__dict, memo), self.suppress_exception) def __contains__(self, o: object) -> bool: @@ -147,6 +165,9 @@ def __delitem__(self, v: str) -> None: def __iter__(self): return iter(self.__dict) + def __repr__(self): + return repr(self.dump()) + def __setitem__(self, key: str, value: Any) -> None: """ Override for dictionary's ``setitem`` method. This method make sure that all values are of @@ -163,7 +184,7 @@ def __setitem__(self, key: str, value: Any) -> None: try: param.resolve(value=value, suppress_exception=self.suppress_exception) except ParamValidationError as ve: - raise ParamValidationError(f'Invalid input for param {key}: {ve}') from None + raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None else: # if the key isn't there already and if the value isn't of Param type create a new Param object param = Param(value) @@ -195,32 +216,34 @@ def update(self, *args, **kwargs) -> None: return super().update(args[0].__dict) super().update(*args, **kwargs) - def dump(self) -> Dict[str, Any]: + def dump(self) -> dict[str, Any]: """Dumps the ParamsDict object as a dictionary, while suppressing exceptions""" return {k: v.resolve(suppress_exception=True) for k, v in self.items()} - def validate(self) -> Dict[str, Any]: + def validate(self) -> dict[str, Any]: """Validates & returns all the Params object stored in the dictionary""" resolved_dict = {} try: for k, v in self.items(): resolved_dict[k] = v.resolve(suppress_exception=self.suppress_exception) except ParamValidationError as ve: - raise ParamValidationError(f'Invalid input for param {k}: {ve}') from None + raise ParamValidationError(f"Invalid input for param {k}: {ve}") from None return resolved_dict -class DagParam: - """ - Class that represents a DAG run parameter & binds a simple Param object to a name within a DAG instance, - so that it can be resolved during the run time via ``{{ context }}`` dictionary. The ideal use case of - this class is to implicitly convert args passed to a method which is being decorated by ``@dag`` keyword. +class DagParam(ResolveMixin): + """DAG run parameter reference. - It can be used to parameterize your dags. 
You can overwrite its value by setting it on conf - when you trigger your DagRun. + This binds a simple Param object to a name within a DAG instance, so that it + can be resolved during the runtime via the ``{{ context }}`` dictionary. The + ideal use case of this class is to implicitly convert args passed to a + method decorated by ``@dag``. - This can also be used in templates by accessing ``{{context.params}}`` dictionary. + It can be used to parameterize a DAG. You can overwrite its value by setting + it on conf when you trigger your DagRun. + + This can also be used in templates by accessing ``{{ context.params }}``. **Example**: @@ -232,18 +255,42 @@ class DagParam: :param default: Default value used if no parameter was set. """ - def __init__(self, current_dag: "DAG", name: str, default: Any = NOTSET): + def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET): if default is not NOTSET: current_dag.params[name] = default self._name = name self._default = default + def iter_references(self) -> Iterable[tuple[Operator, str]]: + return () + def resolve(self, context: Context) -> Any: """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.""" with contextlib.suppress(KeyError): - return context['dag_run'].conf[self._name] + return context["dag_run"].conf[self._name] if self._default is not NOTSET: return self._default with contextlib.suppress(KeyError): - return context['params'][self._name] - raise AirflowException(f'No value could be resolved for parameter {self._name}') + return context["params"][self._name] + raise AirflowException(f"No value could be resolved for parameter {self._name}") + + +def process_params( + dag: DAG, + task: Operator, + dag_run: DagRun | None, + *, + suppress_exception: bool, +) -> dict[str, Any]: + """Merge, validate params, and convert them into a simple dict.""" + from airflow.configuration import conf + + params = ParamsDict(suppress_exception=suppress_exception) + with contextlib.suppress(AttributeError): + params.update(dag.params) + if task.params: + params.update(task.params) + if conf.getboolean("core", "dag_run_conf_overrides_params") and dag_run and dag_run.conf: + logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf) + params.update(dag_run.conf) + return params.validate() diff --git a/airflow/models/pool.py b/airflow/models/pool.py index f195b93d42623..d6b8ad38872dd 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
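The merge order implemented by `process_params` above (DAG-level params first, then task-level params, then `DagRun.conf` when `core.dag_run_conf_overrides_params` is enabled) can be pictured with plain dictionaries; every key and value below is invented purely for illustration:

# Later updates win, mirroring the successive params.update() calls in
# process_params(): dag.params < task.params < dag_run.conf (when allowed).
dag_params = {"env": "dev", "retries": 1}
task_params = {"retries": 3}
dag_run_conf = {"env": "prod"}

merged = {**dag_params, **task_params, **dag_run_conf}
assert merged == {"env": "prod", "retries": 3}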
+from __future__ import annotations -from typing import Dict, Iterable, Optional, Tuple +from typing import Iterable from sqlalchemy import Column, Integer, String, Text, func from sqlalchemy.orm.session import Session @@ -50,7 +51,7 @@ class Pool(Base): slots = Column(Integer, default=0) description = Column(Text) - DEFAULT_POOL_NAME = 'default_pool' + DEFAULT_POOL_NAME = "default_pool" def __repr__(self): return str(self.pool) @@ -141,7 +142,7 @@ def slots_stats( *, lock_rows: bool = False, session: Session = NEW_SESSION, - ) -> Dict[str, PoolStats]: + ) -> dict[str, PoolStats]: """ Get Pool stats (Number of Running, Queued, Open & Total tasks) @@ -154,17 +155,17 @@ def slots_stats( """ from airflow.models.taskinstance import TaskInstance # Avoid circular import - pools: Dict[str, PoolStats] = {} + pools: dict[str, PoolStats] = {} query = session.query(Pool.pool, Pool.slots) if lock_rows: query = with_row_locks(query, session=session, **nowait(session)) - pool_rows: Iterable[Tuple[str, int]] = query.all() + pool_rows: Iterable[tuple[str, int]] = query.all() for (pool_name, total_slots) in pool_rows: if total_slots == -1: - total_slots = float('inf') # type: ignore + total_slots = float("inf") # type: ignore pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0) state_count_by_pool = ( @@ -178,7 +179,7 @@ def slots_stats( # Some databases return decimal.Decimal here. count = int(count) - stats_dict: Optional[PoolStats] = pools.get(pool_name) + stats_dict: PoolStats | None = pools.get(pool_name) if not stats_dict: continue # TypedDict key must be a string literal, so we use if-statements to set value @@ -202,10 +203,10 @@ def to_json(self): :return: the pool object in json format """ return { - 'id': self.id, - 'pool': self.pool, - 'slots': self.slots, - 'description': self.description, + "id": self.id, + "pool": self.pool, + "slots": self.slots, + "description": self.description, } @provide_session @@ -262,6 +263,24 @@ def queued_slots(self, session: Session = NEW_SESSION): or 0 ) + @provide_session + def scheduled_slots(self, session: Session = NEW_SESSION): + """ + Get the number of slots scheduled at the moment. + + :param session: SQLAlchemy ORM Session + :return: the number of scheduled slots + """ + from airflow.models.taskinstance import TaskInstance # Avoid circular import + + return int( + session.query(func.sum(TaskInstance.pool_slots)) + .filter(TaskInstance.pool == self.pool) + .filter(TaskInstance.state == State.SCHEDULED) + .scalar() + or 0 + ) + @provide_session def open_slots(self, session: Session = NEW_SESSION) -> float: """ @@ -271,6 +290,6 @@ def open_slots(self, session: Session = NEW_SESSION) -> float: :return: the number of slots """ if self.slots == -1: - return float('inf') + return float("inf") else: return self.slots - self.occupied_slots(session) diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py index c7bad78b5f3f4..7fd2c29edfa61 100644 --- a/airflow/models/renderedtifields.py +++ b/airflow/models/renderedtifields.py @@ -16,11 +16,13 @@ # specific language governing permissions and limitations # under the License. 
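As a rough usage sketch of the Pool helpers touched in airflow/models/pool.py above, including the newly added `scheduled_slots`. This assumes an initialized metadata database; the query for the default pool is our own construction and not part of the diff:

from airflow.models.pool import Pool
from airflow.utils.session import create_session

with create_session() as session:
    default_pool = session.query(Pool).filter(Pool.pool == Pool.DEFAULT_POOL_NAME).one()
    # The slot helpers are @provide_session-decorated, so no session argument
    # is needed; open_slots() returns float("inf") when slots == -1.
    print("open:", default_pool.open_slots())
    print("queued:", default_pool.queued_slots())
    print("scheduled:", default_pool.scheduled_slots())  # added in this change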
"""Save Rendered Template Fields""" +from __future__ import annotations + import os -from typing import Optional +from typing import TYPE_CHECKING import sqlalchemy_jsonfield -from sqlalchemy import Column, ForeignKeyConstraint, Integer, and_, not_, tuple_ +from sqlalchemy import Column, ForeignKeyConstraint, Integer, PrimaryKeyConstraint, text from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Session, relationship @@ -31,6 +33,10 @@ from airflow.settings import json from airflow.utils.retries import retry_db_transaction from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import tuple_not_in_condition + +if TYPE_CHECKING: + from sqlalchemy.sql import FromClause class RenderedTaskInstanceFields(Base): @@ -41,11 +47,19 @@ class RenderedTaskInstanceFields(Base): dag_id = Column(StringID(), primary_key=True) task_id = Column(StringID(), primary_key=True) run_id = Column(StringID(), primary_key=True) - map_index = Column(Integer, primary_key=True, server_default='-1') + map_index = Column(Integer, primary_key=True, server_default=text("-1")) rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False) k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True) __table_args__ = ( + PrimaryKeyConstraint( + "dag_id", + "task_id", + "run_id", + "map_index", + name="rendered_task_instance_fields_pkey", + mssql_clustered=True, + ), ForeignKeyConstraint( [dag_id, task_id, run_id, map_index], [ @@ -54,13 +68,13 @@ class RenderedTaskInstanceFields(Base): "task_instance.run_id", "task_instance.map_index", ], - name='rtif_ti_fkey', + name="rtif_ti_fkey", ondelete="CASCADE", ), ) task_instance = relationship( "TaskInstance", - lazy='joined', + lazy="joined", back_populates="rendered_task_instance_fields", ) @@ -98,7 +112,7 @@ def __repr__(self): prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}" if self.map_index != -1: prefix += f" map_index={self.map_index}" - return prefix + '>' + return prefix + ">" def _redact(self): from airflow.utils.log.secrets_masker import redact @@ -111,7 +125,7 @@ def _redact(self): @classmethod @provide_session - def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> Optional[dict]: + def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None: """ Get templated field for a TaskInstance from the RenderedTaskInstanceFields table. @@ -139,7 +153,7 @@ def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) @classmethod @provide_session - def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> Optional[dict]: + def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None: """ Get rendered Kubernetes Pod Yaml for a TaskInstance from the RenderedTaskInstanceFields table. @@ -174,9 +188,9 @@ def delete_old_records( cls, task_id: str, dag_id: str, - num_to_keep=conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0), - session: Session = None, - ): + num_to_keep: int = conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0), + session: Session = NEW_SESSION, + ) -> None: """ Keep only Last X (num_to_keep) number of records for a task by deleting others. 
@@ -202,49 +216,30 @@ def delete_old_records( .limit(num_to_keep) ) - if session.bind.dialect.name in ["postgresql", "sqlite"]: - # Fetch Top X records given dag_id & task_id ordered by Execution Date - subq1 = tis_to_keep_query.subquery() - excluded = session.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.run_id) - session.query(cls).filter( - cls.dag_id == dag_id, - cls.task_id == task_id, - tuple_(cls.dag_id, cls.task_id, cls.run_id).notin_(excluded), - ).delete(synchronize_session=False) - elif session.bind.dialect.name in ["mysql"]: - cls._remove_old_rendered_ti_fields_mysql(dag_id, session, task_id, tis_to_keep_query) - else: - # Fetch Top X records given dag_id & task_id ordered by Execution Date - tis_to_keep = tis_to_keep_query.all() - - filter_tis = [ - not_( - and_( - cls.dag_id == ti.dag_id, - cls.task_id == ti.task_id, - cls.run_id == ti.run_id, - ) - ) - for ti in tis_to_keep - ] - - session.query(cls).filter(and_(*filter_tis)).delete(synchronize_session=False) - + cls._do_delete_old_records( + dag_id=dag_id, + task_id=task_id, + ti_clause=tis_to_keep_query.subquery(), + session=session, + ) session.flush() @classmethod @retry_db_transaction - def _remove_old_rendered_ti_fields_mysql(cls, dag_id, session, task_id, tis_to_keep_query): - # Fetch Top X records given dag_id & task_id ordered by Execution Date - subq1 = tis_to_keep_query.subquery('subq1') - # Second Subquery - # Workaround for MySQL Limitation (https://stackoverflow.com/a/19344141/5691525) - # Limitation: This version of MySQL does not yet support - # LIMIT & IN/ALL/ANY/SOME subquery - subq2 = session.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.run_id).subquery('subq2') + def _do_delete_old_records( + cls, + *, + task_id: str, + dag_id: str, + ti_clause: FromClause, + session: Session, + ) -> None: # This query might deadlock occasionally and it should be retried if fails (see decorator) session.query(cls).filter( cls.dag_id == dag_id, cls.task_id == task_id, - tuple_(cls.dag_id, cls.task_id, cls.run_id).notin_(subq2), + tuple_not_in_condition( + (cls.dag_id, cls.task_id, cls.run_id), + session.query(ti_clause.c.dag_id, ti_clause.c.task_id, ti_clause.c.run_id), + ), ).delete(synchronize_session=False) diff --git a/airflow/models/sensorinstance.py b/airflow/models/sensorinstance.py deleted file mode 100644 index b1681e52278f3..0000000000000 --- a/airflow/models/sensorinstance.py +++ /dev/null @@ -1,185 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
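The dialect-specific branches removed above collapse into a single `tuple_not_in_condition` call. Its implementation is not part of this hunk, but on backends with row-value support it presumably reduces to SQLAlchemy's `tuple_(...).notin_(...)`; a standalone sketch of that core expression follows, with an illustrative table and a placeholder keep-list subquery:

from sqlalchemy import column, select, table, tuple_

rtif = table("rendered_task_instance_fields", column("dag_id"), column("task_id"), column("run_id"))
rows_to_keep = select(rtif.c.dag_id, rtif.c.task_id, rtif.c.run_id).limit(1)

# Rows whose (dag_id, task_id, run_id) triple is NOT in the keep-list get deleted.
condition = tuple_(rtif.c.dag_id, rtif.c.task_id, rtif.c.run_id).notin_(rows_to_keep)
print(condition)  # renders roughly as: (dag_id, task_id, run_id) NOT IN (SELECT ... LIMIT :param_1)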
- -import json - -from sqlalchemy import BigInteger, Column, Index, Integer, String, Text - -from airflow.configuration import conf -from airflow.exceptions import AirflowException -from airflow.models.base import ID_LEN, Base -from airflow.utils import timezone -from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import UtcDateTime -from airflow.utils.state import State - - -class SensorInstance(Base): - """ - SensorInstance support the smart sensor service. It stores the sensor task states - and context that required for poking include poke context and execution context. - In sensor_instance table we also save the sensor operator classpath so that inside - smart sensor there is no need to import the dagbag and create task object for each - sensor task. - - SensorInstance include another set of columns to support the smart sensor shard on - large number of sensor instance. The key idea is to generate the hash code from the - poke context and use it to map to a shorter shard code which can be used as an index. - Every smart sensor process takes care of tasks whose `shardcode` are in a certain range. - - """ - - __tablename__ = "sensor_instance" - - id = Column(Integer, primary_key=True) - task_id = Column(String(ID_LEN), nullable=False) - dag_id = Column(String(ID_LEN), nullable=False) - execution_date = Column(UtcDateTime, nullable=False) - state = Column(String(20)) - _try_number = Column('try_number', Integer, default=0) - start_date = Column(UtcDateTime) - operator = Column(String(1000), nullable=False) - op_classpath = Column(String(1000), nullable=False) - hashcode = Column(BigInteger, nullable=False) - shardcode = Column(Integer, nullable=False) - poke_context = Column(Text, nullable=False) - execution_context = Column(Text) - created_at = Column(UtcDateTime, default=timezone.utcnow(), nullable=False) - updated_at = Column(UtcDateTime, default=timezone.utcnow(), onupdate=timezone.utcnow(), nullable=False) - - # SmartSensor doesn't support mapped operators, but this is needed for compatibly with the - # log_filename_template of TaskInstances - map_index = -1 - - __table_args__ = ( - Index('ti_primary_key', dag_id, task_id, execution_date, unique=True), - Index('si_hashcode', hashcode), - Index('si_shardcode', shardcode), - Index('si_state_shard', state, shardcode), - Index('si_updated_at', updated_at), - ) - - def __init__(self, ti): - self.dag_id = ti.dag_id - self.task_id = ti.task_id - self.execution_date = ti.execution_date - - @staticmethod - def get_classpath(obj): - """ - Get the object dotted class path. Used for getting operator classpath. - - :param obj: - :return: The class path of input object - :rtype: str - """ - module_name, class_name = obj.__module__, obj.__class__.__name__ - - return module_name + "." + class_name - - @classmethod - @provide_session - def register(cls, ti, poke_context, execution_context, session=None): - """ - Register task instance ti for a sensor in sensor_instance table. Persist the - context used for a sensor and set the sensor_instance table state to sensing. - - :param ti: The task instance for the sensor to be registered. - :param poke_context: Context used for sensor poke function. - :param execution_context: Context used for execute sensor such as timeout - setting and email configuration. - :param session: SQLAlchemy ORM Session - :return: True if the ti was registered successfully. 
- :rtype: Boolean - """ - if poke_context is None: - raise AirflowException('poke_context should not be None') - - encoded_poke = json.dumps(poke_context) - encoded_execution_context = json.dumps(execution_context) - - sensor = ( - session.query(SensorInstance) - .filter( - SensorInstance.dag_id == ti.dag_id, - SensorInstance.task_id == ti.task_id, - SensorInstance.execution_date == ti.execution_date, - ) - .with_for_update() - .first() - ) - - if sensor is None: - sensor = SensorInstance(ti=ti) - - sensor.operator = ti.operator - sensor.op_classpath = SensorInstance.get_classpath(ti.task) - sensor.poke_context = encoded_poke - sensor.execution_context = encoded_execution_context - - sensor.hashcode = hash(encoded_poke) - sensor.shardcode = sensor.hashcode % conf.getint('smart_sensor', 'shard_code_upper_limit') - sensor.try_number = ti.try_number - - sensor.state = State.SENSING - sensor.start_date = timezone.utcnow() - session.add(sensor) - session.commit() - - return True - - @property - def try_number(self): - """ - Return the try number that this task number will be when it is actually - run. - If the TI is currently running, this will match the column in the - database, in all other cases this will be incremented. - """ - # This is designed so that task logs end up in the right file. - if self.state in State.running: - return self._try_number - return self._try_number + 1 - - @try_number.setter - def try_number(self, value): - self._try_number = value - - def __repr__(self): - return ( - "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} " - "execution_context: {self.execution_context} state: {self.state}>".format(self=self) - ) - - @provide_session - def get_dagrun(self, session): - """ - Returns the DagRun for this SensorInstance - - :param session: SQLAlchemy ORM Session - :return: DagRun - """ - from airflow.models.dagrun import DagRun # Avoid circular import - - dr = ( - session.query(DagRun) - .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == self.execution_date) - .one() - ) - - return dr diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 45424709a8421..3e1ec0c70cce3 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -15,17 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Serialized DAG table in database.""" +from __future__ import annotations import hashlib import logging import zlib from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional import sqlalchemy_jsonfield -from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_ +from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_ from sqlalchemy.orm import Session, backref, foreign, relationship from sqlalchemy.sql.expression import func, literal @@ -62,23 +61,24 @@ class SerializedDagModel(Base): it solves the webserver scalability issue. """ - __tablename__ = 'serialized_dag' + __tablename__ = "serialized_dag" dag_id = Column(String(ID_LEN), primary_key=True) fileloc = Column(String(2000), nullable=False) # The max length of fileloc exceeds the limit of indexing. 
- fileloc_hash = Column(BigInteger, nullable=False) - _data = Column('data', sqlalchemy_jsonfield.JSONField(json=json), nullable=True) - _data_compressed = Column('data_compressed', LargeBinary, nullable=True) + fileloc_hash = Column(BigInteger(), nullable=False) + _data = Column("data", sqlalchemy_jsonfield.JSONField(json=json), nullable=True) + _data_compressed = Column("data_compressed", LargeBinary, nullable=True) last_updated = Column(UtcDateTime, nullable=False) dag_hash = Column(String(32), nullable=False) + processor_subdir = Column(String(2000), nullable=True) - __table_args__ = (Index('idx_fileloc_hash', fileloc_hash, unique=False),) + __table_args__ = (Index("idx_fileloc_hash", fileloc_hash, unique=False),) dag_runs = relationship( DagRun, primaryjoin=dag_id == foreign(DagRun.dag_id), # type: ignore - backref=backref('serialized_dag', uselist=False, innerjoin=True), + backref=backref("serialized_dag", uselist=False, innerjoin=True), ) dag_model = relationship( @@ -87,16 +87,17 @@ class SerializedDagModel(Base): foreign_keys=dag_id, uselist=False, innerjoin=True, - backref=backref('serialized_dag', uselist=False, innerjoin=True), + backref=backref("serialized_dag", uselist=False, innerjoin=True), ) load_op_links = True - def __init__(self, dag: DAG): + def __init__(self, dag: DAG, processor_subdir: str | None = None): self.dag_id = dag.dag_id self.fileloc = dag.fileloc self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc) self.last_updated = timezone.utcnow() + self.processor_subdir = processor_subdir dag_data = SerializedDAG.to_dict(dag) dag_data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8") @@ -119,7 +120,13 @@ def __repr__(self): @classmethod @provide_session - def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: Session = None) -> bool: + def write_dag( + cls, + dag: DAG, + min_update_interval: int | None = None, + processor_subdir: str | None = None, + session: Session = None, + ) -> bool: """Serializes a DAG and writes it into database. If the record already exists, it checks if the Serialized DAG changed or not. If it is changed, it updates the record, ignores otherwise. @@ -142,19 +149,21 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated, ) ) - .first() - is not None + .scalar() ): - # TODO: .first() is not None can be changed to .scalar() once we update to sqlalchemy 1.4+ - # as the associated sqlalchemy bug for MySQL was fixed - # related issue : https://github.com/sqlalchemy/sqlalchemy/issues/5481 return False log.debug("Checking if DAG (%s) changed", dag.dag_id) - new_serialized_dag = cls(dag) - serialized_dag_hash_from_db = session.query(cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar() + new_serialized_dag = cls(dag, processor_subdir) + serialized_dag_db = ( + session.query(cls.dag_hash, cls.processor_subdir).filter(cls.dag_id == dag.dag_id).first() + ) - if serialized_dag_hash_from_db == new_serialized_dag.dag_hash: + if ( + serialized_dag_db is not None + and serialized_dag_db.dag_hash == new_serialized_dag.dag_hash + and serialized_dag_db.processor_subdir == new_serialized_dag.processor_subdir + ): log.debug("Serialized DAG (%s) is unchanged. 
Skipping writing to DB", dag.dag_id) return False @@ -165,7 +174,7 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: @classmethod @provide_session - def read_all_dags(cls, session: Session = None) -> Dict[str, 'SerializedDAG']: + def read_all_dags(cls, session: Session = None) -> dict[str, SerializedDAG]: """Reads all DAGs in serialized_dag table. :param session: ORM Session @@ -206,7 +215,7 @@ def dag(self): SerializedDAG._load_operator_extra_links = self.load_op_links if isinstance(self.data, dict): - dag = SerializedDAG.from_dict(self.data) # type: Any + dag = SerializedDAG.from_dict(self.data) else: dag = SerializedDAG.from_json(self.data) return dag @@ -222,7 +231,9 @@ def remove_dag(cls, dag_id: str, session: Session = None): @classmethod @provide_session - def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): + def remove_deleted_dags( + cls, alive_dag_filelocs: list[str], processor_subdir: str | None = None, session=None + ): """Deletes DAGs not included in alive_dag_filelocs. :param alive_dag_filelocs: file paths of alive DAGs @@ -236,7 +247,14 @@ def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): session.execute( cls.__table__.delete().where( - and_(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)) + and_( + cls.fileloc_hash.notin_(alive_fileloc_hashes), + cls.fileloc.notin_(alive_dag_filelocs), + or_( + cls.processor_subdir is None, + cls.processor_subdir == processor_subdir, + ), + ) ) ) @@ -252,7 +270,7 @@ def has_dag(cls, dag_id: str, session: Session = None) -> bool: @classmethod @provide_session - def get_dag(cls, dag_id: str, session: Session = None) -> Optional['SerializedDAG']: + def get_dag(cls, dag_id: str, session: Session = None) -> SerializedDAG | None: row = cls.get(dag_id, session=session) if row: return row.dag @@ -260,7 +278,7 @@ def get_dag(cls, dag_id: str, session: Session = None) -> Optional['SerializedDA @classmethod @provide_session - def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagModel']: + def get(cls, dag_id: str, session: Session = None) -> SerializedDagModel | None: """ Get the SerializedDAG for the given dag ID. It will cope with being passed the ID of a subdag by looking up the @@ -281,7 +299,7 @@ def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagMod @staticmethod @provide_session - def bulk_sync_to_db(dags: List[DAG], session: Session = None): + def bulk_sync_to_db(dags: list[DAG], processor_subdir: str | None = None, session: Session = None): """ Saves DAGs as Serialized DAG objects in the database. Each DAG is saved in a separate database query. 
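A hypothetical write path mirroring the extended `write_dag` signature above; the DAG and the subdirectory are invented, and an initialized metadata database is assumed:

from datetime import datetime

from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel

with DAG(dag_id="serialization_demo", start_date=datetime(2022, 1, 1)) as dag:
    pass  # tasks omitted; an empty DAG still serializes

# Returns False when the stored hash and processor_subdir are both unchanged
# within min_update_interval; True when the serialized record is written.
written = SerializedDagModel.write_dag(
    dag=dag,
    min_update_interval=30,
    processor_subdir="/files/dags/team_a",  # invented path
)

Comparing `processor_subdir` alongside the DAG hash is what lets multiple DAG processors share the table without overwriting each other's rows.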
@@ -293,12 +311,15 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None): for dag in dags: if not dag.is_subdag: SerializedDagModel.write_dag( - dag, min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, session=session + dag=dag, + min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, + processor_subdir=processor_subdir, + session=session, ) @classmethod @provide_session - def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> Optional[datetime]: + def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> datetime | None: """ Get the date when the Serialized DAG associated to DAG was last updated in serialized_dag table @@ -310,7 +331,7 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> Opti @classmethod @provide_session - def get_max_last_updated_datetime(cls, session: Session = None) -> Optional[datetime]: + def get_max_last_updated_datetime(cls, session: Session = None) -> datetime | None: """ Get the maximum date when any DAG was last updated in serialized_dag table @@ -320,20 +341,19 @@ def get_max_last_updated_datetime(cls, session: Session = None) -> Optional[date @classmethod @provide_session - def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> Optional[str]: + def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str | None: """ Get the latest DAG version for a given DAG ID. :param dag_id: DAG ID :param session: ORM Session :return: DAG Hash, or None if the DAG is not found - :rtype: str | None """ return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar() @classmethod @provide_session - def get_dag_dependencies(cls, session: Session = None) -> Dict[str, List['DagDependency']]: + def get_dag_dependencies(cls, session: Session = None) -> dict[str, list[DagDependency]]: """ Get the dependencies between DAGs diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index d5b1481cb19b2..c2ec087cf6042 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -15,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Iterable, Sequence +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone from airflow.utils.log.logging_mixin import LoggingMixin @@ -29,8 +31,9 @@ from pendulum import DateTime from sqlalchemy import Session - from airflow.models.baseoperator import BaseOperator from airflow.models.dagrun import DagRun + from airflow.models.operator import Operator + from airflow.models.taskmixin import DAGNode # The key used by SkipMixin to store XCom data. 
XCOM_SKIPMIXIN_KEY = "skipmixin_key" @@ -42,18 +45,29 @@ XCOM_SKIPMIXIN_FOLLOWED = "followed" +def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: + from airflow.models.baseoperator import BaseOperator + from airflow.models.mappedoperator import MappedOperator + + return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))] + + class SkipMixin(LoggingMixin): """A Mixin to skip Tasks Instances""" - def _set_state_to_skipped(self, dag_run: "DagRun", tasks: "Iterable[BaseOperator]", session: "Session"): + def _set_state_to_skipped( + self, + dag_run: DagRun, + tasks: Iterable[Operator], + session: Session, + ) -> None: """Used internally to set state of task instances to skipped from the same dag run.""" - task_ids = [d.task_id for d in tasks] now = timezone.utcnow() session.query(TaskInstance).filter( TaskInstance.dag_id == dag_run.dag_id, TaskInstance.run_id == dag_run.run_id, - TaskInstance.task_id.in_(task_ids), + TaskInstance.task_id.in_(d.task_id for d in tasks), ).update( { TaskInstance.state: State.SKIPPED, @@ -66,10 +80,10 @@ def _set_state_to_skipped(self, dag_run: "DagRun", tasks: "Iterable[BaseOperator @provide_session def skip( self, - dag_run: "DagRun", - execution_date: "DateTime", - tasks: Sequence["BaseOperator"], - session: "Session" = NEW_SESSION, + dag_run: DagRun, + execution_date: DateTime, + tasks: Iterable[DAGNode], + session: Session = NEW_SESSION, ): """ Sets tasks instances to skipped from the same dag run. @@ -83,7 +97,8 @@ def skip( :param tasks: tasks to skip (not task_ids) :param session: db session to use """ - if not tasks: + task_list = _ensure_tasks(tasks) + if not task_list: return if execution_date and not dag_run: @@ -91,14 +106,14 @@ def skip( warnings.warn( "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) dag_run = ( session.query(DagRun) .filter( - DagRun.dag_id == tasks[0].dag_id, + DagRun.dag_id == task_list[0].dag_id, DagRun.execution_date == execution_date, ) .one() @@ -111,24 +126,24 @@ def skip( if dag_run is None: raise ValueError("dag_run is required") - self._set_state_to_skipped(dag_run, tasks, session) + self._set_state_to_skipped(dag_run, task_list, session) session.commit() # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. 
- task_id: Optional[str] = getattr(self, "task_id", None) + task_id: str | None = getattr(self, "task_id", None) if task_id is not None: from airflow.models.xcom import XCom XCom.set( key=XCOM_SKIPMIXIN_KEY, - value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in tasks]}, + value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]}, task_id=task_id, dag_id=dag_run.dag_id, run_id=dag_run.run_id, session=session, ) - def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, Iterable[str]]): + def skip_all_except(self, ti: TaskInstance, branch_task_ids: None | str | Iterable[str]): """ This method implements the logic for a branching operator; given a single task ID or list of task IDs to follow, this skips all other tasks @@ -139,19 +154,40 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, It """ self.log.info("Following branch %s", branch_task_ids) if isinstance(branch_task_ids, str): - branch_task_ids = {branch_task_ids} + branch_task_id_set = {branch_task_ids} + elif isinstance(branch_task_ids, Iterable): + branch_task_id_set = set(branch_task_ids) + invalid_task_ids_type = { + (bti, type(bti).__name__) for bti in branch_task_ids if not isinstance(bti, str) + } + if invalid_task_ids_type: + raise AirflowException( + f"'branch_task_ids' expected all task IDs are strings. " + f"Invalid tasks found: {invalid_task_ids_type}." + ) elif branch_task_ids is None: - branch_task_ids = () - - branch_task_ids = set(branch_task_ids) + branch_task_id_set = set() + else: + raise AirflowException( + "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, " + f"but got {type(branch_task_ids).__name__!r}." + ) dag_run = ti.get_dagrun() task = ti.task dag = task.dag - assert dag # For Mypy. + if TYPE_CHECKING: + assert dag + + valid_task_ids = set(dag.task_ids) + invalid_task_ids = branch_task_id_set - valid_task_ids + if invalid_task_ids: + raise AirflowException( + "'branch_task_ids' must contain only valid task_ids. " + f"Invalid tasks found: {invalid_task_ids}." + ) - # At runtime, the downstream list will only be operators - downstream_tasks = cast("List[BaseOperator]", task.downstream_list) + downstream_tasks = _ensure_tasks(task.downstream_list) if downstream_tasks: # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"), @@ -166,11 +202,11 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, It # v / # task1 # - for branch_task_id in list(branch_task_ids): - branch_task_ids.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)) + for branch_task_id in list(branch_task_id_set): + branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)) - skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_ids] - follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_ids] + skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set] + follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set] self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks]) with create_session() as session: diff --git a/airflow/models/slamiss.py b/airflow/models/slamiss.py index 6c841e3ee6e05..cd2f0f1dd5271 100644 --- a/airflow/models/slamiss.py +++ b/airflow/models/slamiss.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
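Stepping back to the `skip_all_except` change above: the accepted shapes for `branch_task_ids` (a single ID, an iterable of string IDs, or None) can be restated outside Airflow as a small stand-alone sketch, with paraphrased error messages and a hypothetical helper name:

from __future__ import annotations

from typing import Iterable


def normalize_branch_task_ids(branch_task_ids: None | str | Iterable[str]) -> set[str]:
    if branch_task_ids is None:
        return set()
    if isinstance(branch_task_ids, str):  # a bare string is one task ID, not an iterable of characters
        return {branch_task_ids}
    if isinstance(branch_task_ids, Iterable):
        ids = set(branch_task_ids)
        non_strings = {i for i in ids if not isinstance(i, str)}
        if non_strings:
            raise ValueError(f"all branch task IDs must be strings, got: {non_strings}")
        return ids
    raise ValueError(f"unsupported type: {type(branch_task_ids).__name__}")


assert normalize_branch_task_ids("task_a") == {"task_a"}
assert normalize_branch_task_ids(["task_a", "task_b"]) == {"task_a", "task_b"}
assert normalize_branch_task_ids(None) == set()

The real method additionally rejects IDs that do not exist in the DAG, as shown in the diff.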
+from __future__ import annotations from sqlalchemy import Boolean, Column, Index, String, Text @@ -39,7 +40,7 @@ class SlaMiss(Base): description = Column(Text) notification_sent = Column(Boolean, default=False) - __table_args__ = (Index('sm_dag', dag_id, unique=False),) + __table_args__ = (Index("sm_dag", dag_id, unique=False),) def __repr__(self): return str((self.dag_id, self.task_id, self.execution_date.isoformat())) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index f7de99c308cac..ead4cee000f28 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -16,8 +16,9 @@ # specific language governing permissions and limitations # under the License. """Taskfail tracks the failed run durations of each task instance""" +from __future__ import annotations -from sqlalchemy import Column, ForeignKeyConstraint, Integer +from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, text from sqlalchemy.orm import relationship from airflow.models.base import Base, StringID @@ -33,12 +34,13 @@ class TaskFail(Base): task_id = Column(StringID(), nullable=False) dag_id = Column(StringID(), nullable=False) run_id = Column(StringID(), nullable=False) - map_index = Column(Integer, nullable=False) + map_index = Column(Integer, nullable=False, server_default=text("-1")) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Integer) __table_args__ = ( + Index("idx_task_fail_task_instance", dag_id, task_id, run_id, map_index), ForeignKeyConstraint( [dag_id, task_id, run_id, map_index], [ @@ -47,7 +49,7 @@ class TaskFail(Base): "task_instance.run_id", "task_instance.map_index", ], - name='task_fail_ti_fkey', + name="task_fail_ti_fkey", ondelete="CASCADE", ), ) @@ -79,4 +81,4 @@ def __repr__(self): prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}" if self.map_index != -1: prefix += f" map_index={self.map_index}" - return prefix + '>' + return prefix + ">" diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 0f5d49b819762..9708d73124b47 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
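The TaskFail `map_index` column above now gets a database-side default via `server_default=text("-1")` rather than only a client-side `default`. A self-contained toy (SQLite, not Airflow's models) showing the practical difference:

from sqlalchemy import Column, Integer, String, create_engine, text
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Demo(Base):
    __tablename__ = "demo"

    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    # server_default puts DEFAULT -1 into the DDL, so rows inserted outside
    # the ORM (raw SQL, migrations, other tools) still get -1; a plain
    # default= is only applied by the Python client.
    map_index = Column(Integer, nullable=False, server_default=text("-1"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(text("INSERT INTO demo (name) VALUES ('x')"))
    assert conn.execute(text("SELECT map_index FROM demo")).scalar() == -1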
+from __future__ import annotations + import collections.abc import contextlib import hashlib @@ -28,37 +30,24 @@ from datetime import datetime, timedelta from functools import partial from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, - ContextManager, - Dict, - Generator, - Iterable, - List, - NamedTuple, - Optional, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple from urllib.parse import quote -import attr import dill import jinja2 +import lazy_object_proxy import pendulum from jinja2 import TemplateAssertionError, UndefinedError from sqlalchemy import ( Column, + DateTime, Float, ForeignKeyConstraint, Index, Integer, - PickleType, + PrimaryKeyConstraint, String, + Text, and_, false, func, @@ -70,38 +59,38 @@ from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import reconstructor, relationship from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value -from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList -from sqlalchemy.sql.expression import ColumnOperators -from sqlalchemy.sql.sqltypes import BigInteger +from sqlalchemy.sql.expression import ColumnOperators, case from airflow import settings from airflow.compat.functools import cache from airflow.configuration import conf +from airflow.datasets import Dataset +from airflow.datasets.manager import dataset_manager from airflow.exceptions import ( AirflowException, AirflowFailException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, - AirflowSmartSensorException, AirflowTaskTimeout, DagRunNotFound, + RemovedInAirflow3Warning, TaskDeferralError, TaskDeferred, UnmappableXComLengthPushed, UnmappableXComTypePushed, XComForMappingNotPushed, ) -from airflow.models.base import COLLATION_ARGS, ID_LEN, Base +from airflow.models.base import Base, StringID from airflow.models.log import Log -from airflow.models.param import ParamsDict +from airflow.models.mappedoperator import MappedOperator +from airflow.models.param import process_params from airflow.models.taskfail import TaskFail from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule -from airflow.models.xcom import XCOM_RETURN_KEY, XCom +from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom from airflow.plugins_manager import integrate_macros_plugins from airflow.sentry import Sentry from airflow.stats import Stats @@ -109,7 +98,7 @@ from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS from airflow.timetables.base import DataInterval -from airflow.typing_compat import Literal +from airflow.typing_compat import Literal, TypeGuard from airflow.utils import timezone from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge from airflow.utils.email import send_email @@ -120,21 +109,30 @@ from airflow.utils.platform import getuser from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks +from airflow.utils.sqlalchemy import ( + ExecutorConfigType, + ExtendedJSON, + UtcDateTime, + tuple_in_condition, + with_row_locks, +) from airflow.utils.state 
import DagRunState, State, TaskInstanceState from airflow.utils.timeout import timeout TR = TaskReschedule -_CURRENT_CONTEXT: List[Context] = [] +_CURRENT_CONTEXT: list[Context] = [] log = logging.getLogger(__name__) if TYPE_CHECKING: + from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun + from airflow.models.dataset import DatasetEvent from airflow.models.operator import Operator + from airflow.utils.task_group import MappedTaskGroup, TaskGroup @contextlib.contextmanager @@ -157,12 +155,12 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: def clear_task_instances( - tis, - session, - activate_dag_runs=None, - dag=None, - dag_run_state: Union[DagRunState, Literal[False]] = DagRunState.QUEUED, -): + tis: list[TaskInstance], + session: Session, + activate_dag_runs: None = None, + dag: DAG | None = None, + dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED, +) -> None: """ Clears a set of task instances, but makes sure the running ones get killed. @@ -176,7 +174,7 @@ def clear_task_instances( """ job_ids = [] # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id - task_id_by_key: Dict[str, Dict[str, Dict[int, Dict[int, Set[str]]]]] = defaultdict( + task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set))) ) for ti in tis: @@ -201,6 +199,7 @@ def clear_task_instances( ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries) ti.state = None ti.external_executor_id = None + ti.clear_next_method_args() session.merge(ti) task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id) @@ -249,7 +248,7 @@ def clear_task_instances( warnings.warn( "`activate_dag_runs` parameter to clear_task_instances function is deprecated. " "Please use `dag_run_state`", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) if not activate_dag_runs: @@ -279,91 +278,20 @@ def clear_task_instances( if dag_run_state == DagRunState.QUEUED: dr.last_scheduling_decision = None dr.start_date = None + session.flush() -class _LazyXComAccessIterator(collections.abc.Iterator): - __slots__ = ['_cm', '_it'] +def _is_mappable_value(value: Any) -> TypeGuard[Collection]: + """Whether a value can be used for task mapping. - def __init__(self, cm: ContextManager[Query]): - self._cm = cm - self._it = None - - def __del__(self): - if self._it: - self._cm.__exit__(None, None, None) - - def __iter__(self): - return self - - def __next__(self): - if not self._it: - self._it = iter(self._cm.__enter__()) - return XCom.deserialize_value(next(self._it)) - - -@attr.define -class _LazyXComAccess(collections.abc.Sequence): - """Wrapper to lazily pull XCom with a sequence-like interface. - - Note that since the session bound to the parent query may have died when we - actually access the sequence's content, we must create a new session - for every function call with ``with_session()``. + We only allow collections with guaranteed ordering, but exclude character + sequences since that's usually not what users would expect to be mappable. 
""" - - dag_id: str - run_id: str - task_id: str - _query: Query = attr.ib(repr=False) - _len: Optional[int] = attr.ib(init=False, repr=False, default=None) - - @classmethod - def build_from_single_xcom(cls, first: "XCom", query: Query) -> "_LazyXComAccess": - return cls( - dag_id=first.dag_id, - run_id=first.run_id, - task_id=first.task_id, - query=query.with_entities(XCom.value) - .filter( - XCom.run_id == first.run_id, - XCom.task_id == first.task_id, - XCom.dag_id == first.dag_id, - XCom.map_index >= 0, - ) - .order_by(None) - .order_by(XCom.map_index.asc()), - ) - - def __len__(self): - if self._len is None: - with self._get_bound_query() as query: - self._len = query.count() - return self._len - - def __iter__(self): - return _LazyXComAccessIterator(self._get_bound_query()) - - def __getitem__(self, key): - if not isinstance(key, int): - raise ValueError("only support index access for now") - try: - with self._get_bound_query() as query: - r = query.offset(key).limit(1).one() - except NoResultFound: - raise IndexError(key) from None - return XCom.deserialize_value(r) - - @contextlib.contextmanager - def _get_bound_query(self) -> Generator[Query, None, None]: - # Do we have a valid session already? - if self._query.session and self._query.session.is_active: - yield self._query - return - - session = settings.Session() - try: - yield self._query.with_session(session) - finally: - session.close() + if not isinstance(value, (collections.abc.Sequence, dict)): + return False + if isinstance(value, (bytearray, bytes, str)): + return False + return True class TaskInstanceKey(NamedTuple): @@ -376,23 +304,23 @@ class TaskInstanceKey(NamedTuple): map_index: int = -1 @property - def primary(self) -> Tuple[str, str, str, int]: + def primary(self) -> tuple[str, str, str, int]: """Return task instance primary key part of the key""" return self.dag_id, self.task_id, self.run_id, self.map_index @property - def reduced(self) -> 'TaskInstanceKey': + def reduced(self) -> TaskInstanceKey: """Remake the key by subtracting 1 from try number to match in memory information""" return TaskInstanceKey( self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1), self.map_index ) - def with_try_number(self, try_number: int) -> 'TaskInstanceKey': + def with_try_number(self, try_number: int) -> TaskInstanceKey: """Returns TaskInstanceKey with provided ``try_number``""" return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index) @property - def key(self) -> "TaskInstanceKey": + def key(self) -> TaskInstanceKey: """For API-compatibly with TaskInstance. Returns self @@ -400,6 +328,16 @@ def key(self) -> "TaskInstanceKey": return self +def _creator_note(val): + """Custom creator for the ``note`` association proxy.""" + if isinstance(val, str): + return TaskInstanceNote(content=val) + elif isinstance(val, dict): + return TaskInstanceNote(**val) + else: + return TaskInstanceNote(*val) + + class TaskInstance(Base, LoggingMixin): """ Task instances store the state of a task instance. 
This table is the @@ -419,38 +357,41 @@ class TaskInstance(Base, LoggingMixin): """ __tablename__ = "task_instance" - - task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False) - dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False) - run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False) + task_id = Column(StringID(), primary_key=True, nullable=False) + dag_id = Column(StringID(), primary_key=True, nullable=False) + run_id = Column(StringID(), primary_key=True, nullable=False) map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Float) state = Column(String(20)) - _try_number = Column('try_number', Integer, default=0) - max_tries = Column(Integer) + _try_number = Column("try_number", Integer, default=0) + max_tries = Column(Integer, server_default=text("-1")) hostname = Column(String(1000)) unixname = Column(String(1000)) job_id = Column(Integer) pool = Column(String(256), nullable=False) - pool_slots = Column(Integer, default=1, nullable=False, server_default=text("1")) + pool_slots = Column(Integer, default=1, nullable=False) queue = Column(String(256)) priority_weight = Column(Integer) operator = Column(String(1000)) queued_dttm = Column(UtcDateTime) queued_by_job_id = Column(Integer) pid = Column(Integer) - executor_config = Column(PickleType(pickler=dill)) + executor_config = Column(ExecutorConfigType(pickler=dill)) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) - external_executor_id = Column(String(ID_LEN, **COLLATION_ARGS)) + external_executor_id = Column(StringID()) # The trigger to resume on if we are in state DEFERRED - trigger_id = Column(BigInteger) + trigger_id = Column(Integer) # Optional timeout datetime for the trigger (past this, we'll fail) - trigger_timeout = Column(UtcDateTime) + trigger_timeout = Column(DateTime) + # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of + # migration, we are keeping it as DateTime pending a change where expensive + # migration is inevitable. # The method to call next, and any extra arguments to pass to it. # Usually used when resuming from DEFERRED. 
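Several of the columns above (map_index and max_tries) now carry server_default=text("-1") instead of relying only on a client-side default. The distinction: default= is applied by SQLAlchemy when the INSERT goes through the ORM or Core, while server_default= becomes part of the table DDL, so the database itself fills the value for raw SQL inserts and for pre-existing rows when a migration adds the column. A generic SQLAlchemy sketch of the difference, with made-up table and column names:

# Generic illustration, not Airflow code: compare a client-side default with a
# server-side default on an integer column.
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, text

metadata = MetaData()
demo = Table(
    "demo_task_instance",
    metadata,
    Column("id", Integer, primary_key=True),
    # Applied by SQLAlchemy only when the INSERT goes through the ORM/Core:
    Column("client_side", Integer, default=1),
    # Part of CREATE TABLE, so the database applies it even to raw SQL inserts:
    Column("map_index", Integer, nullable=False, server_default=text("-1")),
)

engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(text("INSERT INTO demo_task_instance (id) VALUES (1)"))
    row = conn.execute(text("SELECT client_side, map_index FROM demo_task_instance")).one()
    print(row)  # (None, -1): only the server-side default reached the raw INSERT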
@@ -461,23 +402,26 @@ class TaskInstance(Base, LoggingMixin): # refresh_from_db() or they won't display in the UI correctly __table_args__ = ( - Index('ti_dag_state', dag_id, state), - Index('ti_dag_run', dag_id, run_id), - Index('ti_state', state), - Index('ti_state_lkp', dag_id, task_id, run_id, state), - Index('ti_pool', pool, state, priority_weight), - Index('ti_job_id', job_id), - Index('ti_trigger_id', trigger_id), + Index("ti_dag_state", dag_id, state), + Index("ti_dag_run", dag_id, run_id), + Index("ti_state", state), + Index("ti_state_lkp", dag_id, task_id, run_id, state), + Index("ti_pool", pool, state, priority_weight), + Index("ti_job_id", job_id), + Index("ti_trigger_id", trigger_id), + PrimaryKeyConstraint( + "dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey", mssql_clustered=True + ), ForeignKeyConstraint( [trigger_id], - ['trigger.id'], - name='task_instance_trigger_id_fkey', - ondelete='CASCADE', + ["trigger.id"], + name="task_instance_trigger_id_fkey", + ondelete="CASCADE", ), ForeignKeyConstraint( [dag_id, run_id], ["dag_run.dag_id", "dag_run.run_id"], - name='task_instance_dag_run_fkey', + name="task_instance_dag_run_fkey", ondelete="CASCADE", ), ) @@ -491,27 +435,21 @@ class TaskInstance(Base, LoggingMixin): viewonly=True, ) - trigger = relationship( - "Trigger", - primaryjoin="TaskInstance.trigger_id == Trigger.id", - foreign_keys=trigger_id, - uselist=False, - innerjoin=True, - ) - - dag_run = relationship("DagRun", back_populates="task_instances", lazy='joined', innerjoin=True) - rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy='noload', uselist=False) - + trigger = relationship("Trigger", uselist=False) + triggerer_job = association_proxy("trigger", "triggerer_job") + dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True) + rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False) execution_date = association_proxy("dag_run", "execution_date") - - task: "Operator" # Not always set... + task_instance_note = relationship("TaskInstanceNote", back_populates="task_instance", uselist=False) + note = association_proxy("task_instance_note", "content", creator=_creator_note) + task: Operator # Not always set... def __init__( self, - task: "Operator", - execution_date: Optional[datetime] = None, - run_id: Optional[str] = None, - state: Optional[str] = None, + task: Operator, + execution_date: datetime | None = None, + run_id: str | None = None, + state: str | None = None, map_index: int = -1, ): super().__init__() @@ -527,7 +465,7 @@ def __init__( warnings.warn( "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id", - DeprecationWarning, + RemovedInAirflow3Warning, # Stack level is 4 because SQLA adds some wrappers around the constructor stacklevel=4, ) @@ -538,7 +476,8 @@ def __init__( execution_date, ) if self.task.has_dag(): - assert self.task.dag # For Mypy. + if TYPE_CHECKING: + assert self.task.dag execution_date = timezone.make_aware(execution_date, self.task.dag.timezone) else: execution_date = timezone.make_aware(execution_date) @@ -562,7 +501,7 @@ def __init__( self.unixname = getuser() if state: self.state = state - self.hostname = '' + self.hostname = "" # Is this TaskInstance being currently running within `airflow tasks run --raw`. 
# Not persisted to the database so only valid for the current process self.raw = False @@ -570,28 +509,28 @@ def __init__( self.test_mode = False @staticmethod - def insert_mapping(run_id: str, task: "Operator", map_index: int) -> dict: + def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]: """:meta private:""" return { - 'dag_id': task.dag_id, - 'task_id': task.task_id, - 'run_id': run_id, - '_try_number': 0, - 'hostname': '', - 'unixname': getuser(), - 'queue': task.queue, - 'pool': task.pool, - 'pool_slots': task.pool_slots, - 'priority_weight': task.priority_weight_total, - 'run_as_user': task.run_as_user, - 'max_tries': task.retries, - 'executor_config': task.executor_config, - 'operator': task.task_type, - 'map_index': map_index, + "dag_id": task.dag_id, + "task_id": task.task_id, + "run_id": run_id, + "_try_number": 0, + "hostname": "", + "unixname": getuser(), + "queue": task.queue, + "pool": task.pool, + "pool_slots": task.pool_slots, + "priority_weight": task.priority_weight_total, + "run_as_user": task.run_as_user, + "max_tries": task.retries, + "executor_config": task.executor_config, + "operator": task.task_type, + "map_index": map_index, } @reconstructor - def init_on_load(self): + def init_on_load(self) -> None: """Initialize the attributes that aren't stored in the DB""" # correctly config the ti log self._log = logging.getLogger("airflow.task") @@ -607,17 +546,16 @@ def try_number(self): database, in all other cases this will be incremented. """ # This is designed so that task logs end up in the right file. - # TODO: whether we need sensing here or not (in sensor and task_instance state machine) - if self.state in State.running: + if self.state == State.RUNNING: return self._try_number return self._try_number + 1 @try_number.setter - def try_number(self, value): + def try_number(self, value: int) -> None: self._try_number = value @property - def prev_attempted_tries(self): + def prev_attempted_tries(self) -> int: """ Based on this instance's try_number, this will calculate the number of previously attempted tries, defaulting to 0. @@ -631,8 +569,7 @@ def prev_attempted_tries(self): return self._try_number @property - def next_try_number(self): - """Setting Next Try Number""" + def next_try_number(self) -> int: return self._try_number + 1 def command_as_list( @@ -654,9 +591,9 @@ def command_as_list( installed. This command is part of the message sent to executors by the orchestrator. 
""" - dag: Union["DAG", "DagModel"] + dag: DAG | DagModel # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded - if hasattr(self, 'task') and hasattr(self.task, 'dag'): + if hasattr(self, "task") and hasattr(self.task, "dag"): dag = self.task.dag else: dag = self.dag_model @@ -671,7 +608,7 @@ def command_as_list( if path: if not path.is_absolute(): - path = 'DAGS_FOLDER' / path + path = "DAGS_FOLDER" / path path = str(path) return TaskInstance.generate_command( @@ -704,14 +641,14 @@ def generate_command( ignore_task_deps: bool = False, ignore_ti_state: bool = False, local: bool = False, - pickle_id: Optional[int] = None, - file_path: Optional[str] = None, + pickle_id: int | None = None, + file_path: str | None = None, raw: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - cfg_path: Optional[str] = None, + job_id: str | None = None, + pool: str | None = None, + cfg_path: str | None = None, map_index: int = -1, - ) -> List[str]: + ) -> list[str]: """ Generates the shell command required to execute this task instance. @@ -735,7 +672,6 @@ def generate_command( :param pool: the Airflow pool that the task should run in :param cfg_path: the Path to the configuration file :return: shell command that can be used to run the task instance - :rtype: list[str] """ cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id] if mark_success: @@ -763,22 +699,28 @@ def generate_command( if cfg_path: cmd.extend(["--cfg-path", cfg_path]) if map_index != -1: - cmd.extend(['--map-index', str(map_index)]) + cmd.extend(["--map-index", str(map_index)]) return cmd @property - def log_url(self): + def log_url(self) -> str: """Log URL for TaskInstance""" iso = quote(self.execution_date.isoformat()) - base_url = conf.get('webserver', 'BASE_URL') - return base_url + f"/log?execution_date={iso}&task_id={self.task_id}&dag_id={self.dag_id}" + base_url = conf.get_mandatory_value("webserver", "BASE_URL") + return ( + f"{base_url}/log" + f"?execution_date={iso}" + f"&task_id={self.task_id}" + f"&dag_id={self.dag_id}" + f"&map_index={self.map_index}" + ) @property - def mark_success_url(self): + def mark_success_url(self) -> str: """URL to mark TI success""" - base_url = conf.get('webserver', 'BASE_URL') - return base_url + ( - "/confirm" + base_url = conf.get_mandatory_value("webserver", "BASE_URL") + return ( + f"{base_url}/confirm" f"?task_id={self.task_id}" f"&dag_id={self.dag_id}" f"&dag_run_id={quote(self.run_id)}" @@ -788,26 +730,22 @@ def mark_success_url(self): ) @provide_session - def current_state(self, session=NEW_SESSION) -> str: + def current_state(self, session: Session = NEW_SESSION) -> str: """ Get the very latest state from the database, if a session is passed, we use and looking up the state becomes part of the session, otherwise a new session is used. + sqlalchemy.inspect is used here to get the primary keys ensuring that if they change + it will not regress + :param session: SQLAlchemy ORM Session """ - return ( - session.query(TaskInstance.state) - .filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == self.run_id, - ) - .scalar() - ) + filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key) + return session.query(TaskInstance.state).filter(*filters).scalar() @provide_session - def error(self, session=NEW_SESSION): + def error(self, session: Session = NEW_SESSION) -> None: """ Forces the task instance's state to FAILED in the database. 
@@ -819,7 +757,7 @@ def error(self, session=NEW_SESSION): session.commit() @provide_session - def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None: + def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None: """ Refreshes the task instance from the database based on the primary key @@ -830,28 +768,35 @@ def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None: """ self.log.debug("Refreshing TaskInstance %s from DB", self) - qry = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == self.run_id, - TaskInstance.map_index == self.map_index, + if self in session: + session.refresh(self, TaskInstance.__mapper__.column_attrs.keys()) + + qry = ( + # To avoid joining any relationships, by default select all + # columns, not the object. This also means we get (effectively) a + # namedtuple back, not a TI object + session.query(*TaskInstance.__table__.columns).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == self.run_id, + TaskInstance.map_index == self.map_index, + ) ) if lock_for_update: for attempt in run_with_db_retries(logger=self.log): with attempt: - ti: Optional[TaskInstance] = qry.with_for_update().first() + ti: TaskInstance | None = qry.with_for_update().one_or_none() else: - ti = qry.first() + ti = qry.one_or_none() if ti: # Fields ordered per model definition self.start_date = ti.start_date self.end_date = ti.end_date self.duration = ti.duration self.state = ti.state - # Get the raw value of try_number column, don't read through the - # accessor here otherwise it will be incremented by one already. - self.try_number = ti._try_number + # Since we selected columns, not the object, this is the raw value + self.try_number = ti.try_number self.max_tries = ti.max_tries self.hostname = ti.hostname self.unixname = ti.unixname @@ -872,7 +817,7 @@ def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None: else: self.state = None - def refresh_from_task(self, task: "Operator", pool_override=None): + def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None: """ Copy common attributes from the given task. @@ -891,7 +836,7 @@ def refresh_from_task(self, task: "Operator", pool_override=None): self.operator = task.task_type @provide_session - def clear_xcom_data(self, session: Session = NEW_SESSION): + def clear_xcom_data(self, session: Session = NEW_SESSION) -> None: """Clear all XCom data from the database for the task instance. If the task is unmapped, all XComs matching this task ID in the same DAG @@ -902,7 +847,7 @@ def clear_xcom_data(self, session: Session = NEW_SESSION): """ self.log.debug("Clearing XCom data") if self.map_index < 0: - map_index: Optional[int] = None + map_index: int | None = None else: map_index = self.map_index XCom.clear( @@ -919,13 +864,17 @@ def key(self) -> TaskInstanceKey: return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) @provide_session - def set_state(self, state: Optional[str], session=NEW_SESSION): + def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool: """ Set TaskInstance state. 
:param state: State to set for the TI :param session: SQLAlchemy ORM Session + :return: Was the state changed """ + if self.state == state: + return False + current_time = timezone.utcnow() self.log.debug("Setting task state for %s to %s", self, state) self.state = state @@ -934,9 +883,10 @@ def set_state(self, state: Optional[str], session=NEW_SESSION): self.end_date = self.end_date or current_time self.duration = (self.end_date - self.start_date).total_seconds() session.merge(self) + return True @property - def is_premature(self): + def is_premature(self) -> bool: """ Returns whether a task is in UP_FOR_RETRY state and its retry interval has elapsed. @@ -945,7 +895,7 @@ def is_premature(self): return self.state == State.UP_FOR_RETRY and not self.ready_for_retry() @provide_session - def are_dependents_done(self, session=NEW_SESSION): + def are_dependents_done(self, session: Session = NEW_SESSION) -> bool: """ Checks whether the immediate dependents of this task instance have succeeded or have been skipped. This is meant to be used by wait_for_downstream. @@ -973,9 +923,9 @@ def are_dependents_done(self, session=NEW_SESSION): @provide_session def get_previous_dagrun( self, - state: Optional[DagRunState] = None, - session: Optional[Session] = None, - ) -> Optional["DagRun"]: + state: DagRunState | None = None, + session: Session | None = None, + ) -> DagRun | None: """The DagRun that ran before this task instance's DagRun. :param state: If passed, it only take into account instances of a specific state. @@ -1006,9 +956,9 @@ def get_previous_dagrun( @provide_session def get_previous_ti( self, - state: Optional[DagRunState] = None, + state: DagRunState | None = None, session: Session = NEW_SESSION, - ) -> Optional['TaskInstance']: + ) -> TaskInstance | None: """ The task instance for the task that ran before this task instance. @@ -1021,7 +971,7 @@ def get_previous_ti( return dagrun.get_task_instance(self.task_id, session=session) @property - def previous_ti(self): + def previous_ti(self) -> TaskInstance | None: """ This attribute is deprecated. Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. @@ -1031,13 +981,13 @@ def previous_ti(self): This attribute is deprecated. Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. """, - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.get_previous_ti() @property - def previous_ti_success(self) -> Optional['TaskInstance']: + def previous_ti_success(self) -> TaskInstance | None: """ This attribute is deprecated. Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. @@ -1047,7 +997,7 @@ def previous_ti_success(self) -> Optional['TaskInstance']: This attribute is deprecated. Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. """, - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.get_previous_ti(state=DagRunState.SUCCESS) @@ -1055,9 +1005,9 @@ def previous_ti_success(self) -> Optional['TaskInstance']: @provide_session def get_previous_execution_date( self, - state: Optional[DagRunState] = None, + state: DagRunState | None = None, session: Session = NEW_SESSION, - ) -> Optional[pendulum.DateTime]: + ) -> pendulum.DateTime | None: """ The execution date from property previous_ti_success. 
@@ -1070,8 +1020,8 @@ def get_previous_execution_date( @provide_session def get_previous_start_date( - self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION - ) -> Optional[pendulum.DateTime]: + self, state: DagRunState | None = None, session: Session = NEW_SESSION + ) -> pendulum.DateTime | None: """ The start date from property previous_ti_success. @@ -1084,7 +1034,7 @@ def get_previous_start_date( return prev_ti and prev_ti.start_date and pendulum.instance(prev_ti.start_date) @property - def previous_start_date_success(self) -> Optional[pendulum.DateTime]: + def previous_start_date_success(self) -> pendulum.DateTime | None: """ This attribute is deprecated. Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method. @@ -1094,13 +1044,15 @@ def previous_start_date_success(self) -> Optional[pendulum.DateTime]: This attribute is deprecated. Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method. """, - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.get_previous_start_date(state=DagRunState.SUCCESS) @provide_session - def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=False): + def are_dependencies_met( + self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False + ) -> bool: """ Returns whether or not all the conditions are met for this task instance to be run given the context for the dependencies (e.g. a task instance being force run from @@ -1132,7 +1084,7 @@ def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=Fa return True @provide_session - def get_failed_dep_statuses(self, dep_context=None, session=NEW_SESSION): + def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION): """Get failed Dependencies""" dep_context = dep_context or DepContext() for dep in dep_context.deps | self.task.deps: @@ -1149,7 +1101,7 @@ def get_failed_dep_statuses(self, dep_context=None, session=NEW_SESSION): if not dep_status.passed: yield dep_status - def __repr__(self): + def __repr__(self) -> str: prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} " if self.map_index != -1: prefix += f"map_index={self.map_index} " - def ready_for_retry(self): + def ready_for_retry(self) -> bool: """ Checks on whether the task instance is in the right state and timeframe to be retried. @@ -1202,7 +1154,7 @@ def ready_for_retry(self): return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow() @provide_session - def get_dagrun(self, session: Session = NEW_SESSION) -> "DagRun": + def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: """ Returns the DagRun for this TaskInstance @@ -1218,7 +1170,7 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> "DagRun": dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one() # Record it in the instance for next time. This means that `self.execution_date` will work correctly - set_committed_value(self, 'dag_run', dr) + set_committed_value(self, "dag_run", dr) return dr @@ -1232,10 +1184,10 @@ def check_and_change_state_before_execution( ignore_ti_state: bool = False, mark_success: bool = False, test_mode: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - external_executor_id: Optional[str] = None, - session=NEW_SESSION, + job_id: str | None = None, + pool: str | None = None, + external_executor_id: str | None = None, + session: Session = NEW_SESSION, ) -> bool: """ Checks dependencies and then sets state to RUNNING if they are met.
Returns @@ -1254,7 +1206,6 @@ def check_and_change_state_before_execution( :param external_executor_id: The identifier of the celery executor :param session: SQLAlchemy ORM Session :return: whether the state was changed to running or not - :rtype: bool """ task = self.task self.refresh_from_task(task, pool_override=pool) @@ -1265,7 +1216,7 @@ def check_and_change_state_before_execution( self.pid = None if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS: - Stats.incr('previously_succeeded', 1, 1) + Stats.incr("previously_succeeded", 1, 1) # TODO: Logging needs cleanup, not clear what is being printed hr_line_break = "\n" + ("-" * 80) # Line break @@ -1349,32 +1300,32 @@ def check_and_change_state_before_execution( self.log.info("Executing %s on %s", self.task, self.execution_date) return True - def _date_or_empty(self, attr: str): - result: Optional[datetime] = getattr(self, attr, None) - return result.strftime('%Y%m%dT%H%M%S') if result else '' + def _date_or_empty(self, attr: str) -> str: + result: datetime | None = getattr(self, attr, None) + return result.strftime("%Y%m%dT%H%M%S") if result else "" - def _log_state(self, lead_msg: str = ''): + def _log_state(self, lead_msg: str = "") -> None: params = [ lead_msg, str(self.state).upper(), self.dag_id, self.task_id, ] - message = '%sMarking task as %s. dag_id=%s, task_id=%s, ' + message = "%sMarking task as %s. dag_id=%s, task_id=%s, " if self.map_index >= 0: params.append(self.map_index) - message += 'map_index=%d, ' + message += "map_index=%d, " self.log.info( - message + 'execution_date=%s, start_date=%s, end_date=%s', + message + "execution_date=%s, start_date=%s, end_date=%s", *params, - self._date_or_empty('execution_date'), - self._date_or_empty('start_date'), - self._date_or_empty('end_date'), + self._date_or_empty("execution_date"), + self._date_or_empty("start_date"), + self._date_or_empty("end_date"), ) # Ensure we unset next_method and next_kwargs to ensure that any # retries don't re-use them. - def clear_next_method_args(self): + def clear_next_method_args(self) -> None: self.log.debug("Clearing next_method and next_kwargs.") self.next_method = None @@ -1386,9 +1337,9 @@ def _run_raw_task( self, mark_success: bool = False, test_mode: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - session=NEW_SESSION, + job_id: str | None = None, + pool: str | None = None, + session: Session = NEW_SESSION, ) -> None: """ Immediately runs the task (without checking or changing db state @@ -1411,10 +1362,10 @@ def _run_raw_task( session.merge(self) session.commit() actual_start_date = timezone.utcnow() - Stats.incr(f'ti.start.{self.task.dag_id}.{self.task.task_id}') + Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}") # Initialize final state counters at zero for state in State.task_states: - Stats.incr(f'ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}', count=0) + Stats.incr(f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", count=0) self.task = self.task.prepare_for_execution() context = self.get_template_context(ignore_param_exceptions=False) @@ -1429,20 +1380,17 @@ def _run_raw_task( # a trigger. self._defer_task(defer=defer, session=session) self.log.info( - 'Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s', + "Pausing task as DEFERRED. 
dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", self.dag_id, self.task_id, - self._date_or_empty('execution_date'), - self._date_or_empty('start_date'), + self._date_or_empty("execution_date"), + self._date_or_empty("start_date"), ) if not test_mode: session.add(Log(self.state, self)) session.merge(self) session.commit() return - except AirflowSmartSensorException as e: - self.log.info(e) - return except AirflowSkipException as e: # Recording SKIP # log only if exception has any arguments to prevent log flooding @@ -1481,7 +1429,7 @@ def _run_raw_task( session.commit() raise finally: - Stats.incr(f'ti.finish.{self.dag_id}.{self.task_id}.{self.state}') + Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}") # Recording SKIPPED or SUCCESS self.clear_next_method_args() @@ -1492,14 +1440,26 @@ def _run_raw_task( # run on_success_callback before db committing # otherwise, the LocalTaskJob sees the state is changed to `success`, # but the task_runner is still running, LocalTaskJob then treats the state is set externally! - self._run_finished_callback(self.task.on_success_callback, context, 'on_success') + self._run_finished_callback(self.task.on_success_callback, context, "on_success") if not test_mode: session.add(Log(self.state, self)) - session.merge(self) - + session.merge(self).task = self.task + if self.state == TaskInstanceState.SUCCESS: + self._register_dataset_changes(session=session) session.commit() + def _register_dataset_changes(self, *, session: Session) -> None: + for obj in self.task.outlets or []: + self.log.debug("outlet obj %s", obj) + # Lineage can have other types of objects besides datasets + if isinstance(obj, Dataset): + dataset_manager.register_dataset_change( + task_instance=self, + dataset=obj, + session=session, + ) + def _execute_task_with_callbacks(self, context, test_mode=False): """Prepare Task for Execution""" from airflow.models.renderedtifields import RenderedTaskInstanceFields @@ -1526,9 +1486,9 @@ def signal_handler(signum, frame): if not self.next_method: self.clear_xcom_data() - with Stats.timer(f'dag.{self.task.dag_id}.{self.task.task_id}.duration'): + with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"): # Set the validated/merged params on the task object. - self.task.params = context['params'] + self.task.params = context["params"] task_orig = self.render_templates(context=context) if not test_mode: @@ -1546,7 +1506,7 @@ def signal_handler(signum, frame): if not self.next_method: self.log.info( "Exporting the following env vars:\n%s", - '\n'.join(f"{k}={v}" for k, v in airflow_context_vars.items()), + "\n".join(f"{k}={v}" for k, v in airflow_context_vars.items()), ) # Run pre_execute callback @@ -1555,22 +1515,6 @@ def signal_handler(signum, frame): # Run on_execute callback self._run_execute_callback(context, self.task) - if self.task.is_smart_sensor_compatible(): - # Try to register it in the smart sensor service. - registered = False - try: - registered = self.task.register_in_sensor_service(self, context) - except Exception: - self.log.warning( - "Failed to register in sensor service." - " Continue to run task in non smart sensor mode.", - exc_info=True, - ) - - if registered: - # Will raise AirflowSmartSensorException to avoid long running execution. 
- self._update_ti_state_for_sensing() - # Execute the task with set_current_context(context): result = self._execute_task(context, task_orig) @@ -1578,20 +1522,12 @@ def signal_handler(signum, frame): # Run post_execute callback self.task.post_execute(context=context, result=result) - Stats.incr(f'operator_successes_{self.task.task_type}', 1, 1) - Stats.incr('ti_successes') + Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1) + Stats.incr("ti_successes") - @provide_session - def _update_ti_state_for_sensing(self, session=NEW_SESSION): - self.log.info('Submitting %s to sensor service', self) - self.state = State.SENSING - self.start_date = timezone.utcnow() - session.merge(self) - session.commit() - # Raise exception for sensing state - raise AirflowSmartSensorException("Task successfully registered in smart sensor.") - - def _run_finished_callback(self, callback, context, callback_type): + def _run_finished_callback( + self, callback: TaskStateChangeCallback | None, context: Context, callback_type: str + ) -> None: """Run callback after task finishes""" try: if callback: @@ -1654,7 +1590,7 @@ def _execute_task(self, context, task_orig): return result @provide_session - def _defer_task(self, session, defer: TaskDeferred): + def _defer_task(self, session: Session, defer: TaskDeferred) -> None: """ Marks the task as deferred and sets up the trigger that is needed to resume it. @@ -1692,7 +1628,7 @@ def _defer_task(self, session, defer: TaskDeferred): else: self.trigger_timeout = self.start_date + execution_timeout - def _run_execute_callback(self, context: Context, task): + def _run_execute_callback(self, context: Context, task: Operator) -> None: """Functions that need to be run before a Task is executed""" try: if task.on_execute_callback: @@ -1710,9 +1646,9 @@ def run( ignore_ti_state: bool = False, mark_success: bool = False, test_mode: bool = False, - job_id: Optional[str] = None, - pool: Optional[str] = None, - session=NEW_SESSION, + job_id: str | None = None, + pool: str | None = None, + session: Session = NEW_SESSION, ) -> None: """Run TaskInstance""" res = self.check_and_change_state_before_execution( @@ -1734,13 +1670,14 @@ def run( mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session ) - def dry_run(self): + def dry_run(self) -> None: """Only Renders Templates for the TI""" from airflow.models.baseoperator import BaseOperator self.task = self.task.prepare_for_execution() self.render_templates() - assert isinstance(self.task, BaseOperator) # For Mypy. 
+ if TYPE_CHECKING: + assert isinstance(self.task, BaseOperator) self.task.dry_run() @provide_session @@ -1777,6 +1714,7 @@ def _handle_reschedule( actual_start_date, self.end_date, reschedule_exception.reschedule_date, + self.map_index, ) ) @@ -1791,10 +1729,10 @@ def _handle_reschedule( session.merge(self) session.commit() - self.log.info('Rescheduling task, marking task as UP_FOR_RESCHEDULE') + self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") @staticmethod - def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> Optional[TracebackType]: + def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None: """ Truncates the traceback of an exception to the first frame called from within a given function @@ -1812,14 +1750,18 @@ def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) - return tb or error.__traceback__ @provide_session - def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None) -> None: + def handle_failure( + self, + error: None | str | Exception | KeyboardInterrupt, + test_mode: bool | None = None, + context: Context | None = None, + force_fail: bool = False, + session: Session = NEW_SESSION, + ) -> None: """Handle Failure for the TaskInstance""" if test_mode is None: test_mode = self.test_mode - if context is None: - context = self.get_template_context() - if error: if isinstance(error, BaseException): tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task) @@ -1831,8 +1773,8 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False, self.end_date = timezone.utcnow() self.set_duration() - Stats.incr(f'operator_failures_{self.task.task_type}') - Stats.incr('ti_failures') + Stats.incr(f"operator_failures_{self.operator}") + Stats.incr("ti_failures") if not test_mode: session.add(Log(State.FAILED, self)) @@ -1841,8 +1783,12 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False, self.clear_next_method_args() + # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task. + if context is None and getattr(self, "task", None): + context = self.get_template_context(session) + if context is not None: - context['exception'] = error + context["exception"] = error # Set state correctly and figure out how to log it and decide whether # to email @@ -1856,34 +1802,35 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False, # only mark task instance as FAILED if the next task instance # try_number exceeds the max_tries ... 
or if force_fail is truthy - task = None + task: BaseOperator | None = None try: - task = self.task.unmap() + if getattr(self, "task", None) and context: + task = self.task.unmap((context, session)) except Exception: - self.log.error("Unable to unmap task, can't determine if we need to send an alert email or not") + self.log.error("Unable to unmap task to determine if we need to send an alert email") if force_fail or not self.is_eligible_to_retry(): self.state = State.FAILED - email_for_state = operator.attrgetter('email_on_failure') + email_for_state = operator.attrgetter("email_on_failure") callback = task.on_failure_callback if task else None - callback_type = 'on_failure' + callback_type = "on_failure" else: if self.state == State.QUEUED: # We increase the try_number so as to fail the task if it fails to start after sometime self._try_number += 1 self.state = State.UP_FOR_RETRY - email_for_state = operator.attrgetter('email_on_retry') + email_for_state = operator.attrgetter("email_on_retry") callback = task.on_retry_callback if task else None - callback_type = 'on_retry' + callback_type = "on_retry" - self._log_state('Immediate failure requested. ' if force_fail else '') + self._log_state("Immediate failure requested. " if force_fail else "") if task and email_for_state(task) and task.email: try: self.email_alert(error, task) except Exception: - self.log.exception('Failed to send email to: %s', task.email) + self.log.exception("Failed to send email to: %s", task.email) - if callback: + if callback and context: self._run_finished_callback(callback, context, callback_type) if not test_mode: @@ -1896,6 +1843,9 @@ def is_eligible_to_retry(self): # If a task is cleared when running, it goes into RESTARTING state and is always # eligible for retry return True + if not getattr(self, "task", None): + # Couldn't load the task, don't know number of retries, guess: + return self.try_number <= self.max_tries return self.task.retries and self.try_number <= self.max_tries @@ -1908,56 +1858,50 @@ def get_template_context( session = settings.Session() from airflow import macros + from airflow.models.abstractoperator import NotMapped integrate_macros_plugins() task = self.task - assert task.dag # For Mypy. + if TYPE_CHECKING: + assert task.dag dag: DAG = task.dag dag_run = self.get_dagrun(session) data_interval = dag.get_run_data_interval(dag_run) - # Validates Params and convert them into a simple dict. - params = ParamsDict(suppress_exception=ignore_param_exceptions) - with contextlib.suppress(AttributeError): - params.update(dag.params) - if task.params: - params.update(task.params) - if conf.getboolean('core', 'dag_run_conf_overrides_params'): - self.overwrite_params_with_dag_run_conf(params=params, dag_run=dag_run) - validated_params = params.validate() + validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions) logical_date = timezone.coerce_datetime(self.execution_date) - ds = logical_date.strftime('%Y-%m-%d') - ds_nodash = ds.replace('-', '') + ds = logical_date.strftime("%Y-%m-%d") + ds_nodash = ds.replace("-", "") ts = logical_date.isoformat() - ts_nodash = logical_date.strftime('%Y%m%dT%H%M%S') - ts_nodash_with_tz = ts.replace('-', '').replace(':', '') + ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") + ts_nodash_with_tz = ts.replace("-", "").replace(":", "") @cache # Prevent multiple database access. 
- def _get_previous_dagrun_success() -> Optional["DagRun"]: + def _get_previous_dagrun_success() -> DagRun | None: return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session) - def _get_previous_dagrun_data_interval_success() -> Optional["DataInterval"]: + def _get_previous_dagrun_data_interval_success() -> DataInterval | None: dagrun = _get_previous_dagrun_success() if dagrun is None: return None return dag.get_run_data_interval(dagrun) - def get_prev_data_interval_start_success() -> Optional[pendulum.DateTime]: + def get_prev_data_interval_start_success() -> pendulum.DateTime | None: data_interval = _get_previous_dagrun_data_interval_success() if data_interval is None: return None return data_interval.start - def get_prev_data_interval_end_success() -> Optional[pendulum.DateTime]: + def get_prev_data_interval_end_success() -> pendulum.DateTime | None: data_interval = _get_previous_dagrun_data_interval_success() if data_interval is None: return None return data_interval.end - def get_prev_start_date_success() -> Optional[pendulum.DateTime]: + def get_prev_start_date_success() -> pendulum.DateTime | None: dagrun = _get_previous_dagrun_success() if dagrun is None: return None @@ -1965,20 +1909,20 @@ def get_prev_start_date_success() -> Optional[pendulum.DateTime]: @cache def get_yesterday_ds() -> str: - return (logical_date - timedelta(1)).strftime('%Y-%m-%d') + return (logical_date - timedelta(1)).strftime("%Y-%m-%d") def get_yesterday_ds_nodash() -> str: - return get_yesterday_ds().replace('-', '') + return get_yesterday_ds().replace("-", "") @cache def get_tomorrow_ds() -> str: - return (logical_date + timedelta(1)).strftime('%Y-%m-%d') + return (logical_date + timedelta(1)).strftime("%Y-%m-%d") def get_tomorrow_ds_nodash() -> str: - return get_tomorrow_ds().replace('-', '') + return get_tomorrow_ds().replace("-", "") @cache - def get_next_execution_date() -> Optional[pendulum.DateTime]: + def get_next_execution_date() -> pendulum.DateTime | None: # For manually triggered dagruns that aren't run on a schedule, # the "next" execution date doesn't make sense, and should be set # to execution date for consistency with how execution_date is set @@ -1992,17 +1936,17 @@ def get_next_execution_date() -> Optional[pendulum.DateTime]: return None return timezone.coerce_datetime(next_info.logical_date) - def get_next_ds() -> Optional[str]: + def get_next_ds() -> str | None: execution_date = get_next_execution_date() if execution_date is None: return None - return execution_date.strftime('%Y-%m-%d') + return execution_date.strftime("%Y-%m-%d") - def get_next_ds_nodash() -> Optional[str]: + def get_next_ds_nodash() -> str | None: ds = get_next_ds() if ds is None: return ds - return ds.replace('-', '') + return ds.replace("-", "") @cache def get_prev_execution_date(): @@ -2013,70 +1957,91 @@ def get_prev_execution_date(): if dag_run.external_trigger: return logical_date with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) return dag.previous_schedule(logical_date) @cache - def get_prev_ds() -> Optional[str]: + def get_prev_ds() -> str | None: execution_date = get_prev_execution_date() if execution_date is None: return None - return execution_date.strftime(r'%Y-%m-%d') + return execution_date.strftime(r"%Y-%m-%d") - def get_prev_ds_nodash() -> Optional[str]: + def get_prev_ds_nodash() -> str | None: prev_ds = get_prev_ds() if prev_ds is None: return None - return prev_ds.replace('-', '') + 
return prev_ds.replace("-", "") + + def get_triggering_events() -> dict[str, list[DatasetEvent]]: + nonlocal dag_run + # The dag_run may not be attached to the session anymore (code base is over-zealous with use of + # `session.expunge_all()`) so re-attach it if we get called + if dag_run not in session: + dag_run = session.merge(dag_run, load=False) + + dataset_events = dag_run.consumed_dataset_events + triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list) + for event in dataset_events: + triggering_events[event.dataset.uri].append(event) + + return triggering_events + + try: + expanded_ti_count: int | None = task.get_mapped_ti_count(self.run_id, session=session) + except NotMapped: + expanded_ti_count = None # NOTE: If you add anything to this dict, make sure to also update the # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in # airflow/utils/context.py! context = { - 'conf': conf, - 'dag': dag, - 'dag_run': dag_run, - 'data_interval_end': timezone.coerce_datetime(data_interval.end), - 'data_interval_start': timezone.coerce_datetime(data_interval.start), - 'ds': ds, - 'ds_nodash': ds_nodash, - 'execution_date': logical_date, - 'inlets': task.inlets, - 'logical_date': logical_date, - 'macros': macros, - 'next_ds': get_next_ds(), - 'next_ds_nodash': get_next_ds_nodash(), - 'next_execution_date': get_next_execution_date(), - 'outlets': task.outlets, - 'params': validated_params, - 'prev_data_interval_start_success': get_prev_data_interval_start_success(), - 'prev_data_interval_end_success': get_prev_data_interval_end_success(), - 'prev_ds': get_prev_ds(), - 'prev_ds_nodash': get_prev_ds_nodash(), - 'prev_execution_date': get_prev_execution_date(), - 'prev_execution_date_success': self.get_previous_execution_date( + "conf": conf, + "dag": dag, + "dag_run": dag_run, + "data_interval_end": timezone.coerce_datetime(data_interval.end), + "data_interval_start": timezone.coerce_datetime(data_interval.start), + "ds": ds, + "ds_nodash": ds_nodash, + "execution_date": logical_date, + "expanded_ti_count": expanded_ti_count, + "inlets": task.inlets, + "logical_date": logical_date, + "macros": macros, + "next_ds": get_next_ds(), + "next_ds_nodash": get_next_ds_nodash(), + "next_execution_date": get_next_execution_date(), + "outlets": task.outlets, + "params": validated_params, + "prev_data_interval_start_success": get_prev_data_interval_start_success(), + "prev_data_interval_end_success": get_prev_data_interval_end_success(), + "prev_ds": get_prev_ds(), + "prev_ds_nodash": get_prev_ds_nodash(), + "prev_execution_date": get_prev_execution_date(), + "prev_execution_date_success": self.get_previous_execution_date( state=DagRunState.SUCCESS, session=session, ), - 'prev_start_date_success': get_prev_start_date_success(), - 'run_id': self.run_id, - 'task': task, - 'task_instance': self, - 'task_instance_key_str': f"{task.dag_id}__{task.task_id}__{ds_nodash}", - 'test_mode': self.test_mode, - 'ti': self, - 'tomorrow_ds': get_tomorrow_ds(), - 'tomorrow_ds_nodash': get_tomorrow_ds_nodash(), - 'ts': ts, - 'ts_nodash': ts_nodash, - 'ts_nodash_with_tz': ts_nodash_with_tz, - 'var': { - 'json': VariableAccessor(deserialize_json=True), - 'value': VariableAccessor(deserialize_json=False), + "prev_start_date_success": get_prev_start_date_success(), + "run_id": self.run_id, + "task": task, + "task_instance": self, + "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}", + "test_mode": self.test_mode, + "ti": self, + "tomorrow_ds": get_tomorrow_ds(), + 
"tomorrow_ds_nodash": get_tomorrow_ds_nodash(), + "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events), + "ts": ts, + "ts_nodash": ts_nodash, + "ts_nodash_with_tz": ts_nodash_with_tz, + "var": { + "json": VariableAccessor(deserialize_json=True), + "value": VariableAccessor(deserialize_json=False), }, - 'conn': ConnectionAccessor(), - 'yesterday_ds': get_yesterday_ds(), - 'yesterday_ds_nodash': get_yesterday_ds_nodash(), + "conn": ConnectionAccessor(), + "yesterday_ds": get_yesterday_ds(), + "yesterday_ds_nodash": get_yesterday_ds_nodash(), } # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890 @@ -2092,7 +2057,7 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) if rendered_task_instance_fields: - self.task = self.task.unmap() + self.task = self.task.unmap(None) for field_name, rendered_value in rendered_task_instance_fields.items(): setattr(self.task, field_name, rendered_value) return @@ -2114,7 +2079,7 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: ) from e @provide_session - def get_rendered_k8s_spec(self, session=NEW_SESSION): + def get_rendered_k8s_spec(self, session: Session = NEW_SESSION): """Fetch rendered template fields from DB""" from airflow.models.renderedtifields import RenderedTaskInstanceFields @@ -2132,7 +2097,7 @@ def overwrite_params_with_dag_run_conf(self, params, dag_run): self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf) params.update(dag_run.conf) - def render_templates(self, context: Optional[Context] = None) -> "Operator": + def render_templates(self, context: Context | None = None) -> Operator: """Render templates in the operator fields. If the task was originally mapped, this may replace ``self.task`` with @@ -2141,13 +2106,17 @@ def render_templates(self, context: Optional[Context] = None) -> "Operator": """ if not context: context = self.get_template_context() - rendered_task = self.task.render_template_fields(context) - if rendered_task is None: # Compatibility -- custom renderer, assume unmapped. - return self.task - original_task, self.task = self.task, rendered_task + original_task = self.task + + # If self.task is mapped, this call replaces self.task to point to the + # unmapped BaseOperator created by this function! This is because the + # MappedOperator is useless for template rendering, and we need to be + # able to access the unmapped task instead. 
+        original_task.render_template_fields(context)
+        return original_task
 
-    def render_k8s_pod_yaml(self) -> Optional[dict]:
+    def render_k8s_pod_yaml(self) -> dict | None:
         """Render k8s pod yaml"""
         from kubernetes.client.api_client import ApiClient
 
@@ -2176,40 +2145,39 @@ def render_k8s_pod_yaml(self) -> Optional[dict]:
         return sanitized_pod
 
     def get_email_subject_content(
-        self, exception: BaseException, task: Optional["BaseOperator"] = None
-    ) -> Tuple[str, str, str]:
+        self, exception: BaseException, task: BaseOperator | None = None
+    ) -> tuple[str, str, str]:
         """Get the email subject content for exceptions."""
         # For a ti from DB (without ti.task), return the default value
-        # Reuse it for smart sensor to send default email alert
         if task is None:
-            task = getattr(self, 'task')
+            task = getattr(self, "task")
         use_default = task is None
-        exception_html = str(exception).replace('\n', '<br>')
+        exception_html = str(exception).replace("\n", "<br>")
 
-        default_subject = 'Airflow alert: {{ti}}'
+        default_subject = "Airflow alert: {{ti}}"
         # For reporting purposes, we report based on 1-indexed,
         # not 0-indexed lists (i.e. Try 1 instead of
         # Try 0 for the first attempt).
         default_html_content = (
-            'Try {{try_number}} out of {{max_tries + 1}}<br>'
-            'Exception:<br>{{exception_html}}<br>'
+            "Try {{try_number}} out of {{max_tries + 1}}<br>"
+            "Exception:<br>{{exception_html}}<br>"
             'Log: <a href="{{ti.log_url}}">Link</a><br>'
-            'Host: {{ti.hostname}}<br>'
+            "Host: {{ti.hostname}}<br>"
             'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
         )
 
         default_html_content_err = (
-            'Try {{try_number}} out of {{max_tries + 1}}<br>'
-            'Exception:<br>Failed attempt to attach error logs<br>'
+            "Try {{try_number}} out of {{max_tries + 1}}<br>"
+            "Exception:<br>Failed attempt to attach error logs<br>"
             'Log: <a href="{{ti.log_url}}">Link</a><br>'
-            'Host: {{ti.hostname}}<br>'
+            "Host: {{ti.hostname}}<br>"
             'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
         )
 
         # This function is called after changing the state from State.RUNNING,
         # so we need to subtract 1 from self.try_number here.
         current_try_number = self.try_number - 1
-        additional_context: Dict[str, Any] = {
+        additional_context: dict[str, Any] = {
             "exception": exception,
             "exception_html": exception_html,
             "try_number": current_try_number,
@@ -2238,19 +2206,24 @@ def get_email_subject_content(
         context_merge(jinja_context, additional_context)
 
         def render(key: str, content: str) -> str:
-            if conf.has_option('email', key):
-                path = conf.get_mandatory_value('email', key)
-                with open(path) as f:
-                    content = f.read()
+            if conf.has_option("email", key):
+                path = conf.get_mandatory_value("email", key)
+                try:
+                    with open(path) as f:
+                        content = f.read()
+                except FileNotFoundError:
+                    self.log.warning(f"Could not find email template file '{path!r}'. Using defaults...")
+                except OSError:
+                    self.log.exception(f"Error while using email template '{path!r}'. Using defaults...")
             return render_template_to_string(jinja_env.from_string(content), jinja_context)
 
-        subject = render('subject_template', default_subject)
-        html_content = render('html_content_template', default_html_content)
-        html_content_err = render('html_content_template', default_html_content_err)
+        subject = render("subject_template", default_subject)
+        html_content = render("html_content_template", default_html_content)
+        html_content_err = render("html_content_template", default_html_content_err)
         return subject, html_content, html_content_err
 
-    def email_alert(self, exception, task: "BaseOperator"):
+    def email_alert(self, exception, task: BaseOperator) -> None:
         """Send alert email with exception information."""
         subject, html_content, html_content_err = self.get_email_subject_content(exception, task=task)
         assert task.email
@@ -2267,18 +2240,18 @@ def set_duration(self) -> None:
             self.duration = None
         self.log.debug("Task Duration set to %s", self.duration)
 
-    def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
+    def _record_task_map_for_downstreams(self, task: Operator, value: Any, *, session: Session) -> None:
+        if next(task.iter_mapped_dependants(), None) is None:  # No mapped dependants, no need to validate.
+            return
         # TODO: We don't push TaskMap for mapped task instances because it's not
         # currently possible for a downstream to depend on one individual mapped
-        # task instance, only a task as a whole. This will change in AIP-42
-        # Phase 2, and we'll need to further analyze the mapped task case.
-        if next(task.iter_mapped_dependants(), None) is None:
+        # task instance. This will change when we implement task mapping inside
+        # a mapped task group, and we'll need to further analyze the case.
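[The reworked `render()` helper in the hunk above now degrades gracefully when `[email] subject_template` / `html_content_template` point at a missing or unreadable file. A minimal standalone sketch of that read-or-fall-back pattern, not the Airflow implementation itself:]

```python
from __future__ import annotations

import logging

log = logging.getLogger(__name__)

DEFAULT_SUBJECT = "Airflow alert: {{ti}}"  # built-in fallback


def load_template(path: str | None, default: str) -> str:
    """Return the file contents if readable, otherwise the built-in default."""
    if not path:
        return default
    try:
        with open(path) as f:
            return f.read()
    except FileNotFoundError:
        log.warning("Could not find email template file %r. Using defaults...", path)
    except OSError:
        log.exception("Error while using email template %r. Using defaults...", path)
    return default


print(load_template("/nonexistent/subject.j2", DEFAULT_SUBJECT))  # falls back
```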
+ if isinstance(task, MappedOperator): return if value is None: raise XComForMappingNotPushed() - if task.is_mapped: - return - if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)): + if not _is_mappable_value(value): raise UnmappableXComTypePushed(value) task_map = TaskMap.from_task_instance_xcom(self, value) max_map_length = conf.getint("core", "max_map_length", fallback=1024) @@ -2291,7 +2264,7 @@ def xcom_push( self, key: str, value: Any, - execution_date: Optional[datetime] = None, + execution_date: datetime | None = None, session: Session = NEW_SESSION, ) -> None: """ @@ -2307,12 +2280,12 @@ def xcom_push( self_execution_date = self.get_dagrun(session).execution_date if execution_date < self_execution_date: raise ValueError( - f'execution_date can not be in the past (current execution_date is ' - f'{self_execution_date}; received {execution_date})' + f"execution_date can not be in the past (current execution_date is " + f"{self_execution_date}; received {execution_date})" ) elif execution_date is not None: message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated." - warnings.warn(message, DeprecationWarning, stacklevel=3) + warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) XCom.set( key=key, @@ -2327,13 +2300,13 @@ def xcom_push( @provide_session def xcom_pull( self, - task_ids: Optional[Union[str, Iterable[str]]] = None, - dag_id: Optional[str] = None, + task_ids: str | Iterable[str] | None = None, + dag_id: str | None = None, key: str = XCOM_RETURN_KEY, include_prior_dates: bool = False, session: Session = NEW_SESSION, *, - map_indexes: Optional[Union[int, Iterable[int]]] = None, + map_indexes: int | Iterable[int] | None = None, default: Any = None, ) -> Any: """Pull XComs that optionally meet certain criteria. @@ -2392,38 +2365,37 @@ def xcom_pull( return default if map_indexes is not None or first.map_index < 0: return XCom.deserialize_value(first) - - return _LazyXComAccess.build_from_single_xcom(first, query) + query = query.order_by(None).order_by(XCom.map_index.asc()) + return LazyXComAccess.build_from_xcom_query(query) # At this point either task_ids or map_indexes is explicitly multi-value. - - results = ( - (r.task_id, r.map_index, XCom.deserialize_value(r)) - for r in query.with_entities(XCom.task_id, XCom.map_index, XCom.value) - ) - - if task_ids is None: - task_id_pos: Dict[str, int] = defaultdict(int) - elif isinstance(task_ids, str): - task_id_pos = {task_ids: 0} + # Order return values to match task_ids and map_indexes ordering. 
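[The `xcom_pull` rewrite continues in the hunk below: instead of sorting rows in Python, it orders the query itself so results come back in the caller-supplied `task_ids` / `map_indexes` order. A self-contained sketch of that SQLAlchemy `case()` ordering trick; the table and values here are hypothetical stand-ins, not the XCom model.]

```python
from sqlalchemy import Column, Integer, String, case, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Row(Base):  # hypothetical stand-in for the XCom table
    __tablename__ = "rows"
    id = Column(Integer, primary_key=True)
    task_id = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Row(task_id=t) for t in ("extract", "transform", "load")])
    session.commit()

    requested = ["load", "extract"]  # caller-specified order
    positions = {tid: i for i, tid in enumerate(requested)}
    stmt = (
        select(Row.task_id)
        .where(Row.task_id.in_(requested))
        # CASE maps each task_id to its position in the requested list,
        # so ORDER BY preserves the caller's ordering.
        .order_by(case(positions, value=Row.task_id))
    )
    print([r.task_id for r in session.execute(stmt)])  # ['load', 'extract']
```

Doing the ordering in SQL also lets the result stay wrapped in the lazy accessor instead of being materialized just to sort it.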
+ query = query.order_by(None) + if task_ids is None or isinstance(task_ids, str): + query = query.order_by(XCom.task_id) else: - task_id_pos = {task_id: i for i, task_id in enumerate(task_ids)} - if map_indexes is None: - map_index_pos: Dict[int, int] = defaultdict(int) - elif isinstance(map_indexes, int): - map_index_pos = {map_indexes: 0} + task_id_whens = {tid: i for i, tid in enumerate(task_ids)} + if task_id_whens: + query = query.order_by(case(task_id_whens, value=XCom.task_id)) + else: + query = query.order_by(XCom.task_id) + if map_indexes is None or isinstance(map_indexes, int): + query = query.order_by(XCom.map_index) + elif isinstance(map_indexes, range): + order = XCom.map_index + if map_indexes.step < 0: + order = order.desc() + query = query.order_by(order) else: - map_index_pos = {map_index: i for i, map_index in enumerate(map_indexes)} - - def _arg_pos(item: Tuple[str, int, Any]) -> Tuple[int, int]: - task_id, map_index, _ = item - return task_id_pos[task_id], map_index_pos[map_index] - - results_sorted_by_arg_pos = sorted(results, key=_arg_pos) - return [value for _, _, value in results_sorted_by_arg_pos] + map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)} + if map_index_whens: + query = query.order_by(case(map_index_whens, value=XCom.map_index)) + else: + query = query.order_by(XCom.map_index) + return LazyXComAccess.build_from_xcom_query(query) @provide_session - def get_num_running_task_instances(self, session): + def get_num_running_task_instances(self, session: Session) -> int: """Return Number of running TIs from the DB""" # .count() is inefficient return ( @@ -2436,13 +2408,13 @@ def get_num_running_task_instances(self, session): .scalar() ) - def init_run_context(self, raw=False): + def init_run_context(self, raw: bool = False) -> None: """Sets the log context.""" self.raw = raw self._set_context(self) @staticmethod - def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Optional[BooleanClauseList]: + def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None: """Returns SQLAlchemy filter to query selected task instances""" # DictKeys type, (what we often pass here from the scheduler) is not directly indexable :( # Or it might be a generator, but we need to be able to iterate over it more than once @@ -2457,37 +2429,81 @@ def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Opt run_id = first.run_id map_index = first.map_index first_task_id = first.task_id + + # pre-compute the set of dag_id, run_id, map_indices and task_ids + dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set() + for t in tis: + dag_ids.add(t.dag_id) + run_ids.add(t.run_id) + map_indices.add(t.map_index) + task_ids.add(t.task_id) + # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+) - if all(t.dag_id == dag_id and t.run_id == run_id and t.map_index == map_index for t in tis): + if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}: return and_( TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id, TaskInstance.map_index == map_index, - TaskInstance.task_id.in_(t.task_id for t in tis), + TaskInstance.task_id.in_(task_ids), ) - if all(t.dag_id == dag_id and t.task_id == first_task_id and t.map_index == map_index for t in tis): + if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}: return and_( 
TaskInstance.dag_id == dag_id, - TaskInstance.run_id.in_(t.run_id for t in tis), + TaskInstance.run_id.in_(run_ids), TaskInstance.map_index == map_index, TaskInstance.task_id == first_task_id, ) - if all(t.dag_id == dag_id and t.run_id == run_id and t.task_id == first_task_id for t in tis): + if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}: return and_( TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id, - TaskInstance.map_index.in_(t.map_index for t in tis), + TaskInstance.map_index.in_(map_indices), TaskInstance.task_id == first_task_id, ) - return tuple_in_condition( - (TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index), - (ti.key.primary for ti in tis), - ) + filter_condition = [] + # create 2 nested groups, both primarily grouped by dag_id and run_id, + # and in the nested group 1 grouped by task_id the other by map_index. + task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) + map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) + for t in tis: + task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index) + map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id) + + # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even + # if its not, this is still a significant optimization over querying for every single tuple key + for cur_dag_id in dag_ids: + for cur_run_id in run_ids: + # we compare the group size between task_id and map_index and use the smaller group + dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)] + dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)] + + if len(dag_task_id_groups) <= len(dag_map_index_groups): + for cur_task_id, cur_map_indices in dag_task_id_groups.items(): + filter_condition.append( + and_( + TaskInstance.dag_id == cur_dag_id, + TaskInstance.run_id == cur_run_id, + TaskInstance.task_id == cur_task_id, + TaskInstance.map_index.in_(cur_map_indices), + ) + ) + else: + for cur_map_index, cur_task_ids in dag_map_index_groups.items(): + filter_condition.append( + and_( + TaskInstance.dag_id == cur_dag_id, + TaskInstance.run_id == cur_run_id, + TaskInstance.task_id.in_(cur_task_ids), + TaskInstance.map_index == cur_map_index, + ) + ) + + return or_(*filter_condition) @classmethod - def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> ColumnOperators: + def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators: """ Build an SQLAlchemy filter for a list where each element can contain whether a task_id, or a tuple of (task_id,map_index) @@ -2499,7 +2515,7 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> task_id_only = [v for v in vals if isinstance(v, str)] with_map_index = [v for v in vals if not isinstance(v, str)] - filters: List[ColumnOperators] = [] + filters: list[ColumnOperators] = [] if task_id_only: filters.append(cls.task_id.in_(task_id_only)) if with_map_index: @@ -2511,6 +2527,164 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> return filters[0] return or_(*filters) + @Sentry.enrich_errors + @provide_session + def schedule_downstream_tasks(self, session=None): + """ + The mini-scheduler for scheduling downstream tasks of this task instance + :meta: private + """ + from sqlalchemy.exc import OperationalError + + from airflow.models import DagRun + + try: + # Re-select the row with a 
lock + dag_run = with_row_locks( + session.query(DagRun).filter_by( + dag_id=self.dag_id, + run_id=self.run_id, + ), + session=session, + ).one() + + task = self.task + if TYPE_CHECKING: + assert task.dag + + # Get a partial DAG with just the specific tasks we want to examine. + # In order for dep checks to work correctly, we include ourself (so + # TriggerRuleDep can check the state of the task we just executed). + partial_dag = task.dag.partial_subset( + task.downstream_task_ids, + include_downstream=True, + include_upstream=False, + include_direct_upstream=True, + ) + + dag_run.dag = partial_dag + info = dag_run.task_instance_scheduling_decisions(session) + + skippable_task_ids = { + task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids + } + + schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] + for schedulable_ti in schedulable_tis: + if not hasattr(schedulable_ti, "task"): + schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id) + + num = dag_run.schedule_tis(schedulable_tis, session=session) + self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) + + session.flush() + + except OperationalError as e: + # Any kind of DB error here is _non fatal_ as this block is just an optimisation. + self.log.info( + "Skipping mini scheduling run due to exception: %s", + e.statement, + exc_info=True, + ) + session.rollback() + + def get_relevant_upstream_map_indexes( + self, + upstream: Operator, + ti_count: int | None, + *, + session: Session, + ) -> int | range | None: + """Infer the map indexes of an upstream "relevant" to this ti. + + The bulk of the logic mainly exists to solve the problem described by + the following example, where 'val' must resolve to different values, + depending on where the reference is being used:: + + @task + def this_task(v): # This is self.task. + return v * 2 + + @task_group + def tg1(inp): + val = upstream(inp) # This is the upstream task. + this_task(val) # When inp is 1, val here should resolve to 2. + return val + + # This val is the same object returned by tg1. + val = tg1.expand(inp=[1, 2, 3]) + + @task_group + def tg2(inp): + another_task(inp, val) # val here should resolve to [2, 4, 6]. + + tg2.expand(inp=["a", "b"]) + + The surrounding mapped task groups of ``upstream`` and ``self.task`` are + inspected to find a common "ancestor". If such an ancestor is found, + we need to return specific map indexes to pull a partial value from + upstream XCom. + + :param upstream: The referenced upstream task. + :param ti_count: The total count of task instance this task was expanded + by the scheduler, i.e. ``expanded_ti_count`` in the template context. + :return: Specific map index or map indexes to pull, or ``None`` if we + want to "whole" return value (i.e. no mapped task groups involved). + """ + # Find the innermost common mapped task group between the current task + # If the current task and the referenced task does not have a common + # mapped task group, the two are in different task mapping contexts + # (like another_task above), and we should use the "whole" value. + common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream) + if common_ancestor is None: + return None + + # This value should never be None since we already know the current task + # is in a mapped task group, and should have been expanded. The check + # exists mainly to satisfy Mypy. 
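[To make the docstring above concrete, here is a worked numeric sketch of the map-index arithmetic this method performs (mirroring the formulas in the continuation of this hunk; the helper below is illustrative, not Airflow code). Suppose the common ancestor group expanded into 3 instances and the current task was expanded to 6 instances overall, i.e. 2 per ancestor instance.]

```python
from __future__ import annotations


def relevant_upstream_map_indexes(
    map_index: int, ti_count: int, ancestor_ti_count: int, upstream_further_mapped: bool
) -> int | range:
    # Project this TI's map_index back onto the ancestor group's expansion.
    ancestor_map_index = map_index * ancestor_ti_count // ti_count
    if not upstream_further_mapped:
        return ancestor_map_index  # exactly one upstream TI is relevant
    # Otherwise pull the contiguous block of upstream indexes that belong to
    # the same ancestor instance.
    further_count = ti_count // ancestor_ti_count
    start = ancestor_map_index * further_count
    return range(start, start + further_count)


# map_index 4 of 6 total, ancestor expanded to 3:
print(relevant_upstream_map_indexes(4, 6, 3, upstream_further_mapped=False))  # 2
print(relevant_upstream_map_indexes(4, 6, 3, upstream_further_mapped=True))   # range(4, 6)
```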
+ if ti_count is None: + return None + + # At this point we know the two tasks share a mapped task group, and we + # should use a "partial" value. Let's break down the mapped ti count + # between the ancestor and further expansion happened inside it. + ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session) + ancestor_map_index = self.map_index * ancestor_ti_count // ti_count + + # If the task is NOT further expanded inside the common ancestor, we + # only want to reference one single ti. We must walk the actual DAG, + # and "ti_count == ancestor_ti_count" does not work, since the further + # expansion may be of length 1. + if not _is_further_mapped_inside(upstream, common_ancestor): + return ancestor_map_index + + # Otherwise we need a partial aggregation for values from selected task + # instances in the ancestor's expansion context. + further_count = ti_count // ancestor_ti_count + map_index_start = ancestor_map_index * further_count + return range(map_index_start, map_index_start + further_count) + + +def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None: + """Given two operators, find their innermost common mapped task group.""" + if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id: + return None + parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()} + common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids) + return next(common_groups, None) + + +def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: + """Whether given operator is *further* mapped inside a task group.""" + if isinstance(operator, MappedOperator): + return True + task_group = operator.task_group + while task_group is not None and task_group.group_id != container.group_id: + if isinstance(task_group, MappedTaskGroup): + return True + task_group = task_group.parent_group + return False + # State of the task instance. # Stores string version of the task state. @@ -2529,8 +2703,8 @@ def __init__( dag_id: str, task_id: str, run_id: str, - start_date: Optional[datetime], - end_date: Optional[datetime], + start_date: datetime | None, + end_date: datetime | None, try_number: int, map_index: int, state: str, @@ -2538,8 +2712,8 @@ def __init__( pool: str, queue: str, key: TaskInstanceKey, - run_as_user: Optional[str] = None, - priority_weight: Optional[int] = None, + run_as_user: str | None = None, + priority_weight: int | None = None, ): self.dag_id = dag_id self.task_id = task_id @@ -2561,8 +2735,23 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ return NotImplemented + def as_dict(self): + warnings.warn( + "This method is deprecated. 
Use BaseSerialization.serialize.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + new_dict = dict(self.__dict__) + for key in new_dict: + if key in ["start_date", "end_date"]: + val = new_dict[key] + if not val or isinstance(val, str): + continue + new_dict.update({key: val.isoformat()}) + return new_dict + @classmethod - def from_ti(cls, ti: TaskInstance): + def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance: return cls( dag_id=ti.dag_id, task_id=ti.task_id, @@ -2576,17 +2765,22 @@ def from_ti(cls, ti: TaskInstance): pool=ti.pool, queue=ti.queue, key=ti.key, - run_as_user=ti.run_as_user if hasattr(ti, 'run_as_user') else None, - priority_weight=ti.priority_weight if hasattr(ti, 'priority_weight') else None, + run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None, + priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None, ) @classmethod - def from_dict(cls, obj_dict: dict): - ti_key = TaskInstanceKey(*obj_dict.pop('key')) + def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance: + warnings.warn( + "This method is deprecated. Use BaseSerialization.deserialize.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + ti_key = TaskInstanceKey(*obj_dict.pop("key")) start_date = None end_date = None - start_date_str: Optional[str] = obj_dict.pop('start_date') - end_date_str: Optional[str] = obj_dict.pop('end_date') + start_date_str: str | None = obj_dict.pop("start_date") + end_date_str: str | None = obj_dict.pop("end_date") if start_date_str: start_date = timezone.parse(start_date_str) if end_date_str: @@ -2594,8 +2788,57 @@ def from_dict(cls, obj_dict: dict): return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key) +class TaskInstanceNote(Base): + """For storage of arbitrary notes concerning the task instance.""" + + __tablename__ = "task_instance_note" + + user_id = Column(Integer, nullable=True) + task_id = Column(StringID(), primary_key=True, nullable=False) + dag_id = Column(StringID(), primary_key=True, nullable=False) + run_id = Column(StringID(), primary_key=True, nullable=False) + map_index = Column(Integer, primary_key=True, nullable=False) + content = Column(String(1000).with_variant(Text(1000), "mysql")) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) + + task_instance = relationship("TaskInstance", back_populates="task_instance_note") + + __table_args__ = ( + PrimaryKeyConstraint( + "task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey", mssql_clustered=True + ), + ForeignKeyConstraint( + (dag_id, task_id, run_id, map_index), + [ + "task_instance.dag_id", + "task_instance.task_id", + "task_instance.run_id", + "task_instance.map_index", + ], + name="task_instance_note_ti_fkey", + ondelete="CASCADE", + ), + ForeignKeyConstraint( + (user_id,), + ["ab_user.id"], + name="task_instance_note_user_fkey", + ), + ) + + def __init__(self, content, user_id=None): + self.content = content + self.user_id = user_id + + def __repr__(self): + prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}" + if self.map_index != -1: + prefix += f" map_index={self.map_index}" + return prefix + ">" + + STATICA_HACK = True -globals()['kcah_acitats'[::-1].upper()] = False +globals()["kcah_acitats"[::-1].upper()] = False if STATICA_HACK: # pragma: no cover from airflow.jobs.base_job import BaseJob diff --git a/airflow/models/tasklog.py b/airflow/models/tasklog.py index 
a5a3e510fcfb0..3e5a9e195ea56 100644 --- a/airflow/models/tasklog.py +++ b/airflow/models/tasklog.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from sqlalchemy import Column, Integer, Text diff --git a/airflow/models/taskmap.py b/airflow/models/taskmap.py index 3945e2dcd7d30..e7abcc1b6e0ae 100644 --- a/airflow/models/taskmap.py +++ b/airflow/models/taskmap.py @@ -15,12 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Table to store information about mapped task instances (AIP-42).""" +from __future__ import annotations import collections.abc import enum -from typing import TYPE_CHECKING, Any, Collection, List, Optional +from typing import TYPE_CHECKING, Any, Collection from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String @@ -82,7 +82,7 @@ def __init__( run_id: str, map_index: int, length: int, - keys: Optional[List[Any]], + keys: list[Any] | None, ) -> None: self.dag_id = dag_id self.task_id = task_id @@ -92,7 +92,7 @@ def __init__( self.keys = keys @classmethod - def from_task_instance_xcom(cls, ti: "TaskInstance", value: Collection) -> "TaskMap": + def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap: if ti.run_id is None: raise ValueError("cannot record task map for unrun task instance") return cls( diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 1d66a719a6aec..211fad6ff909d 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -14,21 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import warnings from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, Sequence import pendulum -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.serialization.enums import DagAttributeTypes if TYPE_CHECKING: from logging import Logger from airflow.models.dag import DAG - from airflow.models.mappedoperator import MappedOperator + from airflow.models.operator import Operator from airflow.utils.edgemodifier import EdgeModifier from airflow.utils.task_group import TaskGroup @@ -37,7 +38,7 @@ class DependencyMixin: """Mixing implementing common dependency setting methods methods like >> and <<.""" @property - def roots(self) -> Sequence["DependencyMixin"]: + def roots(self) -> Sequence[DependencyMixin]: """ List of root nodes -- ones with no upstream dependencies. @@ -46,7 +47,7 @@ def roots(self) -> Sequence["DependencyMixin"]: raise NotImplementedError() @property - def leaves(self) -> Sequence["DependencyMixin"]: + def leaves(self) -> Sequence[DependencyMixin]: """ List of leaf nodes -- ones with only upstream dependencies. 
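[A recurring change across these files (tasklog.py, taskmap.py, taskmixin.py above) is adding `from __future__ import annotations` so the new PEP 585/604 style hints (`list[Any] | None` instead of `Optional[List[Any]]`) remain importable on the older Python versions Airflow still supports. A tiny sketch of the pattern:]

```python
# With the future import, annotations become strings and are not evaluated at
# definition time, so this syntax parses fine even on Python 3.7/3.8.
from __future__ import annotations


def keys_or_none(keys: list[str] | None = None) -> dict[str, int] | None:
    if keys is None:
        return None
    return {k: i for i, k in enumerate(keys)}


print(keys_or_none(["dag_id", "task_id"]))  # {'dag_id': 0, 'task_id': 1}
```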
@@ -55,37 +56,37 @@ def leaves(self) -> Sequence["DependencyMixin"]: raise NotImplementedError() @abstractmethod - def set_upstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]): + def set_upstream(self, other: DependencyMixin | Sequence[DependencyMixin]): """Set a task or a task list to be directly upstream from the current task.""" raise NotImplementedError() @abstractmethod - def set_downstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]): + def set_downstream(self, other: DependencyMixin | Sequence[DependencyMixin]): """Set a task or a task list to be directly downstream from the current task.""" raise NotImplementedError() - def update_relative(self, other: "DependencyMixin", upstream=True) -> None: + def update_relative(self, other: DependencyMixin, upstream=True) -> None: """ Update relationship information about another TaskMixin. Default is no-op. Override if necessary. """ - def __lshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]): + def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): """Implements Task << Task""" self.set_upstream(other) return other - def __rshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]): + def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): """Implements Task >> Task""" self.set_downstream(other) return other - def __rrshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]): + def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): """Called for Task >> [Task] because list don't have __rshift__ operators.""" self.__lshift__(other) return self - def __rlshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]): + def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): """Called for Task << [Task] because list don't have __lshift__ operators.""" self.__rshift__(other) return self @@ -97,7 +98,7 @@ class TaskMixin(DependencyMixin): def __init_subclass__(cls) -> None: warnings.warn( f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}", - category=DeprecationWarning, + category=RemovedInAirflow3Warning, stacklevel=2, ) return super().__init_subclass__() @@ -109,8 +110,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta): unmapped. 
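[The operator overloads above are why both `task >> [a, b]` and `[a, b] >> task` work: a plain list has no `__rshift__`, so Python falls back to the right operand's `__rrshift__`. A stripped-down sketch of that dispatch with a hypothetical `Node` class, not Airflow code:]

```python
from __future__ import annotations


class Node:
    def __init__(self, name: str) -> None:
        self.name = name
        self.upstream: list[str] = []

    def set_upstream(self, other) -> None:
        others = other if isinstance(other, list) else [other]
        self.upstream.extend(o.name for o in others)

    def set_downstream(self, other) -> None:
        others = other if isinstance(other, list) else [other]
        for o in others:
            o.set_upstream(self)

    def __rshift__(self, other):   # self >> other
        self.set_downstream(other)
        return other

    def __lshift__(self, other):   # self << other
        self.set_upstream(other)
        return other

    def __rrshift__(self, other):  # [a, b] >> self: list has no __rshift__,
        self.__lshift__(other)     # so Python asks the right operand instead.
        return self


a, b, c = Node("a"), Node("b"), Node("c")
a >> b            # handled by Node.__rshift__
[a, b] >> c       # handled by Node.__rrshift__
print(b.upstream, c.upstream)  # ['a'] ['a', 'b']
```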
""" - dag: Optional["DAG"] = None - task_group: Optional["TaskGroup"] = None + dag: DAG | None = None + task_group: TaskGroup | None = None """The task_group that contains this node""" @property @@ -119,17 +120,17 @@ def node_id(self) -> str: raise NotImplementedError() @property - def label(self) -> Optional[str]: + def label(self) -> str | None: tg = self.task_group if tg and tg.node_id and tg.prefix_group_id: # "task_group_id.task_id" -> "task_id" return self.node_id[len(tg.node_id) + 1 :] return self.node_id - start_date: Optional[pendulum.DateTime] - end_date: Optional[pendulum.DateTime] - upstream_task_ids: Set[str] - downstream_task_ids: Set[str] + start_date: pendulum.DateTime | None + end_date: pendulum.DateTime | None + upstream_task_ids: set[str] + downstream_task_ids: set[str] def has_dag(self) -> bool: return self.dag is not None @@ -142,24 +143,24 @@ def dag_id(self) -> str: return "_in_memory_dag_" @property - def log(self) -> "Logger": + def log(self) -> Logger: raise NotImplementedError() @property @abstractmethod - def roots(self) -> Sequence["DAGNode"]: + def roots(self) -> Sequence[DAGNode]: raise NotImplementedError() @property @abstractmethod - def leaves(self) -> Sequence["DAGNode"]: + def leaves(self) -> Sequence[DAGNode]: raise NotImplementedError() def _set_relatives( self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], upstream: bool = False, - edge_modifier: Optional["EdgeModifier"] = None, + edge_modifier: EdgeModifier | None = None, ) -> None: """Sets relatives for the task or task list.""" from airflow.models.baseoperator import BaseOperator @@ -169,7 +170,7 @@ def _set_relatives( if not isinstance(task_or_task_list, Sequence): task_or_task_list = [task_or_task_list] - task_list: List[Operator] = [] + task_list: list[Operator] = [] for task_object in task_or_task_list: task_object.update_relative(self, not upstream) relatives = task_object.leaves if upstream else task_object.roots @@ -182,10 +183,10 @@ def _set_relatives( # relationships can only be set if the tasks share a single DAG. Tasks # without a DAG are assigned to that DAG. - dags: Set["DAG"] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag} + dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag} if len(dags) > 1: - raise AirflowException(f'Tried to set relationships between tasks in more than one DAG: {dags}') + raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}") elif len(dags) == 1: dag = dags.pop() else: @@ -195,24 +196,20 @@ def _set_relatives( ) if not self.has_dag(): - # If this task does not yet have a dag, add it to the same dag as the other task and - # put it in the dag's root TaskGroup. + # If this task does not yet have a dag, add it to the same dag as the other task. self.dag = dag - self.dag.task_group.add(self) - def add_only_new(obj, item_set: Set[str], item: str) -> None: + def add_only_new(obj, item_set: set[str], item: str) -> None: """Adds only new items to item set""" if item in item_set: - self.log.warning('Dependency %s, %s already registered for DAG: %s', obj, item, dag.dag_id) + self.log.warning("Dependency %s, %s already registered for DAG: %s", obj, item, dag.dag_id) else: item_set.add(item) for task in task_list: if dag and not task.has_dag(): # If the other task does not yet have a dag, add it to the same dag as this task and - # put it in the dag's root TaskGroup. 
dag.add_task(task) - dag.task_group.add(task) if upstream: add_only_new(task, task.downstream_task_ids, self.node_id) add_only_new(self, self.upstream_task_ids, task.node_id) @@ -226,35 +223,35 @@ def add_only_new(obj, item_set: Set[str], item: str) -> None: def set_downstream( self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional["EdgeModifier"] = None, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, ) -> None: """Set a node (or nodes) to be directly downstream from the current node.""" self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier) def set_upstream( self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional["EdgeModifier"] = None, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, ) -> None: - """Set a node (or nodes) to be directly downstream from the current node.""" + """Set a node (or nodes) to be directly upstream from the current node.""" self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier) @property - def downstream_list(self) -> Iterable["DAGNode"]: + def downstream_list(self) -> Iterable[Operator]: """List of nodes directly downstream""" if not self.dag: - raise AirflowException(f'Operator {self} has not been assigned to a DAG yet') + raise AirflowException(f"Operator {self} has not been assigned to a DAG yet") return [self.dag.get_task(tid) for tid in self.downstream_task_ids] @property - def upstream_list(self) -> Iterable["DAGNode"]: + def upstream_list(self) -> Iterable[Operator]: """List of nodes directly upstream""" if not self.dag: - raise AirflowException(f'Operator {self} has not been assigned to a DAG yet') + raise AirflowException(f"Operator {self} has not been assigned to a DAG yet") return [self.dag.get_task(tid) for tid in self.upstream_task_ids] - def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: + def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: """ Get set of the direct relative ids to the current task, upstream or downstream. @@ -264,7 +261,7 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: else: return self.downstream_task_ids - def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: + def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]: """ Get list of the direct relatives to the current task, upstream or downstream. @@ -274,60 +271,6 @@ def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: else: return self.downstream_list - def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: - """This is used by SerializedTaskGroup to serialize a task group's content.""" + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: + """This is used by TaskGroupSerialization to serialize a task group's content.""" raise NotImplementedError() - - def _iter_all_mapped_downstreams(self) -> Iterator["MappedOperator"]: - """Return mapped nodes that are direct dependencies of the current task. - - For now, this walks the entire DAG to find mapped nodes that has this - current task as an upstream. We cannot use ``downstream_list`` since it - only contains operators, not task groups. In the future, we should - provide a way to record an DAG node's all downstream nodes instead. 
- - Note that this does not guarantee the returned tasks actually use the - current task for task mapping, but only checks those task are mapped - operators, and are downstreams of the current task. - - To get a list of tasks that uses the current task for task mapping, use - :meth:`iter_mapped_dependants` instead. - """ - from airflow.models.mappedoperator import MappedOperator - from airflow.utils.task_group import TaskGroup - - def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]: - """Recursively walk children in a task group. - - This yields all direct children (including both tasks and task - groups), and all children of any task groups. - """ - for key, child in group.children.items(): - yield key, child - if isinstance(child, TaskGroup): - yield from _walk_group(child) - - tg = self.task_group - if not tg: - raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG") - for key, child in _walk_group(tg): - if key == self.node_id: - continue - if not isinstance(child, MappedOperator): - continue - if self.node_id in child.upstream_task_ids: - yield child - - def iter_mapped_dependants(self) -> Iterator["MappedOperator"]: - """Return mapped nodes that depend on the current task the expansion. - - For now, this walks the entire DAG to find mapped nodes that has this - current task as an upstream. We cannot use ``downstream_list`` since it - only contains operators, not task groups. In the future, we should - provide a way to record an DAG node's all downstream nodes instead. - """ - return ( - downstream - for downstream in self._iter_all_mapped_downstreams() - if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) - ) diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py index 518f1e77ff65f..bbe09145c0c95 100644 --- a/airflow/models/taskreschedule.py +++ b/airflow/models/taskreschedule.py @@ -16,11 +16,12 @@ # specific language governing permissions and limitations # under the License. 
"""TaskReschedule tracks rescheduled task instances.""" +from __future__ import annotations import datetime from typing import TYPE_CHECKING -from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, text +from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, text from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import relationship @@ -49,7 +50,7 @@ class TaskReschedule(Base): reschedule_date = Column(UtcDateTime, nullable=False) __table_args__ = ( - Index('idx_task_reschedule_dag_task_run', dag_id, task_id, run_id, map_index, unique=False), + Index("idx_task_reschedule_dag_task_run", dag_id, task_id, run_id, map_index, unique=False), ForeignKeyConstraint( [dag_id, task_id, run_id, map_index], [ @@ -61,11 +62,12 @@ class TaskReschedule(Base): name="task_reschedule_ti_fkey", ondelete="CASCADE", ), + Index("idx_task_reschedule_dag_run", dag_id, run_id), ForeignKeyConstraint( [dag_id, run_id], - ['dag_run.dag_id', 'dag_run.run_id'], - name='task_reschedule_dr_fkey', - ondelete='CASCADE', + ["dag_run.dag_id", "dag_run.run_id"], + name="task_reschedule_dr_fkey", + ondelete="CASCADE", ), ) dag_run = relationship("DagRun") @@ -73,7 +75,7 @@ class TaskReschedule(Base): def __init__( self, - task: "BaseOperator", + task: BaseOperator, run_id: str, try_number: int, start_date: datetime.datetime, @@ -111,6 +113,7 @@ def query_for_task_instance(task_instance, descending=False, session=None, try_n TR.dag_id == task_instance.dag_id, TR.task_id == task_instance.task_id, TR.run_id == task_instance.run_id, + TR.map_index == task_instance.map_index, TR.try_number == try_number, ) if descending: @@ -133,3 +136,15 @@ def find_for_task_instance(task_instance, session=None, try_number=None): return TaskReschedule.query_for_task_instance( task_instance, session=session, try_number=try_number ).all() + + +@event.listens_for(TaskReschedule.__table__, "before_create") +def add_ondelete_for_mssql(table, conn, **kw): + if conn.dialect.name != "mssql": + return + + for constraint in table.constraints: + if constraint.name != "task_reschedule_dr_fkey": + continue + constraint.ondelete = "NO ACTION" + return diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index dc91fdccdf2ba..57d2ac8f2600a 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import datetime from traceback import format_exception -from typing import Any, Dict, Iterable, Optional +from typing import Any, Iterable from sqlalchemy import Column, Integer, String, func, or_ +from sqlalchemy.orm import relationship from airflow.models.base import Base from airflow.models.taskinstance import TaskInstance @@ -26,7 +29,7 @@ from airflow.utils import timezone from airflow.utils.retries import run_with_db_retries from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks from airflow.utils.state import State @@ -55,9 +58,14 @@ class Trigger(Base): created_date = Column(UtcDateTime, nullable=False) triggerer_id = Column(Integer, nullable=True) - def __init__( - self, classpath: str, kwargs: Dict[str, Any], created_date: Optional[datetime.datetime] = None - ): + triggerer_job = relationship( + "BaseJob", + primaryjoin="BaseJob.id == Trigger.triggerer_id", + foreign_keys=triggerer_id, + uselist=False, + ) + + def __init__(self, classpath: str, kwargs: dict[str, Any], created_date: datetime.datetime | None = None): super().__init__() self.classpath = classpath self.kwargs = kwargs @@ -74,7 +82,7 @@ def from_object(cls, trigger: BaseTrigger): @classmethod @provide_session - def bulk_fetch(cls, ids: Iterable[int], session=None) -> Dict[int, "Trigger"]: + def bulk_fetch(cls, ids: Iterable[int], session=None) -> dict[int, Trigger]: """ Fetches all of the Triggers by ID and returns a dict mapping ID -> Trigger instance @@ -188,15 +196,16 @@ def assign_unassigned(cls, triggerer_id, capacity, session=None): # Find triggers who do NOT have an alive triggerer_id, and then assign # up to `capacity` of those to us. - trigger_ids_query = ( + trigger_ids_query = with_row_locks( session.query(cls.id) - # notin_ doesn't find NULL rows .filter(or_(cls.triggerer_id.is_(None), cls.triggerer_id.notin_(alive_triggerer_ids))) - .limit(capacity) - .all() - ) - session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update( - {cls.triggerer_id: triggerer_id}, - synchronize_session=False, - ) + .limit(capacity), + session, + skip_locked=True, + ).all() + if trigger_ids_query: + session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update( + {cls.triggerer_id: triggerer_id}, + synchronize_session=False, + ) session.commit() diff --git a/airflow/models/variable.py b/airflow/models/variable.py index 0904ddb23e6e9..dc7db0f7c09de 100644 --- a/airflow/models/variable.py +++ b/airflow/models/variable.py @@ -15,13 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
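[The `assign_unassigned` change above wraps the candidate-trigger query in `with_row_locks(..., skip_locked=True)` so that multiple triggerer processes can claim disjoint rows without blocking one another. Airflow's helper ultimately builds on SQLAlchemy's `with_for_update`; here is a hedged standalone sketch of that primitive. The table is a stand-in, and SQLite simply ignores the locking hint (Postgres/MySQL 8 honour it).]

```python
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Trigger(Base):  # stand-in, not the Airflow model
    __tablename__ = "trigger_demo"
    id = Column(Integer, primary_key=True)
    triggerer_id = Column(Integer, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Trigger() for _ in range(5)])
    session.commit()

    # SELECT ... FOR UPDATE SKIP LOCKED: rows already locked by another
    # triggerer are skipped instead of waited on.
    claimed = (
        session.query(Trigger.id)
        .filter(Trigger.triggerer_id.is_(None))
        .limit(2)
        .with_for_update(skip_locked=True)
        .all()
    )
    session.query(Trigger).filter(Trigger.id.in_([t.id for t in claimed])).update(
        {Trigger.triggerer_id: 123}, synchronize_session=False
    )
    session.commit()
    print([t.id for t in claimed])  # e.g. [1, 2]
```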
+from __future__ import annotations import json import logging -from typing import Any, Optional +from typing import Any -from cryptography.fernet import InvalidToken as InvalidFernetToken from sqlalchemy import Boolean, Column, Integer, String, Text +from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Session, reconstructor, synonym @@ -47,7 +48,7 @@ class Variable(Base, LoggingMixin): id = Column(Integer, primary_key=True) key = Column(String(ID_LEN), unique=True) - _val = Column('val', Text) + _val = Column("val", Text().with_variant(MEDIUMTEXT, "mysql")) description = Column(Text) is_encrypted = Column(Boolean, unique=False, default=False) @@ -64,14 +65,16 @@ def on_db_load(self): def __repr__(self): # Hiding the value - return f'{self.key} : {self._val}' + return f"{self.key} : {self._val}" def get_val(self): """Get Airflow Variable from Metadata DB and decode it using the Fernet Key""" + from cryptography.fernet import InvalidToken as InvalidFernetToken + if self._val is not None and self.is_encrypted: try: fernet = get_fernet() - return fernet.decrypt(bytes(self._val, 'utf-8')).decode() + return fernet.decrypt(bytes(self._val, "utf-8")).decode() except InvalidFernetToken: self.log.error("Can't decrypt _val for key=%s, invalid token or value", self.key) return None @@ -85,13 +88,13 @@ def set_val(self, value): """Encode the specified value with Fernet Key and store it in Variables Table.""" if value is not None: fernet = get_fernet() - self._val = fernet.encrypt(bytes(value, 'utf-8')).decode() + self._val = fernet.encrypt(bytes(value, "utf-8")).decode() self.is_encrypted = fernet.is_encrypted @declared_attr def val(cls): """Get Airflow Variable from Metadata DB and decode it using the Fernet Key""" - return synonym('_val', descriptor=property(cls.get_val, cls.set_val)) + return synonym("_val", descriptor=property(cls.get_val, cls.set_val)) @classmethod def setdefault(cls, key, default, description=None, deserialize_json=False): @@ -112,7 +115,7 @@ def setdefault(cls, key, default, description=None, deserialize_json=False): Variable.set(key, default, description=description, serialize_json=deserialize_json) return default else: - raise ValueError('Default Value must be set') + raise ValueError("Default Value must be set") else: return obj @@ -135,7 +138,7 @@ def get( if default_var is not cls.__NO_DEFAULT_SENTINEL: return default_var else: - raise KeyError(f'Variable {key} does not exist') + raise KeyError(f"Variable {key} does not exist") else: if deserialize_json: obj = json.loads(var_val) @@ -151,7 +154,7 @@ def set( cls, key: str, value: Any, - description: Optional[str] = None, + description: str | None = None, serialize_json: bool = False, session: Session = None, ): @@ -196,11 +199,11 @@ def update( cls.check_for_write_conflict(key) if cls.get_variable_from_secrets(key=key) is None: - raise KeyError(f'Variable {key} does not exist') + raise KeyError(f"Variable {key} does not exist") obj = session.query(cls).filter(cls.key == key).first() if obj is None: - raise AttributeError(f'Variable {key} does not exist in the Database and cannot be updated.') + raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.") cls.set(key, value, description=obj.description, serialize_json=serialize_json) @@ -219,7 +222,7 @@ def rotate_fernet_key(self): """Rotate Fernet Key""" fernet = get_fernet() if self._val and self.is_encrypted: - self._val = 
fernet.rotate(self._val.encode('utf-8')).decode() + self._val = fernet.rotate(self._val.encode("utf-8")).decode() @staticmethod def check_for_write_conflict(key: str) -> None: @@ -246,14 +249,14 @@ def check_for_write_conflict(key: str) -> None: return except Exception: log.exception( - 'Unable to retrieve variable from secrets backend (%s). ' - 'Checking subsequent secrets backend.', + "Unable to retrieve variable from secrets backend (%s). " + "Checking subsequent secrets backend.", type(secrets_backend).__name__, ) return None @staticmethod - def get_variable_from_secrets(key: str) -> Optional[str]: + def get_variable_from_secrets(key: str) -> str | None: """ Get Airflow Variable by iterating over all Secret Backends. @@ -267,8 +270,8 @@ def get_variable_from_secrets(key: str) -> Optional[str]: return var_val except Exception: log.exception( - 'Unable to retrieve variable from secrets backend (%s). ' - 'Checking subsequent secrets backend.', + "Unable to retrieve variable from secrets backend (%s). " + "Checking subsequent secrets backend.", type(secrets_backend).__name__, ) return None diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index d67160d1fa901..3b4361842409f 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -15,26 +15,44 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations +import collections.abc +import contextlib import datetime import inspect +import itertools import json import logging import pickle import warnings from functools import wraps -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type, Union, cast, overload +from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload +import attr import pendulum -from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, LargeBinary, String +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Index, + Integer, + LargeBinary, + PrimaryKeyConstraint, + String, + text, +) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Query, Session, reconstructor, relationship from sqlalchemy.orm.exc import NoResultFound +from airflow import settings +from airflow.compat.functools import cached_property from airflow.configuration import conf +from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.base import COLLATION_ARGS, ID_LEN, Base from airflow.utils import timezone from airflow.utils.helpers import exactly_one, is_container +from airflow.utils.json import XComDecoder, XComEncoder from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime @@ -44,7 +62,7 @@ # MAX XCOM Size is 48KB # https://github.com/apache/airflow/pull/1618#discussion_r68249677 MAX_XCOM_SIZE = 49344 -XCOM_RETURN_KEY = 'return_value' +XCOM_RETURN_KEY = "return_value" if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey @@ -57,7 +75,7 @@ class BaseXCom(Base, LoggingMixin): dag_run_id = Column(Integer(), nullable=False, primary_key=True) task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True) - map_index = Column(Integer, primary_key=True, nullable=False, server_default="-1") + map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) # 
Denormalized for easier lookup. @@ -72,6 +90,10 @@ class BaseXCom(Base, LoggingMixin): # but it goes over MySQL's index length limit. So we instead index 'key' # separately, and enforce uniqueness with DagRun.id instead. Index("idx_xcom_key", key), + Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index), + PrimaryKeyConstraint( + "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True + ), ForeignKeyConstraint( [dag_id, task_id, run_id, map_index], [ @@ -157,10 +179,10 @@ def set( value: Any, task_id: str, dag_id: str, - execution_date: Optional[datetime.datetime] = None, + execution_date: datetime.datetime | None = None, session: Session = NEW_SESSION, *, - run_id: Optional[str] = None, + run_id: str | None = None, map_index: int = -1, ) -> None: """:sphinx-autoapi-skip:""" @@ -174,7 +196,7 @@ def set( if run_id is None: message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead." - warnings.warn(message, DeprecationWarning, stacklevel=3) + warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) try: dag_run_id, run_id = ( session.query(DagRun.id, DagRun.run_id) @@ -188,6 +210,27 @@ def set( if dag_run_id is None: raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}") + # Seamlessly resolve LazyXComAccess to a list. This is intended to work + # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if + # it's pushed into XCom, the user should be aware of the performance + # implications, and this avoids leaking the implementation detail. + if isinstance(value, LazyXComAccess): + warning_message = ( + "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) " + "to list, which may degrade performance. Review resource " + "requirements for this operation, and call list() to suppress " + "this message. See Dynamic Task Mapping documentation for " + "more information about lazy proxy objects." + ) + log.warning( + warning_message, + "return value" if key == XCOM_RETURN_KEY else f"value {key}", + task_id, + dag_id, + run_id or execution_date, + ) + value = list(value) + value = cls.serialize_value( value=value, key=key, @@ -222,8 +265,8 @@ def set( def get_value( cls, *, - ti_key: "TaskInstanceKey", - key: Optional[str] = None, + ti_key: TaskInstanceKey, + key: str | None = None, session: Session = NEW_SESSION, ) -> Any: """Retrieve an XCom value for a task instance. @@ -255,13 +298,13 @@ def get_value( def get_one( cls, *, - key: Optional[str] = None, - dag_id: Optional[str] = None, - task_id: Optional[str] = None, - run_id: Optional[str] = None, - map_index: Optional[int] = None, + key: str | None = None, + dag_id: str | None = None, + task_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, session: Session = NEW_SESSION, - ) -> Optional[Any]: + ) -> Any | None: """Retrieve an XCom value, optionally meeting certain criteria. This method returns "full" XCom values (i.e. 
uses ``deserialize_value`` @@ -298,28 +341,28 @@ def get_one( def get_one( cls, execution_date: datetime.datetime, - key: Optional[str] = None, - task_id: Optional[str] = None, - dag_id: Optional[str] = None, + key: str | None = None, + task_id: str | None = None, + dag_id: str | None = None, include_prior_dates: bool = False, session: Session = NEW_SESSION, - ) -> Optional[Any]: + ) -> Any | None: """:sphinx-autoapi-skip:""" @classmethod @provide_session def get_one( cls, - execution_date: Optional[datetime.datetime] = None, - key: Optional[str] = None, - task_id: Optional[str] = None, - dag_id: Optional[str] = None, + execution_date: datetime.datetime | None = None, + key: str | None = None, + task_id: str | None = None, + dag_id: str | None = None, include_prior_dates: bool = False, session: Session = NEW_SESSION, *, - run_id: Optional[str] = None, - map_index: Optional[int] = None, - ) -> Optional[Any]: + run_id: str | None = None, + map_index: int | None = None, + ) -> Any | None: """:sphinx-autoapi-skip:""" if not exactly_one(execution_date is not None, run_id is not None): raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed") @@ -337,10 +380,10 @@ def get_one( ) elif execution_date is not None: message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead." - warnings.warn(message, PendingDeprecationWarning, stacklevel=3) + warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", RemovedInAirflow3Warning) query = cls.get_many( execution_date=execution_date, key=key, @@ -365,12 +408,12 @@ def get_many( cls, *, run_id: str, - key: Optional[str] = None, - task_ids: Union[str, Iterable[str], None] = None, - dag_ids: Union[str, Iterable[str], None] = None, - map_indexes: Union[int, Iterable[int], None] = None, + key: str | None = None, + task_ids: str | Iterable[str] | None = None, + dag_ids: str | Iterable[str] | None = None, + map_indexes: int | Iterable[int] | None = None, include_prior_dates: bool = False, - limit: Optional[int] = None, + limit: int | None = None, session: Session = NEW_SESSION, ) -> Query: """Composes a query to get one or more XCom entries. 
@@ -402,12 +445,12 @@ def get_many( def get_many( cls, execution_date: datetime.datetime, - key: Optional[str] = None, - task_ids: Union[str, Iterable[str], None] = None, - dag_ids: Union[str, Iterable[str], None] = None, - map_indexes: Union[int, Iterable[int], None] = None, + key: str | None = None, + task_ids: str | Iterable[str] | None = None, + dag_ids: str | Iterable[str] | None = None, + map_indexes: int | Iterable[int] | None = None, include_prior_dates: bool = False, - limit: Optional[int] = None, + limit: int | None = None, session: Session = NEW_SESSION, ) -> Query: """:sphinx-autoapi-skip:""" @@ -416,16 +459,16 @@ def get_many( @provide_session def get_many( cls, - execution_date: Optional[datetime.datetime] = None, - key: Optional[str] = None, - task_ids: Optional[Union[str, Iterable[str]]] = None, - dag_ids: Optional[Union[str, Iterable[str]]] = None, - map_indexes: Union[int, Iterable[int], None] = None, + execution_date: datetime.datetime | None = None, + key: str | None = None, + task_ids: str | Iterable[str] | None = None, + dag_ids: str | Iterable[str] | None = None, + map_indexes: int | Iterable[int] | None = None, include_prior_dates: bool = False, - limit: Optional[int] = None, + limit: int | None = None, session: Session = NEW_SESSION, *, - run_id: Optional[str] = None, + run_id: str | None = None, ) -> Query: """:sphinx-autoapi-skip:""" from airflow.models.dagrun import DagRun @@ -437,7 +480,7 @@ def get_many( ) if execution_date is not None: message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead." - warnings.warn(message, PendingDeprecationWarning, stacklevel=3) + warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) query = session.query(cls).join(cls.dag_run) @@ -454,7 +497,9 @@ def get_many( elif dag_ids is not None: query = query.filter(cls.dag_id == dag_ids) - if is_container(map_indexes): + if isinstance(map_indexes, range) and map_indexes.step == 1: + query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop) + elif is_container(map_indexes): query = query.filter(cls.map_index.in_(map_indexes)) elif map_indexes is not None: query = query.filter(cls.map_index == map_indexes) @@ -477,13 +522,13 @@ def get_many( @classmethod @provide_session - def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> None: + def delete(cls, xcoms: XCom | Iterable[XCom], session: Session) -> None: """Delete one or multiple XCom entries.""" if isinstance(xcoms, XCom): xcoms = [xcoms] for xcom in xcoms: if not isinstance(xcom, XCom): - raise TypeError(f'Expected XCom; received {xcom.__class__.__name__}') + raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}") session.delete(xcom) session.commit() @@ -495,7 +540,7 @@ def clear( dag_id: str, task_id: str, run_id: str, - map_index: Optional[int] = None, + map_index: int | None = None, session: Session = NEW_SESSION, ) -> None: """Clear all XCom data from the database for the given task instance. 
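[One small but useful detail in the `get_many()` implementation above: when `map_indexes` is a contiguous `range` with step 1, the filter becomes two comparisons instead of an `IN` list with one bind parameter per index. A sketch of that translation; the model and column names are hypothetical stand-ins.]

```python
from sqlalchemy import Column, Integer, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class XComRow(Base):  # stand-in for the XCom table
    __tablename__ = "xcom_demo"
    id = Column(Integer, primary_key=True)
    map_index = Column(Integer, server_default="-1")


def map_index_filter(stmt, map_indexes):
    if isinstance(map_indexes, range) and map_indexes.step == 1:
        # range(0, 10_000) -> "map_index >= 0 AND map_index < 10000"
        return stmt.where(
            XComRow.map_index >= map_indexes.start,
            XComRow.map_index < map_indexes.stop,
        )
    return stmt.where(XComRow.map_index.in_(map_indexes))  # explicit IN list


print(map_index_filter(select(XComRow.id), range(0, 10_000)))
print(map_index_filter(select(XComRow.id), [0, 3, 7]))
```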
@@ -527,13 +572,13 @@ def clear( @provide_session def clear( cls, - execution_date: Optional[pendulum.DateTime] = None, - dag_id: Optional[str] = None, - task_id: Optional[str] = None, + execution_date: pendulum.DateTime | None = None, + dag_id: str | None = None, + task_id: str | None = None, session: Session = NEW_SESSION, *, - run_id: Optional[str] = None, - map_index: Optional[int] = None, + run_id: str | None = None, + map_index: int | None = None, ) -> None: """:sphinx-autoapi-skip:""" from airflow.models import DagRun @@ -553,7 +598,7 @@ def clear( if execution_date is not None: message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead." - warnings.warn(message, DeprecationWarning, stacklevel=3) + warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) run_id = ( session.query(DagRun.run_id) .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) @@ -569,42 +614,52 @@ def clear( def serialize_value( value: Any, *, - key: Optional[str] = None, - task_id: Optional[str] = None, - dag_id: Optional[str] = None, - run_id: Optional[str] = None, - map_index: Optional[int] = None, - ): - """Serialize XCom value to str or pickled object""" - if conf.getboolean('core', 'enable_xcom_pickling'): + key: str | None = None, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> Any: + """Serialize XCom value to str or pickled object.""" + if conf.getboolean("core", "enable_xcom_pickling"): return pickle.dumps(value) try: - return json.dumps(value).encode('UTF-8') - except (ValueError, TypeError): + return json.dumps(value, cls=XComEncoder).encode("UTF-8") + except (ValueError, TypeError) as ex: log.error( - "Could not serialize the XCom value into JSON." + "%s." " If you are using pickle instead of JSON for XCom," " then you need to enable pickle support for XCom" - " in your airflow config." + " in your airflow config or make sure to decorate your" + " object with attr.", + ex, ) raise @staticmethod - def deserialize_value(result: "XCom") -> Any: - """Deserialize XCom value from str or pickle object""" + def _deserialize_value(result: XCom, orm: bool) -> Any: + object_hook = None + if orm: + object_hook = XComDecoder.orm_object_hook + if result.value is None: return None - if conf.getboolean('core', 'enable_xcom_pickling'): + if conf.getboolean("core", "enable_xcom_pickling"): try: return pickle.loads(result.value) except pickle.UnpicklingError: - return json.loads(result.value.decode('UTF-8')) + return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook) else: try: - return json.loads(result.value.decode('UTF-8')) + return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook) except (json.JSONDecodeError, UnicodeDecodeError): return pickle.loads(result.value) + @staticmethod + def deserialize_value(result: XCom) -> Any: + """Deserialize XCom value from str or pickle object""" + return BaseXCom._deserialize_value(result, False) + def orm_deserialize_value(self) -> Any: """ Deserialize method which is used to reconstruct ORM XCom object. @@ -614,10 +669,109 @@ def orm_deserialize_value(self) -> Any: creating XCom orm model. This is used when viewing XCom listing in the webserver, for example. 
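A custom XCom backend written against the new ``serialize_value`` signature above (``value`` plus keyword-only ``key``, ``task_id``, ``dag_id``, ``run_id`` and ``map_index``) is used unchanged, while one that only accepts ``value`` is wrapped by the ``_patch_outdated_serializer`` shim further below and warned about. An illustrative backend follows; the class name and compression scheme are invented for this sketch, and it ignores the pickling option for brevity::

    import json
    import zlib

    from airflow.models.xcom import BaseXCom


    class CompressedXCom(BaseXCom):
        """Sketch of a backend storing zlib-compressed JSON."""

        @staticmethod
        def serialize_value(value, *, key=None, task_id=None, dag_id=None, run_id=None, map_index=None):
            return zlib.compress(json.dumps(value).encode("utf-8"))

        @staticmethod
        def deserialize_value(result):
            return json.loads(zlib.decompress(result.value).decode("utf-8"))

Such a class is activated through the ``xcom_backend`` option in the ``[core]`` section, which ``resolve_xcom_backend()`` further down imports and validates.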
""" - return BaseXCom.deserialize_value(self) + return BaseXCom._deserialize_value(self, True) + + +class _LazyXComAccessIterator(collections.abc.Iterator): + def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None: + self._cm = cm + self._entered = False + + def __del__(self) -> None: + if self._entered: + self._cm.__exit__(None, None, None) + + def __iter__(self) -> collections.abc.Iterator: + return self + def __next__(self) -> Any: + return XCom.deserialize_value(next(self._it)) -def _patch_outdated_serializer(clazz: Type[BaseXCom], params: Iterable[str]) -> None: + @cached_property + def _it(self) -> collections.abc.Iterator: + self._entered = True + return iter(self._cm.__enter__()) + + +@attr.define(slots=True) +class LazyXComAccess(collections.abc.Sequence): + """Wrapper to lazily pull XCom with a sequence-like interface. + + Note that since the session bound to the parent query may have died when we + actually access the sequence's content, we must create a new session + for every function call with ``with_session()``. + + :meta private: + """ + + _query: Query + _len: int | None = attr.ib(init=False, default=None) + + @classmethod + def build_from_xcom_query(cls, query: Query) -> LazyXComAccess: + return cls(query=query.with_entities(XCom.value)) + + def __repr__(self) -> str: + return f"LazyXComAccess([{len(self)} items])" + + def __str__(self) -> str: + return str(list(self)) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, (list, LazyXComAccess)): + z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) + return all(x == y for x, y in z) + return NotImplemented + + def __getstate__(self) -> Any: + # We don't want to go to the trouble of serializing the entire Query + # object, including its filters, hints, etc. (plus SQLAlchemy does not + # provide a public API to inspect a query's contents). Converting the + # query into a SQL string is the best we can get. Theoratically we can + # do the same for count(), but I think it should be performant enough to + # calculate only that eagerly. + with self._get_bound_query() as query: + statement = query.statement.compile(query.session.get_bind()) + return (str(statement), query.count()) + + def __setstate__(self, state: Any) -> None: + statement, self._len = state + self._query = Query(XCom.value).from_statement(text(statement)) + + def __len__(self): + if self._len is None: + with self._get_bound_query() as query: + self._len = query.count() + return self._len + + def __iter__(self): + return _LazyXComAccessIterator(self._get_bound_query()) + + def __getitem__(self, key): + if not isinstance(key, int): + raise ValueError("only support index access for now") + try: + with self._get_bound_query() as query: + r = query.offset(key).limit(1).one() + except NoResultFound: + raise IndexError(key) from None + return XCom.deserialize_value(r) + + @contextlib.contextmanager + def _get_bound_query(self) -> Generator[Query, None, None]: + # Do we have a valid session already? + if self._query.session and self._query.session.is_active: + yield self._query + return + + session = settings.Session() + try: + yield self._query.with_session(session) + finally: + session.close() + + +def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None: """Patch a custom ``serialize_value`` to accept the modern signature. 
To give custom XCom backends more flexibility with how they store values, we @@ -635,19 +789,18 @@ def _shim(**kwargs): f"Method `serialize_value` in XCom backend {XCom.__name__} is using outdated signature and" f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be " f"removed in a future release.", - DeprecationWarning, + RemovedInAirflow3Warning, ) return old_serializer(**kwargs) clazz.serialize_value = _shim # type: ignore[assignment] -def _get_function_params(function) -> List[str]: +def _get_function_params(function) -> list[str]: """ Returns the list of variables names of a function :param function: The function to inspect - :rtype: List[str] """ parameters = inspect.signature(function).parameters bound_arguments = [ @@ -656,8 +809,12 @@ def _get_function_params(function) -> List[str]: return bound_arguments -def resolve_xcom_backend() -> Type[BaseXCom]: - """Resolves custom XCom class""" +def resolve_xcom_backend() -> type[BaseXCom]: + """Resolves custom XCom class + + Confirms that custom XCom class extends the BaseXCom. + Compares the function signature of the custom XCom serialize_value to the base XCom serialize_value. + """ clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}") if not clazz: return BaseXCom diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 2c2a8f9b5868d..d2b80474a9f12 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -14,33 +14,46 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Union -from airflow.exceptions import AirflowException +from __future__ import annotations + +import contextlib +import inspect +from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from airflow.exceptions import XComNotFound from airflow.models.abstractoperator import AbstractOperator +from airflow.models.mappedoperator import MappedOperator from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier +from airflow.utils.mixins import ResolveMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.types import NOTSET +from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: - from sqlalchemy.orm import Session - + from airflow.models.dag import DAG from airflow.models.operator import Operator +# Callable objects contained by MapXComArg. We only accept callables from +# the user, but deserialize them into strings in a serialized XComArg for +# safety (those callables are arbitrary user code). +MapCallables = Sequence[Union[Callable[[Any], Any], str]] -class XComArg(DependencyMixin): - """ - Class that represents a XCom push from a previous operator. - Defaults to "return_value" as only key. - Current implementation supports +class XComArg(ResolveMixin, DependencyMixin): + """Reference to an XCom value pushed from another operator. 
+ + The implementation supports:: + xcomarg >> op xcomarg << op - op >> xcomarg (by BaseOperator code) - op << xcomarg (by BaseOperator code) + op >> xcomarg # By BaseOperator code + op << xcomarg # By BaseOperator code **Example**: The moment you get a result from any operator (decorated or regular) you can :: @@ -53,29 +66,175 @@ class XComArg(DependencyMixin): This object can be used in legacy Operators via Jinja. - **Example**: You can make this result to be part of any generated string :: + **Example**: You can make this result to be part of any generated string:: any_op = AnyOperator() xcomarg = any_op.output op1 = MyOperator(my_text_message=f"the value is {xcomarg}") op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}") - :param operator: operator to which the XComArg belongs to - :param key: key value which is used for xcom_pull (key in the XCom table) + :param operator: Operator instance to which the XComArg references. + :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*, + i.e. the referenced operator's return value. """ - def __init__(self, operator: "Operator", key: str = XCOM_RETURN_KEY): + @overload + def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg: + """Called when the user writes ``XComArg(...)`` directly.""" + + @overload + def __new__(cls: type[XComArg]) -> XComArg: + """Called by Python internals from subclasses.""" + + def __new__(cls, *args, **kwargs) -> XComArg: + if cls is XComArg: + return PlainXComArg(*args, **kwargs) + return super().__new__(cls) + + @staticmethod + def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]: + """Return XCom references in an arbitrary value. + + Recursively traverse ``arg`` and look for XComArg instances in any + collection objects, and instances with ``template_fields`` set. + """ + if isinstance(arg, ResolveMixin): + yield from arg.iter_references() + elif isinstance(arg, (tuple, set, list)): + for elem in arg: + yield from XComArg.iter_xcom_references(elem) + elif isinstance(arg, dict): + for elem in arg.values(): + yield from XComArg.iter_xcom_references(elem) + elif isinstance(arg, AbstractOperator): + for attr in arg.template_fields: + yield from XComArg.iter_xcom_references(getattr(arg, attr)) + + @staticmethod + def apply_upstream_relationship(op: Operator, arg: Any): + """Set dependency for XComArgs. + + This looks for XComArg objects in ``arg`` "deeply" (looking inside + collections objects and classes decorated with ``template_fields``), and + sets the relationship to ``op`` on any found. + """ + for operator, _ in XComArg.iter_xcom_references(arg): + op.set_upstream(operator) + + @property + def roots(self) -> list[DAGNode]: + """Required by TaskMixin""" + return [op for op, _ in self.iter_references()] + + @property + def leaves(self) -> list[DAGNode]: + """Required by TaskMixin""" + return [op for op, _ in self.iter_references()] + + def set_upstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """Proxy to underlying operator set_upstream method. Required by TaskMixin.""" + for operator, _ in self.iter_references(): + operator.set_upstream(task_or_task_list, edge_modifier) + + def set_downstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """Proxy to underlying operator set_downstream method. 
Required by TaskMixin.""" + for operator, _ in self.iter_references(): + operator.set_downstream(task_or_task_list, edge_modifier) + + def _serialize(self) -> dict[str, Any]: + """Called by DAG serialization. + + The implementation should be the inverse function to ``deserialize``, + returning a data dict converted from this XComArg derivative. DAG + serialization does not call this directly, but ``serialize_xcom_arg`` + instead, which adds additional information to dispatch deserialization + to the correct class. + """ + raise NotImplementedError() + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + """Called when deserializing a DAG. + + The implementation should be the inverse function to ``serialize``, + implementing given a data dict converted from this XComArg derivative, + how the original XComArg should be created. DAG serialization relies on + additional information added in ``serialize_xcom_arg`` to dispatch data + dicts to the correct ``_deserialize`` information, so this function does + not need to validate whether the incoming data contains correct keys. + """ + raise NotImplementedError() + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + return MapXComArg(self, [f]) + + def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: + return ZipXComArg([self, *others], fillvalue=fillvalue) + + def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: + """Inspect length of pushed value for task-mapping. + + This is used to determine how many task instances the scheduler should + create for a downstream using this XComArg for task-mapping. + + *None* may be returned if the depended XCom has not been pushed. + """ + raise NotImplementedError() + + def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any: + """Pull XCom value. + + This should only be called during ``op.execute()`` with an appropriate + context (e.g. generated from ``TaskInstance.get_template_context()``). + Although the ``ResolveMixin`` parent mixin also has a ``resolve`` + protocol, this adds the optional ``session`` argument that some of the + subclasses need. + + :meta private: + """ + raise NotImplementedError() + + +class PlainXComArg(XComArg): + """Reference to one single XCom without any additional semantics. + + This class should not be accessed directly, but only through XComArg. The + class inheritance chain and ``__new__`` is implemented in this slightly + convoluted way because we want to + + a. Allow the user to continue using XComArg directly for the simple + semantics (see documentation of the base class for details). + b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom + references. + c. Not allow many properties of PlainXComArg (including ``__getitem__`` and + ``__str__``) to exist on other kinds of XComArg implementations since + they don't make sense. 
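In day-to-day DAG code none of this split is visible: ``XComArg(op)`` and ``op.output`` keep returning the plain flavour via ``__new__``, while ``isinstance`` checks against ``XComArg`` also match the map/zip variants. A sketch (the operator is a placeholder)::

    from airflow.models.xcom_arg import MapXComArg, PlainXComArg, XComArg
    from airflow.operators.bash import BashOperator

    op = BashOperator(task_id="produce", bash_command="echo hi")

    ref = XComArg(op)            # __new__ dispatches to PlainXComArg
    assert isinstance(ref, PlainXComArg)
    assert op.output == ref      # .output builds the same plain reference

    mapped = ref.map(str.strip)  # a different XComArg flavour
    assert isinstance(mapped, MapXComArg) and isinstance(mapped, XComArg)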
+ + :meta private: + """ + + def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY): self.operator = operator self.key = key - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: + if not isinstance(other, PlainXComArg): + return NotImplemented return self.operator == other.operator and self.key == other.key - def __getitem__(self, item: str) -> "XComArg": + def __getitem__(self, item: str) -> XComArg: """Implements xcomresult['some_result_key']""" if not isinstance(item, str): raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}") - return XComArg(operator=self.operator, key=item) + return PlainXComArg(operator=self.operator, key=item) def __iter__(self): """Override iterable protocol to raise error explicitly. @@ -89,9 +248,14 @@ def __iter__(self): This override catches the error eagerly, so an incorrectly implemented DAG fails fast and avoids wasting resources on nonsensical iterating. """ - raise TypeError(f"{self.__class__.__name__!r} object is not iterable") + raise TypeError("'XComArg' object is not iterable") + + def __repr__(self) -> str: + if self.key == XCOM_RETURN_KEY: + return f"XComArg({self.operator!r})" + return f"XComArg({self.operator!r}, {self.key!r})" - def __str__(self): + def __str__(self) -> str: """ Backward compatibility for old-style jinja used in Airflow Operators @@ -103,84 +267,267 @@ def __str__(self): """ xcom_pull_kwargs = [ f"task_ids='{self.operator.task_id}'", - f"dag_id='{self.operator.dag.dag_id}'", + f"dag_id='{self.operator.dag_id}'", ] if self.key is not None: xcom_pull_kwargs.append(f"key='{self.key}'") - xcom_pull_kwargs = ", ".join(xcom_pull_kwargs) + xcom_pull_str = ", ".join(xcom_pull_kwargs) # {{{{ are required for escape {{ in f-string - xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}" + xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}" return xcom_pull - @property - def roots(self) -> List[DAGNode]: - """Required by TaskMixin""" - return [self.operator] + def _serialize(self) -> dict[str, Any]: + return {"task_id": self.operator.task_id, "key": self.key} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls(dag.get_task(data["task_id"]), data["key"]) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield self.operator, self.key + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot map against non-return XCom") + return super().map(f) + + def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot map against non-return XCom") + return super().zip(*others, fillvalue=fillvalue) + + def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: + from airflow.models.taskmap import TaskMap + from airflow.models.xcom import XCom + + task = self.operator + if isinstance(task, MappedOperator): + query = session.query(func.count(XCom.map_index)).filter( + XCom.dag_id == task.dag_id, + XCom.run_id == run_id, + XCom.task_id == task.task_id, + XCom.map_index >= 0, + XCom.key == XCOM_RETURN_KEY, + ) + else: + query = session.query(TaskMap.length).filter( + TaskMap.dag_id == task.dag_id, + TaskMap.run_id == run_id, + TaskMap.task_id == task.task_id, + TaskMap.map_index < 0, + ) + return query.scalar() - @property - def leaves(self) -> List[DAGNode]: - """Required by TaskMixin""" - return [self.operator] + @provide_session + def resolve(self, context: 
Context, session: Session = NEW_SESSION) -> Any: + ti = context["ti"] + task_id = self.operator.task_id + map_indexes = ti.get_relevant_upstream_map_indexes( + self.operator, + context["expanded_ti_count"], + session=session, + ) + result = ti.xcom_pull( + task_ids=task_id, + map_indexes=map_indexes, + key=self.key, + default=NOTSET, + session=session, + ) + if not isinstance(result, ArgNotSet): + return result + if self.key == XCOM_RETURN_KEY: + return None + raise XComNotFound(ti.dag_id, task_id, self.key) + + +def _get_callable_name(f: Callable | str) -> str: + """Try to "describe" a callable by getting its name.""" + if callable(f): + return f.__name__ + # Parse the source to find whatever is behind "def". For safety, we don't + # want to evaluate the code in any meaningful way! + with contextlib.suppress(Exception): + kw, name, _ = f.lstrip().split(None, 2) + if kw == "def": + return name + return "" + + +class _MapResult(Sequence): + def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: + self.value = value + self.callables = callables + + def __getitem__(self, index: Any) -> Any: + value = self.value[index] + + # In the worker, we can access all actual callables. Call them. + callables = [f for f in self.callables if callable(f)] + if len(callables) == len(self.callables): + for f in callables: + value = f(value) + return value + + # In the scheduler, we don't have access to the actual callables, nor do + # we want to run it since it's arbitrary code. This builds a string to + # represent the call chain in the UI or logs instead. + for v in self.callables: + value = f"{_get_callable_name(v)}({value})" + return value + + def __len__(self) -> int: + return len(self.value) + + +class MapXComArg(XComArg): + """An XCom reference with ``map()`` call(s) applied. + + This is based on an XComArg, but also applies a series of "transforms" that + convert the pulled XCom value. + + :meta private: + """ - def set_upstream( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional[EdgeModifier] = None, - ): - """Proxy to underlying operator set_upstream method. Required by TaskMixin.""" - self.operator.set_upstream(task_or_task_list, edge_modifier) + def __init__(self, arg: XComArg, callables: MapCallables) -> None: + for c in callables: + if getattr(c, "_airflow_is_task_decorator", False): + raise ValueError("map() argument must be a plain function, not a @task operator") + self.arg = arg + self.callables = callables - def set_downstream( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional[EdgeModifier] = None, - ): - """Proxy to underlying operator set_downstream method. Required by TaskMixin.""" - self.operator.set_downstream(task_or_task_list, edge_modifier) + def __repr__(self) -> str: + map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables) + return f"{self.arg!r}{map_calls}" + + def _serialize(self) -> dict[str, Any]: + return { + "arg": serialize_xcom_arg(self.arg), + "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], + } + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. 
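Taken together, ``map()`` defers the transformation: in a worker all callables are real and applied to each pulled item, while the scheduler only sees their source strings and renders a readable description instead of executing arbitrary code. The DAG-author side typically looks like the sketch below (dag/task names are placeholders; the mapped reference is meant to be consumed by ``.expand()`` on a downstream task)::

    from datetime import datetime

    from airflow.decorators import task
    from airflow.models.dag import DAG

    with DAG(dag_id="xcom_map_demo", start_date=datetime(2022, 1, 1)):

        @task
        def make_paths():
            return ["  /tmp/a.csv", "/tmp/b.csv  "]

        @task
        def load(path: str):
            print(path)

        cleaned = make_paths().map(str.strip)  # MapXComArg; nothing is pulled yet
        load.expand(path=cleaned)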
+ return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield from self.arg.iter_references() + + def map(self, f: Callable[[Any], Any]) -> MapXComArg: + # Flatten arg.map(f1).map(f2) into one MapXComArg. + return MapXComArg(self.arg, [*self.callables, f]) + + def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: + return self.arg.get_task_map_length(run_id, session=session) @provide_session - def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: - """ - Pull XCom value for the existing arg. This method is run during ``op.execute()`` - in respectable context. - """ - result = context["ti"].xcom_pull( - task_ids=self.operator.task_id, key=str(self.key), default=NOTSET, session=session - ) - if result is NOTSET: - raise AirflowException( - f'XComArg result from {self.operator.task_id} at {context["ti"].dag_id} ' - f'with key="{self.key}" is not found!' - ) - return result + def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any: + value = self.arg.resolve(context, session=session) + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") + return _MapResult(value, self.callables) - @staticmethod - def iter_xcom_args(arg: Any) -> Iterator["XComArg"]: - """Return XComArg instances in an arbitrary value. - This recursively traverse ``arg`` and look for XComArg instances in any - collection objects, and instances with ``template_fields`` set. - """ - if isinstance(arg, XComArg): - yield arg - elif isinstance(arg, (tuple, set, list)): - for elem in arg: - yield from XComArg.iter_xcom_args(elem) - elif isinstance(arg, dict): - for elem in arg.values(): - yield from XComArg.iter_xcom_args(elem) - elif isinstance(arg, AbstractOperator): - for elem in arg.template_fields: - yield from XComArg.iter_xcom_args(elem) +class _ZipResult(Sequence): + def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None: + self.values = values + self.fillvalue = fillvalue @staticmethod - def apply_upstream_relationship(op: "Operator", arg: Any): - """Set dependency for XComArgs. + def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any: + try: + return container[index] + except (IndexError, KeyError): + return fillvalue + + def __getitem__(self, index: Any) -> Any: + if index >= len(self): + raise IndexError(index) + return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) + + def __len__(self) -> int: + lengths = (len(v) for v in self.values) + if isinstance(self.fillvalue, ArgNotSet): + return min(lengths) + return max(lengths) + + +class ZipXComArg(XComArg): + """An XCom reference with ``zip()`` applied. + + This is constructed from multiple XComArg instances, and presents an + iterable that "zips" them together like the built-in ``zip()`` (and + ``itertools.zip_longest()`` if ``fillvalue`` is provided). + """ - This looks for XComArg objects in ``arg`` "deeply" (looking inside - collections objects and classes decorated with ``template_fields``), and - sets the relationship to ``op`` on any found. 
- """ - for ref in XComArg.iter_xcom_args(arg): - op.set_upstream(ref.operator) + def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: + if not args: + raise ValueError("At least one input is required") + self.args = args + self.fillvalue = fillvalue + + def __repr__(self) -> str: + args_iter = iter(self.args) + first = repr(next(args_iter)) + rest = ", ".join(repr(arg) for arg in args_iter) + if isinstance(self.fillvalue, ArgNotSet): + return f"{first}.zip({rest})" + return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" + + def _serialize(self) -> dict[str, Any]: + args = [serialize_xcom_arg(arg) for arg in self.args] + if isinstance(self.fillvalue, ArgNotSet): + return {"args": args} + return {"args": args, "fillvalue": self.fillvalue} + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + for arg in self.args: + yield from arg.iter_references() + + def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: + all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(self.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + if isinstance(self.fillvalue, ArgNotSet): + return min(ready_lengths) + return max(ready_lengths) + + @provide_session + def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any: + values = [arg.resolve(context, session=session) for arg in self.args] + for value in values: + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") + return _ZipResult(values, fillvalue=self.fillvalue) + + +_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = { + "": PlainXComArg, + "map": MapXComArg, + "zip": ZipXComArg, +} + + +def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: + """DAG serialization interface.""" + key = next(k for k, v in _XCOM_ARG_TYPES.items() if v == type(value)) + if key: + return {"type": key, **value._serialize()} + return value._serialize() + + +def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py index ec83c1e147a32..f4acbe0ebf883 100644 --- a/airflow/operators/__init__.py +++ b/airflow/operators/__init__.py @@ -15,4 +15,180 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# fmt: off """Operators.""" +from __future__ import annotations + +from airflow.utils.deprecation_tools import add_deprecated_classes + +__deprecated_classes = { + 'bash_operator': { + 'BashOperator': 'airflow.operators.bash.BashOperator', + }, + 'branch_operator': { + 'BaseBranchOperator': 'airflow.operators.branch.BaseBranchOperator', + }, + 'check_operator': { + 'SQLCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator', + 'SQLIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator', + 'SQLThresholdCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator', + 'SQLValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator', + 'CheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator', + 'IntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator', + 'ThresholdCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator', + 'ValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator', + }, + 'dagrun_operator': { + 'TriggerDagRunLink': 'airflow.operators.trigger_dagrun.TriggerDagRunLink', + 'TriggerDagRunOperator': 'airflow.operators.trigger_dagrun.TriggerDagRunOperator', + }, + 'docker_operator': { + 'DockerOperator': 'airflow.providers.docker.operators.docker.DockerOperator', + }, + 'druid_check_operator': { + 'DruidCheckOperator': 'airflow.providers.apache.druid.operators.druid_check.DruidCheckOperator', + }, + 'dummy': { + 'EmptyOperator': 'airflow.operators.empty.EmptyOperator', + 'DummyOperator': 'airflow.operators.empty.EmptyOperator', + }, + 'dummy_operator': { + 'EmptyOperator': 'airflow.operators.empty.EmptyOperator', + 'DummyOperator': 'airflow.operators.empty.EmptyOperator', + }, + 'email_operator': { + 'EmailOperator': 'airflow.operators.email.EmailOperator', + }, + 'gcs_to_s3': { + 'GCSToS3Operator': 'airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator', + }, + 'google_api_to_s3_transfer': { + 'GoogleApiToS3Operator': ( + 'airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleApiToS3Operator' + ), + 'GoogleApiToS3Transfer': ( + 'airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleApiToS3Operator' + ), + }, + 'hive_operator': { + 'HiveOperator': 'airflow.providers.apache.hive.operators.hive.HiveOperator', + }, + 'hive_stats_operator': { + 'HiveStatsCollectionOperator': ( + 'airflow.providers.apache.hive.operators.hive_stats.HiveStatsCollectionOperator' + ), + }, + 'hive_to_druid': { + 'HiveToDruidOperator': 'airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator', + 'HiveToDruidTransfer': 'airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator', + }, + 'hive_to_mysql': { + 'HiveToMySqlOperator': 'airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator', + 'HiveToMySqlTransfer': 'airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator', + }, + 'hive_to_samba_operator': { + 'HiveToSambaOperator': 'airflow.providers.apache.hive.transfers.hive_to_samba.HiveToSambaOperator', + }, + 'http_operator': { + 'SimpleHttpOperator': 'airflow.providers.http.operators.http.SimpleHttpOperator', + }, + 'jdbc_operator': { + 'JdbcOperator': 'airflow.providers.jdbc.operators.jdbc.JdbcOperator', + }, + 'latest_only_operator': { + 'LatestOnlyOperator': 'airflow.operators.latest_only.LatestOnlyOperator', + }, + 'mssql_operator': { + 'MsSqlOperator': 
'airflow.providers.microsoft.mssql.operators.mssql.MsSqlOperator', + }, + 'mssql_to_hive': { + 'MsSqlToHiveOperator': 'airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator', + 'MsSqlToHiveTransfer': 'airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator', + }, + 'mysql_operator': { + 'MySqlOperator': 'airflow.providers.mysql.operators.mysql.MySqlOperator', + }, + 'mysql_to_hive': { + 'MySqlToHiveOperator': 'airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator', + 'MySqlToHiveTransfer': 'airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator', + }, + 'oracle_operator': { + 'OracleOperator': 'airflow.providers.oracle.operators.oracle.OracleOperator', + }, + 'papermill_operator': { + 'PapermillOperator': 'airflow.providers.papermill.operators.papermill.PapermillOperator', + }, + 'pig_operator': { + 'PigOperator': 'airflow.providers.apache.pig.operators.pig.PigOperator', + }, + 'postgres_operator': { + 'Mapping': 'airflow.providers.postgres.operators.postgres.Mapping', + 'PostgresOperator': 'airflow.providers.postgres.operators.postgres.PostgresOperator', + }, + 'presto_check_operator': { + 'SQLCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator', + 'SQLIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator', + 'SQLValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator', + 'PrestoCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator', + 'PrestoIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator', + 'PrestoValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator', + }, + 'presto_to_mysql': { + 'PrestoToMySqlOperator': 'airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator', + 'PrestoToMySqlTransfer': 'airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator', + }, + 'python_operator': { + 'BranchPythonOperator': 'airflow.operators.python.BranchPythonOperator', + 'PythonOperator': 'airflow.operators.python.PythonOperator', + 'PythonVirtualenvOperator': 'airflow.operators.python.PythonVirtualenvOperator', + 'ShortCircuitOperator': 'airflow.operators.python.ShortCircuitOperator', + }, + 'redshift_to_s3_operator': { + 'RedshiftToS3Operator': 'airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator', + 'RedshiftToS3Transfer': 'airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator', + }, + 's3_file_transform_operator': { + 'S3FileTransformOperator': ( + 'airflow.providers.amazon.aws.operators.s3_file_transform.S3FileTransformOperator' + ), + }, + 's3_to_hive_operator': { + 'S3ToHiveOperator': 'airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator', + 'S3ToHiveTransfer': 'airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator', + }, + 's3_to_redshift_operator': { + 'S3ToRedshiftOperator': 'airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator', + 'S3ToRedshiftTransfer': 'airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator', + }, + 'slack_operator': { + 'SlackAPIOperator': 'airflow.providers.slack.operators.slack.SlackAPIOperator', + 'SlackAPIPostOperator': 'airflow.providers.slack.operators.slack.SlackAPIPostOperator', + }, + 'sql': { + 'BaseSQLOperator': 'airflow.providers.common.sql.operators.sql.BaseSQLOperator', + 'BranchSQLOperator': 
'airflow.providers.common.sql.operators.sql.BranchSQLOperator', + 'SQLCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator', + 'SQLColumnCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLColumnCheckOperator', + 'SQLIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator', + 'SQLTableCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLTableCheckOperator', + 'SQLThresholdCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator', + 'SQLValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator', + '_convert_to_float_if_possible': ( + 'airflow.providers.common.sql.operators.sql._convert_to_float_if_possible' + ), + 'parse_boolean': 'airflow.providers.common.sql.operators.sql.parse_boolean', + }, + 'sql_branch_operator': { + 'BranchSQLOperator': 'airflow.providers.common.sql.operators.sql.BranchSQLOperator', + 'BranchSqlOperator': 'airflow.providers.common.sql.operators.sql.BranchSQLOperator', + }, + 'sqlite_operator': { + 'SqliteOperator': 'airflow.providers.sqlite.operators.sqlite.SqliteOperator', + }, + 'subdag_operator': { + 'SkippedStatePropagationOptions': 'airflow.operators.subdag.SkippedStatePropagationOptions', + 'SubDagOperator': 'airflow.operators.subdag.SubDagOperator', + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py index a4fac2b7f7eab..d50733bafa33f 100644 --- a/airflow/operators/bash.py +++ b/airflow/operators/bash.py @@ -15,8 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import os -from typing import Dict, Optional, Sequence +import shutil +from typing import Sequence from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException, AirflowSkipException @@ -122,23 +125,23 @@ class BashOperator(BaseOperator): """ - template_fields: Sequence[str] = ('bash_command', 'env') - template_fields_renderers = {'bash_command': 'bash', 'env': 'json'} + template_fields: Sequence[str] = ("bash_command", "env") + template_fields_renderers = {"bash_command": "bash", "env": "json"} template_ext: Sequence[str] = ( - '.sh', - '.bash', + ".sh", + ".bash", ) - ui_color = '#f0ede4' + ui_color = "#f0ede4" def __init__( self, *, bash_command: str, - env: Optional[Dict[str, str]] = None, + env: dict[str, str] | None = None, append_env: bool = False, - output_encoding: str = 'utf-8', + output_encoding: str = "utf-8", skip_exit_code: int = 99, - cwd: Optional[str] = None, + cwd: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -148,8 +151,6 @@ def __init__( self.skip_exit_code = skip_exit_code self.cwd = cwd self.append_env = append_env - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead") @cached_property def subprocess_hook(self): @@ -169,13 +170,14 @@ def get_env(self, context): airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) self.log.debug( - 'Exporting the following env vars:\n%s', - '\n'.join(f"{k}={v}" for k, v in airflow_context_vars.items()), + "Exporting the following env vars:\n%s", + "\n".join(f"{k}={v}" for k, v in airflow_context_vars.items()), ) env.update(airflow_context_vars) return env def execute(self, context: Context): + bash_path = shutil.which("bash") 
or "bash" if self.cwd is not None: if not os.path.exists(self.cwd): raise AirflowException(f"Can not find the cwd: {self.cwd}") @@ -183,7 +185,7 @@ def execute(self, context: Context): raise AirflowException(f"The cwd {self.cwd} must be a directory") env = self.get_env(context) result = self.subprocess_hook.run_command( - command=['bash', '-c', self.bash_command], + command=[bash_path, "-c", self.bash_command], env=env, output_encoding=self.output_encoding, cwd=self.cwd, @@ -192,7 +194,7 @@ def execute(self, context: Context): raise AirflowSkipException(f"Bash command returned exit code {self.skip_exit_code}. Skipping.") elif result.exit_code != 0: raise AirflowException( - f'Bash command failed. The command returned a non-zero exit code {result.exit_code}.' + f"Bash command failed. The command returned a non-zero exit code {result.exit_code}." ) return result.output diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py deleted file mode 100644 index 3b7764dfbd316..0000000000000 --- a/airflow/operators/bash_operator.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.bash`.""" - -import warnings - -from airflow.operators.bash import BashOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.bash`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py index cdd546feca469..8dfbe7c2767dc 100644 --- a/airflow/operators/branch.py +++ b/airflow/operators/branch.py @@ -16,10 +16,11 @@ # specific language governing permissions and limitations # under the License. """Branching operators""" +from __future__ import annotations -from typing import Iterable, Union +from typing import Iterable -from airflow.models import BaseOperator +from airflow.models.baseoperator import BaseOperator from airflow.models.skipmixin import SkipMixin from airflow.utils.context import Context @@ -38,7 +39,7 @@ class BaseBranchOperator(BaseOperator, SkipMixin): tasks directly downstream of this operator will be skipped. 
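For context on the class above: a subclass only needs to supply ``choose_branch``; ``execute`` below takes care of skipping everything that was not chosen via ``skip_all_except``. A minimal sketch (the downstream task ids are placeholders and must exist in the DAG)::

    from airflow.operators.branch import BaseBranchOperator


    class WeekendBranchOperator(BaseBranchOperator):
        """Sketch: route to a different task id on weekends."""

        def choose_branch(self, context):
            if context["logical_date"].weekday() >= 5:
                return "process_weekend"
            return "process_weekday"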
""" - def choose_branch(self, context: Context) -> Union[str, Iterable[str]]: + def choose_branch(self, context: Context) -> str | Iterable[str]: """ Subclasses should implement this, running whatever logic is necessary to choose a branch and returning a task_id or list of @@ -50,5 +51,5 @@ def choose_branch(self, context: Context) -> Union[str, Iterable[str]]: def execute(self, context: Context): branches_to_execute = self.choose_branch(context) - self.skip_all_except(context['ti'], branches_to_execute) + self.skip_all_except(context["ti"], branches_to_execute) return branches_to_execute diff --git a/airflow/operators/branch_operator.py b/airflow/operators/branch_operator.py deleted file mode 100644 index b4c71d5bc1f88..0000000000000 --- a/airflow/operators/branch_operator.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.branch`.""" - -import warnings - -from airflow.operators.branch import BaseBranchOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.branch`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py deleted file mode 100644 index 130eb6e577c9c..0000000000000 --- a/airflow/operators/check_operator.py +++ /dev/null @@ -1,96 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.operators.sql`.""" - -import warnings - -from airflow.operators.sql import ( - SQLCheckOperator, - SQLIntervalCheckOperator, - SQLThresholdCheckOperator, - SQLValueCheckOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2 -) - - -class CheckOperator(SQLCheckOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.SQLCheckOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use `airflow.operators.sql.SQLCheckOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) - - -class IntervalCheckOperator(SQLIntervalCheckOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.SQLIntervalCheckOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) - - -class ThresholdCheckOperator(SQLThresholdCheckOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.SQLThresholdCheckOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) - - -class ValueCheckOperator(SQLValueCheckOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.SQLValueCheckOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.operators.sql.SQLValueCheckOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py deleted file mode 100644 index bdcc6671516af..0000000000000 --- a/airflow/operators/dagrun_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.trigger_dagrun`.""" - -import warnings - -from airflow.operators.trigger_dagrun import TriggerDagRunLink, TriggerDagRunOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.trigger_dagrun`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/datetime.py b/airflow/operators/datetime.py index c5a423d563868..335d33707684c 100644 --- a/airflow/operators/datetime.py +++ b/airflow/operators/datetime.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
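These deleted shim modules are what the ``__deprecated_classes`` mapping added to ``airflow/operators/__init__.py`` earlier in this change now covers: the old import paths keep resolving lazily through ``add_deprecated_classes``, with a deprecation warning instead of a dedicated ``.py`` file per alias. Roughly as follows; the warning category and capture are an assumption about ``deprecation_tools``, not quoted from it::

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Old path: no airflow/operators/bash_operator.py exists any more.
        from airflow.operators.bash_operator import BashOperator

    from airflow.operators.bash import BashOperator as CurrentBashOperator

    assert BashOperator is CurrentBashOperator
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)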
+from __future__ import annotations import datetime import warnings -from typing import Iterable, Union +from typing import Iterable -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.operators.branch import BaseBranchOperator from airflow.utils import timezone from airflow.utils.context import Context @@ -47,10 +48,10 @@ class BranchDateTimeOperator(BaseBranchOperator): def __init__( self, *, - follow_task_ids_if_true: Union[str, Iterable[str]], - follow_task_ids_if_false: Union[str, Iterable[str]], - target_lower: Union[datetime.datetime, datetime.time, None], - target_upper: Union[datetime.datetime, datetime.time, None], + follow_task_ids_if_true: str | Iterable[str], + follow_task_ids_if_false: str | Iterable[str], + target_lower: datetime.datetime | datetime.time | None, + target_upper: datetime.datetime | datetime.time | None, use_task_logical_date: bool = False, use_task_execution_date: bool = False, **kwargs, @@ -71,17 +72,19 @@ def __init__( self.use_task_logical_date = use_task_execution_date warnings.warn( "Parameter ``use_task_execution_date`` is deprecated. Use ``use_task_logical_date``.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) - def choose_branch(self, context: Context) -> Union[str, Iterable[str]]: + def choose_branch(self, context: Context) -> str | Iterable[str]: if self.use_task_logical_date: - now = timezone.make_naive(context["logical_date"], self.dag.timezone) + now = context["logical_date"] else: - now = timezone.make_naive(timezone.utcnow(), self.dag.timezone) - + now = timezone.coerce_datetime(timezone.utcnow()) lower, upper = target_times_as_dates(now, self.target_lower, self.target_upper) + lower = timezone.coerce_datetime(lower, self.dag.timezone) + upper = timezone.coerce_datetime(upper, self.dag.timezone) + if upper is not None and upper < now: return self.follow_task_ids_if_false @@ -93,8 +96,8 @@ def choose_branch(self, context: Context) -> Union[str, Iterable[str]]: def target_times_as_dates( base_date: datetime.datetime, - lower: Union[datetime.datetime, datetime.time, None], - upper: Union[datetime.datetime, datetime.time, None], + lower: datetime.datetime | datetime.time | None, + upper: datetime.datetime | datetime.time | None, ): """Ensures upper and lower time targets are datetimes by combining them with base_date""" if isinstance(lower, datetime.datetime) and isinstance(upper, datetime.datetime): diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py deleted file mode 100644 index 88235b4382461..0000000000000 --- a/airflow/operators/docker_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.docker.operators.docker`.""" - -import warnings - -from airflow.providers.docker.operators.docker import DockerOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.docker.operators.docker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/druid_check_operator.py b/airflow/operators/druid_check_operator.py deleted file mode 100644 index 008a91750c91d..0000000000000 --- a/airflow/operators/druid_check_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.operators.druid_check`.""" - -import warnings - -from airflow.providers.apache.druid.operators.druid_check import DruidCheckOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.sql.SQLCheckOperator`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/dummy.py b/airflow/operators/dummy.py deleted file mode 100644 index b2912e92586f0..0000000000000 --- a/airflow/operators/dummy.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.empty`.""" - -import warnings - -from airflow.operators.empty import EmptyOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.empty`.", - DeprecationWarning, - stacklevel=2, -) - - -class DummyOperator(EmptyOperator): - """This class is deprecated. Please use `airflow.operators.empty.EmptyOperator`.""" - - @property - def inherits_from_dummy_operator(self): - return True - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. 
Please use `airflow.operators.empty.EmptyOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - self.inherits_from_empty_operator = self.inherits_from_dummy_operator - super().__init__(**kwargs) diff --git a/airflow/operators/dummy_operator.py b/airflow/operators/dummy_operator.py deleted file mode 100644 index 2b46095027875..0000000000000 --- a/airflow/operators/dummy_operator.py +++ /dev/null @@ -1,38 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.empty`.""" - -import warnings - -from airflow.operators.empty import EmptyOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.empty`.", DeprecationWarning, stacklevel=2 -) - - -class DummyOperator(EmptyOperator): - """This class is deprecated. Please use `airflow.operators.empty.EmptyOperator`.""" - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. Please use `airflow.operators.empty.EmptyOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/operators/email.py b/airflow/operators/email.py index 220bafa944b12..1b8e529593451 100644 --- a/airflow/operators/email.py +++ b/airflow/operators/email.py @@ -15,9 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from __future__ import annotations -from airflow.models import BaseOperator +from typing import Any, Sequence + +from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context from airflow.utils.email import send_email @@ -39,24 +41,24 @@ class EmailOperator(BaseOperator): :param custom_headers: additional headers to add to the MIME message. 
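For reference, the operator is wired up like any other emailing task; the addresses, connection id, file path and header below are placeholders::

    from airflow.operators.email import EmailOperator

    notify = EmailOperator(
        task_id="notify_team",
        to=["data-team@example.com"],
        subject="Daily load finished",
        html_content="<h3>All partitions loaded for {{ ds }}</h3>",
        files=["/tmp/load_report.csv"],
        conn_id="smtp_default",
        custom_headers={"X-Pipeline": "daily_load"},
    )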
""" - template_fields: Sequence[str] = ('to', 'subject', 'html_content', 'files') + template_fields: Sequence[str] = ("to", "subject", "html_content", "files") template_fields_renderers = {"html_content": "html"} - template_ext: Sequence[str] = ('.html',) - ui_color = '#e6faf9' + template_ext: Sequence[str] = (".html",) + ui_color = "#e6faf9" def __init__( self, *, - to: Union[List[str], str], + to: list[str] | str, subject: str, html_content: str, - files: Optional[List] = None, - cc: Optional[Union[List[str], str]] = None, - bcc: Optional[Union[List[str], str]] = None, - mime_subtype: str = 'mixed', - mime_charset: str = 'utf-8', - conn_id: Optional[str] = None, - custom_headers: Optional[Dict[str, Any]] = None, + files: list | None = None, + cc: list[str] | str | None = None, + bcc: list[str] | str | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + conn_id: str | None = None, + custom_headers: dict[str, Any] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/operators/email_operator.py b/airflow/operators/email_operator.py deleted file mode 100644 index 80901d010f669..0000000000000 --- a/airflow/operators/email_operator.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.email`.""" - -import warnings - -from airflow.operators.email import EmailOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.email`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/operators/empty.py b/airflow/operators/empty.py index 1f396d0f6a664..d4438c565e927 100644 --- a/airflow/operators/empty.py +++ b/airflow/operators/empty.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context @@ -27,7 +28,7 @@ class EmptyOperator(BaseOperator): The task is evaluated by the scheduler but never processed by the executor. """ - ui_color = '#e8f7e4' + ui_color = "#e8f7e4" inherits_from_empty_operator = True def execute(self, context: Context): diff --git a/airflow/operators/gcs_to_s3.py b/airflow/operators/gcs_to_s3.py deleted file mode 100644 index d02bc7f224ea9..0000000000000 --- a/airflow/operators/gcs_to_s3.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.gcs_to_s3`.""" - -import warnings - -from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py index 2a42859d17e74..25cd3f1f69bdf 100644 --- a/airflow/operators/generic_transfer.py +++ b/airflow/operators/generic_transfer.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Sequence, Union +from __future__ import annotations + +from typing import Sequence from airflow.hooks.base import BaseHook from airflow.models import BaseOperator @@ -40,13 +42,13 @@ class GenericTransfer(BaseOperator): :param insert_args: extra params for `insert_rows` method. """ - template_fields: Sequence[str] = ('sql', 'destination_table', 'preoperator') + template_fields: Sequence[str] = ("sql", "destination_table", "preoperator") template_ext: Sequence[str] = ( - '.sql', - '.hql', + ".sql", + ".hql", ) template_fields_renderers = {"preoperator": "sql"} - ui_color = '#b0f07c' + ui_color = "#b0f07c" def __init__( self, @@ -55,8 +57,8 @@ def __init__( destination_table: str, source_conn_id: str, destination_conn_id: str, - preoperator: Optional[Union[str, List[str]]] = None, - insert_args: Optional[dict] = None, + preoperator: str | list[str] | None = None, + insert_args: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -73,7 +75,7 @@ def execute(self, context: Context): self.log.info("Extracting data from %s", self.source_conn_id) self.log.info("Executing: \n %s", self.sql) - get_records = getattr(source_hook, 'get_records', None) + get_records = getattr(source_hook, "get_records", None) if not callable(get_records): raise RuntimeError( f"Hook for connection {self.source_conn_id!r} " @@ -83,7 +85,7 @@ def execute(self, context: Context): results = get_records(self.sql) if self.preoperator: - run = getattr(destination_hook, 'run', None) + run = getattr(destination_hook, "run", None) if not callable(run): raise RuntimeError( f"Hook for connection {self.destination_conn_id!r} " @@ -93,7 +95,7 @@ def execute(self, context: Context): self.log.info(self.preoperator) run(self.preoperator) - insert_rows = getattr(destination_hook, 'insert_rows', None) + insert_rows = getattr(destination_hook, "insert_rows", None) if not callable(insert_rows): raise RuntimeError( f"Hook for connection {self.destination_conn_id!r} " diff --git a/airflow/operators/google_api_to_s3_transfer.py b/airflow/operators/google_api_to_s3_transfer.py deleted file mode 100644 index 9566cddc77641..0000000000000 --- a/airflow/operators/google_api_to_s3_transfer.py +++ 
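# Sketch of a GenericTransfer task matching the parameters above; connection ids, tables,
# and insert_args are placeholders. The source hook must provide get_records() and the
# destination hook insert_rows(), which the runtime checks in execute() enforce.
from airflow.operators.generic_transfer import GenericTransfer

copy_rows = GenericTransfer(
    task_id="copy_rows",
    sql="SELECT id, name FROM source_table",
    destination_table="target_table",
    source_conn_id="source_db",
    destination_conn_id="dest_db",
    preoperator="DELETE FROM target_table",  # optional str | list[str], run before the load
    insert_args={"commit_every": 1000},  # forwarded to the destination hook's insert_rows()
)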
/dev/null @@ -1,50 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`. -""" - -import warnings - -from airflow.providers.amazon.aws.transfers.google_api_to_s3 import GoogleApiToS3Operator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`.", - DeprecationWarning, - stacklevel=2, -) - - -class GoogleApiToS3Transfer(GoogleApiToS3Operator): - """This class is deprecated. - - Please use: - `airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleApiToS3Operator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use " - "`airflow.providers.amazon.aws.transfers." - "google_api_to_s3_transfer.GoogleApiToS3Operator`.", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/hive_operator.py b/airflow/operators/hive_operator.py deleted file mode 100644 index b49cf097305ea..0000000000000 --- a/airflow/operators/hive_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.operators.hive`.""" - -import warnings - -from airflow.providers.apache.hive.operators.hive import HiveOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py deleted file mode 100644 index af1e260a4a155..0000000000000 --- a/airflow/operators/hive_stats_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.operators.hive_stats`.""" - -import warnings - -from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive_stats`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/hive_to_druid.py b/airflow/operators/hive_to_druid.py deleted file mode 100644 index a6537a1337a56..0000000000000 --- a/airflow/operators/hive_to_druid.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.apache.druid.transfers.hive_to_druid`. -""" - -import warnings - -from airflow.providers.apache.druid.transfers.hive_to_druid import HiveToDruidOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.druid.transfers.hive_to_druid`.", - DeprecationWarning, - stacklevel=2, -) - - -class HiveToDruidTransfer(HiveToDruidOperator): - """This class is deprecated. - - Please use: - `airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py deleted file mode 100644 index 0a13c7666a4cb..0000000000000 --- a/airflow/operators/hive_to_mysql.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use `airflow.providers.apache.hive.transfers.hive_to_mysql`. -""" - -import warnings - -from airflow.providers.apache.hive.transfers.hive_to_mysql import HiveToMySqlOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_mysql`.", - DeprecationWarning, - stacklevel=2, -) - - -class HiveToMySqlTransfer(HiveToMySqlOperator): - """This class is deprecated. - - Please use: - `airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py deleted file mode 100644 index ed3b180b3e7c4..0000000000000 --- a/airflow/operators/hive_to_samba_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.transfers.hive_to_samba`.""" - -import warnings - -from airflow.providers.apache.hive.transfers.hive_to_samba import HiveToSambaOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_samba`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/http_operator.py b/airflow/operators/http_operator.py deleted file mode 100644 index 6e2ab56df4e58..0000000000000 --- a/airflow/operators/http_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
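# The modules deleted above are thin deprecation shims; user code migrates by switching to
# the provider import path named in each warning. Illustrative before/after, assuming the
# matching provider distribution (here apache-airflow-providers-apache-hive) is installed.

# before (deprecated core module, now removed):
#   from airflow.operators.hive_to_mysql import HiveToMySqlTransfer
# after:
from airflow.providers.apache.hive.transfers.hive_to_mysql import HiveToMySqlOperator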
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.http.operators.http`.""" - -import warnings - -from airflow.providers.http.operators.http import SimpleHttpOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.http.operators.http`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/jdbc_operator.py b/airflow/operators/jdbc_operator.py deleted file mode 100644 index ff36f9f5d6467..0000000000000 --- a/airflow/operators/jdbc_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.jdbc.operators.jdbc`.""" - -import warnings - -from airflow.providers.jdbc.operators.jdbc import JdbcOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.jdbc.operators.jdbc`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/latest_only.py b/airflow/operators/latest_only.py index 8ca9688f96627..f0ba76ecf6b67 100644 --- a/airflow/operators/latest_only.py +++ b/airflow/operators/latest_only.py @@ -19,7 +19,9 @@ This module contains an operator to run downstream tasks only for the latest scheduled DagRun """ -from typing import TYPE_CHECKING, Iterable, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable import pendulum @@ -42,37 +44,37 @@ class LatestOnlyOperator(BaseBranchOperator): marked as externally triggered. 
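# A small sketch of how LatestOnlyOperator is typically wired: downstream tasks run only
# for the most recent scheduled interval and are skipped on catch-up runs. Names and dates
# are illustrative.
from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.latest_only import LatestOnlyOperator

with DAG(dag_id="latest_only_example", start_date=datetime(2022, 1, 1), schedule="@daily") as dag:
    latest_only = LatestOnlyOperator(task_id="latest_only")
    notify = EmptyOperator(task_id="notify")
    latest_only >> notify  # "notify" is skipped unless this run covers the latest interval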
""" - ui_color = '#e9ffdb' # nyanza + ui_color = "#e9ffdb" # nyanza - def choose_branch(self, context: Context) -> Union[str, Iterable[str]]: + def choose_branch(self, context: Context) -> str | Iterable[str]: # If the DAG Run is externally triggered, then return without # skipping downstream tasks - dag_run: "DagRun" = context["dag_run"] + dag_run: DagRun = context["dag_run"] if dag_run.external_trigger: self.log.info("Externally triggered DAG_Run: allowing execution to proceed.") - return list(context['task'].get_direct_relative_ids(upstream=False)) + return list(context["task"].get_direct_relative_ids(upstream=False)) - dag: "DAG" = context["dag"] + dag: DAG = context["dag"] next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False) - now = pendulum.now('UTC') + now = pendulum.now("UTC") if next_info is None: self.log.info("Last scheduled execution: allowing execution to proceed.") - return list(context['task'].get_direct_relative_ids(upstream=False)) + return list(context["task"].get_direct_relative_ids(upstream=False)) left_window, right_window = next_info.data_interval self.log.info( - 'Checking latest only with left_window: %s right_window: %s now: %s', + "Checking latest only with left_window: %s right_window: %s now: %s", left_window, right_window, now, ) if not left_window < now <= right_window: - self.log.info('Not latest execution, skipping downstream.') + self.log.info("Not latest execution, skipping downstream.") # we return an empty list, thus the parent BaseBranchOperator # won't exclude any downstream tasks from skipping. return [] else: - self.log.info('Latest, allowing execution to proceed.') - return list(context['task'].get_direct_relative_ids(upstream=False)) + self.log.info("Latest, allowing execution to proceed.") + return list(context["task"].get_direct_relative_ids(upstream=False)) diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py deleted file mode 100644 index 07644f4a82c10..0000000000000 --- a/airflow/operators/latest_only_operator.py +++ /dev/null @@ -1,25 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.latest_only`""" -import warnings - -from airflow.operators.latest_only import LatestOnlyOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.latest_only`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py deleted file mode 100644 index d1047b827a722..0000000000000 --- a/airflow/operators/mssql_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.mssql.operators.mssql`.""" - -import warnings - -from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.mssql.operators.mssql`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py deleted file mode 100644 index 02edb36b89b15..0000000000000 --- a/airflow/operators/mssql_to_hive.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.apache.hive.transfers.mssql_to_hive`. -""" - -import warnings - -from airflow.providers.apache.hive.transfers.mssql_to_hive import MsSqlToHiveOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mssql_to_hive`.", - DeprecationWarning, - stacklevel=2, -) - - -class MsSqlToHiveTransfer(MsSqlToHiveOperator): - """This class is deprecated. - - Please use: - `airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py deleted file mode 100644 index 82a94edd66add..0000000000000 --- a/airflow/operators/mysql_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.mysql.operators.mysql`.""" - -import warnings - -from airflow.providers.mysql.operators.mysql import MySqlOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.mysql.operators.mysql`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/mysql_to_hive.py b/airflow/operators/mysql_to_hive.py deleted file mode 100644 index 95bd4302e7961..0000000000000 --- a/airflow/operators/mysql_to_hive.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.transfers.mysql_to_hive`.""" - -import warnings - -from airflow.providers.apache.hive.transfers.mysql_to_hive import MySqlToHiveOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mysql_to_hive`.", - DeprecationWarning, - stacklevel=2, -) - - -class MySqlToHiveTransfer(MySqlToHiveOperator): - """ - This class is deprecated. - Please use `airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/oracle_operator.py b/airflow/operators/oracle_operator.py deleted file mode 100644 index 8ad61db754dcb..0000000000000 --- a/airflow/operators/oracle_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.oracle.operators.oracle`.""" - -import warnings - -from airflow.providers.oracle.operators.oracle import OracleOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.oracle.operators.oracle`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/papermill_operator.py b/airflow/operators/papermill_operator.py deleted file mode 100644 index 5d63e38e13721..0000000000000 --- a/airflow/operators/papermill_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.papermill.operators.papermill`.""" - -import warnings - -from airflow.providers.papermill.operators.papermill import PapermillOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.papermill.operators.papermill`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/pig_operator.py b/airflow/operators/pig_operator.py deleted file mode 100644 index 3b2ea0e05ac99..0000000000000 --- a/airflow/operators/pig_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.pig.operators.pig`.""" - -import warnings - -from airflow.providers.apache.pig.operators.pig import PigOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.pig.operators.pig`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py deleted file mode 100644 index e5dc53c82bde6..0000000000000 --- a/airflow/operators/postgres_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.postgres.operators.postgres`.""" - -import warnings - -from airflow.providers.postgres.operators.postgres import Mapping, PostgresOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.postgres.operators.postgres`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py deleted file mode 100644 index 693471f18ceb4..0000000000000 --- a/airflow/operators/presto_check_operator.py +++ /dev/null @@ -1,78 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.sql`.""" - -import warnings - -from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2 -) - - -class PrestoCheckOperator(SQLCheckOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.SQLCheckOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.operators.sql.SQLCheckOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) - - -class PrestoIntervalCheckOperator(SQLIntervalCheckOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.SQLIntervalCheckOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """ - This class is deprecated.l - Please use `airflow.operators.sql.SQLIntervalCheckOperator`. - """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) - - -class PrestoValueCheckOperator(SQLValueCheckOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.SQLValueCheckOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """ - This class is deprecated.l - Please use `airflow.operators.sql.SQLValueCheckOperator`. 
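# Migration sketch for the Presto check shims removed above; each maps onto the generic SQL
# check operators named in the deprecation warnings. The conn_id and query are placeholders.
# before: from airflow.operators.presto_check_operator import PrestoCheckOperator
from airflow.operators.sql import SQLCheckOperator

row_count_check = SQLCheckOperator(
    task_id="check_row_count",
    conn_id="presto_default",
    sql="SELECT COUNT(*) FROM events WHERE ds = '{{ ds }}'",
)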
- """, - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/presto_to_mysql.py b/airflow/operators/presto_to_mysql.py deleted file mode 100644 index bfc117327d672..0000000000000 --- a/airflow/operators/presto_to_mysql.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.mysql.transfers.presto_to_mysql`. -""" - -import warnings - -from airflow.providers.mysql.transfers.presto_to_mysql import PrestoToMySqlOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.mysql.transfers.presto_to_mysql`.", - DeprecationWarning, - stacklevel=2, -) - - -class PrestoToMySqlTransfer(PrestoToMySqlOperator): - """This class is deprecated. - - Please use: - `airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 920a75d08f766..0a0dd34fef355 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -15,20 +15,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import inspect import os import pickle import shutil +import subprocess import sys import types import warnings +from abc import ABCMeta, abstractmethod +from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent -from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Collection, Iterable, Mapping, Sequence import dill -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowConfigException, AirflowException, RemovedInAirflow3Warning from airflow.models.baseoperator import BaseOperator from airflow.models.skipmixin import SkipMixin from airflow.models.taskinstance import _CURRENT_CONTEXT @@ -36,9 +41,10 @@ from airflow.utils.operator_helpers import KeywordParameters from airflow.utils.process_utils import execute_in_subprocess from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script +from airflow.version import version as airflow_version -def task(python_callable: Optional[Callable] = None, multiple_outputs: Optional[bool] = None, **kwargs): +def task(python_callable: Callable | None = None, multiple_outputs: bool | None = None, **kwargs): """ Deprecated function that calls @task.python and allows users to turn a python function into an Airflow task. Please use the following instead: @@ -68,7 +74,7 @@ def my_task() from airflow.decorators import task @task def my_task()""", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return python_task(python_callable=python_callable, multiple_outputs=multiple_outputs, **kwargs) @@ -121,41 +127,39 @@ def my_python_callable(**kwargs): such as transmission a large amount of XCom to TaskAPI. """ - template_fields: Sequence[str] = ('templates_dict', 'op_args', 'op_kwargs') + template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs") template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"} - BLUE = '#ffefeb' + BLUE = "#ffefeb" ui_color = BLUE # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects(e.g protobuf). 
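# The deprecated module-level task() helper above redirects users to the TaskFlow decorator;
# a minimal sketch of the suggested replacement (the function body is illustrative).

# before (deprecated):
#   from airflow.operators.python import task
# after:
from airflow.decorators import task


@task
def my_task():
    return "done"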
shallow_copy_attrs: Sequence[str] = ( - 'python_callable', - 'op_kwargs', + "python_callable", + "op_kwargs", ) - mapped_arguments_validated_by_init = True - def __init__( self, *, python_callable: Callable, - op_args: Optional[Collection[Any]] = None, - op_kwargs: Optional[Mapping[str, Any]] = None, - templates_dict: Optional[Dict[str, Any]] = None, - templates_exts: Optional[Sequence[str]] = None, + op_args: Collection[Any] | None = None, + op_kwargs: Mapping[str, Any] | None = None, + templates_dict: dict[str, Any] | None = None, + templates_exts: Sequence[str] | None = None, show_return_value_in_logs: bool = True, **kwargs, ) -> None: if kwargs.get("provide_context"): warnings.warn( "provide_context is deprecated as of 2.0 and is no longer required", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) - kwargs.pop('provide_context', None) + kwargs.pop("provide_context", None) super().__init__(**kwargs) if not callable(python_callable): - raise AirflowException('`python_callable` param must be callable') + raise AirflowException("`python_callable` param must be callable") self.python_callable = python_callable self.op_args = op_args or () self.op_kwargs = op_kwargs or {} @@ -179,12 +183,11 @@ def execute(self, context: Context) -> Any: def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking() - def execute_callable(self): + def execute_callable(self) -> Any: """ Calls the python callable with the given arguments. :return: the return value of the call. - :rtype: any """ return self.python_callable(*self.op_args, **self.op_kwargs) @@ -205,22 +208,8 @@ class BranchPythonOperator(PythonOperator, SkipMixin): def execute(self, context: Context) -> Any: branch = super().execute(context) - # TODO: The logic should be moved to SkipMixin to be available to all branch operators. - if isinstance(branch, str): - branches = {branch} - elif isinstance(branch, list): - branches = set(branch) - elif branch is None: - branches = set() - else: - raise AirflowException("Branch callable must return either None, a task ID, or a list of IDs") - valid_task_ids = set(context["dag"].task_ids) - invalid_task_ids = branches - valid_task_ids - if invalid_task_ids: - raise AirflowException( - f"Branch callable must return valid task_ids. 
Invalid tasks found: {invalid_task_ids}" - ) - self.skip_all_except(context['ti'], branch) + self.log.info("Branch callable return %s", branch) + self.skip_all_except(context["ti"], branch) return branch @@ -259,10 +248,10 @@ def execute(self, context: Context) -> Any: self.log.info("Condition result is %s", condition) if condition: - self.log.info('Proceeding with downstream tasks...') + self.log.info("Proceeding with downstream tasks...") return condition - downstream_tasks = context['task'].get_flat_relatives(upstream=False) + downstream_tasks = context["task"].get_flat_relatives(upstream=False) self.log.debug("Downstream task IDs %s", downstream_tasks) if downstream_tasks: @@ -281,7 +270,161 @@ def execute(self, context: Context) -> Any: self.log.info("Done.") -class PythonVirtualenvOperator(PythonOperator): +class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): + BASE_SERIALIZABLE_CONTEXT_KEYS = { + "ds", + "ds_nodash", + "expanded_ti_count", + "inlets", + "next_ds", + "next_ds_nodash", + "outlets", + "prev_ds", + "prev_ds_nodash", + "run_id", + "task_instance_key_str", + "test_mode", + "tomorrow_ds", + "tomorrow_ds_nodash", + "ts", + "ts_nodash", + "ts_nodash_with_tz", + "yesterday_ds", + "yesterday_ds_nodash", + } + PENDULUM_SERIALIZABLE_CONTEXT_KEYS = { + "data_interval_end", + "data_interval_start", + "execution_date", + "logical_date", + "next_execution_date", + "prev_data_interval_end_success", + "prev_data_interval_start_success", + "prev_execution_date", + "prev_execution_date_success", + "prev_start_date_success", + } + AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = { + "macros", + "conf", + "dag", + "dag_run", + "task", + "params", + "triggering_dataset_events", + } + + def __init__( + self, + *, + python_callable: Callable, + use_dill: bool = False, + op_args: Collection[Any] | None = None, + op_kwargs: Mapping[str, Any] | None = None, + string_args: Iterable[str] | None = None, + templates_dict: dict | None = None, + templates_exts: list[str] | None = None, + expect_airflow: bool = True, + **kwargs, + ): + if ( + not isinstance(python_callable, types.FunctionType) + or isinstance(python_callable, types.LambdaType) + and python_callable.__name__ == "" + ): + raise AirflowException("PythonVirtualenvOperator only supports functions for python_callable arg") + super().__init__( + python_callable=python_callable, + op_args=op_args, + op_kwargs=op_kwargs, + templates_dict=templates_dict, + templates_exts=templates_exts, + **kwargs, + ) + self.string_args = string_args or [] + self.use_dill = use_dill + self.pickling_library = dill if self.use_dill else pickle + self.expect_airflow = expect_airflow + + @abstractmethod + def _iter_serializable_context_keys(self): + pass + + def execute(self, context: Context) -> Any: + serializable_keys = set(self._iter_serializable_context_keys()) + serializable_context = context_copy_partial(context, serializable_keys) + return super().execute(context=serializable_context) + + def get_python_source(self): + """ + Returns the source of self.python_callable + @return: + """ + return dedent(inspect.getsource(self.python_callable)) + + def _write_args(self, file: Path): + if self.op_args or self.op_kwargs: + file.write_bytes(self.pickling_library.dumps({"args": self.op_args, "kwargs": self.op_kwargs})) + + def _write_string_args(self, file: Path): + file.write_text("\n".join(map(str, self.string_args))) + + def _read_result(self, path: Path): + if path.stat().st_size == 0: + return None + try: + return 
self.pickling_library.loads(path.read_bytes()) + except ValueError: + self.log.error( + "Error deserializing result. Note that result deserialization " + "is not supported across major Python versions." + ) + raise + + def __deepcopy__(self, memo): + # module objects can't be copied _at all__ + memo[id(self.pickling_library)] = self.pickling_library + return super().__deepcopy__(memo) + + def _execute_python_callable_in_subprocess(self, python_path: Path, tmp_dir: Path): + op_kwargs: dict[str, Any] = {k: v for k, v in self.op_kwargs.items()} + if self.templates_dict: + op_kwargs["templates_dict"] = self.templates_dict + input_path = tmp_dir / "script.in" + output_path = tmp_dir / "script.out" + string_args_path = tmp_dir / "string_args.txt" + script_path = tmp_dir / "script.py" + self._write_args(input_path) + self._write_string_args(string_args_path) + write_python_script( + jinja_context=dict( + op_args=self.op_args, + op_kwargs=op_kwargs, + expect_airflow=self.expect_airflow, + pickling_library=self.pickling_library.__name__, + python_callable=self.python_callable.__name__, + python_callable_source=self.get_python_source(), + ), + filename=os.fspath(script_path), + render_template_as_native_obj=self.dag.render_template_as_native_obj, + ) + + execute_in_subprocess( + cmd=[ + os.fspath(python_path), + os.fspath(script_path), + os.fspath(input_path), + os.fspath(output_path), + os.fspath(string_args_path), + ] + ) + return self._read_result(output_path) + + def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: + return KeywordParameters.determine(self.python_callable, self.op_args, context).serializing() + + +class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): """ Allows one to run a function in a virtualenv that is created and destroyed automatically (with certain caveats). @@ -325,67 +468,31 @@ class PythonVirtualenvOperator(PythonOperator): in your callable's context after the template has been applied :param templates_exts: a list of file extensions to resolve while processing templated fields, for examples ``['.sql', '.hql']`` + :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator + will raise warning if Airflow is not installed, and it will attempt to load Airflow + macros when starting. 
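# A hedged usage sketch for PythonVirtualenvOperator as documented above; the requirement
# pin is illustrative. Imports live inside the callable because it executes in a fresh venv,
# and the return value must be picklable (or use_dill=True).
from airflow.operators.python import PythonVirtualenvOperator


def _colorize():
    from colorama import Fore  # resolved inside the virtualenv, not the Airflow environment

    return f"{Fore.GREEN}ok"


run_in_venv = PythonVirtualenvOperator(
    task_id="run_in_venv",
    python_callable=_colorize,
    requirements=["colorama==0.4.6"],
    system_site_packages=False,
)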
""" - template_fields: Sequence[str] = tuple({'requirements'} | set(PythonOperator.template_fields)) - - template_ext: Sequence[str] = ('.txt',) - BASE_SERIALIZABLE_CONTEXT_KEYS = { - 'ds', - 'ds_nodash', - 'inlets', - 'next_ds', - 'next_ds_nodash', - 'outlets', - 'prev_ds', - 'prev_ds_nodash', - 'run_id', - 'task_instance_key_str', - 'test_mode', - 'tomorrow_ds', - 'tomorrow_ds_nodash', - 'ts', - 'ts_nodash', - 'ts_nodash_with_tz', - 'yesterday_ds', - 'yesterday_ds_nodash', - } - PENDULUM_SERIALIZABLE_CONTEXT_KEYS = { - 'data_interval_end', - 'data_interval_start', - 'execution_date', - 'logical_date', - 'next_execution_date', - 'prev_data_interval_end_success', - 'prev_data_interval_start_success', - 'prev_execution_date', - 'prev_execution_date_success', - 'prev_start_date_success', - } - AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {'macros', 'conf', 'dag', 'dag_run', 'task', 'params'} + template_fields: Sequence[str] = tuple({"requirements"} | set(PythonOperator.template_fields)) + template_ext: Sequence[str] = (".txt",) def __init__( self, *, python_callable: Callable, - requirements: Union[None, Iterable[str], str] = None, - python_version: Optional[Union[str, int, float]] = None, + requirements: None | Iterable[str] | str = None, + python_version: str | int | float | None = None, use_dill: bool = False, system_site_packages: bool = True, - pip_install_options: Optional[List[str]] = None, - op_args: Optional[Collection[Any]] = None, - op_kwargs: Optional[Mapping[str, Any]] = None, - string_args: Optional[Iterable[str]] = None, - templates_dict: Optional[Dict] = None, - templates_exts: Optional[List[str]] = None, + pip_install_options: list[str] | None = None, + op_args: Collection[Any] | None = None, + op_kwargs: Mapping[str, Any] | None = None, + string_args: Iterable[str] | None = None, + templates_dict: dict | None = None, + templates_exts: list[str] | None = None, + expect_airflow: bool = True, **kwargs, ): - if ( - not isinstance(python_callable, types.FunctionType) - or isinstance(python_callable, types.LambdaType) - and python_callable.__name__ == "" - ): - raise AirflowException('PythonVirtualenvOperator only supports functions for python_callable arg') if ( python_version and str(python_version)[0] != str(sys.version_info.major) @@ -394,41 +501,35 @@ def __init__( raise AirflowException( "Passing op_args or op_kwargs is not supported across different Python " "major versions for PythonVirtualenvOperator. Please use string_args." + f"Sys version: {sys.version_info}. 
Venv version: {python_version}" ) if not shutil.which("virtualenv"): - raise AirflowException('PythonVirtualenvOperator requires virtualenv, please install it.') - super().__init__( - python_callable=python_callable, - op_args=op_args, - op_kwargs=op_kwargs, - templates_dict=templates_dict, - templates_exts=templates_exts, - **kwargs, - ) + raise AirflowException("PythonVirtualenvOperator requires virtualenv, please install it.") if not requirements: - self.requirements: Union[List[str], str] = [] + self.requirements: list[str] | str = [] elif isinstance(requirements, str): self.requirements = requirements else: self.requirements = list(requirements) - self.string_args = string_args or [] self.python_version = python_version - self.use_dill = use_dill self.system_site_packages = system_site_packages self.pip_install_options = pip_install_options - self.pickling_library = dill if self.use_dill else pickle - - def execute(self, context: Context) -> Any: - serializable_keys = set(self._iter_serializable_context_keys()) - serializable_context = context_copy_partial(context, serializable_keys) - return super().execute(context=serializable_context) - - def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: - return KeywordParameters.determine(self.python_callable, self.op_args, context).serializing() + super().__init__( + python_callable=python_callable, + use_dill=use_dill, + op_args=op_args, + op_kwargs=op_kwargs, + string_args=string_args, + templates_dict=templates_dict, + templates_exts=templates_exts, + expect_airflow=expect_airflow, + **kwargs, + ) def execute_callable(self): - with TemporaryDirectory(prefix='venv') as tmp_dir: - requirements_file_name = f'{tmp_dir}/requirements.txt' + with TemporaryDirectory(prefix="venv") as tmp_dir: + tmp_path = Path(tmp_dir) + requirements_file_name = f"{tmp_dir}/requirements.txt" if not isinstance(self.requirements, str): requirements_file_contents = "\n".join(str(dependency) for dependency in self.requirements) @@ -436,94 +537,184 @@ def execute_callable(self): requirements_file_contents = self.requirements if not self.system_site_packages and self.use_dill: - requirements_file_contents += '\ndill' + requirements_file_contents += "\ndill" - with open(requirements_file_name, 'w') as file: + with open(requirements_file_name, "w") as file: file.write(requirements_file_contents) - - if self.templates_dict: - self.op_kwargs['templates_dict'] = self.templates_dict - - input_filename = os.path.join(tmp_dir, 'script.in') - output_filename = os.path.join(tmp_dir, 'script.out') - string_args_filename = os.path.join(tmp_dir, 'string_args.txt') - script_filename = os.path.join(tmp_dir, 'script.py') - prepare_virtualenv( venv_directory=tmp_dir, - python_bin=f'python{self.python_version}' if self.python_version else None, + python_bin=f"python{self.python_version}" if self.python_version else None, system_site_packages=self.system_site_packages, requirements_file_path=requirements_file_name, pip_install_options=self.pip_install_options, ) + python_path = tmp_path / "bin" / "python" - self._write_args(input_filename) - self._write_string_args(string_args_filename) - write_python_script( - jinja_context=dict( - op_args=self.op_args, - op_kwargs=self.op_kwargs, - pickling_library=self.pickling_library.__name__, - python_callable=self.python_callable.__name__, - python_callable_source=self.get_python_source(), - ), - filename=script_filename, - render_template_as_native_obj=self.dag.render_template_as_native_obj, - ) + return 
self._execute_python_callable_in_subprocess(python_path, tmp_path) - execute_in_subprocess( - cmd=[ - f'{tmp_dir}/bin/python', - script_filename, - input_filename, - output_filename, - string_args_filename, - ] - ) + def _iter_serializable_context_keys(self): + yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS + if self.system_site_packages or "apache-airflow" in self.requirements: + yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS + yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS + elif "pendulum" in self.requirements: + yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS - return self._read_result(output_filename) - def get_python_source(self): - """ - Returns the source of self.python_callable - @return: - """ - return dedent(inspect.getsource(self.python_callable)) +class ExternalPythonOperator(_BasePythonVirtualenvOperator): + """ + Allows one to run a function in a virtualenv that is not re-created but used as is + without the overhead of creating the virtualenv (with certain caveats). - def _write_args(self, filename): - if self.op_args or self.op_kwargs: - with open(filename, 'wb') as file: - self.pickling_library.dump({'args': self.op_args, 'kwargs': self.op_kwargs}, file) + The function must be defined using def, and not be + part of a class. All imports must happen inside the function + and no variables outside the scope may be referenced. A global scope + variable named virtualenv_string_args will be available (populated by + string_args). In addition, one can pass stuff through op_args and op_kwargs, and one + can use a return value. + Note that if your virtualenv runs in a different Python major version than Airflow, + you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to + Airflow through plugins. You can use string_args though. + + If Airflow is installed in the external environment in different version that the version + used by the operator, the operator will fail., + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ExternalPythonOperator` + + :param python: Full path string (file-system specific) that points to a Python binary inside + a virtualenv that should be used (in ``VENV/bin`` folder). Should be absolute path + (so usually start with "/" or "X:/" depending on the filesystem/os used). + :param python_callable: A python function with no references to outside variables, + defined with def, which will be run in a virtualenv + :param use_dill: Whether to use dill to serialize + the args and result (pickle is default). This allow more complex types + but if dill is not preinstalled in your venv, the task will fail with use_dill enabled. + :param op_args: A list of positional arguments to pass to python_callable. + :param op_kwargs: A dict of keyword arguments to pass to python_callable. + :param string_args: Strings that are present in the global var virtualenv_string_args, + available to python_callable at runtime as a list[str]. Note that args are split + by newline. + :param templates_dict: a dictionary where the values are templates that + will get templated by the Airflow engine sometime between + ``__init__`` and ``execute`` takes place and are made available + in your callable's context after the template has been applied + :param templates_exts: a list of file extensions to resolve while + processing templated fields, for examples ``['.sql', '.hql']`` + :param expect_airflow: expect Airflow to be installed in the target environment. 
If true, the operator + will raise warning if Airflow is not installed, and it will attempt to load Airflow + macros when starting. + """ + + template_fields: Sequence[str] = tuple({"python"} | set(PythonOperator.template_fields)) + + def __init__( + self, + *, + python: str, + python_callable: Callable, + use_dill: bool = False, + op_args: Collection[Any] | None = None, + op_kwargs: Mapping[str, Any] | None = None, + string_args: Iterable[str] | None = None, + templates_dict: dict | None = None, + templates_exts: list[str] | None = None, + expect_airflow: bool = True, + expect_pendulum: bool = False, + **kwargs, + ): + if not python: + raise ValueError("Python Path must be defined in ExternalPythonOperator") + self.python = python + self.expect_pendulum = expect_pendulum + super().__init__( + python_callable=python_callable, + use_dill=use_dill, + op_args=op_args, + op_kwargs=op_kwargs, + string_args=string_args, + templates_dict=templates_dict, + templates_exts=templates_exts, + expect_airflow=expect_airflow, + **kwargs, + ) + + def execute_callable(self): + python_path = Path(self.python) + if not python_path.exists(): + raise ValueError(f"Python Path '{python_path}' must exists") + if not python_path.is_file(): + raise ValueError(f"Python Path '{python_path}' must be a file") + if not python_path.is_absolute(): + raise ValueError(f"Python Path '{python_path}' must be an absolute path.") + python_version_as_list_of_strings = self._get_python_version_from_environment() + if ( + python_version_as_list_of_strings + and str(python_version_as_list_of_strings[0]) != str(sys.version_info.major) + and (self.op_args or self.op_kwargs) + ): + raise AirflowException( + "Passing op_args or op_kwargs is not supported across different Python " + "major versions for ExternalPythonOperator. Please use string_args." + f"Sys version: {sys.version_info}. Venv version: {python_version_as_list_of_strings}" + ) + with TemporaryDirectory(prefix="tmd") as tmp_dir: + tmp_path = Path(tmp_dir) + return self._execute_python_callable_in_subprocess(python_path, tmp_path) + + def _get_python_version_from_environment(self) -> list[str]: + try: + result = subprocess.check_output([self.python, "--version"], text=True) + return result.strip().split(" ")[-1].split(".") + except Exception as e: + raise ValueError(f"Error while executing {self.python}: {e}") def _iter_serializable_context_keys(self): yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS - if self.system_site_packages or 'apache-airflow' in self.requirements: + if self._get_airflow_version_from_target_env(): yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS - elif 'pendulum' in self.requirements: + elif self._is_pendulum_installed_in_target_env(): yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS - def _write_string_args(self, filename): - with open(filename, 'w') as file: - file.write('\n'.join(map(str, self.string_args))) - - def _read_result(self, filename): - if os.stat(filename).st_size == 0: - return None - with open(filename, 'rb') as file: - try: - return self.pickling_library.load(file) - except ValueError: - self.log.error( - "Error deserializing result. Note that result deserialization " - "is not supported across major Python versions." 
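A minimal usage sketch for the ExternalPythonOperator introduced above, assuming Airflow 2.4+ (for the ``schedule`` DAG argument), a pre-built virtualenv at the hypothetical path /opt/venvs/data-tools, and a hypothetical DAG id; the callable keeps all imports inside its body, and returning a value is only safe when the venv runs the same major Python version as Airflow: ::

    import pendulum

    from airflow import DAG
    from airflow.operators.python import ExternalPythonOperator


    def describe_interpreter():
        # All imports must happen inside the function body.
        import sys

        return f"running under {sys.version_info!r}"


    with DAG(
        dag_id="example_external_python",  # hypothetical DAG id
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ):
        ExternalPythonOperator(
            task_id="run_in_existing_venv",
            python="/opt/venvs/data-tools/bin/python",  # hypothetical absolute path to VENV/bin/python
            python_callable=describe_interpreter,
        )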
+ def _is_pendulum_installed_in_target_env(self) -> bool: + try: + subprocess.check_call([self.python, "-c", "import pendulum"]) + return True + except Exception as e: + if self.expect_pendulum: + self.log.warning("When checking for Pendulum installed in venv got %s", e) + self.log.warning( + "Pendulum is not properly installed in the virtualenv " + "Pendulum context keys will not be available. " + "Please Install Pendulum or Airflow in your venv to access them." ) - raise + return False - def __deepcopy__(self, memo): - # module objects can't be copied _at all__ - memo[id(self.pickling_library)] = self.pickling_library - return super().__deepcopy__(memo) + def _get_airflow_version_from_target_env(self) -> str | None: + try: + result = subprocess.check_output( + [self.python, "-c", "from airflow import version; print(version.version)"], text=True + ) + target_airflow_version = result.strip() + if target_airflow_version != airflow_version: + raise AirflowConfigException( + f"The version of Airflow installed for the {self.python}(" + f"{target_airflow_version}) is different than the runtime Airflow version: " + f"{airflow_version}. Make sure your environment has the same Airflow version " + f"installed as the Airflow runtime." + ) + return target_airflow_version + except Exception as e: + if self.expect_airflow: + self.log.warning("When checking for Airflow installed in venv got %s", e) + self.log.warning( + f"This means that Airflow is not properly installed by " + f"{self.python}. Airflow context keys will not be available. " + f"Please Install Airflow {airflow_version} in your environment to access them." + ) + return None def get_current_context() -> Context: diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py deleted file mode 100644 index ac8c6448d241d..0000000000000 --- a/airflow/operators/python_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.python`.""" - -import warnings - -from airflow.operators.python import ( # noqa - BranchPythonOperator, - PythonOperator, - PythonVirtualenvOperator, - ShortCircuitOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.python`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/operators/redshift_to_s3_operator.py b/airflow/operators/redshift_to_s3_operator.py deleted file mode 100644 index 9fceb700d42c7..0000000000000 --- a/airflow/operators/redshift_to_s3_operator.py +++ /dev/null @@ -1,48 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.transfers.redshift_to_s3`. -""" - -import warnings - -from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.redshift_to_s3`.", - DeprecationWarning, - stacklevel=2, -) - - -class RedshiftToS3Transfer(RedshiftToS3Operator): - """ - This class is deprecated. - Please use: :class:`airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py deleted file mode 100644 index 828031d814102..0000000000000 --- a/airflow/operators/s3_file_transform_operator.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.operators.s3_file_transform` -""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_file_transform`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py deleted file mode 100644 index b0e1f6b69258a..0000000000000 --- a/airflow/operators/s3_to_hive_operator.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.transfers.s3_to_hive`.""" - -import warnings - -from airflow.providers.apache.hive.transfers.s3_to_hive import S3ToHiveOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.s3_to_hive`.", - DeprecationWarning, - stacklevel=2, -) - - -class S3ToHiveTransfer(S3ToHiveOperator): - """ - This class is deprecated. - Please use `airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/s3_to_redshift_operator.py b/airflow/operators/s3_to_redshift_operator.py deleted file mode 100644 index f14a2912a8e8f..0000000000000 --- a/airflow/operators/s3_to_redshift_operator.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.amazon.aws.transfers.s3_to_redshift`. -""" - -import warnings - -from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_redshift`.", - DeprecationWarning, - stacklevel=2, -) - - -class S3ToRedshiftTransfer(S3ToRedshiftOperator): - """This class is deprecated. - - Please use: - `airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/slack_operator.py b/airflow/operators/slack_operator.py deleted file mode 100644 index 3af49e222218e..0000000000000 --- a/airflow/operators/slack_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.slack.operators.slack`.""" - -import warnings - -from airflow.providers.slack.operators.slack import SlackAPIOperator, SlackAPIPostOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.slack.operators.slack`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/smooth.py b/airflow/operators/smooth.py index 9dcbccb9e0a2f..66bf632de8395 100644 --- a/airflow/operators/smooth.py +++ b/airflow/operators/smooth.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context @@ -25,7 +27,7 @@ class SmoothOperator(BaseOperator): Sade song "Smooth Operator". """ - ui_color = '#e8f7e4' + ui_color = "#e8f7e4" yt_link: str = "https://www.youtube.com/watch?v=4TYv2PhG89A" def __init__(self, **kwargs) -> None: diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py deleted file mode 100644 index efa5d0d81a8f0..0000000000000 --- a/airflow/operators/sql.py +++ /dev/null @@ -1,557 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union - -from airflow.compat.functools import cached_property -from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook -from airflow.hooks.dbapi import DbApiHook -from airflow.models import BaseOperator, SkipMixin -from airflow.utils.context import Context - - -def parse_boolean(val: str) -> Union[str, bool]: - """Try to parse a string into boolean. - - Raises ValueError if the input is not a valid true- or false-like string value. 
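The modules deleted in this hunk were thin deprecation shims; a migration sketch, assuming the relevant provider distributions (apache-airflow-providers-slack, apache-airflow-providers-apache-hive) are installed: ::

    # Import paths removed by this change (they previously emitted DeprecationWarning):
    #   from airflow.operators.slack_operator import SlackAPIPostOperator
    #   from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer

    # Provider-package imports named in the shims' own warnings:
    from airflow.providers.slack.operators.slack import SlackAPIPostOperator
    from airflow.providers.apache.hive.transfers.s3_to_hive import S3ToHiveOperator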
- """ - val = val.lower() - if val in ('y', 'yes', 't', 'true', 'on', '1'): - return True - if val in ('n', 'no', 'f', 'false', 'off', '0'): - return False - raise ValueError(f"{val!r} is not a boolean-like string value") - - -class BaseSQLOperator(BaseOperator): - """ - This is a base class for generic SQL Operator to get a DB Hook - - The provided method is .get_db_hook(). The default behavior will try to - retrieve the DB hook based on connection type. - You can custom the behavior by overriding the .get_db_hook() method. - """ - - def __init__( - self, - *, - conn_id: Optional[str] = None, - database: Optional[str] = None, - hook_params: Optional[Dict] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.conn_id = conn_id - self.database = database - self.hook_params = {} if hook_params is None else hook_params - - @cached_property - def _hook(self): - """Get DB Hook based on connection type""" - self.log.debug("Get connection for %s", self.conn_id) - conn = BaseHook.get_connection(self.conn_id) - - hook = conn.get_hook(hook_params=self.hook_params) - if not isinstance(hook, DbApiHook): - raise AirflowException( - f'The connection type is not supported by {self.__class__.__name__}. ' - f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}' - ) - - if self.database: - hook.schema = self.database - - return hook - - def get_db_hook(self) -> DbApiHook: - """ - Get the database hook for the connection. - - :return: the database hook object. - :rtype: DbApiHook - """ - return self._hook - - -class SQLCheckOperator(BaseSQLOperator): - """ - Performs checks against a db. The ``SQLCheckOperator`` expects - a sql query that will return a single row. Each value on that - first row is evaluated using python ``bool`` casting. If any of the - values return ``False`` the check is failed and errors out. - - Note that Python bool casting evals the following as ``False``: - - * ``False`` - * ``0`` - * Empty string (``""``) - * Empty list (``[]``) - * Empty dictionary or set (``{}``) - - Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if - the count ``== 0``. You can craft much more complex query that could, - for instance, check that the table has the same number of rows as - the source table upstream, or that the count of today's partition is - greater than yesterday's partition, or that a set of metrics are less - than 3 standard deviation for the 7 day average. - - This operator can be used as a data quality check in your pipeline, and - depending on where you put it in your DAG, you have the choice to - stop the critical path, preventing from - publishing dubious data, or on the side and receive email alerts - without stopping the progress of the DAG. - - :param sql: the sql to be executed. (templated) - :param conn_id: the connection ID used to connect to the database. 
- :param database: name of database which overwrite the defined one in connection - """ - - template_fields: Sequence[str] = ("sql",) - template_ext: Sequence[str] = ( - ".hql", - ".sql", - ) - template_fields_renderers = {"sql": "sql"} - ui_color = "#fff7e6" - - def __init__( - self, *, sql: str, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs - ) -> None: - super().__init__(conn_id=conn_id, database=database, **kwargs) - self.sql = sql - - def execute(self, context: Context): - self.log.info("Executing SQL check: %s", self.sql) - records = self.get_db_hook().get_first(self.sql) - - self.log.info("Record: %s", records) - if not records: - raise AirflowException("The query returned None") - elif not all(bool(r) for r in records): - raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") - - self.log.info("Success.") - - -def _convert_to_float_if_possible(s): - """ - A small helper function to convert a string to a numeric value - if appropriate - - :param s: the string to be converted - """ - try: - ret = float(s) - except (ValueError, TypeError): - ret = s - return ret - - -class SQLValueCheckOperator(BaseSQLOperator): - """ - Performs a simple value check using sql code. - - :param sql: the sql to be executed. (templated) - :param conn_id: the connection ID used to connect to the database. - :param database: name of database which overwrite the defined one in connection - """ - - __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"} - template_fields: Sequence[str] = ( - "sql", - "pass_value", - ) - template_ext: Sequence[str] = ( - ".hql", - ".sql", - ) - template_fields_renderers = {"sql": "sql"} - ui_color = "#fff7e6" - - def __init__( - self, - *, - sql: str, - pass_value: Any, - tolerance: Any = None, - conn_id: Optional[str] = None, - database: Optional[str] = None, - **kwargs, - ): - super().__init__(conn_id=conn_id, database=database, **kwargs) - self.sql = sql - self.pass_value = str(pass_value) - tol = _convert_to_float_if_possible(tolerance) - self.tol = tol if isinstance(tol, float) else None - self.has_tolerance = self.tol is not None - - def execute(self, context=None): - self.log.info("Executing SQL check: %s", self.sql) - records = self.get_db_hook().get_first(self.sql) - - if not records: - raise AirflowException("The query returned None") - - pass_value_conv = _convert_to_float_if_possible(self.pass_value) - is_numeric_value_check = isinstance(pass_value_conv, float) - - tolerance_pct_str = str(self.tol * 100) + "%" if self.has_tolerance else None - error_msg = ( - "Test failed.\nPass value:{pass_value_conv}\n" - "Tolerance:{tolerance_pct_str}\n" - "Query:\n{sql}\nResults:\n{records!s}" - ).format( - pass_value_conv=pass_value_conv, - tolerance_pct_str=tolerance_pct_str, - sql=self.sql, - records=records, - ) - - if not is_numeric_value_check: - tests = self._get_string_matches(records, pass_value_conv) - elif is_numeric_value_check: - try: - numeric_records = self._to_float(records) - except (ValueError, TypeError): - raise AirflowException(f"Converting a result to float failed.\n{error_msg}") - tests = self._get_numeric_matches(numeric_records, pass_value_conv) - else: - tests = [] - - if not all(tests): - raise AirflowException(error_msg) - - def _to_float(self, records): - return [float(record) for record in records] - - def _get_string_matches(self, records, pass_value_conv): - return [str(record) == pass_value_conv for record in records] - - def _get_numeric_matches(self, numeric_records, 
numeric_pass_value_conv): - if self.has_tolerance: - return [ - numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol) - for record in numeric_records - ] - - return [record == numeric_pass_value_conv for record in numeric_records] - - -class SQLIntervalCheckOperator(BaseSQLOperator): - """ - Checks that the values of metrics given as SQL expressions are within - a certain tolerance of the ones from days_back before. - - :param table: the table name - :param conn_id: the connection ID used to connect to the database. - :param database: name of database which will overwrite the defined one in connection - :param days_back: number of days between ds and the ds we want to check - against. Defaults to 7 days - :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds' - :param ratio_formula: which formula to use to compute the ratio between - the two metrics. Assuming cur is the metric of today and ref is - the metric to today - days_back. - - max_over_min: computes max(cur, ref) / min(cur, ref) - relative_diff: computes abs(cur-ref) / ref - - Default: 'max_over_min' - :param ignore_zero: whether we should ignore zero metrics - :param metrics_thresholds: a dictionary of ratios indexed by metrics - """ - - __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"} - template_fields: Sequence[str] = ("sql1", "sql2") - template_ext: Sequence[str] = ( - ".hql", - ".sql", - ) - template_fields_renderers = {"sql1": "sql", "sql2": "sql"} - ui_color = "#fff7e6" - - ratio_formulas = { - "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref), - "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref, - } - - def __init__( - self, - *, - table: str, - metrics_thresholds: Dict[str, int], - date_filter_column: Optional[str] = "ds", - days_back: SupportsAbs[int] = -7, - ratio_formula: Optional[str] = "max_over_min", - ignore_zero: bool = True, - conn_id: Optional[str] = None, - database: Optional[str] = None, - **kwargs, - ): - super().__init__(conn_id=conn_id, database=database, **kwargs) - if ratio_formula not in self.ratio_formulas: - msg_template = "Invalid diff_method: {diff_method}. 
Supported diff methods are: {diff_methods}" - - raise AirflowException( - msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas) - ) - self.ratio_formula = ratio_formula - self.ignore_zero = ignore_zero - self.table = table - self.metrics_thresholds = metrics_thresholds - self.metrics_sorted = sorted(metrics_thresholds.keys()) - self.date_filter_column = date_filter_column - self.days_back = -abs(days_back) - sqlexp = ", ".join(self.metrics_sorted) - sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}=" - - self.sql1 = sqlt + "'{{ ds }}'" - self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'" - - def execute(self, context=None): - hook = self.get_db_hook() - self.log.info("Using ratio formula: %s", self.ratio_formula) - self.log.info("Executing SQL check: %s", self.sql2) - row2 = hook.get_first(self.sql2) - self.log.info("Executing SQL check: %s", self.sql1) - row1 = hook.get_first(self.sql1) - - if not row2: - raise AirflowException(f"The query {self.sql2} returned None") - if not row1: - raise AirflowException(f"The query {self.sql1} returned None") - - current = dict(zip(self.metrics_sorted, row1)) - reference = dict(zip(self.metrics_sorted, row2)) - - ratios = {} - test_results = {} - - for metric in self.metrics_sorted: - cur = current[metric] - ref = reference[metric] - threshold = self.metrics_thresholds[metric] - if cur == 0 or ref == 0: - ratios[metric] = None - test_results[metric] = self.ignore_zero - else: - ratios[metric] = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric]) - test_results[metric] = ratios[metric] < threshold - - self.log.info( - ( - "Current metric for %s: %s\n" - "Past metric for %s: %s\n" - "Ratio for %s: %s\n" - "Threshold: %s\n" - ), - metric, - cur, - metric, - ref, - metric, - ratios[metric], - threshold, - ) - - if not all(test_results.values()): - failed_tests = [it[0] for it in test_results.items() if not it[1]] - self.log.warning( - "The following %s tests out of %s failed:", - len(failed_tests), - len(self.metrics_sorted), - ) - for k in failed_tests: - self.log.warning( - "'%s' check failed. %s is above %s", - k, - ratios[k], - self.metrics_thresholds[k], - ) - raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}") - - self.log.info("All tests have passed") - - -class SQLThresholdCheckOperator(BaseSQLOperator): - """ - Performs a value check using sql code against a minimum threshold - and a maximum threshold. Thresholds can be in the form of a numeric - value OR a sql statement that results a numeric. - - :param sql: the sql to be executed. (templated) - :param conn_id: the connection ID used to connect to the database. 
- :param database: name of database which overwrite the defined one in connection - :param min_threshold: numerical value or min threshold sql to be executed (templated) - :param max_threshold: numerical value or max threshold sql to be executed (templated) - """ - - template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold") - template_ext: Sequence[str] = ( - ".hql", - ".sql", - ) - template_fields_renderers = {"sql": "sql"} - - def __init__( - self, - *, - sql: str, - min_threshold: Any, - max_threshold: Any, - conn_id: Optional[str] = None, - database: Optional[str] = None, - **kwargs, - ): - super().__init__(conn_id=conn_id, database=database, **kwargs) - self.sql = sql - self.min_threshold = _convert_to_float_if_possible(min_threshold) - self.max_threshold = _convert_to_float_if_possible(max_threshold) - - def execute(self, context=None): - hook = self.get_db_hook() - result = hook.get_first(self.sql)[0] - - if isinstance(self.min_threshold, float): - lower_bound = self.min_threshold - else: - lower_bound = hook.get_first(self.min_threshold)[0] - - if isinstance(self.max_threshold, float): - upper_bound = self.max_threshold - else: - upper_bound = hook.get_first(self.max_threshold)[0] - - meta_data = { - "result": result, - "task_id": self.task_id, - "min_threshold": lower_bound, - "max_threshold": upper_bound, - "within_threshold": lower_bound <= result <= upper_bound, - } - - self.push(meta_data) - if not meta_data["within_threshold"]: - error_msg = ( - f'Threshold Check: "{meta_data.get("task_id")}" failed.\n' - f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n' - f'Check description: {meta_data.get("description")}\n' - f"SQL: {self.sql}\n" - f'Result: {round(meta_data.get("result"), 2)} is not within thresholds ' - f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}' - ) - raise AirflowException(error_msg) - - self.log.info("Test %s Successful.", self.task_id) - - def push(self, meta_data): - """ - Optional: Send data check info and metadata to an external database. - Default functionality will log metadata. - """ - info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items()) - self.log.info("Log from %s:\n%s", self.dag_id, info) - - -class BranchSQLOperator(BaseSQLOperator, SkipMixin): - """ - Allows a DAG to "branch" or follow a specified path based on the results of a SQL query. - - :param sql: The SQL code to be executed, should return true or false (templated) - Template reference are recognized by str ending in '.sql'. - Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1) - or string (true/y/yes/1/on/false/n/no/0/off). - :param follow_task_ids_if_true: task id or task ids to follow if query returns true - :param follow_task_ids_if_false: task id or task ids to follow if query returns false - :param conn_id: the connection ID used to connect to the database. - :param database: name of database which overwrite the defined one in connection - :param parameters: (optional) the parameters to render the SQL query with. 
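A usage sketch for BranchSQLOperator with the parameters documented above; since airflow/operators/sql.py is deleted in this change, the import assumes the class now ships in the common.sql provider (apache-airflow-providers-common-sql), and the connection id, table and task ids are hypothetical: ::

    import pendulum

    from airflow import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.providers.common.sql.operators.sql import BranchSQLOperator

    with DAG(
        dag_id="example_branch_sql",  # hypothetical DAG id
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule="@daily",
        catchup=False,
    ):
        check = BranchSQLOperator(
            task_id="new_rows_check",
            conn_id="my_postgres",  # hypothetical connection id
            sql="SELECT COUNT(*) > 0 FROM new_rows WHERE ds = '{{ ds }}'",
            follow_task_ids_if_true=["process_rows"],
            follow_task_ids_if_false=["skip_processing"],
        )
        check >> [EmptyOperator(task_id="process_rows"), EmptyOperator(task_id="skip_processing")]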
- """ - - template_fields: Sequence[str] = ("sql",) - template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} - ui_color = "#a22034" - ui_fgcolor = "#F7F7F7" - - def __init__( - self, - *, - sql: str, - follow_task_ids_if_true: List[str], - follow_task_ids_if_false: List[str], - conn_id: str = "default_conn_id", - database: Optional[str] = None, - parameters: Optional[Union[Mapping, Iterable]] = None, - **kwargs, - ) -> None: - super().__init__(conn_id=conn_id, database=database, **kwargs) - self.sql = sql - self.parameters = parameters - self.follow_task_ids_if_true = follow_task_ids_if_true - self.follow_task_ids_if_false = follow_task_ids_if_false - - def execute(self, context: Context): - self.log.info( - "Executing: %s (with parameters %s) with connection: %s", - self.sql, - self.parameters, - self.conn_id, - ) - record = self.get_db_hook().get_first(self.sql, self.parameters) - if not record: - raise AirflowException( - "No rows returned from sql query. Operator expected True or False return value." - ) - - if isinstance(record, list): - if isinstance(record[0], list): - query_result = record[0][0] - else: - query_result = record[0] - elif isinstance(record, tuple): - query_result = record[0] - else: - query_result = record - - self.log.info("Query returns %s, type '%s'", query_result, type(query_result)) - - follow_branch = None - try: - if isinstance(query_result, bool): - if query_result: - follow_branch = self.follow_task_ids_if_true - elif isinstance(query_result, str): - # return result is not Boolean, try to convert from String to Boolean - if parse_boolean(query_result): - follow_branch = self.follow_task_ids_if_true - elif isinstance(query_result, int): - if bool(query_result): - follow_branch = self.follow_task_ids_if_true - else: - raise AirflowException( - f"Unexpected query return result '{query_result}' type '{type(query_result)}'" - ) - - if follow_branch is None: - follow_branch = self.follow_task_ids_if_false - except ValueError: - raise AirflowException( - f"Unexpected query return result '{query_result}' type '{type(query_result)}'" - ) - - self.skip_all_except(context["ti"], follow_branch) diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py deleted file mode 100644 index 5987bce61003f..0000000000000 --- a/airflow/operators/sql_branch_operator.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.sql`.""" -import warnings - -from airflow.operators.sql import BranchSQLOperator - -warnings.warn( - "This module is deprecated. 
Please use :mod:`airflow.operators.sql`.", DeprecationWarning, stacklevel=2 -) - - -class BranchSqlOperator(BranchSQLOperator): - """ - This class is deprecated. - Please use `airflow.operators.sql.BranchSQLOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.operators.sql.BranchSQLOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) diff --git a/airflow/operators/sqlite_operator.py b/airflow/operators/sqlite_operator.py deleted file mode 100644 index 68791d69846c0..0000000000000 --- a/airflow/operators/sqlite_operator.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.sqlite.operators.sqlite`.""" - -import warnings - -from airflow.providers.sqlite.operators.sqlite import SqliteOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.sqlite.operators.sqlite`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py index bd81314dda8fd..1b3f0049788cc 100644 --- a/airflow/operators/subdag.py +++ b/airflow/operators/subdag.py @@ -19,16 +19,16 @@ This module is deprecated. Please use :mod:`airflow.utils.task_group`. The module which provides a way to nest your DAGs and so your levels of complexity. """ +from __future__ import annotations import warnings from datetime import datetime from enum import Enum -from typing import Dict, Optional, Tuple from sqlalchemy.orm.session import Session from airflow.api.common.experimental.get_task_instance import get_task_instance -from airflow.exceptions import AirflowException, TaskInstanceNotFound +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskInstanceNotFound from airflow.models import DagRun from airflow.models.dag import DAG, DagContext from airflow.models.pool import Pool @@ -43,8 +43,8 @@ class SkippedStatePropagationOptions(Enum): """Available options for skipped state propagation of subdag's tasks to parent dag tasks.""" - ALL_LEAVES = 'all_leaves' - ANY_LEAF = 'any_leaf' + ALL_LEAVES = "all_leaves" + ANY_LEAF = "any_leaf" class SubDagOperator(BaseSensorOperator): @@ -66,10 +66,10 @@ class SubDagOperator(BaseSensorOperator): parent dag's downstream task. 
""" - ui_color = '#555' - ui_fgcolor = '#fff' + ui_color = "#555" + ui_fgcolor = "#fff" - subdag: "DAG" + subdag: DAG @provide_session def __init__( @@ -77,8 +77,8 @@ def __init__( *, subdag: DAG, session: Session = NEW_SESSION, - conf: Optional[Dict] = None, - propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None, + conf: dict | None = None, + propagate_skipped_state: SkippedStatePropagationOptions | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -91,17 +91,17 @@ def __init__( warnings.warn( """This class is deprecated. Please use `airflow.utils.task_group.TaskGroup`.""", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=4, ) def _validate_dag(self, kwargs): - dag = kwargs.get('dag') or DagContext.get_current_dag() + dag = kwargs.get("dag") or DagContext.get_current_dag() if not dag: - raise AirflowException('Please pass in the `dag` param or call within a DAG context manager') + raise AirflowException("Please pass in the `dag` param or call within a DAG context manager") - if dag.dag_id + '.' + kwargs['task_id'] != self.subdag.dag_id: + if dag.dag_id + "." + kwargs["task_id"] != self.subdag.dag_id: raise AirflowException( f"The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. " f"Expected '{dag.dag_id}.{kwargs['task_id']}'; received '{self.subdag.dag_id}'." @@ -152,14 +152,14 @@ def _reset_dag_run_and_task_instances(self, dag_run, execution_date): def pre_execute(self, context): super().pre_execute(context) - execution_date = context['execution_date'] + execution_date = context["execution_date"] dag_run = self._get_dagrun(execution_date) if dag_run is None: - if context['data_interval_start'] is None or context['data_interval_end'] is None: - data_interval: Optional[Tuple[datetime, datetime]] = None + if context["data_interval_start"] is None or context["data_interval_end"] is None: + data_interval: tuple[datetime, datetime] | None = None else: - data_interval = (context['data_interval_start'], context['data_interval_end']) + data_interval = (context["data_interval_start"], context["data_interval_end"]) dag_run = self.subdag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=execution_date, @@ -175,13 +175,13 @@ def pre_execute(self, context): self._reset_dag_run_and_task_instances(dag_run, execution_date) def poke(self, context: Context): - execution_date = context['execution_date'] + execution_date = context["execution_date"] dag_run = self._get_dagrun(execution_date=execution_date) return dag_run.state != State.RUNNING def post_execute(self, context, result=None): super().post_execute(context) - execution_date = context['execution_date'] + execution_date = context["execution_date"] dag_run = self._get_dagrun(execution_date=execution_date) self.log.info("Execution finished. State is %s", dag_run.state) @@ -192,14 +192,14 @@ def post_execute(self, context, result=None): self._skip_downstream_tasks(context) def _check_skipped_states(self, context): - leaves_tis = self._get_leaves_tis(context['execution_date']) + leaves_tis = self._get_leaves_tis(context["execution_date"]) if self.propagate_skipped_state == SkippedStatePropagationOptions.ANY_LEAF: return any(ti.state == State.SKIPPED for ti in leaves_tis) if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES: return all(ti.state == State.SKIPPED for ti in leaves_tis) raise AirflowException( - f'Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used.' 
+ f"Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used." ) def _get_leaves_tis(self, execution_date): @@ -216,15 +216,15 @@ def _get_leaves_tis(self, execution_date): def _skip_downstream_tasks(self, context): self.log.info( - 'Skipping downstream tasks because propagate_skipped_state is set to %s ' - 'and skipped task(s) were found.', + "Skipping downstream tasks because propagate_skipped_state is set to %s " + "and skipped task(s) were found.", self.propagate_skipped_state, ) - downstream_tasks = context['task'].downstream_list - self.log.debug('Downstream task_ids %s', downstream_tasks) + downstream_tasks = context["task"].downstream_list + self.log.debug("Downstream task_ids %s", downstream_tasks) if downstream_tasks: - self.skip(context['dag_run'], context['execution_date'], downstream_tasks) + self.skip(context["dag_run"], context["execution_date"], downstream_tasks) - self.log.info('Done.') + self.log.info("Done.") diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py deleted file mode 100644 index bb5a088d23b6d..0000000000000 --- a/airflow/operators/subdag_operator.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.operators.subdag`.""" - -import warnings - -from airflow.operators.subdag import SkippedStatePropagationOptions, SubDagOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.operators.subdag`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 0689f14c56261..b687c05dc4b97 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -15,15 +15,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import datetime import json import time -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Sequence, cast from airflow.api.common.trigger_dag import trigger_dag from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists -from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun +from airflow.models.baseoperator import BaseOperator, BaseOperatorLink +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun from airflow.models.xcom import XCom from airflow.utils import timezone from airflow.utils.context import Context @@ -36,7 +40,6 @@ if TYPE_CHECKING: - from airflow.models.abstractoperator import AbstractOperator from airflow.models.taskinstance import TaskInstanceKey @@ -46,14 +49,9 @@ class TriggerDagRunLink(BaseOperatorLink): DAG triggered by task using TriggerDagRunOperator. """ - name = 'Triggered DAG' + name = "Triggered DAG" - def get_link( - self, - operator: "AbstractOperator", - *, - ti_key: "TaskInstanceKey", - ) -> str: + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: # Fetch the correct execution date for the triggerED dag which is # stored in xcom during execution of the triggerING task. when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO) @@ -72,6 +70,8 @@ class TriggerDagRunOperator(BaseOperator): :param execution_date: Execution date for the dag (templated). :param reset_dag_run: Whether or not clear existing dag run if already exists. This is useful when backfill or rerun an existing dag run. + This only resets (not recreates) the dag run. + Dag run conf is immutable and will not be reset on rerun of an existing dag run. When reset_dag_run=False and dag run exists, DagRunAlreadyExists will be raised. When reset_dag_run=True and dag run exists, existing dag run will be cleared to rerun. :param wait_for_completion: Whether or not wait for dag run completion. (default: False) @@ -79,29 +79,26 @@ class TriggerDagRunOperator(BaseOperator): (default: 60) :param allowed_states: List of allowed states, default is ``['success']``. :param failed_states: List of failed or dis-allowed states, default is ``None``. + :param notes: Set a custom note for the newly created DagRun. 
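A usage sketch for the TriggerDagRunOperator parameters documented above, with a hypothetical target DAG id; note that conf must remain JSON-serializable and, per the note added above, is not reset when an existing run is cleared with reset_dag_run=True: ::

    import pendulum

    from airflow import DAG
    from airflow.operators.trigger_dagrun import TriggerDagRunOperator

    with DAG(
        dag_id="example_trigger_controller",  # hypothetical DAG id
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ):
        TriggerDagRunOperator(
            task_id="trigger_target",
            trigger_dag_id="example_trigger_target",  # hypothetical target DAG id
            conf={"message": "hello"},  # must be JSON-serializable
            reset_dag_run=True,
            wait_for_completion=True,
            poke_interval=30,
            allowed_states=["success"],
            failed_states=["failed"],
        )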
""" template_fields: Sequence[str] = ("trigger_dag_id", "trigger_run_id", "execution_date", "conf") template_fields_renderers = {"conf": "py"} ui_color = "#ffefeb" - - @property - def operator_extra_links(self): - """Return operator extra links""" - return [TriggerDagRunLink()] + operator_extra_links = [TriggerDagRunLink()] def __init__( self, *, trigger_dag_id: str, - trigger_run_id: Optional[str] = None, - conf: Optional[Dict] = None, - execution_date: Optional[Union[str, datetime.datetime]] = None, + trigger_run_id: str | None = None, + conf: dict | None = None, + execution_date: str | datetime.datetime | None = None, reset_dag_run: bool = False, wait_for_completion: bool = False, poke_interval: int = 60, - allowed_states: Optional[List] = None, - failed_states: Optional[List] = None, + allowed_states: list | None = None, + failed_states: list | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -121,11 +118,6 @@ def __init__( self.execution_date = execution_date - try: - json.dumps(self.conf) - except TypeError: - raise AirflowException("conf parameter should be JSON Serializable") - def execute(self, context: Context): if isinstance(self.execution_date, datetime.datetime): parsed_execution_date = self.execution_date @@ -134,6 +126,11 @@ def execute(self, context: Context): else: parsed_execution_date = timezone.utcnow() + try: + json.dumps(self.conf) + except TypeError: + raise AirflowException("conf parameter should be JSON Serializable") + if self.trigger_run_id: run_id = self.trigger_run_id else: @@ -160,14 +157,14 @@ def execute(self, context: Context): dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True) dag = dag_bag.get_dag(self.trigger_dag_id) dag.clear(start_date=parsed_execution_date, end_date=parsed_execution_date) - dag_run = DagRun.find(dag_id=dag.dag_id, run_id=run_id)[0] + dag_run = e.dag_run else: raise e if dag_run is None: raise RuntimeError("The dag_run should be set here!") # Store the execution date from the dag run (either created or found above) to # be used when creating the extra link on the webserver. - ti = context['task_instance'] + ti = context["task_instance"] ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat()) ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) @@ -175,7 +172,7 @@ def execute(self, context: Context): # wait for dag to complete while True: self.log.info( - 'Waiting for %s on %s to become allowed state %s ...', + "Waiting for %s on %s to become allowed state %s ...", self.trigger_dag_id, dag_run.execution_date, self.allowed_states, diff --git a/airflow/operators/weekday.py b/airflow/operators/weekday.py index b23d57e9fb1d4..2b4e0f4698df6 100644 --- a/airflow/operators/weekday.py +++ b/airflow/operators/weekday.py @@ -15,9 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import warnings -from typing import Iterable, Union +from typing import Iterable +from airflow.exceptions import RemovedInAirflow3Warning from airflow.operators.branch import BaseBranchOperator from airflow.utils import timezone from airflow.utils.context import Context @@ -30,6 +33,40 @@ class BranchDayOfWeekOperator(BaseBranchOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:BranchDayOfWeekOperator` + **Example** (with single day): :: + + from airflow.operators.empty import EmptyOperator + + monday = EmptyOperator(task_id='monday') + other_day = EmptyOperator(task_id='other_day') + + monday_check = DayOfWeekSensor( + task_id='monday_check', + week_day='Monday', + use_task_logical_date=True, + follow_task_ids_if_true='monday', + follow_task_ids_if_false='other_day', + dag=dag) + monday_check >> [monday, other_day] + + **Example** (with :class:`~airflow.utils.weekday.WeekDay` enum): :: + + # import WeekDay Enum + from airflow.utils.weekday import WeekDay + from airflow.operators.empty import EmptyOperator + + workday = EmptyOperator(task_id='workday') + weekend = EmptyOperator(task_id='weekend') + weekend_check = BranchDayOfWeekOperator( + task_id='weekend_check', + week_day={WeekDay.SATURDAY, WeekDay.SUNDAY}, + use_task_logical_date=True, + follow_task_ids_if_true='weekend', + follow_task_ids_if_false='workday', + dag=dag) + # add downstream dependencies as you would do with any branch operator + weekend_check >> [workday, weekend] + :param follow_task_ids_if_true: task id or task ids to follow if criteria met :param follow_task_ids_if_false: task id or task ids to follow if criteria does not met :param week_day: Day of the week to check (full name). Optionally, a set @@ -41,17 +78,20 @@ class BranchDayOfWeekOperator(BaseBranchOperator): * ``{WeekDay.TUESDAY}`` * ``{WeekDay.SATURDAY, WeekDay.SUNDAY}`` + To use `WeekDay` enum, import it from `airflow.utils.weekday` + :param use_task_logical_date: If ``True``, uses task's logical date to compare with is_today. Execution Date is Useful for backfilling. If ``False``, uses system's day of the week. + :param use_task_execution_day: deprecated parameter, same effect as `use_task_logical_date` """ def __init__( self, *, - follow_task_ids_if_true: Union[str, Iterable[str]], - follow_task_ids_if_false: Union[str, Iterable[str]], - week_day: Union[str, Iterable[str]], + follow_task_ids_if_true: str | Iterable[str], + follow_task_ids_if_false: str | Iterable[str], + week_day: str | Iterable[str] | WeekDay | Iterable[WeekDay], use_task_logical_date: bool = False, use_task_execution_day: bool = False, **kwargs, @@ -65,12 +105,12 @@ def __init__( self.use_task_logical_date = use_task_execution_day warnings.warn( "Parameter ``use_task_execution_day`` is deprecated. Use ``use_task_logical_date``.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) self._week_day_num = WeekDay.validate_week_day(week_day) - def choose_branch(self, context: Context) -> Union[str, Iterable[str]]: + def choose_branch(self, context: Context) -> str | Iterable[str]: if self.use_task_logical_date: now = context["logical_date"] else: diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 82e295fa1906c..51c74a37c3b1c 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
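Most files in this change add ``from __future__ import annotations`` and replace typing.Optional/Union/List/Dict with PEP 604 unions and builtin generics; a small illustration with a hypothetical function, valid on all Python versions Airflow supports because annotation evaluation is postponed: ::

    from __future__ import annotations

    # Before this change the same signature would typically be written as:
    #   from typing import Dict, List, Optional
    #   def get_tags(limit: Optional[int] = None) -> Dict[str, List[str]]: ...

    def get_tags(limit: int | None = None) -> dict[str, list[str]]:
        # hypothetical helper, used only to illustrate the annotation style
        return {}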
"""Manages all plugins.""" +from __future__ import annotations + import importlib import importlib.machinery import importlib.util @@ -24,7 +26,7 @@ import os import sys import types -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type +from typing import TYPE_CHECKING, Any, Iterable try: import importlib_metadata @@ -45,26 +47,26 @@ log = logging.getLogger(__name__) -import_errors: Dict[str, str] = {} +import_errors: dict[str, str] = {} -plugins = None # type: Optional[List[AirflowPlugin]] +plugins: list[AirflowPlugin] | None = None # Plugin components to integrate as modules -registered_hooks: Optional[List['BaseHook']] = None -macros_modules: Optional[List[Any]] = None -executors_modules: Optional[List[Any]] = None +registered_hooks: list[BaseHook] | None = None +macros_modules: list[Any] | None = None +executors_modules: list[Any] | None = None # Plugin components to integrate directly -admin_views: Optional[List[Any]] = None -flask_blueprints: Optional[List[Any]] = None -menu_links: Optional[List[Any]] = None -flask_appbuilder_views: Optional[List[Any]] = None -flask_appbuilder_menu_links: Optional[List[Any]] = None -global_operator_extra_links: Optional[List[Any]] = None -operator_extra_links: Optional[List[Any]] = None -registered_operator_link_classes: Optional[Dict[str, Type]] = None -registered_ti_dep_classes: Optional[Dict[str, Type]] = None -timetable_classes: Optional[Dict[str, Type["Timetable"]]] = None +admin_views: list[Any] | None = None +flask_blueprints: list[Any] | None = None +menu_links: list[Any] | None = None +flask_appbuilder_views: list[Any] | None = None +flask_appbuilder_menu_links: list[Any] | None = None +global_operator_extra_links: list[Any] | None = None +operator_extra_links: list[Any] | None = None +registered_operator_link_classes: dict[str, type] | None = None +registered_ti_dep_classes: dict[str, type] | None = None +timetable_classes: dict[str, type[Timetable]] | None = None """Mapping of class names to class of OperatorLinks registered by plugins. Used by the DAG serialization code to only allow specific classes to be created @@ -113,7 +115,7 @@ class EntryPointSource(AirflowPluginSource): """Class used to define Plugins loaded from entrypoint.""" def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution): - self.dist = dist.metadata['name'] + self.dist = dist.metadata["Name"] self.version = dist.version self.entrypoint = str(entrypoint) @@ -131,16 +133,16 @@ class AirflowPluginException(Exception): class AirflowPlugin: """Class used to define AirflowPlugin.""" - name: Optional[str] = None - source: Optional[AirflowPluginSource] = None - hooks: List[Any] = [] - executors: List[Any] = [] - macros: List[Any] = [] - admin_views: List[Any] = [] - flask_blueprints: List[Any] = [] - menu_links: List[Any] = [] - appbuilder_views: List[Any] = [] - appbuilder_menu_items: List[Any] = [] + name: str | None = None + source: AirflowPluginSource | None = None + hooks: list[Any] = [] + executors: list[Any] = [] + macros: list[Any] = [] + admin_views: list[Any] = [] + flask_blueprints: list[Any] = [] + menu_links: list[Any] = [] + appbuilder_views: list[Any] = [] + appbuilder_menu_items: list[Any] = [] # A list of global operator extra links that can redirect users to # external systems. These extra links will be available on the @@ -148,20 +150,20 @@ class AirflowPlugin: # # Note: the global operator extra link can be overridden at each # operator level. 
- global_operator_extra_links: List[Any] = [] + global_operator_extra_links: list[Any] = [] # A list of operator extra links to override or add operator links # to existing Airflow Operators. # These extra links will be available on the task page in form of # buttons. - operator_extra_links: List[Any] = [] + operator_extra_links: list[Any] = [] - ti_deps: List[Any] = [] + ti_deps: list[Any] = [] # A list of timetable classes that can be used for DAG scheduling. - timetables: List[Type["Timetable"]] = [] + timetables: list[type[Timetable]] = [] - listeners: List[ModuleType] = [] + listeners: list[ModuleType] = [] @classmethod def validate(cls): @@ -221,8 +223,8 @@ def load_entrypoint_plugins(): log.debug("Loading plugins from entrypoints") - for entry_point, dist in entry_points_with_dist('airflow.plugins'): - log.debug('Importing entry_point plugin %s', entry_point.name) + for entry_point, dist in entry_points_with_dist("airflow.plugins"): + log.debug("Importing entry_point plugin %s", entry_point.name) try: plugin_class = entry_point.load() if not is_valid_plugin(plugin_class): @@ -245,7 +247,7 @@ def load_plugins_from_plugin_directory(): if not os.path.isfile(file_path): continue mod_name, file_ext = os.path.splitext(os.path.split(file_path)[-1]) - if file_ext != '.py': + if file_ext != ".py": continue try: @@ -254,25 +256,25 @@ def load_plugins_from_plugin_directory(): mod = importlib.util.module_from_spec(spec) sys.modules[spec.name] = mod loader.exec_module(mod) - log.debug('Importing plugin module %s', file_path) + log.debug("Importing plugin module %s", file_path) for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)): plugin_instance = mod_attr_value() plugin_instance.source = PluginsDirectorySource(file_path) register_plugin(plugin_instance) except Exception as e: - log.exception('Failed to import plugin %s', file_path) + log.exception("Failed to import plugin %s", file_path) import_errors[file_path] = str(e) -def make_module(name: str, objects: List[Any]): +def make_module(name: str, objects: list[Any]): """Creates new module.""" if not objects: return None - log.debug('Creating module %s', name) + log.debug("Creating module %s", name) name = name.lower() module = types.ModuleType(name) - module._name = name.split('.')[-1] # type: ignore + module._name = name.split(".")[-1] # type: ignore module._objects = objects # type: ignore module.__dict__.update((o.__name__, o) for o in objects) return module @@ -342,13 +344,13 @@ def initialize_web_ui_plugins(): for plugin in plugins: flask_appbuilder_views.extend(plugin.appbuilder_views) flask_appbuilder_menu_links.extend(plugin.appbuilder_menu_items) - flask_blueprints.extend([{'name': plugin.name, 'blueprint': bp} for bp in plugin.flask_blueprints]) + flask_blueprints.extend([{"name": plugin.name, "blueprint": bp} for bp in plugin.flask_blueprints]) if (plugin.admin_views and not plugin.appbuilder_views) or ( plugin.menu_links and not plugin.appbuilder_menu_items ): log.warning( - "Plugin \'%s\' may not be compatible with the current Airflow version. " + "Plugin '%s' may not be compatible with the current Airflow version. " "Please contact the author of the plugin.", plugin.name, ) @@ -450,7 +452,7 @@ def integrate_executor_plugins() -> None: raise AirflowPluginException("Invalid plugin name") plugin_name: str = plugin.name - executors_module = make_module('airflow.executors.' + plugin_name, plugin.executors) + executors_module = make_module("airflow.executors." 
+ plugin_name, plugin.executors) if executors_module: executors_modules.append(executors_module) sys.modules[executors_module.__name__] = executors_module @@ -479,7 +481,7 @@ def integrate_macros_plugins() -> None: if plugin.name is None: raise AirflowPluginException("Invalid plugin name") - macros_module = make_module(f'airflow.macros.{plugin.name}', plugin.macros) + macros_module = make_module(f"airflow.macros.{plugin.name}", plugin.macros) if macros_module: macros_modules.append(macros_module) @@ -489,7 +491,7 @@ def integrate_macros_plugins() -> None: setattr(macros, plugin.name, macros_module) -def integrate_listener_plugins(listener_manager: "ListenerManager") -> None: +def integrate_listener_plugins(listener_manager: ListenerManager) -> None: global plugins ensure_plugins_loaded() @@ -503,7 +505,7 @@ def integrate_listener_plugins(listener_manager: "ListenerManager") -> None: listener_manager.add_listener(listener) -def get_plugin_info(attrs_to_dump: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]: +def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str, Any]]: """ Dump plugins attributes @@ -519,23 +521,23 @@ def get_plugin_info(attrs_to_dump: Optional[Iterable[str]] = None) -> List[Dict[ plugins_info = [] if plugins: for plugin in plugins: - info: Dict[str, Any] = {"name": plugin.name} + info: dict[str, Any] = {"name": plugin.name} for attr in attrs_to_dump: - if attr in ('global_operator_extra_links', 'operator_extra_links'): + if attr in ("global_operator_extra_links", "operator_extra_links"): info[attr] = [ - f'<{as_importable_string(d.__class__)} object>' for d in getattr(plugin, attr) + f"<{as_importable_string(d.__class__)} object>" for d in getattr(plugin, attr) ] - elif attr in ('macros', 'timetables', 'hooks', 'executors'): + elif attr in ("macros", "timetables", "hooks", "executors"): info[attr] = [as_importable_string(d) for d in getattr(plugin, attr)] - elif attr == 'listeners': + elif attr == "listeners": # listeners are always modules info[attr] = [d.__name__ for d in getattr(plugin, attr)] - elif attr == 'appbuilder_views': + elif attr == "appbuilder_views": info[attr] = [ - {**d, 'view': as_importable_string(d['view'].__class__) if 'view' in d else None} + {**d, "view": as_importable_string(d["view"].__class__) if "view" in d else None} for d in getattr(plugin, attr) ] - elif attr == 'flask_blueprints': + elif attr == "flask_blueprints": info[attr] = [ ( f"<{as_importable_string(d.__class__)}: " diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json index d34fcce95e7c2..ff0537db32cf4 100644 --- a/airflow/provider.yaml.schema.json +++ b/airflow/provider.yaml.schema.json @@ -21,8 +21,8 @@ "type": "string" } }, - "additional-dependencies": { - "description": "Additional dependencies that should be added to the provider", + "dependencies": { + "description": "Dependencies that should be added to the provider", "type": "array", "items": { "type": "string" @@ -194,17 +194,6 @@ ] } }, - "hook-class-names": { - "type": "array", - "description": "Hook class names that provide connection types to core (deprecated by connection-types)", - "items": { - "type": "string" - }, - "deprecated": { - "description": "The hook-class-names property has been deprecated in favour of connection-types which is more performant version allowing to only import individual Hooks rather than all hooks at once", - "deprecatedVersion": "2.2" - } - }, "connection-types": { "type": "array", "description": "Array of connection types mapped 
to hook class names", @@ -219,9 +208,12 @@ "description": "Hook class name that implements the connection type", "type": "string" } - } - }, - "required": ["connection-type", "hook-class-name"] + }, + "required": [ + "connection-type", + "hook-class-name" + ] + } }, "extra-links": { "type": "array", @@ -231,8 +223,26 @@ } }, "additional-extras": { - "type": "object", - "description": "Additional extras that the provider should have" + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "description": "Name of the extra", + "type": "string" + }, + "dependencies": { + "description": "Dependencies that should be added for the extra", + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": [ "name", "dependencies"] + }, + + "description": "Additional extras that the provider should have. Replaces auto-generated cross-provider extras, if matching the same prefix, so that you can specify boundaries for existing dependencies." }, "task-decorators": { "type": "array", @@ -273,6 +283,7 @@ "name", "package-name", "description", + "dependencies", "versions" ] } diff --git a/airflow/providers/airbyte/.latest-doc-only-change.txt b/airflow/providers/airbyte/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/airbyte/.latest-doc-only-change.txt +++ b/airflow/providers/airbyte/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/airbyte/CHANGELOG.rst b/airflow/providers/airbyte/CHANGELOG.rst index deb0e79e968e0..103cf3179fc46 100644 --- a/airflow/providers/airbyte/CHANGELOG.rst +++ b/airflow/providers/airbyte/CHANGELOG.rst @@ -16,9 +16,61 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``AIP-47 - Migrate Airbyte DAGs to new design (#25135)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``'AirbyteHook' add cancel job option (#24593)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.4 ..... diff --git a/airflow/providers/airbyte/example_dags/example_airbyte_trigger_job.py b/airflow/providers/airbyte/example_dags/example_airbyte_trigger_job.py deleted file mode 100644 index 55563ff5e03bd..0000000000000 --- a/airflow/providers/airbyte/example_dags/example_airbyte_trigger_job.py +++ /dev/null @@ -1,57 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Example DAG demonstrating the usage of the AirbyteTriggerSyncOperator.""" - -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator -from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor - -with DAG( - dag_id='example_airbyte_operator', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - dagrun_timeout=timedelta(minutes=60), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_airbyte_synchronous] - sync_source_destination = AirbyteTriggerSyncOperator( - task_id='airbyte_sync_source_dest_example', - connection_id='15bc3800-82e4-48c3-a32d-620661273f28', - ) - # [END howto_operator_airbyte_synchronous] - - # [START howto_operator_airbyte_asynchronous] - async_source_destination = AirbyteTriggerSyncOperator( - task_id='airbyte_async_source_dest_example', - connection_id='15bc3800-82e4-48c3-a32d-620661273f28', - asynchronous=True, - ) - - airbyte_sensor = AirbyteJobSensor( - task_id='airbyte_sensor_source_dest_example', - airbyte_job_id=async_source_destination.output, - ) - # [END howto_operator_airbyte_asynchronous] - - # Task dependency created via `XComArgs`: - # async_source_destination >> airbyte_sensor diff --git a/airflow/providers/airbyte/hooks/airbyte.py b/airflow/providers/airbyte/hooks/airbyte.py index b1f6317530514..0ececccdfa907 100644 --- a/airflow/providers/airbyte/hooks/airbyte.py +++ b/airflow/providers/airbyte/hooks/airbyte.py @@ -15,8 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import time -from typing import Any, Optional, Union +from typing import Any from airflow.exceptions import AirflowException from airflow.providers.http.hooks.http import HttpHook @@ -31,10 +33,10 @@ class AirbyteHook(HttpHook): :param api_version: Optional. Airbyte API version. 
""" - conn_name_attr = 'airbyte_conn_id' - default_conn_name = 'airbyte_default' - conn_type = 'airbyte' - hook_name = 'Airbyte' + conn_name_attr = "airbyte_conn_id" + default_conn_name = "airbyte_default" + conn_type = "airbyte" + hook_name = "Airbyte" RUNNING = "running" SUCCEEDED = "succeeded" @@ -48,9 +50,7 @@ def __init__(self, airbyte_conn_id: str = "airbyte_default", api_version: str = super().__init__(http_conn_id=airbyte_conn_id) self.api_version: str = api_version - def wait_for_job( - self, job_id: Union[str, int], wait_seconds: float = 3, timeout: Optional[float] = 3600 - ) -> None: + def wait_for_job(self, job_id: str | int, wait_seconds: float = 3, timeout: float | None = 3600) -> None: """ Helper method which polls a job to check if it finishes. @@ -107,21 +107,33 @@ def get_job(self, job_id: int) -> Any: headers={"accept": "application/json"}, ) + def cancel_job(self, job_id: int) -> Any: + """ + Cancel the job when task is cancelled + + :param job_id: Required. Id of the Airbyte job + """ + return self.run( + endpoint=f"api/{self.api_version}/jobs/cancel", + json={"id": job_id}, + headers={"accept": "application/json"}, + ) + def test_connection(self): """Tests the Airbyte connection by hitting the health API""" - self.method = 'GET' + self.method = "GET" try: res = self.run( endpoint=f"api/{self.api_version}/health", headers={"accept": "application/json"}, - extra_options={'check_response': False}, + extra_options={"check_response": False}, ) if res.status_code == 200: - return True, 'Connection successfully tested' + return True, "Connection successfully tested" else: return False, res.text except Exception as e: return False, str(e) finally: - self.method = 'POST' + self.method = "POST" diff --git a/airflow/providers/airbyte/operators/airbyte.py b/airflow/providers/airbyte/operators/airbyte.py index ef2e2c1559902..e5ebc443e6d2d 100644 --- a/airflow/providers/airbyte/operators/airbyte.py +++ b/airflow/providers/airbyte/operators/airbyte.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.airbyte.hooks.airbyte import AirbyteHook @@ -45,16 +47,16 @@ class AirbyteTriggerSyncOperator(BaseOperator): Only used when ``asynchronous`` is False. 
""" - template_fields: Sequence[str] = ('connection_id',) + template_fields: Sequence[str] = ("connection_id",) def __init__( self, connection_id: str, airbyte_conn_id: str = "airbyte_default", - asynchronous: Optional[bool] = False, + asynchronous: bool | None = False, api_version: str = "v1", wait_seconds: float = 3, - timeout: Optional[float] = 3600, + timeout: float | None = 3600, **kwargs, ) -> None: super().__init__(**kwargs) @@ -65,16 +67,22 @@ def __init__( self.wait_seconds = wait_seconds self.asynchronous = asynchronous - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Create Airbyte Job and wait to finish""" - hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version) - job_object = hook.submit_sync_connection(connection_id=self.connection_id) - job_id = job_object.json()['job']['id'] + self.hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version) + job_object = self.hook.submit_sync_connection(connection_id=self.connection_id) + self.job_id = job_object.json()["job"]["id"] - self.log.info("Job %s was submitted to Airbyte Server", job_id) + self.log.info("Job %s was submitted to Airbyte Server", self.job_id) if not self.asynchronous: - self.log.info('Waiting for job %s to complete', job_id) - hook.wait_for_job(job_id=job_id, wait_seconds=self.wait_seconds, timeout=self.timeout) - self.log.info('Job %s completed successfully', job_id) + self.log.info("Waiting for job %s to complete", self.job_id) + self.hook.wait_for_job(job_id=self.job_id, wait_seconds=self.wait_seconds, timeout=self.timeout) + self.log.info("Job %s completed successfully", self.job_id) + + return self.job_id - return job_id + def on_kill(self): + """Cancel the job if task is cancelled""" + if self.job_id: + self.log.info("on_kill: cancel the airbyte Job %s", self.job_id) + self.hook.cancel_job(self.job_id) diff --git a/airflow/providers/airbyte/provider.yaml b/airflow/providers/airbyte/provider.yaml index a292b5378b14e..2d164460e6cf9 100644 --- a/airflow/providers/airbyte/provider.yaml +++ b/airflow/providers/airbyte/provider.yaml @@ -22,6 +22,9 @@ description: | `Airbyte `__ versions: + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.4 - 2.1.3 - 2.1.2 @@ -30,8 +33,9 @@ versions: - 2.0.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-http integrations: - integration-name: Airbyte @@ -56,9 +60,6 @@ sensors: python-modules: - airflow.providers.airbyte.sensors.airbyte -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.airbyte.hooks.airbyte.AirbyteHook - connection-types: - hook-class-name: airflow.providers.airbyte.hooks.airbyte.AirbyteHook connection-type: airbyte diff --git a/airflow/providers/airbyte/sensors/airbyte.py b/airflow/providers/airbyte/sensors/airbyte.py index 10c5954ee3a79..f1c03c90a35b1 100644 --- a/airflow/providers/airbyte/sensors/airbyte.py +++ b/airflow/providers/airbyte/sensors/airbyte.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """This module contains a Airbyte Job sensor.""" +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException @@ -36,14 +38,14 @@ class AirbyteJobSensor(BaseSensorOperator): :param api_version: Optional. Airbyte API version. 
""" - template_fields: Sequence[str] = ('airbyte_job_id',) - ui_color = '#6C51FD' + template_fields: Sequence[str] = ("airbyte_job_id",) + ui_color = "#6C51FD" def __init__( self, *, airbyte_job_id: int, - airbyte_conn_id: str = 'airbyte_default', + airbyte_conn_id: str = "airbyte_default", api_version: str = "v1", **kwargs, ) -> None: @@ -52,10 +54,10 @@ def __init__( self.airbyte_job_id = airbyte_job_id self.api_version = api_version - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version) job = hook.get_job(job_id=self.airbyte_job_id) - status = job.json()['job']['status'] + status = job.json()["job"]["status"] if status == hook.FAILED: raise AirflowException(f"Job failed: \n{job}") diff --git a/airflow/providers/alibaba/CHANGELOG.rst b/airflow/providers/alibaba/CHANGELOG.rst index 3eceea9636fe2..265b65f634550 100644 --- a/airflow/providers/alibaba/CHANGELOG.rst +++ b/airflow/providers/alibaba/CHANGELOG.rst @@ -16,9 +16,83 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +2.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Use log.exception where more economical than log.error (#27517)`` +* ``Replace urlparse with urlsplit (#27389)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +2.1.0 +..... + +Features +~~~~~~~~ + +* ``Auto tail file logs in Web UI (#26169)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +2.0.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + + * ``SSL Bucket, Light Logic Refactor and Docstring Update for Alibaba Provider (#23891)`` + +Misc +~~~~ + + * ``Apply per-run log templates to log handlers (#24153)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Migrate Alibaba example DAGs to new design #22437 (#24130)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 1.1.1 ..... diff --git a/airflow/providers/alibaba/cloud/example_dags/example_oss_bucket.py b/airflow/providers/alibaba/cloud/example_dags/example_oss_bucket.py deleted file mode 100644 index 6d16ce3feb23f..0000000000000 --- a/airflow/providers/alibaba/cloud/example_dags/example_oss_bucket.py +++ /dev/null @@ -1,41 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -from datetime import datetime - -from airflow.models.dag import DAG -from airflow.providers.alibaba.cloud.operators.oss import OSSCreateBucketOperator, OSSDeleteBucketOperator - -# [START howto_operator_oss_bucket] -with DAG( - dag_id='oss_bucket_dag', - start_date=datetime(2021, 1, 1), - default_args={'bucket_name': 'your bucket', 'region': 'your region'}, - max_active_runs=1, - tags=['example'], - catchup=False, -) as dag: - - create_bucket = OSSCreateBucketOperator(task_id='task1') - - delete_bucket = OSSDeleteBucketOperator(task_id='task2') - - create_bucket >> delete_bucket -# [END howto_operator_oss_bucket] diff --git a/airflow/providers/alibaba/cloud/example_dags/example_oss_object.py b/airflow/providers/alibaba/cloud/example_dags/example_oss_object.py deleted file mode 100644 index ac3a07674492f..0000000000000 --- a/airflow/providers/alibaba/cloud/example_dags/example_oss_object.py +++ /dev/null @@ -1,61 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# Ignore missing args provided by default_args -# type: ignore[call-arg] - -from datetime import datetime - -from airflow.models.dag import DAG -from airflow.providers.alibaba.cloud.operators.oss import ( - OSSDeleteBatchObjectOperator, - OSSDeleteObjectOperator, - OSSDownloadObjectOperator, - OSSUploadObjectOperator, -) - -with DAG( - dag_id='oss_object_dag', - start_date=datetime(2021, 1, 1), - default_args={'bucket_name': 'your bucket', 'region': 'your region'}, - max_active_runs=1, - tags=['example'], - catchup=False, -) as dag: - - create_object = OSSUploadObjectOperator( - file='your local file', - key='your oss key', - task_id='task1', - ) - - download_object = OSSDownloadObjectOperator( - file='your local file', - key='your oss key', - task_id='task2', - ) - - delete_object = OSSDeleteObjectOperator( - key='your oss key', - task_id='task3', - ) - - delete_batch_object = OSSDeleteBatchObjectOperator( - keys=['obj1', 'obj2', 'obj3'], - task_id='task4', - ) - - create_object >> download_object >> delete_object >> delete_batch_object diff --git a/airflow/providers/alibaba/cloud/hooks/oss.py b/airflow/providers/alibaba/cloud/hooks/oss.py index 08272adb25e70..ab29958a00ba7 100644 --- a/airflow/providers/alibaba/cloud/hooks/oss.py +++ b/airflow/providers/alibaba/cloud/hooks/oss.py @@ -15,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from functools import wraps from inspect import signature -from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Callable, TypeVar, cast +from urllib.parse import urlsplit import oss2 from oss2.exceptions import ClientError @@ -43,10 +45,10 @@ def provide_bucket_name(func: T) -> T: def wrapper(*args, **kwargs) -> T: bound_args = function_signature.bind(*args, **kwargs) self = args[0] - if bound_args.arguments.get('bucket_name') is None and self.oss_conn_id: + if bound_args.arguments.get("bucket_name") is None and self.oss_conn_id: connection = self.get_connection(self.oss_conn_id) if connection.schema: - bound_args.arguments['bucket_name'] = connection.schema + bound_args.arguments["bucket_name"] = connection.schema return func(*bound_args.args, **bound_args.kwargs) @@ -65,13 +67,13 @@ def wrapper(*args, **kwargs) -> T: bound_args = function_signature.bind(*args, **kwargs) def get_key() -> str: - if 'key' in bound_args.arguments: - return 'key' - raise ValueError('Missing key parameter!') + if "key" in bound_args.arguments: + return "key" + raise ValueError("Missing key parameter!") key_name = get_key() - if 'bucket_name' not in bound_args.arguments or bound_args.arguments['bucket_name'] is None: - bound_args.arguments['bucket_name'], bound_args.arguments['key'] = OSSHook.parse_oss_url( + if "bucket_name" not in bound_args.arguments or bound_args.arguments["bucket_name"] is None: + bound_args.arguments["bucket_name"], bound_args.arguments["key"] = OSSHook.parse_oss_url( bound_args.arguments[key_name] ) @@ -83,18 +85,18 @@ def get_key() -> str: class OSSHook(BaseHook): """Interact with Alibaba Cloud OSS, using the oss2 library.""" - conn_name_attr = 'alibabacloud_conn_id' - default_conn_name = 'oss_default' - conn_type = 'oss' - hook_name = 'OSS' + conn_name_attr = "alibabacloud_conn_id" + default_conn_name = "oss_default" + conn_type = "oss" + hook_name = "OSS" - def __init__(self, region: Optional[str] = None, 
oss_conn_id='oss_default', *args, **kwargs) -> None: + def __init__(self, region: str | None = None, oss_conn_id="oss_default", *args, **kwargs) -> None: self.oss_conn_id = oss_conn_id self.oss_conn = self.get_connection(oss_conn_id) self.region = self.get_default_region() if region is None else region super().__init__(*args, **kwargs) - def get_conn(self) -> "Connection": + def get_conn(self) -> Connection: """Returns connection for the hook.""" return self.oss_conn @@ -106,26 +108,25 @@ def parse_oss_url(ossurl: str) -> tuple: :param ossurl: The OSS Url to parse. :return: the parsed bucket name and key """ - parsed_url = urlparse(ossurl) + parsed_url = urlsplit(ossurl) if not parsed_url.netloc: raise AirflowException(f'Please provide a bucket_name instead of "{ossurl}"') bucket_name = parsed_url.netloc - key = parsed_url.path.lstrip('/') + key = parsed_url.path.lstrip("/") return bucket_name, key @provide_bucket_name @unify_bucket_name_and_key - def object_exists(self, key: str, bucket_name: Optional[str] = None) -> bool: + def object_exists(self, key: str, bucket_name: str | None = None) -> bool: """ Check if object exists. :param key: the path of the object :param bucket_name: the name of the bucket :return: True if it exists and False if not. - :rtype: bool """ try: return self.get_bucket(bucket_name).object_exists(key) @@ -134,21 +135,20 @@ def object_exists(self, key: str, bucket_name: Optional[str] = None) -> bool: return False @provide_bucket_name - def get_bucket(self, bucket_name: Optional[str] = None) -> oss2.api.Bucket: + def get_bucket(self, bucket_name: str | None = None) -> oss2.api.Bucket: """ Returns a oss2.Bucket object :param bucket_name: the name of the bucket :return: the bucket object to the bucket name. - :rtype: oss2.api.Bucket """ auth = self.get_credential() assert self.region is not None - return oss2.Bucket(auth, f'https://oss-{self.region}.aliyuncs.com', bucket_name) + return oss2.Bucket(auth, f"https://oss-{self.region}.aliyuncs.com", bucket_name) @provide_bucket_name @unify_bucket_name_and_key - def load_string(self, key: str, content: str, bucket_name: Optional[str] = None) -> None: + def load_string(self, key: str, content: str, bucket_name: str | None = None) -> None: """ Loads a string to OSS @@ -167,7 +167,7 @@ def upload_local_file( self, key: str, file: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, ) -> None: """ Upload a local file to OSS @@ -187,8 +187,8 @@ def download_file( self, key: str, local_file: str, - bucket_name: Optional[str] = None, - ) -> Optional[str]: + bucket_name: str | None = None, + ) -> str | None: """ Download file from OSS @@ -196,7 +196,6 @@ def download_file( :param local_file: local path + file name to save. :param bucket_name: the name of the bucket :return: the file name. 
- :rtype: str """ try: self.get_bucket(bucket_name).get_object_to_file(key, local_file) @@ -210,7 +209,7 @@ def download_file( def delete_object( self, key: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, ) -> None: """ Delete object from OSS @@ -229,7 +228,7 @@ def delete_object( def delete_objects( self, key: list, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, ) -> None: """ Delete objects from OSS @@ -246,7 +245,7 @@ def delete_objects( @provide_bucket_name def delete_bucket( self, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, ) -> None: """ Delete bucket from OSS @@ -262,7 +261,7 @@ def delete_bucket( @provide_bucket_name def create_bucket( self, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, ) -> None: """ Create bucket @@ -277,7 +276,7 @@ def create_bucket( @provide_bucket_name @unify_bucket_name_and_key - def append_string(self, bucket_name: Optional[str], content: str, key: str, pos: int) -> None: + def append_string(self, bucket_name: str | None, content: str, key: str, pos: int) -> None: """ Append string to a remote existing file @@ -295,7 +294,7 @@ def append_string(self, bucket_name: Optional[str], content: str, key: str, pos: @provide_bucket_name @unify_bucket_name_and_key - def read_key(self, bucket_name: Optional[str], key: str) -> str: + def read_key(self, bucket_name: str | None, key: str) -> str: """ Read oss remote object content with the specified key @@ -311,7 +310,7 @@ def read_key(self, bucket_name: Optional[str], key: str) -> str: @provide_bucket_name @unify_bucket_name_and_key - def head_key(self, bucket_name: Optional[str], key: str) -> oss2.models.HeadObjectResult: + def head_key(self, bucket_name: str | None, key: str) -> oss2.models.HeadObjectResult: """ Get meta info of the specified remote object @@ -327,7 +326,7 @@ def head_key(self, bucket_name: Optional[str], key: str) -> oss2.models.HeadObje @provide_bucket_name @unify_bucket_name_and_key - def key_exist(self, bucket_name: Optional[str], key: str) -> bool: + def key_exist(self, bucket_name: str | None, key: str) -> bool: """ Find out whether the specified key exists in the oss remote storage @@ -335,7 +334,7 @@ def key_exist(self, bucket_name: Optional[str], key: str) -> bool: :param key: oss bucket key """ # full_path = None - self.log.info('Looking up oss bucket %s for bucket key %s ...', bucket_name, key) + self.log.info("Looking up oss bucket %s for bucket key %s ...", bucket_name, key) try: return self.get_bucket(bucket_name).object_exists(key) except Exception as e: @@ -344,14 +343,14 @@ def key_exist(self, bucket_name: Optional[str], key: str) -> bool: def get_credential(self) -> oss2.auth.Auth: extra_config = self.oss_conn.extra_dejson - auth_type = extra_config.get('auth_type', None) + auth_type = extra_config.get("auth_type", None) if not auth_type: raise Exception("No auth_type specified in extra_config. 
") - if auth_type != 'AK': + if auth_type != "AK": raise Exception(f"Unsupported auth_type: {auth_type}") - oss_access_key_id = extra_config.get('access_key_id', None) - oss_access_key_secret = extra_config.get('access_key_secret', None) + oss_access_key_id = extra_config.get("access_key_id", None) + oss_access_key_secret = extra_config.get("access_key_secret", None) if not oss_access_key_id: raise Exception(f"No access_key_id is specified for connection: {self.oss_conn_id}") @@ -360,16 +359,16 @@ def get_credential(self) -> oss2.auth.Auth: return oss2.Auth(oss_access_key_id, oss_access_key_secret) - def get_default_region(self) -> Optional[str]: + def get_default_region(self) -> str | None: extra_config = self.oss_conn.extra_dejson - auth_type = extra_config.get('auth_type', None) + auth_type = extra_config.get("auth_type", None) if not auth_type: raise Exception("No auth_type specified in extra_config. ") - if auth_type != 'AK': + if auth_type != "AK": raise Exception(f"Unsupported auth_type: {auth_type}") - default_region = extra_config.get('region', None) + default_region = extra_config.get("region", None) if not default_region: raise Exception(f"No region is specified for connection: {self.oss_conn_id}") return default_region diff --git a/airflow/providers/alibaba/cloud/log/oss_task_handler.py b/airflow/providers/alibaba/cloud/log/oss_task_handler.py index bb1d801ac9821..c443b4e014392 100644 --- a/airflow/providers/alibaba/cloud/log/oss_task_handler.py +++ b/airflow/providers/alibaba/cloud/log/oss_task_handler.py @@ -15,16 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import contextlib import os import pathlib -import sys - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from airflow.compat.functools import cached_property from airflow.configuration import conf from airflow.providers.alibaba.cloud.hooks.oss import OSSHook from airflow.utils.log.file_task_handler import FileTaskHandler @@ -38,27 +35,27 @@ class OSSTaskHandler(FileTaskHandler, LoggingMixin): uploads to and reads from OSS remote storage. """ - def __init__(self, base_log_folder, oss_log_folder, filename_template): + def __init__(self, base_log_folder, oss_log_folder, filename_template=None): self.log.info("Using oss_task_handler for remote logging...") super().__init__(base_log_folder, filename_template) (self.bucket_name, self.base_folder) = OSSHook.parse_oss_url(oss_log_folder) - self.log_relative_path = '' + self.log_relative_path = "" self._hook = None self.closed = False self.upload_on_close = True @cached_property def hook(self): - remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID') + remote_conn_id = conf.get("logging", "REMOTE_LOG_CONN_ID") self.log.info("remote_conn_id: %s", remote_conn_id) try: return OSSHook(oss_conn_id=remote_conn_id) except Exception as e: - self.log.error(e, exc_info=True) + self.log.exception(e) self.log.error( 'Could not create an OSSHook with connection id "%s". ' - 'Please make sure that airflow[oss] is installed and ' - 'the OSS connection exists.', + "Please make sure that airflow[oss] is installed and " + "the OSS connection exists.", remote_conn_id, ) @@ -73,7 +70,7 @@ def set_context(self, ti): # Clear the file first so that duplicate data is not uploaded # when re-using the same path (e.g. 
with rescheduled sensors) if self.upload_on_close: - with open(self.handler.baseFilename, 'w'): + with open(self.handler.baseFilename, "w"): pass def close(self): @@ -117,13 +114,13 @@ def _read(self, ti, try_number, metadata=None): remote_loc = log_relative_path if not self.oss_log_exists(remote_loc): - return super()._read(ti, try_number) + return super()._read(ti, try_number, metadata) # If OSS remote file exists, we do not fetch logs from task instance # local machine even if there are errors reading remote logs, as # returned remote_log will contain error messages. remote_log = self.oss_read(remote_loc, return_error=True) - log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n' - return log, {'end_of_log': True} + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} def oss_log_exists(self, remote_log_location): """ @@ -132,7 +129,7 @@ def oss_log_exists(self, remote_log_location): :param remote_log_location: log's location in remote storage :return: True if location exists else False """ - oss_remote_log_location = f'{self.base_folder}/{remote_log_location}' + oss_remote_log_location = f"{self.base_folder}/{remote_log_location}" with contextlib.suppress(Exception): return self.hook.key_exist(self.bucket_name, oss_remote_log_location) return False @@ -147,11 +144,11 @@ def oss_read(self, remote_log_location, return_error=False): error occurs. Otherwise returns '' when an error occurs. """ try: - oss_remote_log_location = f'{self.base_folder}/{remote_log_location}' + oss_remote_log_location = f"{self.base_folder}/{remote_log_location}" self.log.info("read remote log: %s", oss_remote_log_location) return self.hook.read_key(self.bucket_name, oss_remote_log_location) except Exception: - msg = f'Could not read logs from {oss_remote_log_location}' + msg = f"Could not read logs from {oss_remote_log_location}" self.log.exception(msg) # return error if needed if return_error: @@ -167,7 +164,7 @@ def oss_write(self, log, remote_log_location, append=True): :param append: if False, any existing log file is overwritten. If True, the new log is appended to any existing logs. """ - oss_remote_log_location = f'{self.base_folder}/{remote_log_location}' + oss_remote_log_location = f"{self.base_folder}/{remote_log_location}" pos = 0 if append and self.oss_log_exists(oss_remote_log_location): head = self.hook.head_key(self.bucket_name, oss_remote_log_location) @@ -178,7 +175,7 @@ def oss_write(self, log, remote_log_location, append=True): self.hook.append_string(self.bucket_name, log, oss_remote_log_location, pos) except Exception: self.log.exception( - 'Could not write logs to %s, log write pos is: %s, Append is %s', + "Could not write logs to %s, log write pos is: %s, Append is %s", oss_remote_log_location, str(pos), str(append), diff --git a/airflow/providers/alibaba/cloud/operators/oss.py b/airflow/providers/alibaba/cloud/operators/oss.py index 8ec9b4b13975e..2e43529566b03 100644 --- a/airflow/providers/alibaba/cloud/operators/oss.py +++ b/airflow/providers/alibaba/cloud/operators/oss.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module contains Alibaba Cloud OSS operators.""" -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.alibaba.cloud.hooks.oss import OSSHook @@ -38,8 +39,8 @@ class OSSCreateBucketOperator(BaseOperator): def __init__( self, region: str, - bucket_name: Optional[str] = None, - oss_conn_id: str = 'oss_default', + bucket_name: str | None = None, + oss_conn_id: str = "oss_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -47,7 +48,7 @@ def __init__( self.region = region self.bucket_name = bucket_name - def execute(self, context: 'Context'): + def execute(self, context: Context): oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region) oss_hook.create_bucket(bucket_name=self.bucket_name) @@ -64,8 +65,8 @@ class OSSDeleteBucketOperator(BaseOperator): def __init__( self, region: str, - bucket_name: Optional[str] = None, - oss_conn_id: str = 'oss_default', + bucket_name: str | None = None, + oss_conn_id: str = "oss_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -73,7 +74,7 @@ def __init__( self.region = region self.bucket_name = bucket_name - def execute(self, context: 'Context'): + def execute(self, context: Context): oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region) oss_hook.delete_bucket(bucket_name=self.bucket_name) @@ -94,8 +95,8 @@ def __init__( key: str, file: str, region: str, - bucket_name: Optional[str] = None, - oss_conn_id: str = 'oss_default', + bucket_name: str | None = None, + oss_conn_id: str = "oss_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -105,7 +106,7 @@ def __init__( self.region = region self.bucket_name = bucket_name - def execute(self, context: 'Context'): + def execute(self, context: Context): oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region) oss_hook.upload_local_file(bucket_name=self.bucket_name, key=self.key, file=self.file) @@ -126,8 +127,8 @@ def __init__( key: str, file: str, region: str, - bucket_name: Optional[str] = None, - oss_conn_id: str = 'oss_default', + bucket_name: str | None = None, + oss_conn_id: str = "oss_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -137,7 +138,7 @@ def __init__( self.region = region self.bucket_name = bucket_name - def execute(self, context: 'Context'): + def execute(self, context: Context): oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region) oss_hook.download_file(bucket_name=self.bucket_name, key=self.key, local_file=self.file) @@ -156,8 +157,8 @@ def __init__( self, keys: list, region: str, - bucket_name: Optional[str] = None, - oss_conn_id: str = 'oss_default', + bucket_name: str | None = None, + oss_conn_id: str = "oss_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -166,7 +167,7 @@ def __init__( self.region = region self.bucket_name = bucket_name - def execute(self, context: 'Context'): + def execute(self, context: Context): oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region) oss_hook.delete_objects(bucket_name=self.bucket_name, key=self.keys) @@ -185,8 +186,8 @@ def __init__( self, key: str, region: str, - bucket_name: Optional[str] = None, - oss_conn_id: str = 'oss_default', + bucket_name: str | None = None, + oss_conn_id: str = "oss_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -195,6 +196,6 @@ def __init__( self.region = region self.bucket_name = bucket_name - def execute(self, context: 'Context'): + def execute(self, context: 
Context): oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region) oss_hook.delete_object(bucket_name=self.bucket_name, key=self.key) diff --git a/airflow/providers/alibaba/cloud/sensors/oss_key.py b/airflow/providers/alibaba/cloud/sensors/oss_key.py index 0160783178b08..98b2c25beaee1 100644 --- a/airflow/providers/alibaba/cloud/sensors/oss_key.py +++ b/airflow/providers/alibaba/cloud/sensors/oss_key.py @@ -15,16 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - -from typing import TYPE_CHECKING, Optional, Sequence -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Sequence +from urllib.parse import urlsplit +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.providers.alibaba.cloud.hooks.oss import OSSHook from airflow.sensors.base import BaseSensorOperator @@ -47,14 +43,14 @@ class OSSKeySensor(BaseSensorOperator): :param oss_conn_id: The Airflow connection used for OSS credentials. """ - template_fields: Sequence[str] = ('bucket_key', 'bucket_name') + template_fields: Sequence[str] = ("bucket_key", "bucket_name") def __init__( self, bucket_key: str, region: str, - bucket_name: Optional[str] = None, - oss_conn_id: Optional[str] = 'oss_default', + bucket_name: str | None = None, + oss_conn_id: str | None = "oss_default", **kwargs, ): super().__init__(**kwargs) @@ -63,9 +59,9 @@ def __init__( self.bucket_key = bucket_key self.region = region self.oss_conn_id = oss_conn_id - self.hook: Optional[OSSHook] = None + self.hook: OSSHook | None = None - def poke(self, context: 'Context'): + def poke(self, context: Context): """ Check if the object exists in the bucket to pull key. 
@param self - the object itself @@ -73,21 +69,21 @@ def poke(self, context: 'Context'): @returns True if the object exists, False otherwise """ if self.bucket_name is None: - parsed_url = urlparse(self.bucket_key) - if parsed_url.netloc == '': - raise AirflowException('If key is a relative path from root, please provide a bucket_name') + parsed_url = urlsplit(self.bucket_key) + if parsed_url.netloc == "": + raise AirflowException("If key is a relative path from root, please provide a bucket_name") self.bucket_name = parsed_url.netloc - self.bucket_key = parsed_url.path.lstrip('/') + self.bucket_key = parsed_url.path.lstrip("/") else: - parsed_url = urlparse(self.bucket_key) - if parsed_url.scheme != '' or parsed_url.netloc != '': + parsed_url = urlsplit(self.bucket_key) + if parsed_url.scheme != "" or parsed_url.netloc != "": raise AirflowException( - 'If bucket_name is provided, bucket_key' - ' should be relative path from root' - ' level, rather than a full oss:// url' + "If bucket_name is provided, bucket_key" + " should be relative path from root" + " level, rather than a full oss:// url" ) - self.log.info('Poking for key : oss://%s/%s', self.bucket_name, self.bucket_key) + self.log.info("Poking for key : oss://%s/%s", self.bucket_name, self.bucket_key) return self.get_hook.object_exists(key=self.bucket_key, bucket_name=self.bucket_name) @cached_property diff --git a/airflow/providers/alibaba/provider.yaml b/airflow/providers/alibaba/provider.yaml index ca5e03f1b27d2..cd5cbced35249 100644 --- a/airflow/providers/alibaba/provider.yaml +++ b/airflow/providers/alibaba/provider.yaml @@ -22,13 +22,18 @@ description: | Alibaba Cloud integration (including `Alibaba Cloud `__). versions: + - 2.2.0 + - 2.1.0 + - 2.0.1 + - 2.0.0 - 1.1.1 - 1.1.0 - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - oss2>=2.14.0 integrations: - integration-name: Alibaba Cloud OSS @@ -53,8 +58,6 @@ hooks: python-modules: - airflow.providers.alibaba.cloud.hooks.oss -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.alibaba.cloud.hooks.oss.OSSHook connection-types: - hook-class-name: airflow.providers.alibaba.cloud.hooks.oss.OSSHook diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst index e8a0f442c457c..d55861f1c6d2d 100644 --- a/airflow/providers/amazon/CHANGELOG.rst +++ b/airflow/providers/amazon/CHANGELOG.rst @@ -16,9 +16,300 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +6.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. 
+ +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Replace urlparse with urlsplit (#27389)`` + +Features +~~~~~~~~ + +* ``Add info about JSON Connection format for AWS SSM Parameter Store Secrets Backend (#27134)`` +* ``Add default name to EMR Serverless jobs (#27458)`` +* ``Adding 'preserve_file_name' param to 'S3Hook.download_file' method (#26886)`` +* ``Add GlacierUploadArchiveOperator (#26652)`` +* ``Add RdsStopDbOperator and RdsStartDbOperator (#27076)`` +* ``'GoogleApiToS3Operator' : add 'gcp_conn_id' to template fields (#27017)`` +* ``Add SQLExecuteQueryOperator (#25717)`` +* ``Add information about Amazon Elastic MapReduce Connection (#26687)`` +* ``Add BatchOperator template fields (#26805)`` +* ``Improve testing AWS Connection response (#26953)`` + +Bug Fixes +~~~~~~~~~ + +* ``SagemakerProcessingOperator stopped honoring 'existing_jobs_found' (#27456)`` +* ``CloudWatch task handler doesn't fall back to local logs when Amazon CloudWatch logs aren't found (#27564)`` +* ``Fix backwards compatibility for RedshiftSQLOperator (#27602)`` +* ``Fix typo in redshift sql hook get_ui_field_behaviour (#27533)`` +* ``Fix example_emr_serverless system test (#27149)`` +* ``Fix param in docstring RedshiftSQLHook get_table_primary_key method (#27330)`` +* ``Adds s3_key_prefix to template fields (#27207)`` +* ``Fix assume role if user explicit set credentials (#26946)`` +* ``Fix failure state in waiter call for EmrServerlessStartJobOperator. (#26853)`` +* ``Fix a bunch of deprecation warnings AWS tests (#26857)`` +* ``Fix null strings bug in SqlToS3Operator in non parquet formats (#26676)`` +* ``Sagemaker hook: remove extra call at the end when waiting for completion (#27551)`` +* ``ECS Buglette (#26921)`` +* ``Avoid circular imports in AWS Secrets Backends if obtain secrets from config (#26784)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``sagemaker operators: mutualize init of aws_conn_id (#27579)`` + * ``Upgrade dependencies in order to avoid backtracking (#27531)`` + * ``Code quality improvements on sagemaker operators/hook (#27453)`` + * ``Update old style typing (#26872)`` + * ``System test for SQL to S3 Transfer (AIP-47) (#27097)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Convert emr_eks example dag to system test (#26723)`` + * ``System test for Dynamo DB (#26729)`` + * ``ECS System Test (#26808)`` + * ``RDS Instance System Tests (#26733)`` + +6.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +.. warning:: + In this version of provider Amazon S3 Connection (``conn_type="s3"``) removed due to the fact that it was always + an alias to AWS connection ``conn_type="aws"`` + In practice the only impact is you won't be able to ``test`` the connection in the web UI / API. + In order to restore ability to test connection you need to change connection type from **Amazon S3** (``conn_type="s3"``) + to **Amazon Web Services** (``conn_type="aws"``) manually. 
+ +* ``Remove Amazon S3 Connection Type (#25980)`` + +Features +~~~~~~~~ + +* ``Add RdsDbSensor to amazon provider package (#26003)`` +* ``Set template_fields on RDS operators (#26005)`` +* ``Auto tail file logs in Web UI (#26169)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix SageMakerEndpointConfigOperator's return value (#26541)`` +* ``EMR Serverless Fix for Jobs marked as success even on failure (#26218)`` +* ``Fix AWS Connection warn condition for invalid 'profile_name' argument (#26464)`` +* ``Athena and EMR operator max_retries mix-up fix (#25971)`` +* ``Fixes SageMaker operator return values (#23628)`` +* ``Remove redundant catch exception in Amazon Log Task Handlers (#26442)`` + +Misc +~~~~ + +* ``Remove duplicated connection-type within the provider (#26628)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Redshift to S3 and S3 to Redshift System test (AIP-47) (#26613)`` + * ``Convert example_eks_with_fargate_in_one_step.py and example_eks_with_fargate_profile to AIP-47 (#26537)`` + * ``Redshift System Test (AIP-47) (#26187)`` + * ``GoogleAPIToS3Operator System Test (AIP-47) (#26370)`` + * ``Convert EKS with Nodegroups sample DAG to a system test (AIP-47) (#26539)`` + * ``Convert EC2 sample DAG to system test (#26540)`` + * ``Convert S3 example DAG to System test (AIP-47) (#26535)`` + * ``Convert 'example_eks_with_nodegroup_in_one_step' sample DAG to system test (AIP-47) (#26410)`` + * ``Migrate DMS sample dag to system test (#26270)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``D400 first line should end with period batch02 (#25268)`` + * ``Change links to 'boto3' documentation (#26708)`` + +5.1.0 +..... + + +Features +~~~~~~~~ + +* ``Additional mask aws credentials (#26014)`` +* ``Add RedshiftDeleteClusterSnapshotOperator (#25975)`` +* ``Add redshift create cluster snapshot operator (#25857)`` +* ``Add common-sql lower bound for common-sql (#25789)`` +* ``Allow AWS Secrets Backends use AWS Connection capabilities (#25628)`` +* ``Implement 'EmrEksCreateClusterOperator' (#25816)`` +* ``Improve error handling/messaging around bucket exist check (#25805)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix display aws connection info (#26025)`` +* ``Fix 'EcsBaseOperator' and 'EcsBaseSensor' arguments (#25989)`` +* ``Fix RDS system test (#25839)`` +* ``Avoid circular import problems when instantiating AWS SM backend (#25810)`` +* ``fix bug construction of Connection object in version 5.0.0rc3 (#25716)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Fix EMR serverless system test (#25969)`` + * ``Add 'output' property to MappedOperator (#25604)`` + * ``Add Airflow specific warning classes (#25799)`` + * ``Replace SQL with Common SQL in pre commit (#26058)`` + * ``Hook into Mypy to get rid of those cast() (#26023)`` + * ``Raise an error on create bucket if use regional endpoint for us-east-1 and region not set (#25945)`` + * ``Update AWS system tests to use SystemTestContextBuilder (#25748)`` + * ``Convert Quicksight Sample DAG to System Test (#25696)`` + * ``Consolidate to one 'schedule' param (#25410)`` + +5.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Avoid requirement that AWS Secret Manager JSON values be urlencoded. 
(#25432)`` +* ``Remove deprecated modules (#25543)`` +* ``Resolve Amazon Hook's 'region_name' and 'config' in wrapper (#25336)`` +* ``Resolve and validate AWS Connection parameters in wrapper (#25256)`` +* ``Standardize AwsLambda (#25100)`` +* ``Refactor monolithic ECS Operator into Operators, Sensors, and a Hook (#25413)`` +* ``Remove deprecated modules from Amazon provider package (#25609)`` + +Features +~~~~~~~~ + +* ``Add EMR Serverless Operators and Hooks (#25324)`` +* ``Hide unused fields for Amazon Web Services connection (#25416)`` +* ``Enable Auto-incrementing Transform job name in SageMakerTransformOperator (#25263)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` +* ``SQSPublishOperator should allow sending messages to a FIFO Queue (#25171)`` +* ``Glue Job Driver logging (#25142)`` +* ``Bump typing-extensions and mypy for ParamSpec (#25088)`` +* ``Enable multiple query execution in RedshiftDataOperator (#25619)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix S3Hook transfer config arguments validation (#25544)`` +* ``Fix BatchOperator links on wait_for_completion = True (#25228)`` +* ``Makes changes to SqlToS3Operator method _fix_int_dtypes (#25083)`` +* ``refactor: Deprecate parameter 'host' as an extra attribute for the connection. Depreciation is happening in favor of 'endpoint_url' in extra. (#25494)`` +* ``Get boto3.session.Session by appropriate method (#25569)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``System test for EMR Serverless (#25559)`` + * ``Convert Local to S3 example DAG to System Test (AIP-47) (#25345)`` + * ``Convert ECS Fargate Sample DAG to System Test (#25316)`` + * ``Sagemaker System Tests - Part 3 of 3 - example_sagemaker_endpoint.py (AIP-47) (#25134)`` + * ``Convert RDS Export Sample DAG to System Test (AIP-47) (#25205)`` + * ``AIP-47 - Migrate redshift DAGs to new design #22438 (#24239)`` + * ``Convert Glue Sample DAG to System Test (#25136)`` + * ``Convert the batch sample dag to system tests (AIP-47) (#24448)`` + * ``Migrate datasync sample dag to system tests (AIP-47) (#24354)`` + * ``Sagemaker System Tests - Part 2 of 3 - example_sagemaker.py (#25079)`` + * ``Migrate lambda sample dag to system test (AIP-47) (#24355)`` + * ``SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + +4.1.0 +..... + +Features +~~~~~~~~ + +* ``Add test_connection method to AWS hook (#24662)`` +* ``Add AWS operators to create and delete RDS Database (#24099)`` +* ``Add batch option to 'SqsSensor' (#24554)`` +* ``Add AWS Batch & AWS CloudWatch Extra Links (#24406)`` +* ``Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)`` +* ``Move all SQL classes to common-sql provider (#24836)`` +* ``Amazon appflow (#24057)`` +* ``Make extra_args in S3Hook immutable between calls (#24527)`` + +Bug Fixes +~~~~~~~~~ + +* ``Refactor and fix AWS secret manager invalid exception (#24898)`` +* ``fix: RedshiftDataHook and RdsHook not use cached connection (#24387)`` +* ``Fix links to sources for examples (#24386)`` +* ``Fix S3KeySensor. See #24321 (#24378)`` +* ``Fix: 'emr_conn_id' should be optional in 'EmrCreateJobFlowOperator' (#24306)`` +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Convert RDS Event and Snapshot Sample DAGs to System Tests (#24932)`` + * ``Convert Step Functions Example DAG to System Test (AIP-47) (#24643)`` + * ``Update AWS Connection docs and deprecate some extras (#24670)`` + * ``Remove 'xcom_push' flag from providers (#24823)`` + * ``Align Black and blacken-docs configs (#24785)`` + * ``Restore Optional value of script_location (#24754)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Use our yaml util in all providers (#24720)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + * ``Convert SQS Sample DAG to System Test (#24513)`` + * ``Convert Cloudformation Sample DAG to System Test (#24447)`` + * ``Convert SNS Sample DAG to System Test (#24384)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Add partition related methods to GlueCatalogHook: (#23857)`` +* ``Add support for associating custom tags to job runs submitted via EmrContainerOperator (#23769)`` +* ``Add number of node params only for single-node cluster in RedshiftCreateClusterOperator (#23839)`` + +Bug Fixes +~~~~~~~~~ + +* ``fix: StepFunctionHook ignores explicit set 'region_name' (#23976)`` +* ``Fix Amazon EKS example DAG raises warning during Imports (#23849)`` +* ``Move string arg evals to 'execute()' in 'EksCreateClusterOperator' (#23877)`` +* ``fix: patches #24215. Won't raise KeyError when 'create_job_kwargs' contains the 'Command' key. (#24308)`` + +Misc +~~~~ + +* ``Light Refactor and Clean-up AWS Provider (#23907)`` +* ``Update sample dag and doc for RDS (#23651)`` +* ``Reformat the whole AWS documentation (#23810)`` +* ``Replace "absolute()" with "resolve()" in pathlib objects (#23675)`` +* ``Apply per-run log templates to log handlers (#24153)`` +* ``Refactor GlueJobHook get_or_create_glue_job method. (#24215)`` +* ``Update the DMS Sample DAG and Docs (#23681)`` +* ``Update doc and sample dag for Quicksight (#23653)`` +* ``Update doc and sample dag for EMR Containers (#24087)`` +* ``Add AWS project structure tests (re: AIP-47) (#23630)`` +* ``Add doc and sample dag for GCSToS3Operator (#23730)`` +* ``Remove old Athena Sample DAG (#24170)`` +* ``Clean up f-strings in logging calls (#23597)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Introduce 'flake8-implicit-str-concat' plugin to static checks (#23873)`` + * ``pydocstyle D202 added (#24221)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.4.0 ..... @@ -66,7 +357,6 @@ Misc * ``Bump pre-commit hook versions (#22887)`` * ``Use new Breese for building, pulling and verifying the images. (#23104)`` -.. Review and move the new changes to one of the sections above: 3.3.0 ..... 
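
A minimal, illustrative sketch of the new ``RdsDbSensor`` named in the Features list above (not part of this patch); the DAG id, the ``my-rds-instance`` identifier, and reliance on the default AWS connection are assumptions made only for illustration:

.. code-block:: python

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.sensors.rds import RdsDbSensor

    with DAG(
        dag_id="example_rds_db_sensor_sketch",
        start_date=datetime(2022, 1, 1),
        catchup=False,
        tags=["example"],
    ) as dag:
        # Wait until the pre-existing RDS DB instance reports the "available" state.
        await_db_instance = RdsDbSensor(
            task_id="await_db_instance",
            db_identifier="my-rds-instance",  # placeholder; substitute a real DB instance identifier
            db_type="instance",
            target_states=["available"],
        )

As with the other sensors in this provider, the task simply polls and succeeds once the database reaches one of the ``target_states``.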
diff --git a/airflow/providers/amazon/aws/example_dags/example_appflow.py b/airflow/providers/amazon/aws/example_dags/example_appflow.py new file mode 100644 index 0000000000000..c986961499dc8 --- /dev/null +++ b/airflow/providers/amazon/aws/example_dags/example_appflow.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow import DAG +from airflow.models.baseoperator import chain +from airflow.operators.bash import BashOperator +from airflow.providers.amazon.aws.operators.appflow import ( + AppflowRecordsShortCircuitOperator, + AppflowRunAfterOperator, + AppflowRunBeforeOperator, + AppflowRunDailyOperator, + AppflowRunFullOperator, + AppflowRunOperator, +) + +SOURCE_NAME = "salesforce" +FLOW_NAME = "salesforce-campaign" + +with DAG( + "example_appflow", + start_date=datetime(2022, 1, 1), + catchup=False, + tags=["example"], +) as dag: + + # [START howto_operator_appflow_run] + campaign_dump = AppflowRunOperator( + task_id="campaign_dump", + source=SOURCE_NAME, + flow_name=FLOW_NAME, + ) + # [END howto_operator_appflow_run] + + # [START howto_operator_appflow_run_full] + campaign_dump_full = AppflowRunFullOperator( + task_id="campaign_dump_full", + source=SOURCE_NAME, + flow_name=FLOW_NAME, + ) + # [END howto_operator_appflow_run_full] + + # [START howto_operator_appflow_run_daily] + campaign_dump_daily = AppflowRunDailyOperator( + task_id="campaign_dump_daily", + source=SOURCE_NAME, + flow_name=FLOW_NAME, + source_field="LastModifiedDate", + filter_date="{{ ds }}", + ) + # [END howto_operator_appflow_run_daily] + + # [START howto_operator_appflow_run_before] + campaign_dump_before = AppflowRunBeforeOperator( + task_id="campaign_dump_before", + source=SOURCE_NAME, + flow_name=FLOW_NAME, + source_field="LastModifiedDate", + filter_date="{{ ds }}", + ) + # [END howto_operator_appflow_run_before] + + # [START howto_operator_appflow_run_after] + campaign_dump_after = AppflowRunAfterOperator( + task_id="campaign_dump_after", + source=SOURCE_NAME, + flow_name=FLOW_NAME, + source_field="LastModifiedDate", + filter_date="3000-01-01", # Future date, so no records to dump + ) + # [END howto_operator_appflow_run_after] + + # [START howto_operator_appflow_shortcircuit] + campaign_dump_short_circuit = AppflowRecordsShortCircuitOperator( + task_id="campaign_dump_short_circuit", + flow_name=FLOW_NAME, + appflow_run_task_id="campaign_dump_after", # Should shortcircuit, no records expected + ) + # [END howto_operator_appflow_shortcircuit] + + should_be_skipped = BashOperator( + task_id="should_be_skipped", + bash_command="echo 1", + ) + + chain( + campaign_dump, + campaign_dump_full, + campaign_dump_daily, + campaign_dump_before, + campaign_dump_after, + 
campaign_dump_short_circuit, + should_be_skipped, + ) diff --git a/airflow/providers/amazon/aws/example_dags/example_athena.py b/airflow/providers/amazon/aws/example_dags/example_athena.py deleted file mode 100644 index 26b47bdfa19b5..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_athena.py +++ /dev/null @@ -1,124 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.decorators import task -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.amazon.aws.operators.athena import AthenaOperator -from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator, S3DeleteObjectsOperator -from airflow.providers.amazon.aws.sensors.athena import AthenaSensor - -S3_BUCKET = getenv("S3_BUCKET", "test-bucket") -S3_KEY = getenv('S3_KEY', 'athena-demo') -ATHENA_TABLE = getenv('ATHENA_TABLE', 'test_table') -ATHENA_DATABASE = getenv('ATHENA_DATABASE', 'default') - -SAMPLE_DATA = """"Alice",20 -"Bob",25 -"Charlie",30 -""" -SAMPLE_FILENAME = 'airflow_sample.csv' - - -@task -def read_results_from_s3(query_execution_id): - s3_hook = S3Hook() - file_obj = s3_hook.get_conn().get_object(Bucket=S3_BUCKET, Key=f'{S3_KEY}/{query_execution_id}.csv') - file_content = file_obj['Body'].read().decode('utf-8') - print(file_content) - - -QUERY_CREATE_TABLE = f""" -CREATE EXTERNAL TABLE IF NOT EXISTS {ATHENA_DATABASE}.{ATHENA_TABLE} ( `name` string, `age` int ) -ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' -WITH SERDEPROPERTIES ( 'serialization.format' = ',', 'field.delim' = ',' -) LOCATION 's3://{S3_BUCKET}/{S3_KEY}/{ATHENA_TABLE}' -TBLPROPERTIES ('has_encrypted_data'='false') -""" - -QUERY_READ_TABLE = f""" -SELECT * from {ATHENA_DATABASE}.{ATHENA_TABLE} -""" - -QUERY_DROP_TABLE = f""" -DROP TABLE IF EXISTS {ATHENA_DATABASE}.{ATHENA_TABLE} -""" - -with DAG( - dag_id='example_athena', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - upload_sample_data = S3CreateObjectOperator( - task_id='upload_sample_data', - s3_bucket=S3_BUCKET, - s3_key=f'{S3_KEY}/{ATHENA_TABLE}/{SAMPLE_FILENAME}', - data=SAMPLE_DATA, - replace=True, - ) - - create_table = AthenaOperator( - task_id='create_table', - query=QUERY_CREATE_TABLE, - database=ATHENA_DATABASE, - output_location=f's3://{S3_BUCKET}/{S3_KEY}', - ) - - # [START howto_operator_athena] - read_table = AthenaOperator( - task_id='read_table', - query=QUERY_READ_TABLE, - database=ATHENA_DATABASE, - output_location=f's3://{S3_BUCKET}/{S3_KEY}', - ) - # [END howto_operator_athena] - - # [START howto_sensor_athena] - await_query = AthenaSensor( - task_id='await_query', - query_execution_id=read_table.output, - 
) - # [END howto_sensor_athena] - - drop_table = AthenaOperator( - task_id='drop_table', - query=QUERY_DROP_TABLE, - database=ATHENA_DATABASE, - output_location=f's3://{S3_BUCKET}/{S3_KEY}', - ) - - remove_s3_files = S3DeleteObjectsOperator( - task_id='remove_s3_files', - bucket=S3_BUCKET, - prefix=S3_KEY, - ) - - ( - upload_sample_data - >> create_table - >> read_table - >> await_query - >> read_results_from_s3(read_table.output) - >> drop_table - >> remove_s3_files - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_batch.py b/airflow/providers/amazon/aws/example_dags/example_batch.py deleted file mode 100644 index 959e4de52d300..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_batch.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from datetime import datetime -from json import loads -from os import environ - -from airflow import DAG -from airflow.providers.amazon.aws.operators.batch import BatchOperator -from airflow.providers.amazon.aws.sensors.batch import BatchSensor - -# The inputs below are required for the submit batch example DAG. -JOB_NAME = environ.get('BATCH_JOB_NAME', 'example_job_name') -JOB_DEFINITION = environ.get('BATCH_JOB_DEFINITION', 'example_job_definition') -JOB_QUEUE = environ.get('BATCH_JOB_QUEUE', 'example_job_queue') -JOB_OVERRIDES = loads(environ.get('BATCH_JOB_OVERRIDES', '{}')) - -# An existing (externally triggered) job id is required for the sensor example DAG. -JOB_ID = environ.get('BATCH_JOB_ID', '00000000-0000-0000-0000-000000000000') - - -with DAG( - dag_id='example_batch_submit_job', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as submit_dag: - - # [START howto_operator_batch] - submit_batch_job = BatchOperator( - task_id='submit_batch_job', - job_name=JOB_NAME, - job_queue=JOB_QUEUE, - job_definition=JOB_DEFINITION, - overrides=JOB_OVERRIDES, - ) - # [END howto_operator_batch] - -with DAG( - dag_id='example_batch_wait_for_job_sensor', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as sensor_dag: - - # [START howto_sensor_batch] - wait_for_batch_job = BatchSensor( - task_id='wait_for_batch_job', - job_id=JOB_ID, - ) - # [END howto_sensor_batch] diff --git a/airflow/providers/amazon/aws/example_dags/example_cloudformation.py b/airflow/providers/amazon/aws/example_dags/example_cloudformation.py deleted file mode 100644 index c31d3215ff5d8..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_cloudformation.py +++ /dev/null @@ -1,83 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from datetime import datetime -from json import dumps - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.cloud_formation import ( - CloudFormationCreateStackOperator, - CloudFormationDeleteStackOperator, -) -from airflow.providers.amazon.aws.sensors.cloud_formation import ( - CloudFormationCreateStackSensor, - CloudFormationDeleteStackSensor, -) - -CLOUDFORMATION_STACK_NAME = 'example-stack-name' -# The CloudFormation template must have at least one resource to be usable, this example uses SQS -# as a free and serverless option. -CLOUDFORMATION_TEMPLATE = { - 'Description': 'Stack from Airflow CloudFormation example DAG', - 'Resources': { - 'ExampleQueue': { - 'Type': 'AWS::SQS::Queue', - } - }, -} -CLOUDFORMATION_CREATE_PARAMETERS = { - 'StackName': CLOUDFORMATION_STACK_NAME, - 'TemplateBody': dumps(CLOUDFORMATION_TEMPLATE), - 'TimeoutInMinutes': 2, - 'OnFailure': 'DELETE', # Don't leave stacks behind if creation fails. -} - -with DAG( - dag_id='example_cloudformation', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_cloudformation_create_stack] - create_stack = CloudFormationCreateStackOperator( - task_id='create_stack', - stack_name=CLOUDFORMATION_STACK_NAME, - cloudformation_parameters=CLOUDFORMATION_CREATE_PARAMETERS, - ) - # [END howto_operator_cloudformation_create_stack] - - # [START howto_sensor_cloudformation_create_stack] - wait_for_stack_create = CloudFormationCreateStackSensor( - task_id='wait_for_stack_creation', stack_name=CLOUDFORMATION_STACK_NAME - ) - # [END howto_sensor_cloudformation_create_stack] - - # [START howto_operator_cloudformation_delete_stack] - delete_stack = CloudFormationDeleteStackOperator( - task_id='delete_stack', stack_name=CLOUDFORMATION_STACK_NAME - ) - # [END howto_operator_cloudformation_delete_stack] - - # [START howto_sensor_cloudformation_delete_stack] - wait_for_stack_delete = CloudFormationDeleteStackSensor( - task_id='wait_for_stack_deletion', trigger_rule='all_done', stack_name=CLOUDFORMATION_STACK_NAME - ) - # [END howto_sensor_cloudformation_delete_stack] - - chain(create_stack, wait_for_stack_create, delete_stack, wait_for_stack_delete) diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync.py b/airflow/providers/amazon/aws/example_dags/example_datasync.py deleted file mode 100644 index 09c474079e497..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_datasync.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import json -import re -from datetime import datetime -from os import getenv - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator - -TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn") -SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/") -DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix") -CREATE_TASK_KWARGS = json.loads(getenv("CREATE_TASK_KWARGS", '{"Name": "Created by Airflow"}')) -CREATE_SOURCE_LOCATION_KWARGS = json.loads(getenv("CREATE_SOURCE_LOCATION_KWARGS", '{}')) -default_destination_location_kwargs = """\ -{"S3BucketArn": "arn:aws:s3:::mybucket", - "S3Config": {"BucketAccessRoleArn": - "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"} -}""" -CREATE_DESTINATION_LOCATION_KWARGS = json.loads( - getenv("CREATE_DESTINATION_LOCATION_KWARGS", re.sub(r"[\s+]", '', default_destination_location_kwargs)) -) -UPDATE_TASK_KWARGS = json.loads(getenv("UPDATE_TASK_KWARGS", '{"Name": "Updated by Airflow"}')) - -with models.DAG( - "example_datasync", - schedule_interval=None, # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_datasync_specific_task] - # Execute a specific task - datasync_specific_task = DataSyncOperator(task_id="datasync_specific_task", task_arn=TASK_ARN) - # [END howto_operator_datasync_specific_task] - - # [START howto_operator_datasync_search_task] - # Search and execute a task - datasync_search_task = DataSyncOperator( - task_id="datasync_search_task", - source_location_uri=SOURCE_LOCATION_URI, - destination_location_uri=DESTINATION_LOCATION_URI, - ) - # [END howto_operator_datasync_search_task] - - # [START howto_operator_datasync_create_task] - # Create a task (the task does not exist) - datasync_create_task = DataSyncOperator( - task_id="datasync_create_task", - source_location_uri=SOURCE_LOCATION_URI, - destination_location_uri=DESTINATION_LOCATION_URI, - create_task_kwargs=CREATE_TASK_KWARGS, - create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS, - create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS, - update_task_kwargs=UPDATE_TASK_KWARGS, - delete_task_after_execution=True, - ) - # [END howto_operator_datasync_create_task] - - chain( - datasync_specific_task, - datasync_search_task, - datasync_create_task, - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_dms.py b/airflow/providers/amazon/aws/example_dags/example_dms.py deleted file mode 100644 index caffe44353fda..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_dms.py +++ /dev/null @@ -1,347 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Note: DMS requires you to configure specific IAM roles/permissions. For more information, see -https://docs.aws.amazon.com/dms/latest/userguide/CHAP_Security.html#CHAP_Security.APIRole -""" - -import json -import os -from datetime import datetime - -import boto3 -from sqlalchemy import Column, MetaData, String, Table, create_engine - -from airflow import DAG -from airflow.decorators import task -from airflow.models.baseoperator import chain -from airflow.operators.python import get_current_context -from airflow.providers.amazon.aws.operators.dms import ( - DmsCreateTaskOperator, - DmsDeleteTaskOperator, - DmsDescribeTasksOperator, - DmsStartTaskOperator, - DmsStopTaskOperator, -) -from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor - -S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name') -ROLE_ARN = os.getenv('ROLE_ARN', 'arn:aws:iam::1234567890:role/s3_target_endpoint_role') - -# The project name will be used as a prefix for various entity names. -# Use either PascalCase or camelCase. While some names require kebab-case -# and others require snake_case, they all accept mixedCase strings. -PROJECT_NAME = 'DmsDemo' - -# Config values for setting up the "Source" database. -RDS_ENGINE = 'postgres' -RDS_PROTOCOL = 'postgresql' -RDS_USERNAME = 'username' -# NEVER store your production password in plaintext in a DAG like this. -# Use Airflow Secrets or a secret manager for this in production. -RDS_PASSWORD = 'rds_password' - -# Config values for RDS. -RDS_INSTANCE_NAME = f'{PROJECT_NAME}-instance' -RDS_DB_NAME = f'{PROJECT_NAME}_source_database' - -# Config values for DMS. -DMS_REPLICATION_INSTANCE_NAME = f'{PROJECT_NAME}-replication-instance' -DMS_REPLICATION_TASK_ID = f'{PROJECT_NAME}-replication-task' -SOURCE_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-source-endpoint' -TARGET_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-target-endpoint' - -# Sample data. 
-TABLE_NAME = f'{PROJECT_NAME}-table' -TABLE_HEADERS = ['apache_project', 'release_year'] -SAMPLE_DATA = [ - ('Airflow', '2015'), - ('OpenOffice', '2012'), - ('Subversion', '2000'), - ('NiFi', '2006'), -] -TABLE_DEFINITION = { - 'TableCount': '1', - 'Tables': [ - { - 'TableName': TABLE_NAME, - 'TableColumns': [ - { - 'ColumnName': TABLE_HEADERS[0], - 'ColumnType': 'STRING', - 'ColumnNullable': 'false', - 'ColumnIsPk': 'true', - }, - {"ColumnName": TABLE_HEADERS[1], "ColumnType": 'STRING', "ColumnLength": "4"}, - ], - 'TableColumnsTotal': '2', - } - ], -} -TABLE_MAPPINGS = { - 'rules': [ - { - 'rule-type': 'selection', - 'rule-id': '1', - 'rule-name': '1', - 'object-locator': { - 'schema-name': 'public', - 'table-name': TABLE_NAME, - }, - 'rule-action': 'include', - } - ] -} - - -def _create_rds_instance(): - print('Creating RDS Instance.') - - rds_client = boto3.client('rds') - rds_client.create_db_instance( - DBName=RDS_DB_NAME, - DBInstanceIdentifier=RDS_INSTANCE_NAME, - AllocatedStorage=20, - DBInstanceClass='db.t3.micro', - Engine=RDS_ENGINE, - MasterUsername=RDS_USERNAME, - MasterUserPassword=RDS_PASSWORD, - ) - - rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME) - - response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME) - return response['DBInstances'][0]['Endpoint'] - - -def _create_rds_table(rds_endpoint): - print('Creating table.') - - hostname = rds_endpoint['Address'] - port = rds_endpoint['Port'] - rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}' - engine = create_engine(rds_url) - - table = Table( - TABLE_NAME, - MetaData(engine), - Column(TABLE_HEADERS[0], String, primary_key=True), - Column(TABLE_HEADERS[1], String), - ) - - with engine.connect() as connection: - # Create the Table. - table.create() - load_data = table.insert().values(SAMPLE_DATA) - connection.execute(load_data) - - # Read the data back to verify everything is working. 
- connection.execute(table.select()) - - -def _create_dms_replication_instance(ti, dms_client): - print('Creating replication instance.') - instance_arn = dms_client.create_replication_instance( - ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME, - ReplicationInstanceClass='dms.t3.micro', - )['ReplicationInstance']['ReplicationInstanceArn'] - - ti.xcom_push(key='replication_instance_arn', value=instance_arn) - return instance_arn - - -def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint): - print('Creating DMS source endpoint.') - source_endpoint_arn = dms_client.create_endpoint( - EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER, - EndpointType='source', - EngineName=RDS_ENGINE, - Username=RDS_USERNAME, - Password=RDS_PASSWORD, - ServerName=rds_instance_endpoint['Address'], - Port=rds_instance_endpoint['Port'], - DatabaseName=RDS_DB_NAME, - )['Endpoint']['EndpointArn'] - - print('Creating DMS target endpoint.') - target_endpoint_arn = dms_client.create_endpoint( - EndpointIdentifier=TARGET_ENDPOINT_IDENTIFIER, - EndpointType='target', - EngineName='s3', - S3Settings={ - 'BucketName': S3_BUCKET, - 'BucketFolder': PROJECT_NAME, - 'ServiceAccessRoleArn': ROLE_ARN, - 'ExternalTableDefinition': json.dumps(TABLE_DEFINITION), - }, - )['Endpoint']['EndpointArn'] - - ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn) - ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn) - - -def _await_setup_assets(dms_client, instance_arn): - print("Awaiting asset provisioning.") - dms_client.get_waiter('replication_instance_available').wait( - Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}] - ) - - -def _delete_rds_instance(): - print('Deleting RDS Instance.') - - rds_client = boto3.client('rds') - rds_client.delete_db_instance( - DBInstanceIdentifier=RDS_INSTANCE_NAME, - SkipFinalSnapshot=True, - ) - - rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME) - - -def _delete_dms_assets(dms_client): - ti = get_current_context()['ti'] - replication_instance_arn = ti.xcom_pull(key='replication_instance_arn') - source_arn = ti.xcom_pull(key='source_endpoint_arn') - target_arn = ti.xcom_pull(key='target_endpoint_arn') - - print('Deleting DMS assets.') - dms_client.delete_replication_instance(ReplicationInstanceArn=replication_instance_arn) - dms_client.delete_endpoint(EndpointArn=source_arn) - dms_client.delete_endpoint(EndpointArn=target_arn) - - -def _await_all_teardowns(dms_client): - print('Awaiting tear-down.') - dms_client.get_waiter('replication_instance_deleted').wait( - Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}] - ) - - dms_client.get_waiter('endpoint_deleted').wait( - Filters=[ - { - 'Name': 'endpoint-id', - 'Values': [SOURCE_ENDPOINT_IDENTIFIER, TARGET_ENDPOINT_IDENTIFIER], - } - ] - ) - - -@task -def set_up(): - ti = get_current_context()['ti'] - dms_client = boto3.client('dms') - - rds_instance_endpoint = _create_rds_instance() - _create_rds_table(rds_instance_endpoint) - instance_arn = _create_dms_replication_instance(ti, dms_client) - _create_dms_endpoints(ti, dms_client, rds_instance_endpoint) - _await_setup_assets(dms_client, instance_arn) - - -@task(trigger_rule='all_done') -def clean_up(): - dms_client = boto3.client('dms') - - _delete_rds_instance() - _delete_dms_assets(dms_client) - _await_all_teardowns(dms_client) - - -with DAG( - dag_id='example_dms', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - 
catchup=False, -) as dag: - - # [START howto_operator_dms_create_task] - create_task = DmsCreateTaskOperator( - task_id='create_task', - replication_task_id=DMS_REPLICATION_TASK_ID, - source_endpoint_arn='{{ ti.xcom_pull(key="source_endpoint_arn") }}', - target_endpoint_arn='{{ ti.xcom_pull(key="target_endpoint_arn") }}', - replication_instance_arn='{{ ti.xcom_pull(key="replication_instance_arn") }}', - table_mappings=TABLE_MAPPINGS, - ) - # [END howto_operator_dms_create_task] - - # [START howto_operator_dms_start_task] - start_task = DmsStartTaskOperator( - task_id='start_task', - replication_task_arn=create_task.output, - ) - # [END howto_operator_dms_start_task] - - # [START howto_operator_dms_describe_tasks] - describe_tasks = DmsDescribeTasksOperator( - task_id='describe_tasks', - describe_tasks_kwargs={ - 'Filters': [ - { - 'Name': 'replication-instance-arn', - 'Values': ['{{ ti.xcom_pull(key="replication_instance_arn") }}'], - } - ] - }, - do_xcom_push=False, - ) - # [END howto_operator_dms_describe_tasks] - - await_task_start = DmsTaskBaseSensor( - task_id='await_task_start', - replication_task_arn=create_task.output, - target_statuses=['running'], - termination_statuses=['stopped', 'deleting', 'failed'], - ) - - # [START howto_operator_dms_stop_task] - stop_task = DmsStopTaskOperator( - task_id='stop_task', - replication_task_arn=create_task.output, - ) - # [END howto_operator_dms_stop_task] - - # TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here. - # [START howto_sensor_dms_task_completed] - await_task_stop = DmsTaskCompletedSensor( - task_id='await_task_stop', - replication_task_arn=create_task.output, - ) - # [END howto_sensor_dms_task_completed] - - # [START howto_operator_dms_delete_task] - delete_task = DmsDeleteTaskOperator( - task_id='delete_task', - replication_task_arn=create_task.output, - trigger_rule='all_done', - ) - # [END howto_operator_dms_delete_task] - - chain( - set_up() - >> create_task - >> start_task - >> describe_tasks - >> await_task_start - >> stop_task - >> await_task_stop - >> delete_task - >> clean_up() - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_dynamodb_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_dynamodb_to_s3.py deleted file mode 100644 index 66334fc996522..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_dynamodb_to_s3.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from datetime import datetime -from os import environ - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator - -TABLE_NAME = environ.get('DYNAMO_TABLE_NAME', 'ExistingDynamoDbTableName') -BUCKET_NAME = environ.get('S3_BUCKET_NAME', 'ExistingS3BucketName') - - -with DAG( - dag_id='example_dynamodb_to_s3', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - # [START howto_transfer_dynamodb_to_s3] - backup_db = DynamoDBToS3Operator( - task_id='backup_db', - dynamodb_table_name=TABLE_NAME, - s3_bucket_name=BUCKET_NAME, - # Max output file size in bytes. If the Table is too large, multiple files will be created. - file_size=1000, - ) - # [END howto_transfer_dynamodb_to_s3] - - # [START howto_transfer_dynamodb_to_s3_segmented] - # Segmenting allows the transfer to be parallelized into {segment} number of parallel tasks. - backup_db_segment_1 = DynamoDBToS3Operator( - task_id='backup-1', - dynamodb_table_name=TABLE_NAME, - s3_bucket_name=BUCKET_NAME, - # Max output file size in bytes. If the Table is too large, multiple files will be created. - file_size=1000, - dynamodb_scan_kwargs={ - "TotalSegments": 2, - "Segment": 0, - }, - ) - - backup_db_segment_2 = DynamoDBToS3Operator( - task_id="backup-2", - dynamodb_table_name=TABLE_NAME, - s3_bucket_name=BUCKET_NAME, - # Max output file size in bytes. If the Table is too large, multiple files will be created. - file_size=1000, - dynamodb_scan_kwargs={ - "TotalSegments": 2, - "Segment": 1, - }, - ) - # [END howto_transfer_dynamodb_to_s3_segmented] - - chain(backup_db, [backup_db_segment_1, backup_db_segment_2]) diff --git a/airflow/providers/amazon/aws/example_dags/example_ec2.py b/airflow/providers/amazon/aws/example_dags/example_ec2.py deleted file mode 100644 index 7d12aef36746e..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_ec2.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import os -from datetime import datetime - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator -from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor - -INSTANCE_ID = os.getenv("INSTANCE_ID", "instance-id") - -with DAG( - dag_id='example_ec2', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - # [START howto_operator_ec2_start_instance] - start_instance = EC2StartInstanceOperator( - task_id="ec2_start_instance", - instance_id=INSTANCE_ID, - ) - # [END howto_operator_ec2_start_instance] - - # [START howto_sensor_ec2_instance_state] - instance_state = EC2InstanceStateSensor( - task_id="ec2_instance_state", - instance_id=INSTANCE_ID, - target_state="running", - ) - # [END howto_sensor_ec2_instance_state] - - # [START howto_operator_ec2_stop_instance] - stop_instance = EC2StopInstanceOperator( - task_id="ec2_stop_instance", - instance_id=INSTANCE_ID, - ) - # [END howto_operator_ec2_stop_instance] - - chain(start_instance, instance_state, stop_instance) diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs.py b/airflow/providers/amazon/aws/example_dags/example_ecs.py deleted file mode 100644 index b8f1f67b8eb08..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_ecs.py +++ /dev/null @@ -1,59 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import os -from datetime import datetime - -from airflow import DAG -from airflow.providers.amazon.aws.operators.ecs import EcsOperator - -with DAG( - dag_id='example_ecs', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_ecs] - hello_world = EcsOperator( - task_id="hello_world", - cluster=os.environ.get("CLUSTER_NAME", "existing_cluster_name"), - task_definition=os.environ.get("TASK_DEFINITION", "existing_task_definition_name"), - launch_type="EXTERNAL|EC2", - aws_conn_id="aws_ecs", - overrides={ - "containerOverrides": [ - { - "name": "hello-world-container", - "command": ["echo", "hello", "world"], - }, - ], - }, - tags={ - "Customer": "X", - "Project": "Y", - "Application": "Z", - "Version": "0.0.1", - "Environment": "Development", - }, - # [START howto_awslogs_ecs] - awslogs_group="/ecs/hello-world", - awslogs_region="aws-region", - awslogs_stream_prefix="ecs/hello-world-container" - # [END howto_awslogs_ecs] - ) - # [END howto_operator_ecs] diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py deleted file mode 100644 index 1e48367429faa..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import os -from datetime import datetime - -from airflow import DAG -from airflow.providers.amazon.aws.operators.ecs import EcsOperator - -with DAG( - dag_id='example_ecs_fargate', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_ecs] - hello_world = EcsOperator( - task_id="hello_world", - cluster=os.environ.get("CLUSTER_NAME", "existing_cluster_name"), - task_definition=os.environ.get("TASK_DEFINITION", "existing_task_definition_name"), - launch_type="FARGATE", - aws_conn_id="aws_ecs", - overrides={ - "containerOverrides": [ - { - "name": "hello-world-container", - "command": ["echo", "hello", "world"], - }, - ], - }, - network_configuration={ - "awsvpcConfiguration": { - "securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")], - "subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")], - }, - }, - tags={ - "Customer": "X", - "Project": "Y", - "Application": "Z", - "Version": "0.0.1", - "Environment": "Development", - }, - awslogs_group="/ecs/hello-world", - awslogs_stream_prefix="prefix_b/hello-world-container", - ) - # [END howto_operator_ecs] diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py index 87d6d9ef44e68..da984adcf5547 100644 --- a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py +++ b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# mypy ignore arg types (for templated fields) -# type: ignore[arg-type] +from __future__ import annotations from datetime import datetime @@ -31,6 +29,10 @@ ) from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksNodegroupStateSensor +# mypy ignore arg types (for templated fields) +# type: ignore[arg-type] + + # Example Jinja Template format, substitute your values: """ { @@ -48,10 +50,9 @@ """ with DAG( - dag_id='example_eks_templated', - schedule_interval=None, + dag_id="example_eks_templated", start_date=datetime(2021, 1, 1), - tags=['example', 'templated'], + tags=["example", "templated"], catchup=False, # render_template_as_native_obj=True is what converts the Jinja to Python objects, instead of a string. render_template_as_native_obj=True, @@ -62,21 +63,22 @@ # Create an Amazon EKS Cluster control plane without attaching a compute service. 
create_cluster = EksCreateClusterOperator( - task_id='create_eks_cluster', + task_id="create_eks_cluster", cluster_name=CLUSTER_NAME, compute=None, cluster_role_arn="{{ dag_run.conf['cluster_role_arn'] }}", - resources_vpc_config="{{ dag_run.conf['resources_vpc_config'] }}", + # This only works with render_template_as_native_obj flag (this dag has it set) + resources_vpc_config="{{ dag_run.conf['resources_vpc_config'] }}", # type: ignore[arg-type] ) await_create_cluster = EksClusterStateSensor( - task_id='wait_for_create_cluster', + task_id="wait_for_create_cluster", cluster_name=CLUSTER_NAME, target_state=ClusterStates.ACTIVE, ) create_nodegroup = EksCreateNodegroupOperator( - task_id='create_eks_nodegroup', + task_id="create_eks_nodegroup", cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, nodegroup_subnets="{{ dag_run.conf['nodegroup_subnets'] }}", @@ -84,7 +86,7 @@ ) await_create_nodegroup = EksNodegroupStateSensor( - task_id='wait_for_create_nodegroup', + task_id="wait_for_create_nodegroup", cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.ACTIVE, @@ -103,25 +105,25 @@ ) delete_nodegroup = EksDeleteNodegroupOperator( - task_id='delete_eks_nodegroup', + task_id="delete_eks_nodegroup", cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, ) await_delete_nodegroup = EksNodegroupStateSensor( - task_id='wait_for_delete_nodegroup', + task_id="wait_for_delete_nodegroup", cluster_name=CLUSTER_NAME, nodegroup_name=NODEGROUP_NAME, target_state=NodegroupStates.NONEXISTENT, ) delete_cluster = EksDeleteClusterOperator( - task_id='delete_eks_cluster', + task_id="delete_eks_cluster", cluster_name=CLUSTER_NAME, ) await_delete_cluster = EksClusterStateSensor( - task_id='wait_for_delete_cluster', + task_id="wait_for_delete_cluster", cluster_name=CLUSTER_NAME, target_state=ClusterStates.NONEXISTENT, ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py deleted file mode 100644 index e08e6525e6fc8..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py +++ /dev/null @@ -1,107 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -from datetime import datetime -from os import environ - -from airflow.models.dag import DAG -from airflow.providers.amazon.aws.hooks.eks import ClusterStates, FargateProfileStates -from airflow.providers.amazon.aws.operators.eks import ( - EksCreateClusterOperator, - EksDeleteClusterOperator, - EksPodOperator, -) -from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksFargateProfileStateSensor - -CLUSTER_NAME = 'fargate-all-in-one' -FARGATE_PROFILE_NAME = f'{CLUSTER_NAME}-profile' - -ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name') -SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ') -VPC_CONFIG = { - 'subnetIds': SUBNETS, - 'endpointPublicAccess': True, - 'endpointPrivateAccess': False, -} - - -with DAG( - dag_id='example_eks_with_fargate_in_one_step', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_eks_create_cluster_with_fargate_profile] - # Create an Amazon EKS cluster control plane and an AWS Fargate compute platform in one step. - create_cluster_and_fargate_profile = EksCreateClusterOperator( - task_id='create_eks_cluster_and_fargate_profile', - cluster_name=CLUSTER_NAME, - cluster_role_arn=ROLE_ARN, - resources_vpc_config=VPC_CONFIG, - compute='fargate', - fargate_profile_name=FARGATE_PROFILE_NAME, - # Opting to use the same ARN for the cluster and the pod here, - # but a different ARN could be configured and passed if desired. - fargate_pod_execution_role_arn=ROLE_ARN, - ) - # [END howto_operator_eks_create_cluster_with_fargate_profile] - - await_create_fargate_profile = EksFargateProfileStateSensor( - task_id='wait_for_create_fargate_profile', - cluster_name=CLUSTER_NAME, - fargate_profile_name=FARGATE_PROFILE_NAME, - target_state=FargateProfileStates.ACTIVE, - ) - - start_pod = EksPodOperator( - task_id="run_pod", - pod_name="run_pod", - cluster_name=CLUSTER_NAME, - image="amazon/aws-cli:latest", - cmds=["sh", "-c", "echo Test Airflow; date"], - labels={"demo": "hello_world"}, - get_logs=True, - # Delete the pod when it reaches its final state, or the execution is interrupted. - is_delete_operator_pod=True, - ) - - # An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles. - # Setting the `force` to `True` will delete any attached resources before deleting the cluster. - delete_all = EksDeleteClusterOperator( - task_id='delete_fargate_profile_and_cluster', - cluster_name=CLUSTER_NAME, - force_delete_compute=True, - ) - - await_delete_cluster = EksClusterStateSensor( - task_id='wait_for_delete_cluster', - cluster_name=CLUSTER_NAME, - target_state=ClusterStates.NONEXISTENT, - ) - - ( - create_cluster_and_fargate_profile - >> await_create_fargate_profile - >> start_pod - >> delete_all - >> await_delete_cluster - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py deleted file mode 100644 index 3ca3b2eb87728..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py +++ /dev/null @@ -1,138 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -from datetime import datetime -from os import environ - -from airflow.models.dag import DAG -from airflow.providers.amazon.aws.hooks.eks import ClusterStates, FargateProfileStates -from airflow.providers.amazon.aws.operators.eks import ( - EksCreateClusterOperator, - EksCreateFargateProfileOperator, - EksDeleteClusterOperator, - EksDeleteFargateProfileOperator, - EksPodOperator, -) -from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksFargateProfileStateSensor - -CLUSTER_NAME = 'fargate-demo' -FARGATE_PROFILE_NAME = f'{CLUSTER_NAME}-profile' -SELECTORS = [{'namespace': 'default'}] - -ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name') -SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ') -VPC_CONFIG = { - 'subnetIds': SUBNETS, - 'endpointPublicAccess': True, - 'endpointPrivateAccess': False, -} - - -with DAG( - dag_id='example_eks_with_fargate_profile', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # Create an Amazon EKS Cluster control plane without attaching a compute service. - create_cluster = EksCreateClusterOperator( - task_id='create_eks_cluster', - cluster_name=CLUSTER_NAME, - cluster_role_arn=ROLE_ARN, - resources_vpc_config=VPC_CONFIG, - compute=None, - ) - - await_create_cluster = EksClusterStateSensor( - task_id='wait_for_create_cluster', - cluster_name=CLUSTER_NAME, - target_state=ClusterStates.ACTIVE, - ) - - # [START howto_operator_eks_create_fargate_profile] - create_fargate_profile = EksCreateFargateProfileOperator( - task_id='create_eks_fargate_profile', - cluster_name=CLUSTER_NAME, - pod_execution_role_arn=ROLE_ARN, - fargate_profile_name=FARGATE_PROFILE_NAME, - selectors=SELECTORS, - ) - # [END howto_operator_eks_create_fargate_profile] - - # [START howto_sensor_eks_fargate] - await_create_fargate_profile = EksFargateProfileStateSensor( - task_id='wait_for_create_fargate_profile', - cluster_name=CLUSTER_NAME, - fargate_profile_name=FARGATE_PROFILE_NAME, - target_state=FargateProfileStates.ACTIVE, - ) - # [END howto_sensor_eks_fargate] - - start_pod = EksPodOperator( - task_id="run_pod", - cluster_name=CLUSTER_NAME, - pod_name="run_pod", - image="amazon/aws-cli:latest", - cmds=["sh", "-c", "echo Test Airflow; date"], - labels={"demo": "hello_world"}, - get_logs=True, - # Delete the pod when it reaches its final state, or the execution is interrupted. 
- is_delete_operator_pod=True, - ) - - # [START howto_operator_eks_delete_fargate_profile] - delete_fargate_profile = EksDeleteFargateProfileOperator( - task_id='delete_eks_fargate_profile', - cluster_name=CLUSTER_NAME, - fargate_profile_name=FARGATE_PROFILE_NAME, - ) - # [END howto_operator_eks_delete_fargate_profile] - - await_delete_fargate_profile = EksFargateProfileStateSensor( - task_id='wait_for_delete_fargate_profile', - cluster_name=CLUSTER_NAME, - fargate_profile_name=FARGATE_PROFILE_NAME, - target_state=FargateProfileStates.NONEXISTENT, - ) - - delete_cluster = EksDeleteClusterOperator( - task_id='delete_eks_cluster', - cluster_name=CLUSTER_NAME, - ) - - await_delete_cluster = EksClusterStateSensor( - task_id='wait_for_delete_cluster', - cluster_name=CLUSTER_NAME, - target_state=ClusterStates.NONEXISTENT, - ) - - ( - create_cluster - >> await_create_cluster - >> create_fargate_profile - >> await_create_fargate_profile - >> start_pod - >> delete_fargate_profile - >> await_delete_fargate_profile - >> delete_cluster - >> await_delete_cluster - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py deleted file mode 100644 index 38d1bd1ad4c2f..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py +++ /dev/null @@ -1,109 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -from datetime import datetime -from os import environ - -from airflow.models.dag import DAG -from airflow.providers.amazon.aws.hooks.eks import ClusterStates, NodegroupStates -from airflow.providers.amazon.aws.operators.eks import ( - EksCreateClusterOperator, - EksDeleteClusterOperator, - EksPodOperator, -) -from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksNodegroupStateSensor - -CLUSTER_NAME = environ.get('EKS_CLUSTER_NAME', 'eks-demo') -NODEGROUP_NAME = f'{CLUSTER_NAME}-nodegroup' -ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name') -SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ') -VPC_CONFIG = { - 'subnetIds': SUBNETS, - 'endpointPublicAccess': True, - 'endpointPrivateAccess': False, -} - - -with DAG( - dag_id='example_eks_with_nodegroup_in_one_step', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_eks_create_cluster_with_nodegroup] - # Create an Amazon EKS cluster control plane and an EKS nodegroup compute platform in one step. 
- create_cluster_and_nodegroup = EksCreateClusterOperator( - task_id='create_eks_cluster_and_nodegroup', - cluster_name=CLUSTER_NAME, - nodegroup_name=NODEGROUP_NAME, - cluster_role_arn=ROLE_ARN, - nodegroup_role_arn=ROLE_ARN, - # Opting to use the same ARN for the cluster and the nodegroup here, - # but a different ARN could be configured and passed if desired. - resources_vpc_config=VPC_CONFIG, - # Compute defaults to 'nodegroup' but is called out here for the purposed of the example. - compute='nodegroup', - ) - # [END howto_operator_eks_create_cluster_with_nodegroup] - - await_create_nodegroup = EksNodegroupStateSensor( - task_id='wait_for_create_nodegroup', - cluster_name=CLUSTER_NAME, - nodegroup_name=NODEGROUP_NAME, - target_state=NodegroupStates.ACTIVE, - ) - - start_pod = EksPodOperator( - task_id="run_pod", - cluster_name=CLUSTER_NAME, - pod_name="run_pod", - image="amazon/aws-cli:latest", - cmds=["sh", "-c", "echo Test Airflow; date"], - labels={"demo": "hello_world"}, - get_logs=True, - # Delete the pod when it reaches its final state, or the execution is interrupted. - is_delete_operator_pod=True, - ) - - # [START howto_operator_eks_force_delete_cluster] - # An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles. - # Setting the `force` to `True` will delete any attached resources before deleting the cluster. - delete_all = EksDeleteClusterOperator( - task_id='delete_nodegroup_and_cluster', - cluster_name=CLUSTER_NAME, - force_delete_compute=True, - ) - # [END howto_operator_eks_force_delete_cluster] - - await_delete_cluster = EksClusterStateSensor( - task_id='wait_for_delete_cluster', - cluster_name=CLUSTER_NAME, - target_state=ClusterStates.NONEXISTENT, - ) - - ( - create_cluster_and_nodegroup - >> await_create_nodegroup - >> start_pod - >> delete_all - >> await_delete_cluster - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py deleted file mode 100644 index efeeb14e04bfb..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py +++ /dev/null @@ -1,145 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -from datetime import datetime -from os import environ - -from airflow.models.dag import DAG -from airflow.providers.amazon.aws.hooks.eks import ClusterStates, NodegroupStates -from airflow.providers.amazon.aws.operators.eks import ( - EksCreateClusterOperator, - EksCreateNodegroupOperator, - EksDeleteClusterOperator, - EksDeleteNodegroupOperator, - EksPodOperator, -) -from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksNodegroupStateSensor - -CLUSTER_NAME = 'eks-demo' -NODEGROUP_SUFFIX = '-nodegroup' -NODEGROUP_NAME = CLUSTER_NAME + NODEGROUP_SUFFIX -ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name') -SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ') -VPC_CONFIG = { - 'subnetIds': SUBNETS, - 'endpointPublicAccess': True, - 'endpointPrivateAccess': False, -} - - -with DAG( - dag_id='example_eks_with_nodegroups', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_eks_create_cluster] - # Create an Amazon EKS Cluster control plane without attaching compute service. - create_cluster = EksCreateClusterOperator( - task_id='create_eks_cluster', - cluster_name=CLUSTER_NAME, - cluster_role_arn=ROLE_ARN, - resources_vpc_config=VPC_CONFIG, - compute=None, - ) - # [END howto_operator_eks_create_cluster] - - # [START howto_sensor_eks_cluster] - await_create_cluster = EksClusterStateSensor( - task_id='wait_for_create_cluster', - cluster_name=CLUSTER_NAME, - target_state=ClusterStates.ACTIVE, - ) - # [END howto_sensor_eks_cluster] - - # [START howto_operator_eks_create_nodegroup] - create_nodegroup = EksCreateNodegroupOperator( - task_id='create_eks_nodegroup', - cluster_name=CLUSTER_NAME, - nodegroup_name=NODEGROUP_NAME, - nodegroup_subnets=SUBNETS, - nodegroup_role_arn=ROLE_ARN, - ) - # [END howto_operator_eks_create_nodegroup] - - # [START howto_sensor_eks_nodegroup] - await_create_nodegroup = EksNodegroupStateSensor( - task_id='wait_for_create_nodegroup', - cluster_name=CLUSTER_NAME, - nodegroup_name=NODEGROUP_NAME, - target_state=NodegroupStates.ACTIVE, - ) - # [END howto_sensor_eks_nodegroup] - - # [START howto_operator_eks_pod_operator] - start_pod = EksPodOperator( - task_id="run_pod", - cluster_name=CLUSTER_NAME, - pod_name="run_pod", - image="amazon/aws-cli:latest", - cmds=["sh", "-c", "ls"], - labels={"demo": "hello_world"}, - get_logs=True, - # Delete the pod when it reaches its final state, or the execution is interrupted. 
- is_delete_operator_pod=True, - ) - # [END howto_operator_eks_pod_operator] - - # [START howto_operator_eks_delete_nodegroup] - delete_nodegroup = EksDeleteNodegroupOperator( - task_id='delete_eks_nodegroup', - cluster_name=CLUSTER_NAME, - nodegroup_name=NODEGROUP_NAME, - ) - # [END howto_operator_eks_delete_nodegroup] - - await_delete_nodegroup = EksNodegroupStateSensor( - task_id='wait_for_delete_nodegroup', - cluster_name=CLUSTER_NAME, - nodegroup_name=NODEGROUP_NAME, - target_state=NodegroupStates.NONEXISTENT, - ) - - # [START howto_operator_eks_delete_cluster] - delete_cluster = EksDeleteClusterOperator( - task_id='delete_eks_cluster', - cluster_name=CLUSTER_NAME, - ) - # [END howto_operator_eks_delete_cluster] - - await_delete_cluster = EksClusterStateSensor( - task_id='wait_for_delete_cluster', - cluster_name=CLUSTER_NAME, - target_state=ClusterStates.NONEXISTENT, - ) - - ( - create_cluster - >> await_create_cluster - >> create_nodegroup - >> await_create_nodegroup - >> start_pod - >> delete_nodegroup - >> await_delete_nodegroup - >> delete_cluster - >> await_delete_cluster - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_eks_job.py b/airflow/providers/amazon/aws/example_dags/example_emr_eks_job.py deleted file mode 100644 index e8932630d9d45..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_emr_eks_job.py +++ /dev/null @@ -1,75 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -from datetime import datetime - -from airflow import DAG -from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator - -VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster") -JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role") - -# [START howto_operator_emr_eks_config] -JOB_DRIVER_ARG = { - "sparkSubmitJobDriver": { - "entryPoint": "local:///usr/lib/spark/examples/src/main/python/pi.py", - "sparkSubmitParameters": "--conf spark.executors.instances=2 --conf spark.executors.memory=2G --conf spark.executor.cores=2 --conf spark.driver.cores=1", # noqa: E501 - } -} - -CONFIGURATION_OVERRIDES_ARG = { - "applicationConfiguration": [ - { - "classification": "spark-defaults", - "properties": { - "spark.hadoop.hive.metastore.client.factory.class": "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory", # noqa: E501 - }, - } - ], - "monitoringConfiguration": { - "cloudWatchMonitoringConfiguration": { - "logGroupName": "/aws/emr-eks-spark", - "logStreamNamePrefix": "airflow", - } - }, -} -# [END howto_operator_emr_eks_config] - -with DAG( - dag_id='example_emr_eks_job', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # An example of how to get the cluster id and arn from an Airflow connection - # VIRTUAL_CLUSTER_ID = '{{ conn.emr_eks.extra_dejson["virtual_cluster_id"] }}' - # JOB_ROLE_ARN = '{{ conn.emr_eks.extra_dejson["job_role_arn"] }}' - - # [START howto_operator_emr_eks_job] - job_starter = EmrContainerOperator( - task_id="start_job", - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - execution_role_arn=JOB_ROLE_ARN, - release_label="emr-6.3.0-latest", - job_driver=JOB_DRIVER_ARG, - configuration_overrides=CONFIGURATION_OVERRIDES_ARG, - name="pi.py", - ) - # [END howto_operator_emr_eks_job] diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py deleted file mode 100644 index ab92a3cb21a38..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py +++ /dev/null @@ -1,84 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
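The commented-out lines in the EMR on EKS example above show the virtual cluster id and job role ARN being templated out of an Airflow connection via extra_dejson. A minimal sketch of registering such a connection so that those templates resolve; the conn_id emr_eks and the extra values are placeholders, and it follows the same db.merge_conn pattern the Hive-to-DynamoDB example in this change uses:

import json

from airflow.models.connection import Connection
from airflow.utils import db

db.merge_conn(
    Connection(
        conn_id="emr_eks",
        conn_type="aws",
        # These keys are what the templated fields above look up via extra_dejson.
        extra=json.dumps(
            {
                "virtual_cluster_id": "vc-1234567890abcdef0",
                "job_role_arn": "arn:aws:iam::123456789012:role/emr_eks_default_role",
            }
        ),
    )
)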
-import os -from datetime import datetime - -from airflow import DAG -from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator -from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor - -JOB_FLOW_ROLE = os.getenv('EMR_JOB_FLOW_ROLE', 'EMR_EC2_DefaultRole') -SERVICE_ROLE = os.getenv('EMR_SERVICE_ROLE', 'EMR_DefaultRole') - -# [START howto_operator_emr_automatic_steps_config] -SPARK_STEPS = [ - { - 'Name': 'calculate_pi', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], - }, - } -] - -JOB_FLOW_OVERRIDES = { - 'Name': 'PiCalc', - 'ReleaseLabel': 'emr-5.29.0', - 'Applications': [{'Name': 'Spark'}], - 'Instances': { - 'InstanceGroups': [ - { - 'Name': 'Primary node', - 'Market': 'ON_DEMAND', - 'InstanceRole': 'MASTER', - 'InstanceType': 'm5.xlarge', - 'InstanceCount': 1, - }, - ], - 'KeepJobFlowAliveWhenNoSteps': False, - 'TerminationProtected': False, - }, - 'Steps': SPARK_STEPS, - 'JobFlowRole': JOB_FLOW_ROLE, - 'ServiceRole': SERVICE_ROLE, -} -# [END howto_operator_emr_automatic_steps_config] - - -with DAG( - dag_id='example_emr_job_flow_automatic_steps', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_emr_create_job_flow] - job_flow_creator = EmrCreateJobFlowOperator( - task_id='create_job_flow', - job_flow_overrides=JOB_FLOW_OVERRIDES, - ) - # [END howto_operator_emr_create_job_flow] - - # [START howto_sensor_emr_job_flow_sensor] - job_sensor = EmrJobFlowSensor( - task_id='check_job_flow', - job_flow_id=job_flow_creator.output, - ) - # [END howto_sensor_emr_job_flow_sensor] diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py deleted file mode 100644 index d18237ecb0723..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py +++ /dev/null @@ -1,106 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import os -from datetime import datetime - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.emr import ( - EmrAddStepsOperator, - EmrCreateJobFlowOperator, - EmrTerminateJobFlowOperator, -) -from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor - -JOB_FLOW_ROLE = os.getenv('EMR_JOB_FLOW_ROLE', 'EMR_EC2_DefaultRole') -SERVICE_ROLE = os.getenv('EMR_SERVICE_ROLE', 'EMR_DefaultRole') - -SPARK_STEPS = [ - { - 'Name': 'calculate_pi', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], - }, - } -] - -JOB_FLOW_OVERRIDES = { - 'Name': 'PiCalc', - 'ReleaseLabel': 'emr-5.29.0', - 'Applications': [{'Name': 'Spark'}], - 'Instances': { - 'InstanceGroups': [ - { - 'Name': 'Primary node', - 'Market': 'ON_DEMAND', - 'InstanceRole': 'MASTER', - 'InstanceType': 'm5.xlarge', - 'InstanceCount': 1, - }, - ], - 'KeepJobFlowAliveWhenNoSteps': False, - 'TerminationProtected': False, - }, - 'JobFlowRole': JOB_FLOW_ROLE, - 'ServiceRole': SERVICE_ROLE, -} - - -with DAG( - dag_id='example_emr_job_flow_manual_steps', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - cluster_creator = EmrCreateJobFlowOperator( - task_id='create_job_flow', - job_flow_overrides=JOB_FLOW_OVERRIDES, - ) - - # [START howto_operator_emr_add_steps] - step_adder = EmrAddStepsOperator( - task_id='add_steps', - job_flow_id=cluster_creator.output, - steps=SPARK_STEPS, - ) - # [END howto_operator_emr_add_steps] - - # [START howto_sensor_emr_step_sensor] - step_checker = EmrStepSensor( - task_id='watch_step', - job_flow_id=cluster_creator.output, - step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}", - ) - # [END howto_sensor_emr_step_sensor] - - # [START howto_operator_emr_terminate_job_flow] - cluster_remover = EmrTerminateJobFlowOperator( - task_id='remove_cluster', - job_flow_id=cluster_creator.output, - ) - # [END howto_operator_emr_terminate_job_flow] - - chain( - step_adder, - step_checker, - cluster_remover, - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py index d01ca38810729..4c187dc1e0c38 100644 --- a/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from __future__ import annotations import os from datetime import datetime @@ -27,7 +27,6 @@ with models.DAG( "example_ftp_to_s3", - schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, ) as dag: diff --git a/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py index d9d04c73ffa31..ad5c8072df0a8 100644 --- a/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
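The EmrStepSensor above resolves its step_id with a Jinja xcom_pull against the add_steps task, which returns the list of submitted step IDs; the template simply takes the first entry. A short TaskFlow sketch of the same lookup in plain Python; the first_step_id task is hypothetical and not part of the original DAG:

from airflow.decorators import task
from airflow.operators.python import get_current_context


@task
def first_step_id():
    # EmrAddStepsOperator pushes the list of submitted step IDs as its return value;
    # this mirrors the Jinja expression used by the EmrStepSensor.
    ti = get_current_context()["ti"]
    return ti.xcom_pull(task_ids="add_steps", key="return_value")[0]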
+from __future__ import annotations import os from datetime import datetime @@ -26,7 +27,6 @@ with DAG( dag_id="example_gcs_to_s3", - schedule_interval=None, start_date=datetime(2021, 1, 1), tags=["example"], catchup=False, diff --git a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py index 2df40a7d0c37b..9c9f8a1976e89 100644 --- a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py +++ b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py @@ -14,11 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import os from datetime import datetime from airflow import DAG -from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator +from airflow.providers.amazon.aws.operators.glacier import ( + GlacierCreateJobOperator, + GlacierUploadArchiveOperator, +) from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator @@ -28,7 +33,6 @@ with DAG( "example_glacier_to_gcs", - schedule_interval=None, start_date=datetime(2021, 1, 1), # Override to match your needs catchup=False, ) as dag: @@ -45,6 +49,12 @@ ) # [END howto_sensor_glacier_job_operation] + # [START howto_operator_glacier_upload_archive] + upload_archive_to_glacier = GlacierUploadArchiveOperator( + vault_name=VAULT_NAME, body=b"Test Data", task_id="upload_data_to_glacier" + ) + # [END howto_operator_glacier_upload_archive] + # [START howto_transfer_glacier_to_gcs] transfer_archive_to_gcs = GlacierToGCSOperator( task_id="transfer_archive_to_gcs", @@ -59,4 +69,4 @@ ) # [END howto_transfer_glacier_to_gcs] - create_glacier_job >> wait_for_operation_complete >> transfer_archive_to_gcs + create_glacier_job >> wait_for_operation_complete >> upload_archive_to_glacier >> transfer_archive_to_gcs diff --git a/airflow/providers/amazon/aws/example_dags/example_glue.py b/airflow/providers/amazon/aws/example_dags/example_glue.py deleted file mode 100644 index 65417dc24f97c..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_glue.py +++ /dev/null @@ -1,123 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
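The Glacier example above now uploads an archive with GlacierUploadArchiveOperator before handing off to the GCS transfer. A minimal boto3 sketch of the same upload outside Airflow; the vault name is a placeholder, and the response field assumed here is the standard archiveId returned by UploadArchive:

import boto3

glacier = boto3.client("glacier")
response = glacier.upload_archive(
    vaultName="airflow-demo-vault",  # placeholder vault
    body=b"Test Data",
)
print(response["archiveId"])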
-from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.decorators import task -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.amazon.aws.operators.glue import GlueJobOperator -from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator -from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor -from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor - -GLUE_DATABASE_NAME = getenv('GLUE_DATABASE_NAME', 'glue_database_name') -GLUE_EXAMPLE_S3_BUCKET = getenv('GLUE_EXAMPLE_S3_BUCKET', 'glue_example_s3_bucket') - -# Role needs putobject/getobject access to the above bucket as well as the glue -# service role, see docs here: https://docs.aws.amazon.com/glue/latest/dg/create-an-iam-role.html -GLUE_CRAWLER_ROLE = getenv('GLUE_CRAWLER_ROLE', 'glue_crawler_role') -GLUE_CRAWLER_NAME = 'example_crawler' -GLUE_CRAWLER_CONFIG = { - 'Name': GLUE_CRAWLER_NAME, - 'Role': GLUE_CRAWLER_ROLE, - 'DatabaseName': GLUE_DATABASE_NAME, - 'Targets': { - 'S3Targets': [ - { - 'Path': f'{GLUE_EXAMPLE_S3_BUCKET}/input', - } - ] - }, -} - -# Example csv data used as input to the example AWS Glue Job. -EXAMPLE_CSV = ''' -apple,0.5 -milk,2.5 -bread,4.0 -''' - -# Example Spark script to operate on the above sample csv data. -EXAMPLE_SCRIPT = f''' -from pyspark.context import SparkContext -from awsglue.context import GlueContext - -glueContext = GlueContext(SparkContext.getOrCreate()) -datasource = glueContext.create_dynamic_frame.from_catalog( - database='{GLUE_DATABASE_NAME}', table_name='input') -print('There are %s items in the table' % datasource.count()) - -datasource.toDF().write.format('csv').mode("append").save('s3://{GLUE_EXAMPLE_S3_BUCKET}/output') -''' - - -@task(task_id='setup__upload_artifacts_to_s3') -def upload_artifacts_to_s3(): - '''Upload example CSV input data and an example Spark script to be used by the Glue Job''' - s3_hook = S3Hook() - s3_load_kwargs = {"replace": True, "bucket_name": GLUE_EXAMPLE_S3_BUCKET} - s3_hook.load_string(string_data=EXAMPLE_CSV, key='input/input.csv', **s3_load_kwargs) - s3_hook.load_string(string_data=EXAMPLE_SCRIPT, key='etl_script.py', **s3_load_kwargs) - - -with DAG( - dag_id='example_glue', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as glue_dag: - - setup_upload_artifacts_to_s3 = upload_artifacts_to_s3() - - # [START howto_operator_glue_crawler] - crawl_s3 = GlueCrawlerOperator( - task_id='crawl_s3', - config=GLUE_CRAWLER_CONFIG, - wait_for_completion=False, - ) - # [END howto_operator_glue_crawler] - - # [START howto_sensor_glue_crawler] - wait_for_crawl = GlueCrawlerSensor(task_id='wait_for_crawl', crawler_name=GLUE_CRAWLER_NAME) - # [END howto_sensor_glue_crawler] - - # [START howto_operator_glue] - job_name = 'example_glue_job' - submit_glue_job = GlueJobOperator( - task_id='submit_glue_job', - job_name=job_name, - wait_for_completion=False, - script_location=f's3://{GLUE_EXAMPLE_S3_BUCKET}/etl_script.py', - s3_bucket=GLUE_EXAMPLE_S3_BUCKET, - iam_role_name=GLUE_CRAWLER_ROLE.split('/')[-1], - create_job_kwargs={'GlueVersion': '3.0', 'NumberOfWorkers': 2, 'WorkerType': 'G.1X'}, - ) - # [END howto_operator_glue] - - # [START howto_sensor_glue] - wait_for_job = GlueJobSensor( - task_id='wait_for_job', - job_name=job_name, - # Job ID extracted from previous Glue Job Operator task - run_id=submit_glue_job.output, - ) - # [END 
howto_sensor_glue] - - chain(setup_upload_artifacts_to_s3, crawl_s3, wait_for_crawl, submit_glue_job, wait_for_job) diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py index 7b6e4b291a66c..c405e802ac259 100644 --- a/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py @@ -18,6 +18,7 @@ This is a basic example dag for using `GoogleApiToS3Operator` to retrieve Google Sheets data You need to set all env variables to request the data. """ +from __future__ import annotations from datetime import datetime from os import getenv @@ -31,18 +32,17 @@ with DAG( dag_id="example_google_api_sheets_to_s3", - schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as dag: # [START howto_transfer_google_api_sheets_to_s3] task_google_sheets_values_to_s3 = GoogleApiToS3Operator( - task_id='google_sheet_data_to_s3', - google_api_service_name='sheets', - google_api_service_version='v4', - google_api_endpoint_path='sheets.spreadsheets.values.get', - google_api_endpoint_params={'spreadsheetId': GOOGLE_SHEET_ID, 'range': GOOGLE_SHEET_RANGE}, + task_id="google_sheet_data_to_s3", + google_api_service_name="sheets", + google_api_service_version="v4", + google_api_endpoint_path="sheets.spreadsheets.values.get", + google_api_endpoint_params={"spreadsheetId": GOOGLE_SHEET_ID, "range": GOOGLE_SHEET_RANGE}, s3_destination_key=S3_DESTINATION_KEY, ) # [END howto_transfer_google_api_sheets_to_s3] diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_youtube_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_google_api_youtube_to_s3.py deleted file mode 100644 index 682cc83919129..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_google_api_youtube_to_s3.py +++ /dev/null @@ -1,113 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is a more advanced example dag for using `GoogleApiToS3Operator` which uses xcom to pass data between -tasks to retrieve specific information about YouTube videos: - -First it searches for up to 50 videos (due to pagination) in a given time range -(YOUTUBE_VIDEO_PUBLISHED_AFTER, YOUTUBE_VIDEO_PUBLISHED_BEFORE) on a YouTube channel (YOUTUBE_CHANNEL_ID) -saves the response in S3 + passes over the YouTube IDs to the next request which then gets the information -(YOUTUBE_VIDEO_FIELDS) for the requested videos and saves them in S3 (S3_DESTINATION_KEY). - -Further information: - -YOUTUBE_VIDEO_PUBLISHED_AFTER and YOUTUBE_VIDEO_PUBLISHED_BEFORE needs to be formatted -"YYYY-MM-DDThh:mm:ss.sZ". 
See https://developers.google.com/youtube/v3/docs/search/list for more information. -YOUTUBE_VIDEO_PARTS depends on the fields you pass via YOUTUBE_VIDEO_FIELDS. See -https://developers.google.com/youtube/v3/docs/videos/list#parameters for more information. -YOUTUBE_CONN_ID is optional for public videos. It does only need to authenticate when there are private videos -on a YouTube channel you want to retrieve. -""" - -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.decorators import task -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.transfers.google_api_to_s3 import GoogleApiToS3Operator - -YOUTUBE_CHANNEL_ID = getenv( - "YOUTUBE_CHANNEL_ID", "UCSXwxpWZQ7XZ1WL3wqevChA" -) # Youtube channel "Apache Airflow" -YOUTUBE_VIDEO_PUBLISHED_AFTER = getenv("YOUTUBE_VIDEO_PUBLISHED_AFTER", "2019-09-25T00:00:00Z") -YOUTUBE_VIDEO_PUBLISHED_BEFORE = getenv("YOUTUBE_VIDEO_PUBLISHED_BEFORE", "2019-10-18T00:00:00Z") -S3_BUCKET_NAME = getenv("S3_DESTINATION_KEY", "s3://bucket-test") -YOUTUBE_VIDEO_PARTS = getenv("YOUTUBE_VIDEO_PARTS", "snippet") -YOUTUBE_VIDEO_FIELDS = getenv("YOUTUBE_VIDEO_FIELDS", "items(id,snippet(description,publishedAt,tags,title))") - - -@task(task_id='transform_video_ids') -def transform_video_ids(**kwargs): - task_instance = kwargs['task_instance'] - output = task_instance.xcom_pull(task_ids="video_ids_to_s3", key="video_ids_response") - video_ids = [item['id']['videoId'] for item in output['items']] - - if not video_ids: - video_ids = [] - - kwargs['task_instance'].xcom_push(key='video_ids', value={'id': ','.join(video_ids)}) - - -with DAG( - dag_id="example_google_api_youtube_to_s3", - schedule_interval=None, - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_transfer_google_api_youtube_search_to_s3] - task_video_ids_to_s3 = GoogleApiToS3Operator( - task_id='video_ids_to_s3', - google_api_service_name='youtube', - google_api_service_version='v3', - google_api_endpoint_path='youtube.search.list', - google_api_endpoint_params={ - 'part': 'snippet', - 'channelId': YOUTUBE_CHANNEL_ID, - 'maxResults': 50, - 'publishedAfter': YOUTUBE_VIDEO_PUBLISHED_AFTER, - 'publishedBefore': YOUTUBE_VIDEO_PUBLISHED_BEFORE, - 'type': 'video', - 'fields': 'items/id/videoId', - }, - google_api_response_via_xcom='video_ids_response', - s3_destination_key=f'{S3_BUCKET_NAME}/youtube_search.json', - s3_overwrite=True, - ) - # [END howto_transfer_google_api_youtube_search_to_s3] - - task_transform_video_ids = transform_video_ids() - - # [START howto_transfer_google_api_youtube_list_to_s3] - task_video_data_to_s3 = GoogleApiToS3Operator( - task_id='video_data_to_s3', - google_api_service_name='youtube', - google_api_service_version='v3', - google_api_endpoint_path='youtube.videos.list', - google_api_endpoint_params={ - 'part': YOUTUBE_VIDEO_PARTS, - 'maxResults': 50, - 'fields': YOUTUBE_VIDEO_FIELDS, - }, - google_api_endpoint_params_via_xcom='video_ids', - s3_destination_key=f'{S3_BUCKET_NAME}/youtube_videos.json', - s3_overwrite=True, - ) - # [END howto_transfer_google_api_youtube_list_to_s3] - - chain(task_video_ids_to_s3, task_transform_video_ids, task_video_data_to_s3) diff --git a/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py b/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py index 6fccd64c6d596..1e4375febd607 100644 --- a/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py +++ 
b/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This DAG will not work unless you create an Amazon EMR cluster running Apache Hive and copy data into it following steps 1-4 (inclusive) here: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/EMRforDynamoDB.Tutorial.html """ +from __future__ import annotations import os from datetime import datetime @@ -31,33 +31,33 @@ from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator from airflow.utils import db -DYNAMODB_TABLE_NAME = 'example_hive_to_dynamodb_table' -HIVE_CONNECTION_ID = os.getenv('HIVE_CONNECTION_ID', 'hive_on_emr') -HIVE_HOSTNAME = os.getenv('HIVE_HOSTNAME', 'ec2-123-45-67-890.compute-1.amazonaws.com') +DYNAMODB_TABLE_NAME = "example_hive_to_dynamodb_table" +HIVE_CONNECTION_ID = os.getenv("HIVE_CONNECTION_ID", "hive_on_emr") +HIVE_HOSTNAME = os.getenv("HIVE_HOSTNAME", "ec2-123-45-67-890.compute-1.amazonaws.com") # These values assume you set up the Hive data source following the link above. -DYNAMODB_TABLE_HASH_KEY = 'feature_id' -HIVE_SQL = 'SELECT feature_id, feature_name, feature_class, state_alpha FROM hive_features' +DYNAMODB_TABLE_HASH_KEY = "feature_id" +HIVE_SQL = "SELECT feature_id, feature_name, feature_class, state_alpha FROM hive_features" @task def create_dynamodb_table(): - client = DynamoDBHook(client_type='dynamodb').conn + client = DynamoDBHook(client_type="dynamodb").conn client.create_table( TableName=DYNAMODB_TABLE_NAME, KeySchema=[ - {'AttributeName': DYNAMODB_TABLE_HASH_KEY, 'KeyType': 'HASH'}, + {"AttributeName": DYNAMODB_TABLE_HASH_KEY, "KeyType": "HASH"}, ], AttributeDefinitions=[ - {'AttributeName': DYNAMODB_TABLE_HASH_KEY, 'AttributeType': 'N'}, + {"AttributeName": DYNAMODB_TABLE_HASH_KEY, "AttributeType": "N"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 20, 'WriteCapacityUnits': 20}, + ProvisionedThroughput={"ReadCapacityUnits": 20, "WriteCapacityUnits": 20}, ) # DynamoDB table creation is nearly, but not quite, instantaneous. # Wait for the table to be active to avoid race conditions writing to it. - waiter = client.get_waiter('table_exists') - waiter.wait(TableName=DYNAMODB_TABLE_NAME, WaiterConfig={'Delay': 1}) + waiter = client.get_waiter("table_exists") + waiter.wait(TableName=DYNAMODB_TABLE_NAME, WaiterConfig={"Delay": 1}) @task @@ -66,24 +66,24 @@ def get_dynamodb_item_count(): A DynamoDB table has an ItemCount value, but it is only updated every six hours. To verify this DAG worked, we will scan the table and count the items manually. 
""" - table = DynamoDBHook(resource_type='dynamodb').conn.Table(DYNAMODB_TABLE_NAME) + table = DynamoDBHook(resource_type="dynamodb").conn.Table(DYNAMODB_TABLE_NAME) - response = table.scan(Select='COUNT') - item_count = response['Count'] + response = table.scan(Select="COUNT") + item_count = response["Count"] - while 'LastEvaluatedKey' in response: - response = table.scan(Select='COUNT', ExclusiveStartKey=response['LastEvaluatedKey']) - item_count += response['Count'] + while "LastEvaluatedKey" in response: + response = table.scan(Select="COUNT", ExclusiveStartKey=response["LastEvaluatedKey"]) + item_count += response["Count"] - print(f'DynamoDB table contains {item_count} items.') + print(f"DynamoDB table contains {item_count} items.") # Included for sample purposes only; in production you wouldn't delete # the table you just backed your data up to. Using 'all_done' so even # if an intermediate step fails, the DAG will clean up after itself. -@task(trigger_rule='all_done') +@task(trigger_rule="all_done") def delete_dynamodb_table(): - DynamoDBHook(client_type='dynamodb').conn.delete_table(TableName=DYNAMODB_TABLE_NAME) + DynamoDBHook(client_type="dynamodb").conn.delete_table(TableName=DYNAMODB_TABLE_NAME) # Included for sample purposes only; in production this should @@ -96,7 +96,7 @@ def configure_hive_connection(): db.merge_conn( Connection( conn_id=HIVE_CONNECTION_ID, - conn_type='hiveserver2', + conn_type="hiveserver2", host=HIVE_HOSTNAME, port=10000, ) @@ -104,10 +104,9 @@ def configure_hive_connection(): with DAG( - dag_id='example_hive_to_dynamodb', - schedule_interval=None, + dag_id="example_hive_to_dynamodb", start_date=datetime(2021, 1, 1), - tags=['example'], + tags=["example"], catchup=False, ) as dag: # Add the prerequisites docstring to the DAG in the UI. @@ -115,7 +114,7 @@ def configure_hive_connection(): # [START howto_transfer_hive_to_dynamodb] backup_to_dynamodb = HiveToDynamoDBOperator( - task_id='backup_to_dynamodb', + task_id="backup_to_dynamodb", hiveserver2_conn_id=HIVE_CONNECTION_ID, sql=HIVE_SQL, table_name=DYNAMODB_TABLE_NAME, diff --git a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py index 357d92a6f5694..3d63f115f0900 100644 --- a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This is an example dag for using `ImapAttachmentToS3Operator` to transfer an email attachment via IMAP protocol from a mail server to S3 Bucket. 
""" +from __future__ import annotations from datetime import datetime from os import getenv @@ -35,13 +35,12 @@ with DAG( dag_id="example_imap_attachment_to_s3", start_date=datetime(2021, 1, 1), - schedule_interval=None, catchup=False, - tags=['example'], + tags=["example"], ) as dag: # [START howto_transfer_imap_attachment_to_s3] task_transfer_imap_attachment_to_s3 = ImapAttachmentToS3Operator( - task_id='transfer_imap_attachment_to_s3', + task_id="transfer_imap_attachment_to_s3", imap_attachment_name=IMAP_ATTACHMENT_NAME, s3_bucket=S3_BUCKET, s3_key=S3_KEY, diff --git a/airflow/providers/amazon/aws/example_dags/example_lambda.py b/airflow/providers/amazon/aws/example_dags/example_lambda.py deleted file mode 100644 index 3b87c3aa3c3fb..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_lambda.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import json -from datetime import datetime, timedelta -from os import getenv - -from airflow import DAG -from airflow.providers.amazon.aws.operators.aws_lambda import AwsLambdaInvokeFunctionOperator - -# [START howto_operator_lambda_env_variables] -LAMBDA_FUNCTION_NAME = getenv("LAMBDA_FUNCTION_NAME", "test-function") -# [END howto_operator_lambda_env_variables] - -SAMPLE_EVENT = json.dumps({"SampleEvent": {"SampleData": {"Name": "XYZ", "DoB": "1993-01-01"}}}) - -with DAG( - dag_id='example_lambda', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - dagrun_timeout=timedelta(minutes=60), - tags=['example'], - catchup=False, -) as dag: - # [START howto_operator_lambda] - invoke_lambda_function = AwsLambdaInvokeFunctionOperator( - task_id='setup__invoke_lambda_function', - function_name=LAMBDA_FUNCTION_NAME, - payload=SAMPLE_EVENT, - ) - # [END howto_operator_lambda] diff --git a/airflow/providers/amazon/aws/example_dags/example_local_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_local_to_s3.py deleted file mode 100644 index 05f9c74d7685b..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_local_to_s3.py +++ /dev/null @@ -1,42 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - - -import os - -from airflow import models -from airflow.providers.amazon.aws.transfers.local_to_s3 import LocalFilesystemToS3Operator -from airflow.utils.dates import datetime - -S3_BUCKET = os.environ.get("S3_BUCKET", "test-bucket") -S3_KEY = os.environ.get("S3_KEY", "key") - -with models.DAG( - "example_local_to_s3", - schedule_interval=None, - start_date=datetime(2021, 1, 1), # Override to match your needs - catchup=False, -) as dag: - # [START howto_transfer_local_to_s3] - create_local_to_s3_job = LocalFilesystemToS3Operator( - task_id="create_local_to_s3_job", - filename="relative/path/to/file.csv", - dest_key=S3_KEY, - dest_bucket=S3_BUCKET, - replace=True, - ) - # [END howto_transfer_local_to_s3] diff --git a/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py index e95964b59cee1..39e4d6ac89869 100644 --- a/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from __future__ import annotations import os @@ -29,7 +29,6 @@ with models.DAG( "example_mongo_to_s3", - schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, ) as dag: diff --git a/airflow/providers/amazon/aws/example_dags/example_quicksight.py b/airflow/providers/amazon/aws/example_dags/example_quicksight.py deleted file mode 100644 index 5c50a54492617..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_quicksight.py +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
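The example_lambda DAG removed earlier in this change invokes a function through AwsLambdaInvokeFunctionOperator with a JSON-encoded payload. A minimal boto3 sketch of the same invocation done directly; the function name and event are the placeholder values from that example:

import json

import boto3

lambda_client = boto3.client("lambda")
response = lambda_client.invoke(
    FunctionName="test-function",
    Payload=json.dumps({"SampleEvent": {"SampleData": {"Name": "XYZ", "DoB": "1993-01-01"}}}),
)
# The Payload in the response is a streaming body containing the function's return value.
print(json.loads(response["Payload"].read()))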
- -import os -from datetime import datetime - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.quicksight import QuickSightCreateIngestionOperator -from airflow.providers.amazon.aws.sensors.quicksight import QuickSightSensor - -DATA_SET_ID = os.getenv("DATA_SET_ID", "data-set-id") -INGESTION_ID = os.getenv("INGESTION_ID", "ingestion-id") - -with DAG( - dag_id="example_quicksight", - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=["example"], - catchup=False, -) as dag: - # Create and Start the QuickSight SPICE data ingestion - # and does not wait for its completion - # [START howto_operator_quicksight_create_ingestion] - quicksight_create_ingestion_no_waiting = QuickSightCreateIngestionOperator( - task_id="quicksight_create_ingestion_no_waiting", - data_set_id=DATA_SET_ID, - ingestion_id=INGESTION_ID, - wait_for_completion=False, - ) - # [END howto_operator_quicksight_create_ingestion] - - # The following task checks the status of the QuickSight SPICE ingestion - # job until it succeeds. - # [START howto_sensor_quicksight] - quicksight_job_status = QuickSightSensor( - task_id="quicksight_job_status", - data_set_id=DATA_SET_ID, - ingestion_id=INGESTION_ID, - ) - # [END howto_sensor_quicksight] - - chain(quicksight_create_ingestion_no_waiting, quicksight_job_status) diff --git a/airflow/providers/amazon/aws/example_dags/example_rds_event.py b/airflow/providers/amazon/aws/example_dags/example_rds_event.py deleted file mode 100644 index 4ec8b6f5be3c4..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_rds_event.py +++ /dev/null @@ -1,58 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
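The QuickSight example above kicks off a SPICE ingestion without waiting and then polls it with QuickSightSensor until it succeeds. A minimal boto3 sketch of checking that ingestion status directly; the data set and ingestion ids are the placeholder values from the example, the account id is resolved via STS, and the response fields assume the standard describe_ingestion shape:

import boto3

DATA_SET_ID = "data-set-id"  # placeholder
INGESTION_ID = "ingestion-id"  # placeholder

account_id = boto3.client("sts").get_caller_identity()["Account"]
quicksight = boto3.client("quicksight")
ingestion = quicksight.describe_ingestion(
    AwsAccountId=account_id, DataSetId=DATA_SET_ID, IngestionId=INGESTION_ID
)
# IngestionStatus is e.g. RUNNING, COMPLETED or FAILED.
print(ingestion["Ingestion"]["IngestionStatus"])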
- -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.rds import ( - RdsCreateEventSubscriptionOperator, - RdsDeleteEventSubscriptionOperator, -) - -SUBSCRIPTION_NAME = getenv("SUBSCRIPTION_NAME", "subscription-name") -SNS_TOPIC_ARN = getenv("SNS_TOPIC_ARN", "arn:aws:sns:::MyTopic") -RDS_DB_IDENTIFIER = getenv("RDS_DB_IDENTIFIER", "database-identifier") - -with DAG( - dag_id='example_rds_event', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - # [START howto_operator_rds_create_event_subscription] - create_subscription = RdsCreateEventSubscriptionOperator( - task_id='create_subscription', - subscription_name=SUBSCRIPTION_NAME, - sns_topic_arn=SNS_TOPIC_ARN, - source_type='db-instance', - source_ids=[RDS_DB_IDENTIFIER], - event_categories=['availability'], - ) - # [END howto_operator_rds_create_event_subscription] - - # [START howto_operator_rds_delete_event_subscription] - delete_subscription = RdsDeleteEventSubscriptionOperator( - task_id='delete_subscription', - subscription_name=SUBSCRIPTION_NAME, - ) - # [END howto_operator_rds_delete_event_subscription] - - chain(create_subscription, delete_subscription) diff --git a/airflow/providers/amazon/aws/example_dags/example_rds_export.py b/airflow/providers/amazon/aws/example_dags/example_rds_export.py deleted file mode 100644 index 1dce5804910af..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_rds_export.py +++ /dev/null @@ -1,71 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.rds import RdsCancelExportTaskOperator, RdsStartExportTaskOperator -from airflow.providers.amazon.aws.sensors.rds import RdsExportTaskExistenceSensor - -RDS_EXPORT_TASK_IDENTIFIER = getenv("RDS_EXPORT_TASK_IDENTIFIER", "export-task-identifier") -RDS_EXPORT_SOURCE_ARN = getenv( - "RDS_EXPORT_SOURCE_ARN", "arn:aws:rds:::snapshot:snap-id" -) -BUCKET_NAME = getenv("BUCKET_NAME", "bucket-name") -BUCKET_PREFIX = getenv("BUCKET_PREFIX", "bucket-prefix") -ROLE_ARN = getenv("ROLE_ARN", "arn:aws:iam:::role/Role") -KMS_KEY_ID = getenv("KMS_KEY_ID", "arn:aws:kms:::key/key-id") - - -with DAG( - dag_id='example_rds_export', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - # [START howto_operator_rds_start_export_task] - start_export = RdsStartExportTaskOperator( - task_id='start_export', - export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER, - source_arn=RDS_EXPORT_SOURCE_ARN, - s3_bucket_name=BUCKET_NAME, - s3_prefix=BUCKET_PREFIX, - iam_role_arn=ROLE_ARN, - kms_key_id=KMS_KEY_ID, - ) - # [END howto_operator_rds_start_export_task] - - # [START howto_operator_rds_cancel_export] - cancel_export = RdsCancelExportTaskOperator( - task_id='cancel_export', - export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER, - ) - # [END howto_operator_rds_cancel_export] - - # [START howto_sensor_rds_export_task_existence] - export_sensor = RdsExportTaskExistenceSensor( - task_id='export_sensor', - export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER, - target_statuses=['canceled'], - ) - # [END howto_sensor_rds_export_task_existence] - - chain(start_export, cancel_export, export_sensor) diff --git a/airflow/providers/amazon/aws/example_dags/example_rds_snapshot.py b/airflow/providers/amazon/aws/example_dags/example_rds_snapshot.py deleted file mode 100644 index f7e1d02e07d49..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_rds_snapshot.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
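The RDS export example above cancels the export task and then uses RdsExportTaskExistenceSensor to wait for it to reach the canceled status. A minimal boto3 sketch of inspecting that status directly; the task identifier is the placeholder value from the example:

import boto3

rds = boto3.client("rds")
response = rds.describe_export_tasks(ExportTaskIdentifier="export-task-identifier")
for export_task in response["ExportTasks"]:
    # Status moves through values such as STARTING, IN_PROGRESS, CANCELED or COMPLETE.
    print(export_task["ExportTaskIdentifier"], export_task["Status"])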
- -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.rds import ( - RdsCopyDbSnapshotOperator, - RdsCreateDbSnapshotOperator, - RdsDeleteDbSnapshotOperator, -) -from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor - -RDS_DB_IDENTIFIER = getenv("RDS_DB_IDENTIFIER", "database-identifier") -RDS_DB_SNAPSHOT_IDENTIFIER = getenv("RDS_DB_SNAPSHOT_IDENTIFIER", "database-1-snap") - -with DAG( - dag_id='example_rds_snapshot', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - # [START howto_operator_rds_create_db_snapshot] - create_snapshot = RdsCreateDbSnapshotOperator( - task_id='create_snapshot', - db_type='instance', - db_identifier=RDS_DB_IDENTIFIER, - db_snapshot_identifier=RDS_DB_SNAPSHOT_IDENTIFIER, - ) - # [END howto_operator_rds_create_db_snapshot] - - # [START howto_sensor_rds_snapshot_existence] - snapshot_sensor = RdsSnapshotExistenceSensor( - task_id='snapshot_sensor', - db_type='instance', - db_snapshot_identifier=RDS_DB_IDENTIFIER, - target_statuses=['available'], - ) - # [END howto_sensor_rds_snapshot_existence] - - # [START howto_operator_rds_copy_snapshot] - copy_snapshot = RdsCopyDbSnapshotOperator( - task_id='copy_snapshot', - db_type='instance', - source_db_snapshot_identifier=RDS_DB_IDENTIFIER, - target_db_snapshot_identifier=f'{RDS_DB_IDENTIFIER}-copy', - ) - # [END howto_operator_rds_copy_snapshot] - - # [START howto_operator_rds_delete_snapshot] - delete_snapshot = RdsDeleteDbSnapshotOperator( - task_id='delete_snapshot', - db_type='instance', - db_snapshot_identifier=RDS_DB_IDENTIFIER, - ) - # [END howto_operator_rds_delete_snapshot] - - chain(create_snapshot, snapshot_sensor, copy_snapshot, delete_snapshot) diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py b/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py deleted file mode 100644 index dc6deead3c618..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py +++ /dev/null @@ -1,98 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.redshift_cluster import ( - RedshiftCreateClusterOperator, - RedshiftDeleteClusterOperator, - RedshiftPauseClusterOperator, - RedshiftResumeClusterOperator, -) -from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor - -REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "redshift-cluster-1") - -with DAG( - dag_id="example_redshift_cluster", - start_date=datetime(2021, 1, 1), - schedule_interval=None, - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_redshift_cluster] - task_create_cluster = RedshiftCreateClusterOperator( - task_id="redshift_create_cluster", - cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER, - cluster_type="single-node", - node_type="dc2.large", - master_username="adminuser", - master_user_password="dummypass", - ) - # [END howto_operator_redshift_cluster] - - # [START howto_sensor_redshift_cluster] - task_wait_cluster_available = RedshiftClusterSensor( - task_id='sensor_redshift_cluster_available', - cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER, - target_status='available', - poke_interval=5, - timeout=60 * 15, - ) - # [END howto_sensor_redshift_cluster] - - # [START howto_operator_redshift_pause_cluster] - task_pause_cluster = RedshiftPauseClusterOperator( - task_id='redshift_pause_cluster', - cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER, - ) - # [END howto_operator_redshift_pause_cluster] - - task_wait_cluster_paused = RedshiftClusterSensor( - task_id='sensor_redshift_cluster_paused', - cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER, - target_status='paused', - poke_interval=5, - timeout=60 * 15, - ) - - # [START howto_operator_redshift_resume_cluster] - task_resume_cluster = RedshiftResumeClusterOperator( - task_id='redshift_resume_cluster', - cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER, - ) - # [END howto_operator_redshift_resume_cluster] - - # [START howto_operator_redshift_delete_cluster] - task_delete_cluster = RedshiftDeleteClusterOperator( - task_id="delete_cluster", - cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER, - ) - # [END howto_operator_redshift_delete_cluster] - - chain( - task_create_cluster, - task_wait_cluster_available, - task_pause_cluster, - task_wait_cluster_paused, - task_resume_cluster, - task_delete_cluster, - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py b/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py deleted file mode 100644 index cfa3e4cefcb2d..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.decorators import task -from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook -from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator - -REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "redshift_cluster_identifier") -REDSHIFT_DATABASE = getenv("REDSHIFT_DATABASE", "redshift_database") -REDSHIFT_DATABASE_USER = getenv("REDSHIFT_DATABASE_USER", "awsuser") - -REDSHIFT_QUERY = """ -SELECT table_schema, - table_name -FROM information_schema.tables -WHERE table_schema NOT IN ('information_schema', 'pg_catalog') - AND table_type = 'BASE TABLE' -ORDER BY table_schema, - table_name; - """ -POLL_INTERVAL = 10 - - -@task(task_id="output_results") -def output_query_results(statement_id): - hook = RedshiftDataHook() - resp = hook.conn.get_statement_result( - Id=statement_id, - ) - - print(resp) - return resp - - -with DAG( - dag_id="example_redshift_data_execute_sql", - start_date=datetime(2021, 1, 1), - schedule_interval=None, - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_redshift_data] - task_query = RedshiftDataOperator( - task_id='redshift_query', - cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER, - database=REDSHIFT_DATABASE, - db_user=REDSHIFT_DATABASE_USER, - sql=REDSHIFT_QUERY, - poll_interval=POLL_INTERVAL, - await_result=True, - ) - # [END howto_operator_redshift_data] - - task_output = output_query_results(task_query.output) diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_sql.py b/airflow/providers/amazon/aws/example_dags/example_redshift_sql.py deleted file mode 100644 index a71ef71934edc..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_redshift_sql.py +++ /dev/null @@ -1,79 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
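The output_results task above reads a single call to get_statement_result, but the Redshift Data API returns rows page by page. A minimal sketch of draining every page with NextToken; the statement id is a placeholder:

import boto3

client = boto3.client("redshift-data")
records = []
kwargs = {"Id": "hypothetical-statement-id"}  # placeholder statement id
while True:
    page = client.get_statement_result(**kwargs)
    records.extend(page["Records"])
    if "NextToken" not in page:
        break
    kwargs["NextToken"] = page["NextToken"]
print(f"Fetched {len(records)} rows")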
- -from datetime import datetime - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator - -with DAG( - dag_id="example_redshift_sql", - start_date=datetime(2021, 1, 1), - schedule_interval=None, - catchup=False, - tags=['example'], -) as dag: - setup__task_create_table = RedshiftSQLOperator( - task_id='setup__create_table', - sql=""" - CREATE TABLE IF NOT EXISTS fruit ( - fruit_id INTEGER, - name VARCHAR NOT NULL, - color VARCHAR NOT NULL - ); - """, - ) - - setup__task_insert_data = RedshiftSQLOperator( - task_id='setup__task_insert_data', - sql=[ - "INSERT INTO fruit VALUES ( 1, 'Banana', 'Yellow');", - "INSERT INTO fruit VALUES ( 2, 'Apple', 'Red');", - "INSERT INTO fruit VALUES ( 3, 'Lemon', 'Yellow');", - "INSERT INTO fruit VALUES ( 4, 'Grape', 'Purple');", - "INSERT INTO fruit VALUES ( 5, 'Pear', 'Green');", - "INSERT INTO fruit VALUES ( 6, 'Strawberry', 'Red');", - ], - ) - - # [START howto_operator_redshift_sql] - task_select_data = RedshiftSQLOperator( - task_id='task_get_all_table_data', sql="""CREATE TABLE more_fruit AS SELECT * FROM fruit;""" - ) - # [END howto_operator_redshift_sql] - - # [START howto_operator_redshift_sql_with_params] - task_select_filtered_data = RedshiftSQLOperator( - task_id='task_get_filtered_table_data', - sql="""CREATE TABLE filtered_fruit AS SELECT * FROM fruit WHERE color = '{{ params.color }}';""", - params={'color': 'Red'}, - ) - # [END howto_operator_redshift_sql_with_params] - - teardown__task_drop_table = RedshiftSQLOperator( - task_id='teardown__drop_table', - sql='DROP TABLE IF EXISTS fruit', - ) - - chain( - setup__task_create_table, - setup__task_insert_data, - [task_select_data, task_select_filtered_data], - teardown__task_drop_table, - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_redshift_to_s3.py deleted file mode 100644 index 8116e02dc165c..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_redshift_to_s3.py +++ /dev/null @@ -1,43 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator - -S3_BUCKET_NAME = getenv("S3_BUCKET_NAME", "s3_bucket_name") -S3_KEY = getenv("S3_KEY", "s3_key") -REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "redshift_table") - -with DAG( - dag_id="example_redshift_to_s3", - start_date=datetime(2021, 1, 1), - schedule_interval=None, - catchup=False, - tags=['example'], -) as dag: - # [START howto_transfer_redshift_to_s3] - task_transfer_redshift_to_s3 = RedshiftToS3Operator( - task_id='transfer_redshift_to_s3', - s3_bucket=S3_BUCKET_NAME, - s3_key=S3_KEY, - schema='PUBLIC', - table=REDSHIFT_TABLE, - ) - # [END howto_transfer_redshift_to_s3] diff --git a/airflow/providers/amazon/aws/example_dags/example_s3.py b/airflow/providers/amazon/aws/example_dags/example_s3.py deleted file mode 100644 index 7e06575d4a58b..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_s3.py +++ /dev/null @@ -1,224 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime -from typing import List - -from airflow.models.baseoperator import chain -from airflow.models.dag import DAG -from airflow.providers.amazon.aws.operators.s3 import ( - S3CopyObjectOperator, - S3CreateBucketOperator, - S3CreateObjectOperator, - S3DeleteBucketOperator, - S3DeleteBucketTaggingOperator, - S3DeleteObjectsOperator, - S3FileTransformOperator, - S3GetBucketTaggingOperator, - S3ListOperator, - S3ListPrefixesOperator, - S3PutBucketTaggingOperator, -) -from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor - -BUCKET_NAME = os.environ.get('BUCKET_NAME', 'test-airflow-12345') -BUCKET_NAME_2 = os.environ.get('BUCKET_NAME_2', 'test-airflow-123456') -KEY = os.environ.get('KEY', 'key') -KEY_2 = os.environ.get('KEY_2', 'key2') -# Empty string prefix refers to the bucket root -# See what prefix is here https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-prefixes.html -PREFIX = os.environ.get('PREFIX', '') -DELIMITER = os.environ.get('DELIMITER', '/') -TAG_KEY = os.environ.get('TAG_KEY', 'test-s3-bucket-tagging-key') -TAG_VALUE = os.environ.get('TAG_VALUE', 'test-s3-bucket-tagging-value') -DATA = os.environ.get( - 'DATA', - ''' -apple,0.5 -milk,2.5 -bread,4.0 -''', -) - -with DAG( - dag_id='example_s3', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_sensor_s3_key_function_definition] - def check_fn(files: List) -> bool: - """ - Example of custom check: check if all files are bigger than 1kB - - :param files: List of S3 object attributes. 
- Format: [{ - 'Size': int - }] - :return: true if the criteria is met - :rtype: bool - """ - return all(f.get('Size', 0) > 1024 for f in files) - - # [END howto_sensor_s3_key_function_definition] - - # [START howto_operator_s3_create_bucket] - create_bucket = S3CreateBucketOperator( - task_id='s3_create_bucket', - bucket_name=BUCKET_NAME, - ) - # [END howto_operator_s3_create_bucket] - - # [START howto_operator_s3_put_bucket_tagging] - put_tagging = S3PutBucketTaggingOperator( - task_id='s3_put_bucket_tagging', - bucket_name=BUCKET_NAME, - key=TAG_KEY, - value=TAG_VALUE, - ) - # [END howto_operator_s3_put_bucket_tagging] - - # [START howto_operator_s3_get_bucket_tagging] - get_tagging = S3GetBucketTaggingOperator( - task_id='s3_get_bucket_tagging', - bucket_name=BUCKET_NAME, - ) - # [END howto_operator_s3_get_bucket_tagging] - - # [START howto_operator_s3_delete_bucket_tagging] - delete_tagging = S3DeleteBucketTaggingOperator( - task_id='s3_delete_bucket_tagging', - bucket_name=BUCKET_NAME, - ) - # [END howto_operator_s3_delete_bucket_tagging] - - # [START howto_operator_s3_create_object] - create_object = S3CreateObjectOperator( - task_id="s3_create_object", - s3_bucket=BUCKET_NAME, - s3_key=KEY, - data=DATA, - replace=True, - ) - # [END howto_operator_s3_create_object] - - # [START howto_operator_s3_list_prefixes] - list_prefixes = S3ListPrefixesOperator( - task_id="s3_list_prefix_operator", - bucket=BUCKET_NAME, - prefix=PREFIX, - delimiter=DELIMITER, - ) - # [END howto_operator_s3_list_prefixes] - - # [START howto_operator_s3_list] - list_keys = S3ListOperator( - task_id="s3_list_operator", - bucket=BUCKET_NAME, - prefix=PREFIX, - ) - # [END howto_operator_s3_list] - - # [START howto_sensor_s3_key_single_key] - # Check if a file exists - sensor_one_key = S3KeySensor( - task_id="s3_sensor_one_key", - bucket_name=BUCKET_NAME, - bucket_key=KEY, - ) - # [END howto_sensor_s3_key_single_key] - - # [START howto_sensor_s3_key_multiple_keys] - # Check if both files exist - sensor_two_keys = S3KeySensor( - task_id="s3_sensor_two_keys", - bucket_name=BUCKET_NAME, - bucket_key=[KEY, KEY_2], - ) - # [END howto_sensor_s3_key_multiple_keys] - - # [START howto_sensor_s3_key_function] - # Check if a file exists and match a certain pattern defined in check_fn - sensor_key_with_function = S3KeySensor( - task_id="s3_sensor_key_function", - bucket_name=BUCKET_NAME, - bucket_key=KEY, - check_fn=check_fn, - ) - # [END howto_sensor_s3_key_function] - - # [START howto_sensor_s3_keys_unchanged] - sensor_keys_unchanged = S3KeysUnchangedSensor( - task_id="s3_sensor_one_key_size", - bucket_name=BUCKET_NAME_2, - prefix=PREFIX, - inactivity_period=10, - ) - # [END howto_sensor_s3_keys_unchanged] - - # [START howto_operator_s3_copy_object] - copy_object = S3CopyObjectOperator( - task_id="s3_copy_object", - source_bucket_name=BUCKET_NAME, - dest_bucket_name=BUCKET_NAME_2, - source_bucket_key=KEY, - dest_bucket_key=KEY_2, - ) - # [END howto_operator_s3_copy_object] - - # [START howto_operator_s3_file_transform] - transforms_file = S3FileTransformOperator( - task_id="s3_file_transform", - source_s3_key=f's3://{BUCKET_NAME}/{KEY}', - dest_s3_key=f's3://{BUCKET_NAME_2}/{KEY_2}', - # Use `cp` command as transform script as an example - transform_script='cp', - replace=True, - ) - # [END howto_operator_s3_file_transform] - - # [START howto_operator_s3_delete_objects] - delete_objects = S3DeleteObjectsOperator( - task_id="s3_delete_objects", - bucket=BUCKET_NAME_2, - keys=KEY_2, - ) - # [END 
howto_operator_s3_delete_objects] - - # [START howto_operator_s3_delete_bucket] - delete_bucket = S3DeleteBucketOperator( - task_id='s3_delete_bucket', bucket_name=BUCKET_NAME, force_delete=True - ) - # [END howto_operator_s3_delete_bucket] - - chain( - create_bucket, - put_tagging, - get_tagging, - delete_tagging, - create_object, - list_prefixes, - list_keys, - [sensor_one_key, sensor_two_keys, sensor_key_with_function], - copy_object, - transforms_file, - sensor_keys_unchanged, - delete_objects, - delete_bucket, - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py index 6ebe2501c613b..47ca88a0af72b 100644 --- a/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py +++ b/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from __future__ import annotations import os from datetime import datetime @@ -27,7 +27,6 @@ with models.DAG( "example_s3_to_ftp", - schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, ) as dag: diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py deleted file mode 100644 index 82ae0660053f1..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py +++ /dev/null @@ -1,80 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
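The custom-check pattern from the deleted example_s3 DAG above is preserved here in condensed form; the check function is copied from that file, while the bucket and key names are placeholders.

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def check_fn(files: list) -> bool:
    # Succeed only when every matched S3 object is larger than 1 kB.
    return all(f.get("Size", 0) > 1024 for f in files)


sensor_key_with_function = S3KeySensor(
    task_id="s3_sensor_key_function",
    bucket_name="example-bucket",  # placeholder
    bucket_key="example-key",      # placeholder
    check_fn=check_fn,
)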
- -from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.decorators import task -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator -from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator - -S3_BUCKET_NAME = getenv("S3_BUCKET_NAME", "s3_bucket_name") -S3_KEY = getenv("S3_KEY", "s3_filename") -REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "redshift_table") - - -@task(task_id='setup__add_sample_data_to_s3') -def task_add_sample_data_to_s3(): - s3_hook = S3Hook() - s3_hook.load_string("0,Airflow", f'{S3_KEY}/{REDSHIFT_TABLE}', S3_BUCKET_NAME, replace=True) - - -@task(task_id='teardown__remove_sample_data_from_s3') -def task_remove_sample_data_from_s3(): - s3_hook = S3Hook() - if s3_hook.check_for_key(f'{S3_KEY}/{REDSHIFT_TABLE}', S3_BUCKET_NAME): - s3_hook.delete_objects(S3_BUCKET_NAME, f'{S3_KEY}/{REDSHIFT_TABLE}') - - -with DAG( - dag_id="example_s3_to_redshift", - start_date=datetime(2021, 1, 1), - schedule_interval=None, - catchup=False, - tags=['example'], -) as dag: - add_sample_data_to_s3 = task_add_sample_data_to_s3() - - setup__task_create_table = RedshiftSQLOperator( - sql=f'CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE}(Id int, Name varchar)', - task_id='setup__create_table', - ) - # [START howto_transfer_s3_to_redshift] - task_transfer_s3_to_redshift = S3ToRedshiftOperator( - s3_bucket=S3_BUCKET_NAME, - s3_key=S3_KEY, - schema='PUBLIC', - table=REDSHIFT_TABLE, - copy_options=['csv'], - task_id='transfer_s3_to_redshift', - ) - # [END howto_transfer_s3_to_redshift] - teardown__task_drop_table = RedshiftSQLOperator( - sql=f'DROP TABLE IF EXISTS {REDSHIFT_TABLE}', - task_id='teardown__drop_table', - ) - - remove_sample_data_from_s3 = task_remove_sample_data_from_s3() - - chain( - [add_sample_data_to_s3, setup__task_create_table], - task_transfer_s3_to_redshift, - [teardown__task_drop_table, remove_sample_data_from_s3], - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py index d7983265e7697..1c625b84942e4 100644 --- a/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py +++ b/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import os from datetime import datetime @@ -26,7 +27,6 @@ with models.DAG( "example_s3_to_sftp", - schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, ) as dag: diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py deleted file mode 100644 index 41e7666eabb8a..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py +++ /dev/null @@ -1,467 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import base64 -import os -import subprocess -from datetime import datetime -from tempfile import NamedTemporaryFile - -import boto3 - -from airflow import DAG -from airflow.decorators import task -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.amazon.aws.operators.sagemaker import ( - SageMakerDeleteModelOperator, - SageMakerModelOperator, - SageMakerProcessingOperator, - SageMakerTrainingOperator, - SageMakerTransformOperator, - SageMakerTuningOperator, -) -from airflow.providers.amazon.aws.sensors.sagemaker import ( - SageMakerTrainingSensor, - SageMakerTransformSensor, - SageMakerTuningSensor, -) - -# Project name will be used in naming the S3 buckets and various tasks. -# The dataset used in this example is identifying varieties of the Iris flower. -PROJECT_NAME = 'iris' -TIMESTAMP = '{{ ts_nodash }}' - -S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') -RAW_DATA_S3_KEY = f'{PROJECT_NAME}/preprocessing/input.csv' -INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data' -TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/results' -PREDICTION_OUTPUT_S3_KEY = f'{PROJECT_NAME}/transform' - -PROCESSING_LOCAL_INPUT_PATH = '/opt/ml/processing/input' -PROCESSING_LOCAL_OUTPUT_PATH = '/opt/ml/processing/output' - -MODEL_NAME = f'{PROJECT_NAME}-KNN-model' -# Job names can't be reused, so appending a timestamp ensures it is unique. -PROCESSING_JOB_NAME = f'{PROJECT_NAME}-processing-{TIMESTAMP}' -TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' -TRANSFORM_JOB_NAME = f'{PROJECT_NAME}-transform-{TIMESTAMP}' -TUNING_JOB_NAME = f'{PROJECT_NAME}-tune-{TIMESTAMP}' - -ROLE_ARN = os.getenv( - 'SAGEMAKER_ROLE_ARN', - 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', -) -ECR_REPOSITORY = os.getenv('ECR_REPOSITORY', '1234567890.dkr.ecr.us-west-2.amazonaws.com/process_data') -REGION = ECR_REPOSITORY.split('.')[3] - -# For this example we are using a subset of Fischer's Iris Data Set. -# The full dataset can be found at UC Irvine's machine learning repository: -# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data -DATASET = """ - 5.1,3.5,1.4,0.2,Iris-setosa - 4.9,3.0,1.4,0.2,Iris-setosa - 7.0,3.2,4.7,1.4,Iris-versicolor - 6.4,3.2,4.5,1.5,Iris-versicolor - 4.9,2.5,4.5,1.7,Iris-virginica - 7.3,2.9,6.3,1.8,Iris-virginica - """ -SAMPLE_SIZE = DATASET.count('\n') - 1 - -# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR -# repo cited in public SageMaker documentation, so the account number does not need to be redacted. 
-# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title -KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn' - -TASK_TIMEOUT = {'MaxRuntimeInSeconds': 6 * 60} - -RESOURCE_CONFIG = { - 'InstanceCount': 1, - 'InstanceType': 'ml.m5.large', - 'VolumeSizeInGB': 1, -} - -TRAINING_DATA_SOURCE = { - 'CompressionType': 'None', - 'ContentType': 'text/csv', - 'DataSource': { # type: ignore - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', - } - }, -} - -# Define configs for processing, training, model creation, and batch transform jobs -SAGEMAKER_PROCESSING_JOB_CONFIG = { - 'ProcessingJobName': PROCESSING_JOB_NAME, - 'RoleArn': f'{ROLE_ARN}', - 'ProcessingInputs': [ - { - 'InputName': 'input', - 'AppManaged': False, - 'S3Input': { - 'S3Uri': f's3://{S3_BUCKET}/{RAW_DATA_S3_KEY}', - 'LocalPath': PROCESSING_LOCAL_INPUT_PATH, - 'S3DataType': 'S3Prefix', - 'S3InputMode': 'File', - 'S3DataDistributionType': 'FullyReplicated', - 'S3CompressionType': 'None', - }, - }, - ], - 'ProcessingOutputConfig': { - 'Outputs': [ - { - 'OutputName': 'output', - 'S3Output': { - 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}', - 'LocalPath': PROCESSING_LOCAL_OUTPUT_PATH, - 'S3UploadMode': 'EndOfJob', - }, - 'AppManaged': False, - } - ] - }, - 'ProcessingResources': { - 'ClusterConfig': RESOURCE_CONFIG, - }, - 'StoppingCondition': TASK_TIMEOUT, - 'AppSpecification': { - 'ImageUri': ECR_REPOSITORY, - }, -} - -TRAINING_CONFIG = { - 'TrainingJobName': TRAINING_JOB_NAME, - 'RoleArn': ROLE_ARN, - 'AlgorithmSpecification': { - "TrainingImage": KNN_IMAGE_URI, - "TrainingInputMode": "File", - }, - 'HyperParameters': { - 'predictor_type': 'classifier', - 'feature_dim': '4', - 'k': '3', - 'sample_size': str(SAMPLE_SIZE), - }, - 'InputDataConfig': [ - { - 'ChannelName': 'train', - **TRAINING_DATA_SOURCE, # type: ignore [arg-type] - } - ], - 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, - 'ResourceConfig': RESOURCE_CONFIG, - 'StoppingCondition': TASK_TIMEOUT, -} - -MODEL_CONFIG = { - 'ModelName': MODEL_NAME, - 'ExecutionRoleArn': ROLE_ARN, - 'PrimaryContainer': { - 'Mode': 'SingleModel', - 'Image': KNN_IMAGE_URI, - 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', - }, -} - -TRANSFORM_CONFIG = { - 'TransformJobName': TRANSFORM_JOB_NAME, - 'ModelName': MODEL_NAME, - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/test.csv', - } - }, - 'SplitType': 'Line', - 'ContentType': 'text/csv', - }, - 'TransformOutput': {'S3OutputPath': f's3://{S3_BUCKET}/{PREDICTION_OUTPUT_S3_KEY}'}, - 'TransformResources': { - 'InstanceCount': 1, - 'InstanceType': 'ml.m5.large', - }, -} - -TUNING_CONFIG = { - 'HyperParameterTuningJobName': TUNING_JOB_NAME, - 'HyperParameterTuningJobConfig': { - 'Strategy': 'Bayesian', - 'HyperParameterTuningJobObjective': { - 'MetricName': 'test:accuracy', - 'Type': 'Maximize', - }, - 'ResourceLimits': { - # You would bump these up in production as appropriate. - 'MaxNumberOfTrainingJobs': 1, - 'MaxParallelTrainingJobs': 1, - }, - 'ParameterRanges': { - 'CategoricalParameterRanges': [], - 'IntegerParameterRanges': [ - # Set the min and max values of the hyperparameters you want to tune. 
- { - 'Name': 'k', - 'MinValue': '1', - "MaxValue": str(SAMPLE_SIZE), - }, - { - 'Name': 'sample_size', - 'MinValue': '1', - 'MaxValue': str(SAMPLE_SIZE), - }, - ], - }, - }, - 'TrainingJobDefinition': { - 'StaticHyperParameters': { - 'predictor_type': 'classifier', - 'feature_dim': '4', - }, - 'AlgorithmSpecification': {'TrainingImage': KNN_IMAGE_URI, 'TrainingInputMode': 'File'}, - 'InputDataConfig': [ - { - 'ChannelName': 'train', - **TRAINING_DATA_SOURCE, # type: ignore [arg-type] - }, - { - 'ChannelName': 'test', - **TRAINING_DATA_SOURCE, # type: ignore [arg-type] - }, - ], - 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}'}, - 'ResourceConfig': RESOURCE_CONFIG, - 'StoppingCondition': TASK_TIMEOUT, - 'RoleArn': ROLE_ARN, - }, -} - - -# This script will be the entrypoint for the docker image which will handle preprocessing the raw data -# NOTE: The following string must remain dedented as it is being written to a file. -PREPROCESS_SCRIPT = ( - """ -import boto3 -import numpy as np -import pandas as pd - -def main(): - # Load the Iris dataset from {input_path}/input.csv, split it into train/test - # subsets, and write them to {output_path}/ for the Processing Operator. - - columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'] - iris = pd.read_csv('{input_path}/input.csv', names=columns) - - # Process data - iris['species'] = iris['species'].replace({{'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}}) - iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length', 'petal_width']] - - # Split into test and train data - iris_train, iris_test = np.split( - iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(iris))] - ) - - # Remove the "answers" from the test set - iris_test.drop(['species'], axis=1, inplace=True) - - # Write the splits to disk - iris_train.to_csv('{output_path}/train.csv', index=False, header=False) - iris_test.to_csv('{output_path}/test.csv', index=False, header=False) - - print('Preprocessing Done.') - -if __name__ == "__main__": - main() - - """ -).format(input_path=PROCESSING_LOCAL_INPUT_PATH, output_path=PROCESSING_LOCAL_OUTPUT_PATH) - - -@task -def upload_dataset_to_s3(): - """Uploads the provided dataset to a designated Amazon S3 bucket.""" - S3Hook().load_string( - string_data=DATASET, - bucket_name=S3_BUCKET, - key=RAW_DATA_S3_KEY, - replace=True, - ) - - -@task -def build_and_upload_docker_image(): - """ - We need a Docker image with the following requirements: - - Has numpy, pandas, requests, and boto3 installed - - Has our data preprocessing script mounted and set as the entry point - """ - - # Fetch and parse ECR Token to be used for the docker push - ecr_client = boto3.client('ecr', region_name=REGION) - token = ecr_client.get_authorization_token() - credentials = (base64.b64decode(token['authorizationData'][0]['authorizationToken'])).decode('utf-8') - username, password = credentials.split(':') - - with NamedTemporaryFile(mode='w+t') as preprocessing_script, NamedTemporaryFile(mode='w+t') as dockerfile: - - preprocessing_script.write(PREPROCESS_SCRIPT) - preprocessing_script.flush() - - dockerfile.write( - f""" - FROM amazonlinux - COPY {preprocessing_script.name.split('/')[2]} /preprocessing.py - ADD credentials /credentials - ENV AWS_SHARED_CREDENTIALS_FILE=/credentials - RUN yum install python3 pip -y - RUN pip3 install boto3 pandas requests - CMD [ "python3", "/preprocessing.py"] - """ - ) - dockerfile.flush() - - docker_build_and_push_commands = 
f""" - cp /root/.aws/credentials /tmp/credentials && - docker build -f {dockerfile.name} -t {ECR_REPOSITORY} /tmp && - rm /tmp/credentials && - aws ecr get-login-password --region {REGION} | - docker login --username {username} --password {password} {ECR_REPOSITORY} && - docker push {ECR_REPOSITORY} - """ - docker_build = subprocess.Popen( - docker_build_and_push_commands, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - _, err = docker_build.communicate() - - if docker_build.returncode != 0: - raise RuntimeError(err) - - -@task(trigger_rule='all_done') -def cleanup(): - # Delete S3 Artifacts - client = boto3.client('s3') - object_keys = [ - key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents'] - ] - for key in object_keys: - client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]}) - - -with DAG( - dag_id='example_sagemaker', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_sagemaker_processing] - preprocess_raw_data = SageMakerProcessingOperator( - task_id='preprocess_raw_data', - config=SAGEMAKER_PROCESSING_JOB_CONFIG, - do_xcom_push=False, - ) - # [END howto_operator_sagemaker_processing] - - # [START howto_operator_sagemaker_training] - train_model = SageMakerTrainingOperator( - task_id='train_model', - config=TRAINING_CONFIG, - # Waits by default, setting as False to demonstrate the Sensor below. - wait_for_completion=False, - do_xcom_push=False, - ) - # [END howto_operator_sagemaker_training] - - # [START howto_sensor_sagemaker_training] - await_training = SageMakerTrainingSensor( - task_id='await_training', - job_name=TRAINING_JOB_NAME, - ) - # [END howto_sensor_sagemaker_training] - - # [START howto_operator_sagemaker_model] - create_model = SageMakerModelOperator( - task_id='create_model', - config=MODEL_CONFIG, - do_xcom_push=False, - ) - # [END howto_operator_sagemaker_model] - - # [START howto_operator_sagemaker_tuning] - tune_model = SageMakerTuningOperator( - task_id='tune_model', - config=TUNING_CONFIG, - # Waits by default, setting as False to demonstrate the Sensor below. - wait_for_completion=False, - do_xcom_push=False, - ) - # [END howto_operator_sagemaker_tuning] - - # [START howto_sensor_sagemaker_tuning] - await_tune = SageMakerTuningSensor( - task_id='await_tuning', - job_name=TUNING_JOB_NAME, - ) - # [END howto_sensor_sagemaker_tuning] - - # [START howto_operator_sagemaker_transform] - test_model = SageMakerTransformOperator( - task_id='test_model', - config=TRANSFORM_CONFIG, - # Waits by default, setting as False to demonstrate the Sensor below. - wait_for_completion=False, - do_xcom_push=False, - ) - # [END howto_operator_sagemaker_transform] - - # [START howto_sensor_sagemaker_transform] - await_transform = SageMakerTransformSensor( - task_id='await_transform', - job_name=TRANSFORM_JOB_NAME, - ) - # [END howto_sensor_sagemaker_transform] - - # Trigger rule set to "all_done" so clean up will run regardless of success on other tasks. 
- # [START howto_operator_sagemaker_delete_model] - delete_model = SageMakerDeleteModelOperator( - task_id='delete_model', - config={'ModelName': MODEL_NAME}, - trigger_rule='all_done', - ) - # [END howto_operator_sagemaker_delete_model] - - ( - upload_dataset_to_s3() - >> build_and_upload_docker_image() - >> preprocess_raw_data - >> train_model - >> await_training - >> create_model - >> tune_model - >> await_tune - >> test_model - >> await_transform - >> cleanup() - >> delete_model - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py deleted file mode 100644 index 52ee5303f2553..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py +++ /dev/null @@ -1,230 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import json -import os -from datetime import datetime - -import boto3 - -from airflow import DAG -from airflow.decorators import task -from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator -from airflow.providers.amazon.aws.operators.sagemaker import ( - SageMakerDeleteModelOperator, - SageMakerEndpointConfigOperator, - SageMakerEndpointOperator, - SageMakerModelOperator, - SageMakerTrainingOperator, -) -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor - -# Project name will be used in naming the S3 buckets and various tasks. -# The dataset used in this example is identifying varieties of the Iris flower. -PROJECT_NAME = 'iris' -TIMESTAMP = '{{ ts_nodash }}' - -S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket') -ROLE_ARN = os.getenv( - 'SAGEMAKER_ROLE_ARN', - 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole', -) -INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data' -TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/training-results' - -MODEL_NAME = f'{PROJECT_NAME}-KNN-model' -ENDPOINT_NAME = f'{PROJECT_NAME}-endpoint' -# Job names can't be reused, so appending a timestamp ensures it is unique. -ENDPOINT_CONFIG_JOB_NAME = f'{PROJECT_NAME}-endpoint-config-{TIMESTAMP}' -TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}' - -# For an example of how to obtain the following train and test data, please see -# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py -TRAIN_DATA = '0,4.9,2.5,4.5,1.7\n1,7.0,3.2,4.7,1.4\n0,7.3,2.9,6.3,1.8\n2,5.1,3.5,1.4,0.2\n' -SAMPLE_TEST_DATA = '6.4,3.2,4.5,1.5' - -# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR -# repo cited in public SageMaker documentation, so the account number does not need to be redacted. 
-# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title -KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn' - -# Define configs for processing, training, model creation, and batch transform jobs -TRAINING_CONFIG = { - 'TrainingJobName': TRAINING_JOB_NAME, - 'RoleArn': ROLE_ARN, - 'AlgorithmSpecification': { - "TrainingImage": KNN_IMAGE_URI, - "TrainingInputMode": "File", - }, - 'HyperParameters': { - 'predictor_type': 'classifier', - 'feature_dim': '4', - 'k': '3', - 'sample_size': '6', - }, - 'InputDataConfig': [ - { - 'ChannelName': 'train', - 'CompressionType': 'None', - 'ContentType': 'text/csv', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv', - } - }, - } - ], - 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'}, - 'ResourceConfig': { - 'InstanceCount': 1, - 'InstanceType': 'ml.m5.large', - 'VolumeSizeInGB': 1, - }, - 'StoppingCondition': {'MaxRuntimeInSeconds': 6 * 60}, -} - -MODEL_CONFIG = { - 'ModelName': MODEL_NAME, - 'ExecutionRoleArn': ROLE_ARN, - 'PrimaryContainer': { - 'Mode': 'SingleModel', - 'Image': KNN_IMAGE_URI, - 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz', - }, -} - -ENDPOINT_CONFIG_CONFIG = { - 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME, - 'ProductionVariants': [ - { - 'VariantName': f'{PROJECT_NAME}-demo', - 'ModelName': MODEL_NAME, - 'InstanceType': 'ml.t2.medium', - 'InitialInstanceCount': 1, - }, - ], -} - -DEPLOY_ENDPOINT_CONFIG = { - 'EndpointName': ENDPOINT_NAME, - 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME, -} - - -@task -def call_endpoint(): - runtime = boto3.Session().client('sagemaker-runtime') - - response = runtime.invoke_endpoint( - EndpointName=ENDPOINT_NAME, - ContentType='text/csv', - Body=SAMPLE_TEST_DATA, - ) - - return json.loads(response["Body"].read().decode())['predictions'] - - -@task(trigger_rule='all_done') -def cleanup(): - # Delete Endpoint and Endpoint Config - client = boto3.client('sagemaker') - endpoint_config_name = client.list_endpoint_configs()['EndpointConfigs'][0]['EndpointConfigName'] - client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) - client.delete_endpoint(EndpointName=ENDPOINT_NAME) - - # Delete S3 Artifacts - client = boto3.client('s3') - object_keys = [ - key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents'] - ] - for key in object_keys: - client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]}) - - -with DAG( - dag_id='example_sagemaker_endpoint', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - upload_data = S3CreateObjectOperator( - task_id='upload_data', - s3_bucket=S3_BUCKET, - s3_key=f'{INPUT_DATA_S3_KEY}/train.csv', - data=TRAIN_DATA, - replace=True, - ) - - train_model = SageMakerTrainingOperator( - task_id='train_model', - config=TRAINING_CONFIG, - do_xcom_push=False, - ) - - create_model = SageMakerModelOperator( - task_id='create_model', - config=MODEL_CONFIG, - do_xcom_push=False, - ) - - # [START howto_operator_sagemaker_endpoint_config] - configure_endpoint = SageMakerEndpointConfigOperator( - task_id='configure_endpoint', - config=ENDPOINT_CONFIG_CONFIG, - do_xcom_push=False, - ) - # [END howto_operator_sagemaker_endpoint_config] - - # [START howto_operator_sagemaker_endpoint] - 
deploy_endpoint = SageMakerEndpointOperator( - task_id='deploy_endpoint', - config=DEPLOY_ENDPOINT_CONFIG, - # Waits by default, > train_model - >> create_model - >> configure_endpoint - >> deploy_endpoint - >> await_endpoint - >> call_endpoint() - >> cleanup() - >> delete_model - ) diff --git a/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py index 735cda4d3af14..067b5fa86ef5f 100644 --- a/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This is a basic example DAG for using `SalesforceToS3Operator` to retrieve Salesforce account data and upload it to an Amazon S3 bucket. """ +from __future__ import annotations from datetime import datetime from os import getenv @@ -32,7 +32,6 @@ with DAG( dag_id="example_salesforce_to_s3", - schedule_interval=None, start_date=datetime(2021, 7, 8), catchup=False, tags=["example"], diff --git a/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py index 0e2407a7d3546..24f480fb2fec3 100644 --- a/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from __future__ import annotations import os from datetime import datetime @@ -27,7 +27,6 @@ with models.DAG( "example_sftp_to_s3", - schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, ) as dag: diff --git a/airflow/providers/amazon/aws/example_dags/example_sns.py b/airflow/providers/amazon/aws/example_dags/example_sns.py deleted file mode 100644 index 782156b14c3d3..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_sns.py +++ /dev/null @@ -1,39 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
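The example DAGs that remain in this diff (s3_to_ftp, s3_to_sftp, salesforce_to_s3, sftp_to_s3) are all updated the same way: a ``from __future__ import annotations`` import is added and the explicit ``schedule_interval=None`` argument is dropped. A minimal sketch of the resulting module header, with an illustrative dag_id:

from __future__ import annotations

from datetime import datetime

from airflow import models

with models.DAG(
    "example_modernized_header",  # illustrative dag_id
    start_date=datetime(2021, 1, 1),
    catchup=False,
) as dag:
    ...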
-from datetime import datetime -from os import environ - -from airflow import DAG -from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator - -SNS_TOPIC_ARN = environ.get('SNS_TOPIC_ARN', 'arn:aws:sns:us-west-2:123456789012:dummy-topic-name') - -with DAG( - dag_id='example_sns', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_sns_publish_operator] - publish = SnsPublishOperator( - task_id='publish_message', - target_arn=SNS_TOPIC_ARN, - message='This is a sample message sent to SNS via an Apache Airflow DAG task.', - ) - # [END howto_operator_sns_publish_operator] diff --git a/airflow/providers/amazon/aws/example_dags/example_sql_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_sql_to_s3.py deleted file mode 100644 index df2abee0f3052..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_sql_to_s3.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator - -S3_BUCKET = os.environ.get("S3_BUCKET", "test-bucket") -S3_KEY = os.environ.get("S3_KEY", "key") -SQL_QUERY = os.environ.get("SQL_QUERY", "SHOW tables") - -with models.DAG( - "example_sql_to_s3", - schedule_interval=None, - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - # [START howto_transfer_sql_to_s3] - sql_to_s3_task = SqlToS3Operator( - task_id="sql_to_s3_task", - sql_conn_id="mysql_default", - query=SQL_QUERY, - s3_bucket=S3_BUCKET, - s3_key=S3_KEY, - replace=True, - ) - # [END howto_transfer_sql_to_s3] diff --git a/airflow/providers/amazon/aws/example_dags/example_sqs.py b/airflow/providers/amazon/aws/example_dags/example_sqs.py deleted file mode 100644 index 69ff1e5b75831..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_sqs.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
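The deleted example_sql_to_s3 DAG above reduced to a single transfer task, kept here as a sketch; the connection, query, bucket and key values are the removed file's environment-variable defaults, not values guaranteed to exist.

from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

# Run a SQL query against the configured connection and land the result in S3.
sql_to_s3_task = SqlToS3Operator(
    task_id="sql_to_s3_task",
    sql_conn_id="mysql_default",
    query="SHOW tables",
    s3_bucket="test-bucket",
    s3_key="key",
    replace=True,
)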
-from datetime import datetime -from os import getenv - -from airflow import DAG -from airflow.decorators import task -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.hooks.sqs import SqsHook -from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator -from airflow.providers.amazon.aws.sensors.sqs import SqsSensor - -QUEUE_NAME = getenv('QUEUE_NAME', 'Airflow-Example-Queue') - - -@task(task_id="create_queue") -def create_queue_fn(): - """Create the example queue""" - hook = SqsHook() - result = hook.create_queue(queue_name=QUEUE_NAME) - return result['QueueUrl'] - - -@task(task_id="delete_queue") -def delete_queue_fn(queue_url): - """Delete the example queue""" - hook = SqsHook() - hook.get_conn().delete_queue(QueueUrl=queue_url) - - -with DAG( - dag_id='example_sqs', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - create_queue = create_queue_fn() - - # [START howto_operator_sqs] - publish_to_queue = SqsPublishOperator( - task_id='publish_to_queue', - sqs_queue=create_queue, - message_content="{{ task_instance }}-{{ execution_date }}", - ) - # [END howto_operator_sqs] - - # [START howto_sensor_sqs] - read_from_queue = SqsSensor( - task_id='read_from_queue', - sqs_queue=create_queue, - ) - # [END howto_sensor_sqs] - - delete_queue = delete_queue_fn(create_queue) - - chain(create_queue, publish_to_queue, read_from_queue, delete_queue) diff --git a/airflow/providers/amazon/aws/example_dags/example_step_functions.py b/airflow/providers/amazon/aws/example_dags/example_step_functions.py deleted file mode 100644 index 02763e3ea13f1..0000000000000 --- a/airflow/providers/amazon/aws/example_dags/example_step_functions.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
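The publish/consume pairing demonstrated by the deleted example_sqs DAG above is sketched below; the operator arguments come from that file, while the queue URL is a placeholder standing in for the value the original obtained from SqsHook.create_queue.

from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator
from airflow.providers.amazon.aws.sensors.sqs import SqsSensor

# Placeholder; the deleted example created the queue at runtime instead.
QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/123456789012/Airflow-Example-Queue"

publish_to_queue = SqsPublishOperator(
    task_id="publish_to_queue",
    sqs_queue=QUEUE_URL,
    message_content="{{ task_instance }}-{{ execution_date }}",
)

read_from_queue = SqsSensor(
    task_id="read_from_queue",
    sqs_queue=QUEUE_URL,
)

publish_to_queue >> read_from_queue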
-from datetime import datetime -from os import environ - -from airflow import DAG -from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.step_function import ( - StepFunctionGetExecutionOutputOperator, - StepFunctionStartExecutionOperator, -) -from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor - -STEP_FUNCTIONS_STATE_MACHINE_ARN = environ.get('STEP_FUNCTIONS_STATE_MACHINE_ARN', 'state_machine_arn') - -with DAG( - dag_id='example_step_functions', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_step_function_start_execution] - start_execution = StepFunctionStartExecutionOperator( - task_id='start_execution', state_machine_arn=STEP_FUNCTIONS_STATE_MACHINE_ARN - ) - # [END howto_operator_step_function_start_execution] - - # [START howto_sensor_step_function_execution] - wait_for_execution = StepFunctionExecutionSensor( - task_id='wait_for_execution', execution_arn=start_execution.output - ) - # [END howto_sensor_step_function_execution] - - # [START howto_operator_step_function_get_execution_output] - get_execution_output = StepFunctionGetExecutionOutputOperator( - task_id='get_execution_output', execution_arn=start_execution.output - ) - # [END howto_operator_step_function_get_execution_output] - - chain(start_execution, wait_for_execution, get_execution_output) diff --git a/airflow/providers/amazon/aws/exceptions.py b/airflow/providers/amazon/aws/exceptions.py index 27f46dfae7563..b606dc504f2dc 100644 --- a/airflow/providers/amazon/aws/exceptions.py +++ b/airflow/providers/amazon/aws/exceptions.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + # Note: Any AirflowException raised is expected to cause the TaskInstance # to be marked in an ERROR state -import warnings class EcsTaskFailToStart(Exception): @@ -42,19 +42,3 @@ def __init__(self, failures: list, message: str): def __reduce__(self): return EcsOperatorError, (self.failures, self.message) - - -class ECSOperatorError(EcsOperatorError): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.exceptions.EcsOperatorError`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.amazon.aws.exceptions.EcsOperatorError`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/appflow.py b/airflow/providers/amazon/aws/hooks/appflow.py new file mode 100644 index 0000000000000..3bf57e50e0001 --- /dev/null +++ b/airflow/providers/amazon/aws/hooks/appflow.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from datetime import datetime, timezone +from time import sleep +from typing import TYPE_CHECKING + +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + +if TYPE_CHECKING: + from mypy_boto3_appflow.client import AppflowClient + from mypy_boto3_appflow.type_defs import TaskTypeDef + + +class AppflowHook(AwsBaseHook): + """ + Interact with Amazon Appflow, using the boto3 library. + Hook attribute ``conn`` has all methods that listed in documentation. + + .. seealso:: + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/appflow.html + - https://docs.aws.amazon.com/appflow/1.0/APIReference/Welcome.html + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + + """ + + EVENTUAL_CONSISTENCY_OFFSET: int = 15 # seconds + EVENTUAL_CONSISTENCY_POLLING: int = 10 # seconds + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "appflow" + super().__init__(*args, **kwargs) + + @cached_property + def conn(self) -> AppflowClient: + """Get the underlying boto3 Appflow client (cached)""" + return super().conn + + def run_flow(self, flow_name: str, poll_interval: int = 20) -> str: + """ + Execute an AppFlow run. + + :param flow_name: The flow name + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the run status + :return: The run execution ID + """ + ts_before: datetime = datetime.now(timezone.utc) + sleep(self.EVENTUAL_CONSISTENCY_OFFSET) + response_start = self.conn.start_flow(flowName=flow_name) + execution_id = response_start["executionId"] + self.log.info("executionId: %s", execution_id) + + response_desc = self.conn.describe_flow(flowName=flow_name) + last_exec_details = response_desc["lastRunExecutionDetails"] + + # Wait Appflow eventual consistence + self.log.info("Waiting for Appflow eventual consistence...") + while ( + response_desc.get("lastRunExecutionDetails", {}).get( + "mostRecentExecutionTime", datetime(1970, 1, 1, tzinfo=timezone.utc) + ) + < ts_before + ): + sleep(self.EVENTUAL_CONSISTENCY_POLLING) + response_desc = self.conn.describe_flow(flowName=flow_name) + last_exec_details = response_desc["lastRunExecutionDetails"] + + # Wait flow stops + self.log.info("Waiting for flow run...") + while ( + "mostRecentExecutionStatus" not in last_exec_details + or last_exec_details["mostRecentExecutionStatus"] == "InProgress" + ): + sleep(poll_interval) + response_desc = self.conn.describe_flow(flowName=flow_name) + last_exec_details = response_desc["lastRunExecutionDetails"] + + self.log.info("lastRunExecutionDetails: %s", last_exec_details) + + if last_exec_details["mostRecentExecutionStatus"] == "Error": + raise Exception(f"Flow error:\n{json.dumps(response_desc, default=str)}") + + return execution_id + + def update_flow_filter( + self, flow_name: str, filter_tasks: list[TaskTypeDef], set_trigger_ondemand: bool = False + ) -> None: + """ + Update the flow task filter. + All filters will be removed if an empty array is passed to filter_tasks. 
+ + :param flow_name: The flow name + :param filter_tasks: List flow tasks to be added + :param set_trigger_ondemand: If True, set the trigger to on-demand; otherwise, keep the trigger as is + :return: None + """ + response = self.conn.describe_flow(flowName=flow_name) + connector_type = response["sourceFlowConfig"]["connectorType"] + tasks: list[TaskTypeDef] = [] + + # cleanup old filter tasks + for task in response["tasks"]: + if ( + task["taskType"] == "Filter" + and task.get("connectorOperator", {}).get(connector_type) != "PROJECTION" + ): + self.log.info("Removing task: %s", task) + else: + tasks.append(task) # List of non-filter tasks + + tasks += filter_tasks # Add the new filter tasks + + if set_trigger_ondemand: + # Clean up attribute to force on-demand trigger + del response["triggerConfig"]["triggerProperties"] + + self.conn.update_flow( + flowName=response["flowName"], + destinationFlowConfigList=response["destinationFlowConfigList"], + sourceFlowConfig=response["sourceFlowConfig"], + triggerConfig=response["triggerConfig"], + description=response.get("description", "Flow description."), + tasks=tasks, + ) diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py index 82e69ccf6f001..3c70898a34298 100644 --- a/airflow/providers/amazon/aws/hooks/athena.py +++ b/airflow/providers/amazon/aws/hooks/athena.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module contains AWS Athena hook. @@ -23,9 +22,11 @@ PageIterator """ +from __future__ import annotations + import warnings from time import sleep -from typing import Any, Dict, Optional +from typing import Any from botocore.paginate import PageIterator @@ -46,14 +47,14 @@ class AthenaHook(AwsBaseHook): """ INTERMEDIATE_STATES = ( - 'QUEUED', - 'RUNNING', + "QUEUED", + "RUNNING", ) FAILURE_STATES = ( - 'FAILED', - 'CANCELLED', + "FAILED", + "CANCELLED", ) - SUCCESS_STATES = ('SUCCEEDED',) + SUCCESS_STATES = ("SUCCEEDED",) TERMINAL_STATES = ( "SUCCEEDED", "FAILED", @@ -61,16 +62,16 @@ class AthenaHook(AwsBaseHook): ) def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None: - super().__init__(client_type='athena', *args, **kwargs) # type: ignore + super().__init__(client_type="athena", *args, **kwargs) # type: ignore self.sleep_time = sleep_time def run_query( self, query: str, - query_context: Dict[str, str], - result_configuration: Dict[str, Any], - client_request_token: Optional[str] = None, - workgroup: str = 'primary', + query_context: dict[str, str], + result_configuration: dict[str, Any], + client_request_token: str | None = None, + workgroup: str = "primary", ) -> str: """ Run Presto query on athena with provided config and return submitted query_execution_id @@ -83,17 +84,17 @@ def run_query( :return: str """ params = { - 'QueryString': query, - 'QueryExecutionContext': query_context, - 'ResultConfiguration': result_configuration, - 'WorkGroup': workgroup, + "QueryString": query, + "QueryExecutionContext": query_context, + "ResultConfiguration": result_configuration, + "WorkGroup": workgroup, } if client_request_token: - params['ClientRequestToken'] = client_request_token + params["ClientRequestToken"] = client_request_token response = self.get_conn().start_query_execution(**params) - return response['QueryExecutionId'] + return response["QueryExecutionId"] - def check_query_status(self, query_execution_id: str) -> Optional[str]: + def 
check_query_status(self, query_execution_id: str) -> str | None: """ Fetch the status of submitted athena query. Returns None or one of valid query states. @@ -103,15 +104,15 @@ def check_query_status(self, query_execution_id: str) -> Optional[str]: response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id) state = None try: - state = response['QueryExecution']['Status']['State'] + state = response["QueryExecution"]["Status"]["State"] except Exception as ex: - self.log.error('Exception while getting query state %s', ex) + self.log.error("Exception while getting query state %s", ex) finally: # The error is being absorbed here and is being handled by the caller. # The error is being absorbed to implement retries. return state - def get_state_change_reason(self, query_execution_id: str) -> Optional[str]: + def get_state_change_reason(self, query_execution_id: str) -> str | None: """ Fetch the reason for a state change (e.g. error message). Returns None or reason string. @@ -121,17 +122,17 @@ def get_state_change_reason(self, query_execution_id: str) -> Optional[str]: response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id) reason = None try: - reason = response['QueryExecution']['Status']['StateChangeReason'] + reason = response["QueryExecution"]["Status"]["StateChangeReason"] except Exception as ex: - self.log.error('Exception while getting query state change reason: %s', ex) + self.log.error("Exception while getting query state change reason: %s", ex) finally: # The error is being absorbed here and is being handled by the caller. # The error is being absorbed to implement retries. return reason def get_query_results( - self, query_execution_id: str, next_token_id: Optional[str] = None, max_results: int = 1000 - ) -> Optional[dict]: + self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000 + ) -> dict | None: """ Fetch submitted athena query results. returns none if query is in intermediate state or failed/cancelled state else dict of query output @@ -143,23 +144,23 @@ def get_query_results( """ query_state = self.check_query_status(query_execution_id) if query_state is None: - self.log.error('Invalid Query state') + self.log.error("Invalid Query state") return None elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES: self.log.error('Query is in "%s" state. Cannot fetch results', query_state) return None - result_params = {'QueryExecutionId': query_execution_id, 'MaxResults': max_results} + result_params = {"QueryExecutionId": query_execution_id, "MaxResults": max_results} if next_token_id: - result_params['NextToken'] = next_token_id + result_params["NextToken"] = next_token_id return self.get_conn().get_query_results(**result_params) def get_query_results_paginator( self, query_execution_id: str, - max_items: Optional[int] = None, - page_size: Optional[int] = None, - starting_token: Optional[str] = None, - ) -> Optional[PageIterator]: + max_items: int | None = None, + page_size: int | None = None, + starting_token: str | None = None, + ) -> PageIterator | None: """ Fetch submitted athena query results. returns none if query is in intermediate state or failed/cancelled state else a paginator to iterate through pages of results. 
If you @@ -173,46 +174,66 @@ def get_query_results_paginator( """ query_state = self.check_query_status(query_execution_id) if query_state is None: - self.log.error('Invalid Query state (null)') + self.log.error("Invalid Query state (null)") return None if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES: self.log.error('Query is in "%s" state. Cannot fetch results', query_state) return None result_params = { - 'QueryExecutionId': query_execution_id, - 'PaginationConfig': { - 'MaxItems': max_items, - 'PageSize': page_size, - 'StartingToken': starting_token, + "QueryExecutionId": query_execution_id, + "PaginationConfig": { + "MaxItems": max_items, + "PageSize": page_size, + "StartingToken": starting_token, }, } - paginator = self.get_conn().get_paginator('get_query_results') + paginator = self.get_conn().get_paginator("get_query_results") return paginator.paginate(**result_params) - def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] = None) -> Optional[str]: + def poll_query_status( + self, + query_execution_id: str, + max_tries: int | None = None, + max_polling_attempts: int | None = None, + ) -> str | None: """ Poll the status of submitted athena query until query state reaches final state. Returns one of the final states :param query_execution_id: Id of submitted athena query - :param max_tries: Number of times to poll for query state before function exits + :param max_tries: Deprecated - Use max_polling_attempts instead + :param max_polling_attempts: Number of times to poll for query state before function exits :return: str """ + if max_tries: + warnings.warn( + f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated " + f"and will be removed in a future release. Please use 'max_polling_attempts' instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_polling_attempts and max_polling_attempts != max_tries: + raise Exception("max_polling_attempts must be the same value as max_tries") + else: + max_polling_attempts = max_tries + try_number = 1 - final_query_state = None # Query state when query reaches final state or max_tries reached + final_query_state = None # Query state when query reaches final state or max_polling_attempts reached while True: query_state = self.check_query_status(query_execution_id) if query_state is None: - self.log.info('Trial %s: Invalid query state. Retrying again', try_number) + self.log.info("Trial %s: Invalid query state. Retrying again", try_number) elif query_state in self.TERMINAL_STATES: self.log.info( - 'Trial %s: Query execution completed. Final state is %s}', try_number, query_state + "Trial %s: Query execution completed. 
Final state is %s}", try_number, query_state ) final_query_state = query_state break else: - self.log.info('Trial %s: Query is still in non-terminal state - %s', try_number, query_state) - if max_tries and try_number >= max_tries: # Break loop if max_tries reached + self.log.info("Trial %s: Query is still in non-terminal state - %s", try_number, query_state) + if ( + max_polling_attempts and try_number >= max_polling_attempts + ): # Break loop if max_polling_attempts reached final_query_state = query_state break try_number += 1 @@ -233,7 +254,7 @@ def get_output_location(self, query_execution_id: str) -> str: if response: try: - output_location = response['QueryExecution']['ResultConfiguration']['OutputLocation'] + output_location = response["QueryExecution"]["ResultConfiguration"]["OutputLocation"] except KeyError: self.log.error("Error retrieving OutputLocation") raise @@ -244,7 +265,7 @@ def get_output_location(self, query_execution_id: str) -> str: return output_location - def stop_query(self, query_execution_id: str) -> Dict: + def stop_query(self, query_execution_id: str) -> dict: """ Cancel the submitted athena query @@ -252,18 +273,3 @@ def stop_query(self, query_execution_id: str) -> Dict: :return: dict """ return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id) - - -class AWSAthenaHook(AthenaHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.athena.AthenaHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.athena.AthenaHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py b/airflow/providers/amazon/aws/hooks/aws_dynamodb.py deleted file mode 100644 index dedb80073e3e5..0000000000000 --- a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.dynamodb`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index bbf0bfff83e91..b902d102cb346 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -15,44 +15,46 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """ This module contains Base AWS Hook. .. seealso:: For more information on how to use this hook, take a look at the guide: - :ref:`howto/connection:AWSHook` + :ref:`howto/connection:aws` """ +from __future__ import annotations -import configparser import datetime +import json import logging -import sys import warnings from functools import wraps -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union import boto3 import botocore import botocore.session import requests import tenacity +from botocore.client import ClientMeta from botocore.config import Config from botocore.credentials import ReadOnlyCredentials -from slugify import slugify - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - from dateutil.tz import tzlocal +from slugify import slugify +from airflow.compat.functools import cached_property from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.hooks.base import BaseHook -from airflow.models.connection import Connection +from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.log.secrets_masker import mask_secret + +BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource]) + + +if TYPE_CHECKING: + from airflow.models.connection import Connection # Avoid circular imports. class BaseSessionFactory(LoggingMixin): @@ -67,73 +69,85 @@ class BaseSessionFactory(LoggingMixin): :ref:`howto/connection:aws:session-factory` """ - def __init__(self, conn: Connection, region_name: Optional[str], config: Config) -> None: + def __init__( + self, + conn: Connection | AwsConnectionWrapper | None, + region_name: str | None = None, + config: Config | None = None, + ) -> None: super().__init__() - self.conn = conn - self.region_name = region_name - self.config = config - self.extra_config = self.conn.extra_dejson + self._conn = conn + self._region_name = region_name + self._config = config - self.basic_session: Optional[boto3.session.Session] = None - self.role_arn: Optional[str] = None + @cached_property + def conn(self) -> AwsConnectionWrapper: + """Cached AWS Connection Wrapper.""" + return AwsConnectionWrapper( + conn=self._conn, + region_name=self._region_name, + botocore_config=self._config, + ) + + @cached_property + def basic_session(self) -> boto3.session.Session: + """Cached property with basic boto3.session.Session.""" + return self._create_basic_session(session_kwargs=self.conn.session_kwargs) + + @property + def extra_config(self) -> dict[str, Any]: + """AWS Connection extra_config.""" + return self.conn.extra_config + + @property + def region_name(self) -> str | None: + """AWS Region Name read-only property.""" + return self.conn.region_name + + @property + def config(self) -> Config | None: + """Configuration for botocore client read-only property.""" + return self.conn.botocore_config + + @property + def role_arn(self) -> str | None: + """Assume Role ARN from AWS Connection""" + return self.conn.role_arn def create_session(self) -> boto3.session.Session: - """Create AWS session.""" - session_kwargs = {} - if "session_kwargs" in self.extra_config: + """Create boto3 Session from connection config.""" + if not self.conn: self.log.info( - "Retrieving 
session_kwargs from Connection.extra_config['session_kwargs']: %s", - self.extra_config["session_kwargs"], + "No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). " + "See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html", + self.region_name, ) - session_kwargs = self.extra_config["session_kwargs"] - self.basic_session = self._create_basic_session(session_kwargs=session_kwargs) - self.role_arn = self._read_role_arn_from_extra_config() - # If role_arn was specified then STS + assume_role - if self.role_arn is None: + return boto3.session.Session(region_name=self.region_name) + elif not self.role_arn: return self.basic_session - return self._create_session_with_assume_role(session_kwargs=session_kwargs) - - def _get_region_name(self) -> Optional[str]: - region_name = self.region_name - if self.region_name is None and 'region_name' in self.extra_config: - self.log.info("Retrieving region_name from Connection.extra_config['region_name']") - region_name = self.extra_config["region_name"] - return region_name - - def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session: - aws_access_key_id, aws_secret_access_key = self._read_credentials_from_connection() - aws_session_token = self.extra_config.get("aws_session_token") - region_name = self._get_region_name() - self.log.debug( - "Creating session with aws_access_key_id=%s region_name=%s", - aws_access_key_id, - region_name, - ) - - return boto3.session.Session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=region_name, - aws_session_token=aws_session_token, - **session_kwargs, - ) - - def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session: - assume_role_method = self.extra_config.get('assume_role_method', 'assume_role') - self.log.debug("assume_role_method=%s", assume_role_method) - supported_methods = ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity'] - if assume_role_method not in supported_methods: - raise NotImplementedError( - f'assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra.' - f'Currently {supported_methods} are supported.' - '(Exclude this setting will default to "assume_role").' - ) - if assume_role_method == 'assume_role_with_web_identity': + # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only + # to create the initial boto3 session. + # If the user wants to use the 'assume_role' mechanism then only the 'region_name' needs to be + # provided, otherwise other parameters might conflict with the base botocore session. + # Unfortunately it is not a part of public boto3 API, see source of boto3.session.Session: + # https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/session.html#Session + # If we provide 'aws_access_key_id' or 'aws_secret_access_key' or 'aws_session_token' + # as part of session kwargs it will use them instead of assumed credentials. 
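A hedged sketch of the two session paths described in create_session above: fall back to boto3's own credential chain when no connection is given, otherwise build the session from the stored kwargs (the region below is a placeholder):

from __future__ import annotations

import boto3


def make_session(session_kwargs: dict | None = None, region_name: str | None = None) -> boto3.session.Session:
    if not session_kwargs:
        # No connection-derived kwargs: let boto3's credential chain
        # (env vars, shared config, instance profile, ...) resolve credentials.
        return boto3.session.Session(region_name=region_name)
    # Otherwise the connection-derived kwargs (profile_name, keys, ...) win.
    return boto3.session.Session(**session_kwargs)


session = make_session(region_name="eu-west-1")  # placeholder region
print(session.region_name)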
+ assume_session_kwargs = {} + if self.conn.region_name: + assume_session_kwargs["region_name"] = self.conn.region_name + return self._create_session_with_assume_role(session_kwargs=assume_session_kwargs) + + def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session: + return boto3.session.Session(**session_kwargs) + + def _create_session_with_assume_role(self, session_kwargs: dict[str, Any]) -> boto3.session.Session: + if self.conn.assume_role_method == "assume_role_with_web_identity": # Deferred credentials have no initial credentials credential_fetcher = self._get_web_identity_credential_fetcher() credentials = botocore.credentials.DeferredRefreshableCredentials( - method='assume-role-with-web-identity', + method="assume-role-with-web-identity", refresh_using=credential_fetcher.fetch_credentials, time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()), ) @@ -144,41 +158,34 @@ def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> bo refresh_using=self._refresh_credentials, method="sts-assume-role", ) + session = botocore.session.get_session() session._credentials = credentials - - if self.basic_session is None: - raise RuntimeError("The basic session should be created here!") - region_name = self.basic_session.region_name session.set_config_variable("region", region_name) return boto3.session.Session(botocore_session=session, **session_kwargs) - def _refresh_credentials(self) -> Dict[str, Any]: - self.log.debug('Refreshing credentials') - assume_role_method = self.extra_config.get('assume_role_method', 'assume_role') - sts_session = self.basic_session - - if sts_session is None: - raise RuntimeError( - "Session should be initialized when refresh credentials with assume_role is used!" - ) + def _refresh_credentials(self) -> dict[str, Any]: + self.log.debug("Refreshing credentials") + assume_role_method = self.conn.assume_role_method + if assume_role_method not in ("assume_role", "assume_role_with_saml"): + raise NotImplementedError(f"assume_role_method={assume_role_method} not expected") - sts_client = sts_session.client("sts", config=self.config) + sts_client = self.basic_session.client("sts", config=self.config) - if assume_role_method == 'assume_role': + if assume_role_method == "assume_role": sts_response = self._assume_role(sts_client=sts_client) - elif assume_role_method == 'assume_role_with_saml': - sts_response = self._assume_role_with_saml(sts_client=sts_client) else: - raise NotImplementedError(f'assume_role_method={assume_role_method} not expected') - sts_response_http_status = sts_response['ResponseMetadata']['HTTPStatusCode'] - if not sts_response_http_status == 200: - raise RuntimeError(f'sts_response_http_status={sts_response_http_status}') - credentials = sts_response['Credentials'] - expiry_time = credentials.get('Expiration').isoformat() - self.log.debug('New credentials expiry_time: %s', expiry_time) + sts_response = self._assume_role_with_saml(sts_client=sts_client) + + sts_response_http_status = sts_response["ResponseMetadata"]["HTTPStatusCode"] + if sts_response_http_status != 200: + raise RuntimeError(f"sts_response_http_status={sts_response_http_status}") + + credentials = sts_response["Credentials"] + expiry_time = credentials.get("Expiration").isoformat() + self.log.debug("New credentials expiry_time: %s", expiry_time) credentials = { "access_key": credentials.get("AccessKeyId"), "secret_key": credentials.get("SecretAccessKey"), @@ -187,77 +194,37 @@ def _refresh_credentials(self) -> Dict[str, Any]: } return 
credentials - def _read_role_arn_from_extra_config(self) -> Optional[str]: - aws_account_id = self.extra_config.get("aws_account_id") - aws_iam_role = self.extra_config.get("aws_iam_role") - role_arn = self.extra_config.get("role_arn") - if role_arn is None and aws_account_id is not None and aws_iam_role is not None: - self.log.info("Constructing role_arn from aws_account_id and aws_iam_role") - role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}" - self.log.debug("role_arn is %s", role_arn) - return role_arn - - def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]: - aws_access_key_id = None - aws_secret_access_key = None - if self.conn.login: - aws_access_key_id = self.conn.login - aws_secret_access_key = self.conn.password - self.log.info("Credentials retrieved from login") - elif "aws_access_key_id" in self.extra_config and "aws_secret_access_key" in self.extra_config: - aws_access_key_id = self.extra_config["aws_access_key_id"] - aws_secret_access_key = self.extra_config["aws_secret_access_key"] - self.log.info("Credentials retrieved from extra_config") - elif "s3_config_file" in self.extra_config: - aws_access_key_id, aws_secret_access_key = _parse_s3_config( - self.extra_config["s3_config_file"], - self.extra_config.get("s3_config_format"), - self.extra_config.get("profile"), - ) - self.log.info("Credentials retrieved from extra_config['s3_config_file']") - return aws_access_key_id, aws_secret_access_key - - def _strip_invalid_session_name_characters(self, role_session_name: str) -> str: - return slugify(role_session_name, regex_pattern=r'[^\w+=,.@-]+') - - def _assume_role(self, sts_client: boto3.client) -> Dict: - assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {}) - if "external_id" in self.extra_config: # Backwards compatibility - assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id") - role_session_name = self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}") - self.log.debug( - "Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)", - self.role_arn, - role_session_name, - ) - return sts_client.assume_role( - RoleArn=self.role_arn, RoleSessionName=role_session_name, **assume_role_kwargs - ) + def _assume_role(self, sts_client: boto3.client) -> dict: + kw = { + "RoleSessionName": self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"), + **self.conn.assume_role_kwargs, + "RoleArn": self.role_arn, + } + return sts_client.assume_role(**kw) - def _assume_role_with_saml(self, sts_client: boto3.client) -> Dict[str, Any]: - saml_config = self.extra_config['assume_role_with_saml'] - principal_arn = saml_config['principal_arn'] + def _assume_role_with_saml(self, sts_client: boto3.client) -> dict[str, Any]: + saml_config = self.extra_config["assume_role_with_saml"] + principal_arn = saml_config["principal_arn"] - idp_auth_method = saml_config['idp_auth_method'] - if idp_auth_method == 'http_spegno_auth': + idp_auth_method = saml_config["idp_auth_method"] + if idp_auth_method == "http_spegno_auth": saml_assertion = self._fetch_saml_assertion_using_http_spegno_auth(saml_config) else: raise NotImplementedError( - f'idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra.' + f"idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra." 'Currently only "http_spegno_auth" is supported, and must be specified.' 
) self.log.debug("Doing sts_client.assume_role_with_saml to role_arn=%s", self.role_arn) - assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {}) return sts_client.assume_role_with_saml( RoleArn=self.role_arn, PrincipalArn=principal_arn, SAMLAssertion=saml_assertion, - **assume_role_kwargs, + **self.conn.assume_role_kwargs, ) def _get_idp_response( - self, saml_config: Dict[str, Any], auth: requests.auth.AuthBase + self, saml_config: dict[str, Any], auth: requests.auth.AuthBase ) -> requests.models.Response: idp_url = saml_config["idp_url"] self.log.debug("idp_url= %s", idp_url) @@ -285,7 +252,7 @@ def _get_idp_response( return idp_response - def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str: + def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: dict[str, Any]) -> str: # requests_gssapi will need paramiko > 2.6 since you'll need # 'gssapi' not 'python-gssapi' from PyPi. # https://github.com/paramiko/paramiko/pull/1311 @@ -293,32 +260,32 @@ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, An from lxml import etree auth = requests_gssapi.HTTPSPNEGOAuth() - if 'mutual_authentication' in saml_config: - mutual_auth = saml_config['mutual_authentication'] - if mutual_auth == 'REQUIRED': + if "mutual_authentication" in saml_config: + mutual_auth = saml_config["mutual_authentication"] + if mutual_auth == "REQUIRED": auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.REQUIRED) - elif mutual_auth == 'OPTIONAL': + elif mutual_auth == "OPTIONAL": auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.OPTIONAL) - elif mutual_auth == 'DISABLED': + elif mutual_auth == "DISABLED": auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.DISABLED) else: raise NotImplementedError( - f'mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra.' + f"mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra." 'Currently "REQUIRED", "OPTIONAL" and "DISABLED" are supported.' - '(Exclude this setting will default to HTTPSPNEGOAuth() ).' + "(Exclude this setting will default to HTTPSPNEGOAuth() )." ) # Query the IDP idp_response = self._get_idp_response(saml_config, auth=auth) # Assist with debugging. Note: contains sensitive info! 
- xpath = saml_config['saml_response_xpath'] - log_idp_response = 'log_idp_response' in saml_config and saml_config['log_idp_response'] + xpath = saml_config["saml_response_xpath"] + log_idp_response = "log_idp_response" in saml_config and saml_config["log_idp_response"] if log_idp_response: self.log.warning( - 'The IDP response contains sensitive information, but log_idp_response is ON (%s).', + "The IDP response contains sensitive information, but log_idp_response is ON (%s).", log_idp_response, ) - self.log.debug('idp_response.content= %s', idp_response.content) - self.log.debug('xpath= %s', xpath) + self.log.debug("idp_response.content= %s", idp_response.content) + self.log.debug("xpath= %s", xpath) # Extract SAML Assertion from the returned HTML / XML xml = etree.fromstring(idp_response.content) saml_assertion = xml.xpath(xpath) @@ -326,29 +293,26 @@ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, An if len(saml_assertion) == 1: saml_assertion = saml_assertion[0] if not saml_assertion: - raise ValueError('Invalid SAML Assertion') + raise ValueError("Invalid SAML Assertion") return saml_assertion def _get_web_identity_credential_fetcher( self, ) -> botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher: - if self.basic_session is None: - raise Exception("Session should be set where identity is fetched!") base_session = self.basic_session._session or botocore.session.get_session() client_creator = base_session.create_client - federation = self.extra_config.get('assume_role_with_web_identity_federation') - if federation == 'google': + federation = self.extra_config.get("assume_role_with_web_identity_federation") + if federation == "google": web_identity_token_loader = self._get_google_identity_token_loader() else: raise AirflowException( f'Unsupported federation: {federation}. Currently "google" only are supported.' ) - assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {}) return botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher( client_creator=client_creator, web_identity_token_loader=web_identity_token_loader, role_arn=self.role_arn, - extra_args=assume_role_kwargs, + extra_args=self.conn.assume_role_kwargs, ) def _get_google_identity_token_loader(self): @@ -358,7 +322,7 @@ def _get_google_identity_token_loader(self): get_default_id_token_credentials, ) - audience = self.extra_config.get('assume_role_with_web_identity_federation_audience') + audience = self.extra_config.get("assume_role_with_web_identity_federation_audience") google_id_token_credentials = get_default_id_token_credentials(target_audience=audience) @@ -370,8 +334,39 @@ def web_identity_token_loader(): return web_identity_token_loader + def _strip_invalid_session_name_characters(self, role_session_name: str) -> str: + return slugify(role_session_name, regex_pattern=r"[^\w+=,.@-]+") + + def _get_region_name(self) -> str | None: + warnings.warn( + "`BaseSessionFactory._get_region_name` method deprecated and will be removed " + "in a future releases. Please use `BaseSessionFactory.region_name` property instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.region_name + + def _read_role_arn_from_extra_config(self) -> str | None: + warnings.warn( + "`BaseSessionFactory._read_role_arn_from_extra_config` method deprecated and will be removed " + "in a future releases. 
Please use `BaseSessionFactory.role_arn` property instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.role_arn + + def _read_credentials_from_connection(self) -> tuple[str | None, str | None]: + warnings.warn( + "`BaseSessionFactory._read_credentials_from_connection` method deprecated and will be removed " + "in a future releases. Please use `BaseSessionFactory.conn.aws_access_key_id` and " + "`BaseSessionFactory.aws_secret_access_key` properties instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.conn.aws_access_key_id, self.conn.aws_secret_access_key -class AwsBaseHook(BaseHook): + +class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]): """ Interact with AWS. This class is a thin wrapper around the boto3 python library. @@ -381,147 +376,152 @@ class AwsBaseHook(BaseHook): running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param verify: Whether or not to verify SSL certificates. + :param verify: Whether or not to verify SSL certificates. See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. :param client_type: boto3.client client_type. Eg 's3', 'emr' etc :param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc - :param config: Configuration for botocore client. - (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html) + :param config: Configuration for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - conn_name_attr = 'aws_conn_id' - default_conn_name = 'aws_default' - conn_type = 'aws' - hook_name = 'Amazon Web Services' + conn_name_attr = "aws_conn_id" + default_conn_name = "aws_default" + conn_type = "aws" + hook_name = "Amazon Web Services" def __init__( self, - aws_conn_id: Optional[str] = default_conn_name, - verify: Union[bool, str, None] = None, - region_name: Optional[str] = None, - client_type: Optional[str] = None, - resource_type: Optional[str] = None, - config: Optional[Config] = None, + aws_conn_id: str | None = default_conn_name, + verify: bool | str | None = None, + region_name: str | None = None, + client_type: str | None = None, + resource_type: str | None = None, + config: Config | None = None, ) -> None: super().__init__() self.aws_conn_id = aws_conn_id - self.verify = verify self.client_type = client_type self.resource_type = resource_type - self.region_name = region_name - self.config = config - - if not (self.client_type or self.resource_type): - raise AirflowException('Either client_type or resource_type must be provided.') - - def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]: - if not self.aws_conn_id: - session = boto3.session.Session(region_name=region_name) - return session, None + self._region_name = region_name + self._config = config + self._verify = verify - self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id) - - try: - # Fetch the Airflow connection object - connection_object = self.get_connection(self.aws_conn_id) - extra_config = connection_object.extra_dejson - endpoint_url = extra_config.get("host") - - # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config - if "config_kwargs" in extra_config: - self.log.debug( - "Retrieving config_kwargs from 
Connection.extra_config['config_kwargs']: %s", - extra_config["config_kwargs"], + @cached_property + def conn_config(self) -> AwsConnectionWrapper: + """Get the Airflow Connection object and wrap it in helper (cached).""" + connection = None + if self.aws_conn_id: + try: + connection = self.get_connection(self.aws_conn_id) + except AirflowNotFoundException: + warnings.warn( + f"Unable to find AWS Connection ID '{self.aws_conn_id}', switching to empty. " + "This behaviour is deprecated and will be removed in a future releases. " + "Please provide existed AWS connection ID or if required boto3 credential strategy " + "explicit set AWS Connection ID to None.", + DeprecationWarning, + stacklevel=2, ) - self.config = Config(**extra_config["config_kwargs"]) - session = SessionFactory( - conn=connection_object, region_name=region_name, config=self.config - ).create_session() + return AwsConnectionWrapper( + conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify + ) - return session, endpoint_url + @property + def region_name(self) -> str | None: + """AWS Region Name read-only property.""" + return self.conn_config.region_name - except AirflowException: - self.log.warning("Unable to use Airflow Connection for credentials.") - self.log.debug("Fallback on boto3 credential strategy") - # http://boto3.readthedocs.io/en/latest/guide/configuration.html + @property + def config(self) -> Config | None: + """Configuration for botocore client read-only property.""" + return self.conn_config.botocore_config - self.log.debug( - "Creating session using boto3 credential strategy region_name=%s", - region_name, - ) - session = boto3.session.Session(region_name=region_name) - return session, None + @property + def verify(self) -> bool | str | None: + """Verify or not SSL certificates boto3 client/resource read-only property.""" + return self.conn_config.verify + + def get_session(self, region_name: str | None = None) -> boto3.session.Session: + """Get the underlying boto3.session.Session(region_name=region_name).""" + return SessionFactory( + conn=self.conn_config, region_name=region_name, config=self.config + ).create_session() def get_client_type( self, - client_type: Optional[str] = None, - region_name: Optional[str] = None, - config: Optional[Config] = None, + region_name: str | None = None, + config: Config | None = None, ) -> boto3.client: """Get the underlying boto3 client using boto3 session""" - session, endpoint_url = self._get_credentials(region_name=region_name) - - if client_type: - warnings.warn( - "client_type is deprecated. Set client_type from class attribute.", - DeprecationWarning, - stacklevel=2, - ) - else: - client_type = self.client_type + client_type = self.client_type # No AWS Operators use the config argument to this method. 
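A sketch of the session.client(...) call that get_client_type builds; the local endpoint and retry settings are illustrative placeholders (e.g. LocalStack-style testing), not values mandated by this patch:

import boto3
from botocore.config import Config

# endpoint_url and retry settings are illustrative; the hook takes them from
# the connection's endpoint_url and config_kwargs extras.
session = boto3.session.Session(region_name="us-east-1")
s3 = session.client(
    "s3",
    endpoint_url="http://localhost:4566",
    config=Config(retries={"mode": "standard", "max_attempts": 10}),
    verify=False,
)
print(s3.meta.region_name, s3.meta.partition)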
# Keep backward compatibility with other users who might use it if config is None: config = self.config - return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify) + session = self.get_session(region_name=region_name) + return session.client( + client_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify + ) def get_resource_type( self, - resource_type: Optional[str] = None, - region_name: Optional[str] = None, - config: Optional[Config] = None, + region_name: str | None = None, + config: Config | None = None, ) -> boto3.resource: """Get the underlying boto3 resource using boto3 session""" - session, endpoint_url = self._get_credentials(region_name=region_name) - - if resource_type: - warnings.warn( - "resource_type is deprecated. Set resource_type from class attribute.", - DeprecationWarning, - stacklevel=2, - ) - else: - resource_type = self.resource_type + resource_type = self.resource_type # No AWS Operators use the config argument to this method. # Keep backward compatibility with other users who might use it if config is None: config = self.config - return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify) + session = self.get_session(region_name=region_name) + return session.resource( + resource_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify + ) @cached_property - def conn(self) -> Union[boto3.client, boto3.resource]: + def conn(self) -> BaseAwsConnection: """ Get the underlying boto3 client/resource (cached) :return: boto3.client or boto3.resource - :rtype: Union[boto3.client, boto3.resource] """ - if self.client_type: + if not ((not self.client_type) ^ (not self.resource_type)): + raise ValueError( + f"Either client_type={self.client_type!r} or " + f"resource_type={self.resource_type!r} must be provided, not both." + ) + elif self.client_type: return self.get_client_type(region_name=self.region_name) - elif self.resource_type: - return self.get_resource_type(region_name=self.region_name) else: - # Rare possibility - subclasses have not specified a client_type or resource_type - raise NotImplementedError('Could not get boto3 connection!') + return self.get_resource_type(region_name=self.region_name) - def get_conn(self) -> Union[boto3.client, boto3.resource]: + @cached_property + def conn_client_meta(self) -> ClientMeta: + """Get botocore client metadata from Hook connection (cached).""" + conn = self.conn + if isinstance(conn, botocore.client.BaseClient): + return conn.meta + return conn.meta.client.meta + + @property + def conn_region_name(self) -> str: + """Get actual AWS Region Name from Hook connection (cached).""" + return self.conn_client_meta.region_name + + @property + def conn_partition(self) -> str: + """Get associated AWS Region Partition from Hook connection (cached).""" + return self.conn_client_meta.partition + + def get_conn(self) -> BaseAwsConnection: """ Get the underlying boto3 client/resource (cached) @@ -529,29 +529,27 @@ def get_conn(self) -> Union[boto3.client, boto3.resource]: with subclasses that rely on a super().get_conn() method. 
:return: boto3.client or boto3.resource - :rtype: Union[boto3.client, boto3.resource] """ # Compat shim return self.conn - def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session: - """Get the underlying boto3.session.""" - session, _ = self._get_credentials(region_name=region_name) - return session - - def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials: + def get_credentials(self, region_name: str | None = None) -> ReadOnlyCredentials: """ Get the underlying `botocore.Credentials` object. This contains the following authentication attributes: access_key, secret_key and token. + By use this method also secret_key and token will mask in tasks logs. """ - session, _ = self._get_credentials(region_name=region_name) # Credentials are refreshable, so accessing your access key and # secret key separately can lead to a race condition. # See https://stackoverflow.com/a/36291428/8283373 - return session.get_credentials().get_frozen_credentials() + creds = self.get_session(region_name=region_name).get_credentials().get_frozen_credentials() + mask_secret(creds.secret_key) + if creds.token: + mask_secret(creds.token) + return creds - def expand_role(self, role: str, region_name: Optional[str] = None) -> str: + def expand_role(self, role: str, region_name: str | None = None) -> str: """ If the IAM role is a role name, get the Amazon Resource Name (ARN) for the role. If IAM role is already an IAM role ARN, no change is made. @@ -563,8 +561,10 @@ def expand_role(self, role: str, region_name: Optional[str] = None) -> str: if "/" in role: return role else: - session, endpoint_url = self._get_credentials(region_name=region_name) - _client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify) + session = self.get_session(region_name=region_name) + _client = session.client( + "iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify + ) return _client.get_role(RoleName=role)["Role"]["Arn"] @staticmethod @@ -577,21 +577,21 @@ def retry(should_retry: Callable[[Exception], bool]): def retry_decorator(fun: Callable): @wraps(fun) def decorator_f(self, *args, **kwargs): - retry_args = getattr(self, 'retry_args', None) + retry_args = getattr(self, "retry_args", None) if retry_args is None: return fun(self, *args, **kwargs) - multiplier = retry_args.get('multiplier', 1) - min_limit = retry_args.get('min', 1) - max_limit = retry_args.get('max', 1) - stop_after_delay = retry_args.get('stop_after_delay', 10) + multiplier = retry_args.get("multiplier", 1) + min_limit = retry_args.get("min", 1) + max_limit = retry_args.get("max", 1) + stop_after_delay = retry_args.get("stop_after_delay", 10) tenacity_before_logger = tenacity.before_log(self.log, logging.INFO) if self.log else None tenacity_after_logger = tenacity.after_log(self.log, logging.INFO) if self.log else None default_kwargs = { - 'wait': tenacity.wait_exponential(multiplier=multiplier, max=max_limit, min=min_limit), - 'retry': tenacity.retry_if_exception(should_retry), - 'stop': tenacity.stop_after_delay(stop_after_delay), - 'before': tenacity_before_logger, - 'after': tenacity_after_logger, + "wait": tenacity.wait_exponential(multiplier=multiplier, max=max_limit, min=min_limit), + "retry": tenacity.retry_if_exception(should_retry), + "stop": tenacity.stop_after_delay(stop_after_delay), + "before": tenacity_before_logger, + "after": tenacity_after_logger, } return tenacity.retry(**default_kwargs)(fun)(self, *args, **kwargs) @@ -599,59 
+599,81 @@ def decorator_f(self, *args, **kwargs): return retry_decorator + def _get_credentials(self, region_name: str | None) -> tuple[boto3.session.Session, str | None]: + warnings.warn( + "`AwsGenericHook._get_credentials` method deprecated and will be removed in a future releases. " + "Please use `AwsGenericHook.get_session` method and " + "`AwsGenericHook.conn_config.endpoint_url` property instead.", + DeprecationWarning, + stacklevel=2, + ) + + return self.get_session(region_name=region_name), self.conn_config.endpoint_url + + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + """Returns custom UI field behaviour for AWS Connection.""" + return { + "hidden_fields": ["host", "schema", "port"], + "relabeling": { + "login": "AWS Access Key ID", + "password": "AWS Secret Access Key", + }, + "placeholders": { + "login": "AKIAIOSFODNN7EXAMPLE", + "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "extra": json.dumps( + { + "region_name": "us-east-1", + "session_kwargs": {"profile_name": "default"}, + "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}}, + "role_arn": "arn:aws:iam::123456789098:role/role-name", + "assume_role_method": "assume_role", + "assume_role_kwargs": {"RoleSessionName": "airflow"}, + "aws_session_token": "AQoDYXdzEJr...EXAMPLETOKEN", + "endpoint_url": "http://localhost:4566", + }, + indent=2, + ), + }, + } + + def test_connection(self): + """ + Tests the AWS connection by call AWS STS (Security Token Service) GetCallerIdentity API. -def _parse_s3_config( - config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None -) -> Tuple[Optional[str], Optional[str]]: + .. seealso:: + https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html + """ + try: + session = self.get_session() + conn_info = session.client("sts").get_caller_identity() + metadata = conn_info.pop("ResponseMetadata", {}) + if metadata.get("HTTPStatusCode") != 200: + try: + return False, json.dumps(metadata) + except TypeError: + return False, str(metadata) + conn_info["credentials_method"] = session.get_credentials().method + conn_info["region_name"] = session.region_name + return True, ", ".join(f"{k}={v!r}" for k, v in conn_info.items()) + + except Exception as e: + return False, str(f"{type(e).__name__!r} error occurred while testing connection: {e}") + + +class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]): """ - Parses a config file for s3 credentials. Can currently - parse boto, s3cmd.conf and AWS SDK config formats + Interact with AWS. + This class is a thin wrapper around the boto3 python library + with basic conn annotation. - :param config_file_name: path to the config file - :param config_format: config type. One of "boto", "s3cmd" or "aws". - Defaults to "boto" - :param profile: profile name in AWS type config file + .. 
seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook` """ - config = configparser.ConfigParser() - if config.read(config_file_name): # pragma: no cover - sections = config.sections() - else: - raise AirflowException(f"Couldn't read {config_file_name}") - # Setting option names depending on file format - if config_format is None: - config_format = "boto" - conf_format = config_format.lower() - if conf_format == "boto": # pragma: no cover - if profile is not None and "profile " + profile in sections: - cred_section = "profile " + profile - else: - cred_section = "Credentials" - elif conf_format == "aws" and profile is not None: - cred_section = profile - else: - cred_section = "default" - # Option names - if conf_format in ("boto", "aws"): # pragma: no cover - key_id_option = "aws_access_key_id" - secret_key_option = "aws_secret_access_key" - # security_token_option = 'aws_security_token' - else: - key_id_option = "access_key" - secret_key_option = "secret_key" - # Actual Parsing - if cred_section not in sections: - raise AirflowException("This config file format is not recognized") - else: - try: - access_key = config.get(cred_section, key_id_option) - secret_key = config.get(cred_section, secret_key_option) - except Exception: - logging.warning("Option Error in parsing s3 config file") - raise - return access_key, secret_key -def resolve_session_factory() -> Type[BaseSessionFactory]: +def resolve_session_factory() -> type[BaseSessionFactory]: """Resolves custom SessionFactory class""" clazz = conf.getimport("aws", "session_factory", fallback=None) if not clazz: @@ -665,3 +687,14 @@ def resolve_session_factory() -> Type[BaseSessionFactory]: SessionFactory = resolve_session_factory() + + +def _parse_s3_config(config_file_name: str, config_format: str | None = "boto", profile: str | None = None): + """For compatibility with airflow.contrib.hooks.aws_hook""" + from airflow.providers.amazon.aws.utils.connection_wrapper import _parse_s3_config + + return _parse_s3_config( + config_file_name=config_file_name, + config_format=config_format, + profile=profile, + ) diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py index 3b10012b3943f..e9080189e76b2 100644 --- a/airflow/providers/amazon/aws/hooks/batch_client.py +++ b/airflow/providers/amazon/aws/hooks/batch_client.py @@ -20,14 +20,14 @@ .. seealso:: - - http://boto3.readthedocs.io/en/latest/guide/configuration.html - - http://boto3.readthedocs.io/en/latest/reference/services/batch.html + - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html """ -import warnings +from __future__ import annotations + from random import uniform from time import sleep -from typing import Dict, List, Optional, Union import botocore.client import botocore.exceptions @@ -48,17 +48,16 @@ class BatchProtocol(Protocol): .. seealso:: - https://mypy.readthedocs.io/en/latest/protocols.html - - http://boto3.readthedocs.io/en/latest/reference/services/batch.html + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html """ - def describe_jobs(self, jobs: List[str]) -> Dict: + def describe_jobs(self, jobs: list[str]) -> dict: """ Get job descriptions from AWS Batch :param jobs: a list of JobId to describe :return: an API response to describe jobs - :rtype: Dict """ ... 
@@ -71,7 +70,6 @@ def get_waiter(self, waiterName: str) -> botocore.waiter.Waiter: model file (typically this is CamelCasing). :return: a waiter object for the named AWS Batch service - :rtype: botocore.waiter.Waiter .. note:: AWS Batch might not have any waiters (until botocore PR-1307 is released). @@ -94,11 +92,11 @@ def submit_job( jobName: str, jobQueue: str, jobDefinition: str, - arrayProperties: Dict, - parameters: Dict, - containerOverrides: Dict, - tags: Dict, - ) -> Dict: + arrayProperties: dict, + parameters: dict, + containerOverrides: dict, + tags: dict, + ) -> dict: """ Submit a Batch job @@ -117,11 +115,10 @@ def submit_job( :param tags: the same parameter that boto3 will receive :return: an API response - :rtype: Dict """ ... - def terminate_job(self, jobId: str, reason: str) -> Dict: + def terminate_job(self, jobId: str, reason: str) -> dict: """ Terminate a Batch job @@ -130,7 +127,6 @@ def terminate_job(self, jobId: str, reason: str) -> Dict: :param reason: a reason to terminate job ID :return: an API response - :rtype: Dict """ ... @@ -181,36 +177,41 @@ class BatchClientHook(AwsBaseHook): DEFAULT_DELAY_MIN = 1 DEFAULT_DELAY_MAX = 10 - FAILURE_STATE = 'FAILED' - SUCCESS_STATE = 'SUCCEEDED' - RUNNING_STATE = 'RUNNING' + FAILURE_STATE = "FAILED" + SUCCESS_STATE = "SUCCEEDED" + RUNNING_STATE = "RUNNING" INTERMEDIATE_STATES = ( - 'SUBMITTED', - 'PENDING', - 'RUNNABLE', - 'STARTING', + "SUBMITTED", + "PENDING", + "RUNNABLE", + "STARTING", RUNNING_STATE, ) + COMPUTE_ENVIRONMENT_TERMINAL_STATUS = ("VALID", "DELETED") + COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS = ("CREATING", "UPDATING", "DELETING") + + JOB_QUEUE_TERMINAL_STATUS = ("VALID", "DELETED") + JOB_QUEUE_INTERMEDIATE_STATUS = ("CREATING", "UPDATING", "DELETING") + def __init__( - self, *args, max_retries: Optional[int] = None, status_retries: Optional[int] = None, **kwargs + self, *args, max_retries: int | None = None, status_retries: int | None = None, **kwargs ) -> None: # https://github.com/python/mypy/issues/6799 hence type: ignore - super().__init__(client_type='batch', *args, **kwargs) # type: ignore + super().__init__(client_type="batch", *args, **kwargs) # type: ignore self.max_retries = max_retries or self.MAX_RETRIES self.status_retries = status_retries or self.STATUS_RETRIES @property - def client(self) -> Union[BatchProtocol, botocore.client.BaseClient]: + def client(self) -> BatchProtocol | botocore.client.BaseClient: """ An AWS API client for Batch services. 
:return: a boto3 'batch' client for the ``.region_name`` - :rtype: Union[BatchProtocol, botocore.client.BaseClient] """ return self.conn - def terminate_job(self, job_id: str, reason: str) -> Dict: + def terminate_job(self, job_id: str, reason: str) -> dict: """ Terminate a Batch job @@ -219,7 +220,6 @@ def terminate_job(self, job_id: str, reason: str) -> Dict: :param reason: a reason to terminate job ID :return: an API response - :rtype: Dict """ response = self.get_conn().terminate_job(jobId=job_id, reason=reason) self.log.info(response) @@ -232,7 +232,6 @@ def check_job_success(self, job_id: str) -> bool: :param job_id: a Batch job ID - :rtype: bool :raises: AirflowException """ @@ -251,7 +250,7 @@ def check_job_success(self, job_id: str) -> bool: raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}") - def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: + def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None: """ Wait for Batch job to complete @@ -266,7 +265,7 @@ def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> No self.poll_for_job_complete(job_id, delay) self.log.info("AWS Batch job (%s) has completed", job_id) - def poll_for_job_running(self, job_id: str, delay: Union[int, float, None] = None) -> None: + def poll_for_job_running(self, job_id: str, delay: int | float | None = None) -> None: """ Poll for job running. The status that indicates a job is running or already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'. @@ -288,7 +287,7 @@ def poll_for_job_running(self, job_id: str, delay: Union[int, float, None] = Non running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE] self.poll_job_status(job_id, running_status) - def poll_for_job_complete(self, job_id: str, delay: Union[int, float, None] = None) -> None: + def poll_for_job_complete(self, job_id: str, delay: int | float | None = None) -> None: """ Poll for job completion. The status that indicates job completion are: 'SUCCEEDED'|'FAILED'. @@ -306,7 +305,7 @@ def poll_for_job_complete(self, job_id: str, delay: Union[int, float, None] = No complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE] self.poll_job_status(job_id, complete_status) - def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: + def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: """ Poll for job status using an exponential back-off strategy (with max_retries). @@ -315,7 +314,6 @@ def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: :param match_status: a list of job status to match; the Batch job status are: 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' - :rtype: bool :raises: AirflowException """ @@ -348,14 +346,13 @@ def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: ) self.delay(pause) - def get_job_description(self, job_id: str) -> Dict: + def get_job_description(self, job_id: str) -> dict: """ Get job description (using status_retries). 
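The poll_job_status logic above retries describe_jobs with a jittered backoff; a standalone sketch of that poll-until-match pattern (the status getter and limits are illustrative, not the hook's own values):

from __future__ import annotations

import time
from random import uniform
from typing import Callable


def poll_until(get_status: Callable[[], str], match_status: list[str], max_retries: int = 10) -> str:
    # Retry with a growing, jittered pause, roughly as poll_job_status and
    # add_jitter do; raise once the retry budget is exhausted.
    for attempt in range(1, max_retries + 1):
        status = get_status()
        if status in match_status:
            return status
        pause = min(1 + attempt * 2, 30)
        time.sleep(uniform(max(pause - 1, 0), pause + 1))
    raise TimeoutError(f"status did not reach {match_status} in {max_retries} attempts")


print(poll_until(lambda: "SUCCEEDED", ["SUCCEEDED", "FAILED"]))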
:param job_id: a Batch job ID :return: an API response for describe jobs - :rtype: Dict :raises: AirflowException """ @@ -390,7 +387,7 @@ def get_job_description(self, job_id: str) -> Dict: self.delay(pause) @staticmethod - def parse_job_description(job_id: str, response: Dict) -> Dict: + def parse_job_description(job_id: str, response: dict) -> dict: """ Parse job description to extract description for job_id @@ -399,7 +396,6 @@ def parse_job_description(job_id: str, response: Dict) -> Dict: :param response: an API response for describe jobs :return: an API response to describe job_id - :rtype: Dict :raises: AirflowException """ @@ -410,10 +406,47 @@ def parse_job_description(job_id: str, response: Dict) -> Dict: return matching_jobs[0] + def get_job_awslogs_info(self, job_id: str) -> dict[str, str] | None: + """ + Parse job description to extract AWS CloudWatch information. + + :param job_id: AWS Batch Job ID + """ + job_container_desc = self.get_job_description(job_id=job_id).get("container", {}) + log_configuration = job_container_desc.get("logConfiguration", {}) + + # In case if user select other "logDriver" rather than "awslogs" + # than CloudWatch logging should be disabled. + # If user not specify anything than expected that "awslogs" will use + # with default settings: + # awslogs-group = /aws/batch/job + # awslogs-region = `same as AWS Batch Job region` + log_driver = log_configuration.get("logDriver", "awslogs") + if log_driver != "awslogs": + self.log.warning( + "AWS Batch job (%s) uses logDriver (%s). AWS CloudWatch logging disabled.", job_id, log_driver + ) + return None + + awslogs_stream_name = job_container_desc.get("logStreamName") + if not awslogs_stream_name: + # In case of call this method on very early stage of running AWS Batch + # there is possibility than AWS CloudWatch Stream Name not exists yet. + # AWS CloudWatch Stream Name also not created in case of misconfiguration. + self.log.warning("AWS Batch job (%s) doesn't create AWS CloudWatch Stream.", job_id) + return None + + # Try to get user-defined log configuration options + log_options = log_configuration.get("options", {}) + + return { + "awslogs_stream_name": awslogs_stream_name, + "awslogs_group": log_options.get("awslogs-group", "/aws/batch/job"), + "awslogs_region": log_options.get("awslogs-region", self.conn_region_name), + } + @staticmethod - def add_jitter( - delay: Union[int, float], width: Union[int, float] = 1, minima: Union[int, float] = 0 - ) -> float: + def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float: """ Use delay +/- width for random jitter @@ -432,7 +465,6 @@ def add_jitter( :return: uniform(delay - width, delay + width) jitter and it is a non-negative number - :rtype: float """ delay = abs(delay) width = abs(width) @@ -442,7 +474,7 @@ def add_jitter( return uniform(lower, upper) @staticmethod - def delay(delay: Union[int, float, None] = None) -> None: + def delay(delay: int | float | None = None) -> None: """ Pause execution for ``delay`` seconds. @@ -470,7 +502,6 @@ def exponential_delay(tries: int) -> float: :param tries: Number of tries - :rtype: float Examples of behavior: @@ -506,35 +537,3 @@ def exp(tries): delay = 1 + pow(tries * 0.6, 2) delay = min(max_interval, delay) return uniform(delay / 3, delay) - - -class AwsBatchProtocol(BatchProtocol, Protocol): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchProtocol`. 
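Once get_job_awslogs_info (added above) resolves a log group and stream, the events can be read with the CloudWatch Logs client; a sketch with placeholder names:

import boto3

# Group, stream and region are placeholders; the hook derives them from the
# Batch job's logConfiguration options and logStreamName.
logs = boto3.client("logs", region_name="us-east-1")
events = logs.get_log_events(
    logGroupName="/aws/batch/job",
    logStreamName="example-job-definition/default/0123456789abcdef",
    startFromHead=True,
)
for event in events["events"]:
    print(event["timestamp"], event["message"])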
- """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchProtocol`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class AwsBatchClientHook(BatchClientHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchClientHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchClientHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py b/airflow/providers/amazon/aws/hooks/batch_waiters.py index 59ba0e431f25e..0bbb982e41738 100644 --- a/airflow/providers/amazon/aws/hooks/batch_waiters.py +++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py @@ -15,8 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - """ AWS Batch service waiters @@ -25,13 +23,12 @@ - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters - https://github.com/boto/botocore/blob/develop/botocore/waiter.py """ +from __future__ import annotations import json import sys -import warnings from copy import deepcopy from pathlib import Path -from typing import Dict, List, Optional, Union import botocore.client import botocore.exceptions @@ -71,16 +68,12 @@ class BatchWaitersHook(BatchClientHook): # and the details of the config on that waiter can be further modified without any # accidental impact on the generation of new waiters from the defined waiter_model, e.g. waiters.get_waiter("JobExists").config.delay # -> 5 - waiter = waiters.get_waiter( - "JobExists" - ) # -> botocore.waiter.Batch.Waiter.JobExists object + waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object waiter.config.delay = 10 waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model # To use a specific waiter, update the config and call the `wait()` method for jobId, e.g. - waiter = waiters.get_waiter( - "JobExists" - ) # -> botocore.waiter.Batch.Waiter.JobExists object + waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object waiter.config.delay = random.uniform(1, 10) # seconds waiter.config.max_attempts = 10 waiter.wait(jobs=[jobId]) @@ -97,27 +90,26 @@ class BatchWaitersHook(BatchClientHook): :param aws_conn_id: connection id of AWS credentials / region name. If None, credential boto3 strategy will be used - (http://boto3.readthedocs.io/en/latest/guide/configuration.html). + (https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html). :param region_name: region name to use in AWS client. 
Override the AWS region in connection (if provided) """ - def __init__(self, *args, waiter_config: Optional[Dict] = None, **kwargs) -> None: + def __init__(self, *args, waiter_config: dict | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) - self._default_config = None # type: Optional[Dict] + self._default_config: dict | None = None self._waiter_config = waiter_config or self.default_config self._waiter_model = botocore.waiter.WaiterModel(self._waiter_config) @property - def default_config(self) -> Dict: + def default_config(self) -> dict: """ An immutable default waiter configuration :return: a waiter configuration for AWS Batch services - :rtype: Dict """ if self._default_config is None: config_path = Path(__file__).with_name("batch_waiters.json").resolve() @@ -126,7 +118,7 @@ def default_config(self) -> Dict: return deepcopy(self._default_config) # avoid accidental mutation @property - def waiter_config(self) -> Dict: + def waiter_config(self) -> dict: """ An immutable waiter configuration for this instance; a ``deepcopy`` is returned by this property. During the init for BatchWaiters, the waiter_config is used to build a @@ -134,7 +126,6 @@ def waiter_config(self) -> Dict: mutations of waiter_config leaking into the waiter_model. :return: a waiter configuration for AWS Batch services - :rtype: Dict """ return deepcopy(self._waiter_config) # avoid accidental mutation @@ -144,7 +135,6 @@ def waiter_model(self) -> botocore.waiter.WaiterModel: A configured waiter model used to generate waiters on AWS Batch services. :return: a waiter model for AWS Batch services - :rtype: botocore.waiter.WaiterModel """ return self._waiter_model @@ -179,20 +169,18 @@ def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter: model file (typically this is CamelCasing); see ``.list_waiters``. :return: a waiter object for the named AWS Batch service - :rtype: botocore.waiter.Waiter """ return botocore.waiter.create_waiter_with_client(waiter_name, self.waiter_model, self.client) - def list_waiters(self) -> List[str]: + def list_waiters(self) -> list[str]: """ List the waiters in a waiter configuration for AWS Batch services. :return: waiter names for AWS Batch services - :rtype: List[str] """ return self.waiter_model.waiter_names - def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: + def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None: """ Wait for Batch job to complete. This assumes that the ``.waiter_model`` is configured using some variation of the ``.default_config`` so that it can generate waiters with the @@ -231,19 +219,3 @@ def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> No except (botocore.exceptions.ClientError, botocore.exceptions.WaiterError) as err: raise AirflowException(err) - - -class AwsBatchWaitersHook(BatchWaitersHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchWaitersHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. 
" - "Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchWaitersHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/cloud_formation.py b/airflow/providers/amazon/aws/hooks/cloud_formation.py index e96f397628ed9..ac196ee602d1e 100644 --- a/airflow/providers/amazon/aws/hooks/cloud_formation.py +++ b/airflow/providers/amazon/aws/hooks/cloud_formation.py @@ -15,10 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains AWS CloudFormation Hook""" -import warnings -from typing import Optional, Union +from __future__ import annotations from boto3 import client, resource from botocore.exceptions import ClientError @@ -38,17 +36,17 @@ class CloudFormationHook(AwsBaseHook): """ def __init__(self, *args, **kwargs): - super().__init__(client_type='cloudformation', *args, **kwargs) + super().__init__(client_type="cloudformation", *args, **kwargs) - def get_stack_status(self, stack_name: Union[client, resource]) -> Optional[dict]: + def get_stack_status(self, stack_name: client | resource) -> dict | None: """Get stack status from CloudFormation.""" - self.log.info('Poking for stack %s', stack_name) + self.log.info("Poking for stack %s", stack_name) try: - stacks = self.get_conn().describe_stacks(StackName=stack_name)['Stacks'] - return stacks[0]['StackStatus'] + stacks = self.get_conn().describe_stacks(StackName=stack_name)["Stacks"] + return stacks[0]["StackStatus"] except ClientError as e: - if 'does not exist' in str(e): + if "does not exist" in str(e): return None else: raise e @@ -60,11 +58,11 @@ def create_stack(self, stack_name: str, cloudformation_parameters: dict) -> None :param stack_name: stack_name. :param cloudformation_parameters: parameters to be passed to CloudFormation. """ - if 'StackName' not in cloudformation_parameters: - cloudformation_parameters['StackName'] = stack_name + if "StackName" not in cloudformation_parameters: + cloudformation_parameters["StackName"] = stack_name self.get_conn().create_stack(**cloudformation_parameters) - def delete_stack(self, stack_name: str, cloudformation_parameters: Optional[dict] = None) -> None: + def delete_stack(self, stack_name: str, cloudformation_parameters: dict | None = None) -> None: """ Delete stack in CloudFormation. @@ -72,22 +70,6 @@ def delete_stack(self, stack_name: str, cloudformation_parameters: Optional[dict :param cloudformation_parameters: parameters to be passed to CloudFormation (optional). """ cloudformation_parameters = cloudformation_parameters or {} - if 'StackName' not in cloudformation_parameters: - cloudformation_parameters['StackName'] = stack_name + if "StackName" not in cloudformation_parameters: + cloudformation_parameters["StackName"] = stack_name self.get_conn().delete_stack(**cloudformation_parameters) - - -class AWSCloudFormationHook(CloudFormationHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. 
" - "Please use :class:`airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py index b75123d6fc6ec..89cca307502a9 100644 --- a/airflow/providers/amazon/aws/hooks/datasync.py +++ b/airflow/providers/amazon/aws/hooks/datasync.py @@ -14,12 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Interact with AWS DataSync, using the AWS ``boto3`` library.""" +from __future__ import annotations import time -import warnings -from typing import List, Optional from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -53,7 +51,7 @@ class DataSyncHook(AwsBaseHook): TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",) def __init__(self, wait_interval_seconds: int = 30, *args, **kwargs) -> None: - super().__init__(client_type='datasync', *args, **kwargs) # type: ignore[misc] + super().__init__(client_type="datasync", *args, **kwargs) # type: ignore[misc] self.locations: list = [] self.tasks: list = [] # wait_interval_seconds = 0 is used during unit tests @@ -86,7 +84,7 @@ def create_location(self, location_uri: str, **create_location_kwargs) -> str: def get_location_arns( self, location_uri: str, case_sensitive: bool = False, ignore_trailing_slash: bool = True - ) -> List[str]: + ) -> list[str]: """ Return all LocationArns which match a LocationUri. @@ -94,7 +92,6 @@ def get_location_arns( :param bool case_sensitive: Do a case sensitive search for location URI. :param bool ignore_trailing_slash: Ignore / at the end of URI when matching. :return: List of LocationArns. - :rtype: list(str) :raises AirflowBadRequest: if ``location_uri`` is empty """ if not location_uri: @@ -141,7 +138,6 @@ def create_task( :param str destination_location_arn: Destination LocationArn. Must exist already. :param create_task_kwargs: Passed to ``boto.create_task()``. See AWS boto3 datasync documentation. :return: TaskArn of the created Task - :rtype: str """ task = self.get_conn().create_task( SourceLocationArn=source_location_arn, @@ -192,7 +188,6 @@ def get_task_arns_for_location_arns( :param list source_location_arns: List of source LocationArns. :param list destination_location_arns: List of destination LocationArns. :return: list - :rtype: list(TaskArns) :raises AirflowBadRequest: if ``source_location_arns`` or ``destination_location_arns`` are empty. """ if not source_location_arns: @@ -219,7 +214,6 @@ def start_task_execution(self, task_arn: str, **kwargs) -> str: :param str task_arn: TaskArn :return: TaskExecutionArn :param kwargs: kwargs sent to ``boto3.start_task_execution()`` - :rtype: str :raises ClientError: If a TaskExecution is already busy running for this ``task_arn``. :raises AirflowBadRequest: If ``task_arn`` is empty. """ @@ -245,7 +239,6 @@ def get_task_description(self, task_arn: str) -> dict: :param str task_arn: TaskArn :return: AWS metadata about a task. - :rtype: dict :raises AirflowBadRequest: If ``task_arn`` is empty. """ if not task_arn: @@ -258,18 +251,16 @@ def describe_task_execution(self, task_execution_arn: str) -> dict: :param str task_execution_arn: TaskExecutionArn :return: AWS metadata about a task execution. - :rtype: dict :raises AirflowBadRequest: If ``task_execution_arn`` is empty. 
""" return self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn) - def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]: + def get_current_task_execution_arn(self, task_arn: str) -> str | None: """ Get current TaskExecutionArn (if one exists) for the specified ``task_arn``. :param str task_arn: TaskArn :return: CurrentTaskExecutionArn for this ``task_arn`` or None. - :rtype: str :raises AirflowBadRequest: if ``task_arn`` is empty. """ if not task_arn: @@ -287,7 +278,6 @@ def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = :param str task_execution_arn: TaskExecutionArn :param int max_iterations: Maximum number of iterations before timing out. :return: Result of task execution. - :rtype: bool :raises AirflowTaskTimeout: If maximum iterations is exceeded. :raises AirflowBadRequest: If ``task_execution_arn`` is empty. """ @@ -316,18 +306,3 @@ def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = if iterations <= 0: raise AirflowTaskTimeout("Max iterations exceeded!") raise AirflowException(f"Unknown status: {status}") # Should never happen - - -class AWSDataSyncHook(DataSyncHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.datasync.DataSyncHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.datasync.DataSyncHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/dms.py b/airflow/providers/amazon/aws/hooks/dms.py index a1bd19daf3129..7ca541dd41f54 100644 --- a/airflow/providers/amazon/aws/hooks/dms.py +++ b/airflow/providers/amazon/aws/hooks/dms.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import json from enum import Enum -from typing import Optional from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -25,10 +26,10 @@ class DmsTaskWaiterStatus(str, Enum): """Available AWS DMS Task Waiter statuses.""" - DELETED = 'deleted' - READY = 'ready' - RUNNING = 'running' - STOPPED = 'stopped' + DELETED = "deleted" + READY = "ready" + RUNNING = "running" + STOPPED = "stopped" class DmsHook(AwsBaseHook): @@ -39,24 +40,21 @@ def __init__( *args, **kwargs, ): - kwargs['client_type'] = 'dms' + kwargs["client_type"] = "dms" super().__init__(*args, **kwargs) - def describe_replication_tasks(self, **kwargs): + def describe_replication_tasks(self, **kwargs) -> tuple[str | None, list]: """ Describe replication tasks :return: Marker and list of replication tasks - :rtype: (Optional[str], list) """ dms_client = self.get_conn() response = dms_client.describe_replication_tasks(**kwargs) - return response.get('Marker'), response.get('ReplicationTasks', []) + return response.get("Marker"), response.get("ReplicationTasks", []) - def find_replication_tasks_by_arn( - self, replication_task_arn: str, without_settings: Optional[bool] = False - ): + def find_replication_tasks_by_arn(self, replication_task_arn: str, without_settings: bool | None = False): """ Find and describe replication tasks by task ARN :param replication_task_arn: Replication task arn @@ -67,8 +65,8 @@ def find_replication_tasks_by_arn( _, tasks = self.describe_replication_tasks( Filters=[ { - 'Name': 'replication-task-arn', - 'Values': [replication_task_arn], + "Name": "replication-task-arn", + "Values": [replication_task_arn], } ], WithoutSettings=without_settings, @@ -76,7 +74,7 @@ def find_replication_tasks_by_arn( return tasks - def get_task_status(self, replication_task_arn: str) -> Optional[str]: + def get_task_status(self, replication_task_arn: str) -> str | None: """ Retrieve task status. 
@@ -89,11 +87,11 @@ def get_task_status(self, replication_task_arn: str) -> Optional[str]: ) if len(replication_tasks) == 1: - status = replication_tasks[0]['Status'] + status = replication_tasks[0]["Status"] self.log.info('Replication task with ARN(%s) has status "%s".', replication_task_arn, status) return status else: - self.log.info('Replication task with ARN(%s) is not found.', replication_task_arn) + self.log.info("Replication task with ARN(%s) is not found.", replication_task_arn) return None def create_replication_task( @@ -128,7 +126,7 @@ def create_replication_task( **kwargs, ) - replication_task_arn = create_task_response['ReplicationTask']['ReplicationTaskArn'] + replication_task_arn = create_task_response["ReplicationTask"]["ReplicationTaskArn"] self.wait_for_task_status(replication_task_arn, DmsTaskWaiterStatus.READY) return replication_task_arn @@ -182,15 +180,15 @@ def wait_for_task_status(self, replication_task_arn: str, status: DmsTaskWaiterS :param replication_task_arn: Replication task ARN """ if not isinstance(status, DmsTaskWaiterStatus): - raise TypeError('Status must be an instance of DmsTaskWaiterStatus') + raise TypeError("Status must be an instance of DmsTaskWaiterStatus") dms_client = self.get_conn() - waiter = dms_client.get_waiter(f'replication_task_{status}') + waiter = dms_client.get_waiter(f"replication_task_{status}") waiter.wait( Filters=[ { - 'Name': 'replication-task-arn', - 'Values': [ + "Name": "replication-task-arn", + "Values": [ replication_task_arn, ], }, diff --git a/airflow/providers/amazon/aws/hooks/dynamodb.py b/airflow/providers/amazon/aws/hooks/dynamodb.py index 7b298ee15ca62..52c96e9b02c23 100644 --- a/airflow/providers/amazon/aws/hooks/dynamodb.py +++ b/airflow/providers/amazon/aws/hooks/dynamodb.py @@ -15,11 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - """This module contains the AWS DynamoDB hook""" -import warnings -from typing import Iterable, List, Optional +from __future__ import annotations + +from typing import Iterable from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -40,7 +39,7 @@ class DynamoDBHook(AwsBaseHook): """ def __init__( - self, *args, table_keys: Optional[List] = None, table_name: Optional[str] = None, **kwargs + self, *args, table_keys: list | None = None, table_name: str | None = None, **kwargs ) -> None: self.table_keys = table_keys self.table_name = table_name @@ -58,19 +57,3 @@ def write_batch_data(self, items: Iterable) -> bool: return True except Exception as general_error: raise AirflowException(f"Failed to insert items in dynamodb, error: {str(general_error)}") - - -class AwsDynamoDBHook(DynamoDBHook): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.dynamodb.DynamoDBHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.dynamodb.DynamoDBHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py index 96dbaf541051d..bb715c4bf2ca1 100644 --- a/airflow/providers/amazon/aws/hooks/ec2.py +++ b/airflow/providers/amazon/aws/hooks/ec2.py @@ -15,11 +15,10 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations import functools import time -from typing import List, Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -31,16 +30,17 @@ def checker(self, *args, **kwargs): if self._api_type == "client_type": return func(self, *args, **kwargs) + ec2_doc_link = "https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html" raise AirflowException( - """ + f""" This method is only callable when using client_type API for interacting with EC2. Create the EC2Hook object as follows to use this method ec2 = EC2Hook(api_type="client_type") Read following for details on client_type and resource_type APIs: - 1. https://boto3.amazonaws.com/v1/documentation/api/1.9.42/reference/services/ec2.html#client - 2. https://boto3.amazonaws.com/v1/documentation/api/1.9.42/reference/services/ec2.html#service-resource # noqa + 1. {ec2_doc_link}#client + 2. {ec2_doc_link}#service-resource """ ) @@ -70,14 +70,13 @@ def __init__(self, api_type="resource_type", *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def get_instance(self, instance_id: str, filters: Optional[List] = None): + def get_instance(self, instance_id: str, filters: list | None = None): """ Get EC2 instance by id and return it. :param instance_id: id of the AWS EC2 instance :param filters: List of filters to specify instances to get :return: Instance object - :rtype: ec2.Instance """ if self._api_type == "client_type": return self.get_instances(filters=filters, instance_ids=[instance_id]) @@ -121,7 +120,7 @@ def terminate_instances(self, instance_ids: list) -> dict: return self.conn.terminate_instances(InstanceIds=instance_ids) @only_client_type - def describe_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None): + def describe_instances(self, filters: list | None = None, instance_ids: list | None = None): """ Describe EC2 instances, optionally applying filters and selective instance ids @@ -138,7 +137,7 @@ def describe_instances(self, filters: Optional[List] = None, instance_ids: Optio return self.conn.describe_instances(Filters=filters, InstanceIds=instance_ids) @only_client_type - def get_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None) -> list: + def get_instances(self, filters: list | None = None, instance_ids: list | None = None) -> list: """ Get list of instance details, optionally applying filters and selective instance ids @@ -153,7 +152,7 @@ def get_instances(self, filters: Optional[List] = None, instance_ids: Optional[L ] @only_client_type - def get_instance_ids(self, filters: Optional[List] = None) -> list: + def get_instance_ids(self, filters: list | None = None) -> list: """ Get list of instance ids, optionally applying filters to fetch selective instances @@ -168,7 +167,6 @@ def get_instance_state(self, instance_id: str) -> str: :param instance_id: id of the AWS EC2 instance :return: current state of the instance - :rtype: str """ if self._api_type == "client_type": return self.get_instances(instance_ids=[instance_id])[0]["State"]["Name"] @@ -184,7 +182,6 @@ def wait_for_state(self, instance_id: str, target_state: str, check_interval: fl :param check_interval: time in seconds that the job should wait in between each instance state checks until operation is completed :return: None - :rtype: None """ instance_state = 
self.get_instance_state(instance_id=instance_id) diff --git a/airflow/providers/amazon/aws/hooks/ecs.py b/airflow/providers/amazon/aws/hooks/ecs.py new file mode 100644 index 0000000000000..f5a9945d92fa3 --- /dev/null +++ b/airflow/providers/amazon/aws/hooks/ecs.py @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import time +from collections import deque +from datetime import datetime, timedelta +from enum import Enum +from logging import Logger +from threading import Event, Thread +from typing import Generator + +from botocore.exceptions import ClientError, ConnectionClosedError +from botocore.waiter import Waiter + +from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook +from airflow.typing_compat import Protocol, runtime_checkable + + +def should_retry(exception: Exception): + """Check if exception is related to ECS resource quota (CPU, MEM).""" + if isinstance(exception, EcsOperatorError): + return any( + quota_reason in failure["reason"] + for quota_reason in ["RESOURCE:MEMORY", "RESOURCE:CPU"] + for failure in exception.failures + ) + return False + + +def should_retry_eni(exception: Exception): + """Check if exception is related to ENI (Elastic Network Interfaces).""" + if isinstance(exception, EcsTaskFailToStart): + return any( + eni_reason in exception.message + for eni_reason in ["network interface provisioning", "ResourceInitializationError"] + ) + return False + + +class EcsClusterStates(Enum): + """Contains the possible State values of an ECS Cluster.""" + + ACTIVE = "ACTIVE" + PROVISIONING = "PROVISIONING" + DEPROVISIONING = "DEPROVISIONING" + FAILED = "FAILED" + INACTIVE = "INACTIVE" + + +class EcsTaskDefinitionStates(Enum): + """Contains the possible State values of an ECS Task Definition.""" + + ACTIVE = "ACTIVE" + INACTIVE = "INACTIVE" + + +class EcsTaskStates(Enum): + """Contains the possible State values of an ECS Task.""" + + PROVISIONING = "PROVISIONING" + PENDING = "PENDING" + ACTIVATING = "ACTIVATING" + RUNNING = "RUNNING" + DEACTIVATING = "DEACTIVATING" + STOPPING = "STOPPING" + DEPROVISIONING = "DEPROVISIONING" + STOPPED = "STOPPED" + NONE = "NONE" + + +class EcsHook(AwsGenericHook): + """ + Interact with AWS Elastic Container Service, using the boto3 library + Hook attribute `conn` has all methods that listed in documentation + + .. 
seealso:: + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html + - https://docs.aws.amazon.com/AmazonECS/latest/APIReference/Welcome.html + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook` + + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "ecs" + super().__init__(*args, **kwargs) + + def get_cluster_state(self, cluster_name: str) -> str: + return self.conn.describe_clusters(clusters=[cluster_name])["clusters"][0]["status"] + + def get_task_definition_state(self, task_definition: str) -> str: + return self.conn.describe_task_definition(taskDefinition=task_definition)["taskDefinition"]["status"] + + def get_task_state(self, cluster, task) -> str: + return self.conn.describe_tasks(cluster=cluster, tasks=[task])["tasks"][0]["lastStatus"] + + +class EcsTaskLogFetcher(Thread): + """ + Fetches Cloudwatch log events with specific interval as a thread + and sends the log events to the info channel of the provided logger. + """ + + def __init__( + self, + *, + log_group: str, + log_stream_name: str, + fetch_interval: timedelta, + logger: Logger, + aws_conn_id: str | None = "aws_default", + region_name: str | None = None, + ): + super().__init__() + self._event = Event() + + self.fetch_interval = fetch_interval + + self.logger = logger + self.log_group = log_group + self.log_stream_name = log_stream_name + + self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, region_name=region_name) + + def run(self) -> None: + logs_to_skip = 0 + while not self.is_stopped(): + time.sleep(self.fetch_interval.total_seconds()) + log_events = self._get_log_events(logs_to_skip) + for log_event in log_events: + self.logger.info(self._event_to_str(log_event)) + logs_to_skip += 1 + + def _get_log_events(self, skip: int = 0) -> Generator: + try: + yield from self.hook.get_log_events(self.log_group, self.log_stream_name, skip=skip) + except ClientError as error: + if error.response["Error"]["Code"] != "ResourceNotFoundException": + self.logger.warning("Error on retrieving Cloudwatch log events", error) + + yield from () + except ConnectionClosedError as error: + self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events", error) + yield from () + + def _event_to_str(self, event: dict) -> str: + event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0) + formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] + message = event["message"] + return f"[{formatted_event_dt}] {message}" + + def get_last_log_messages(self, number_messages) -> list: + return [log["message"] for log in deque(self._get_log_events(), maxlen=number_messages)] + + def get_last_log_message(self) -> str | None: + try: + return self.get_last_log_messages(1)[0] + except IndexError: + return None + + def is_stopped(self) -> bool: + return self._event.is_set() + + def stop(self): + self._event.set() + + +@runtime_checkable +class EcsProtocol(Protocol): + """ + A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on + :py:meth:`.EcsOperator.client`. + + .. 
seealso:: + + - https://mypy.readthedocs.io/en/latest/protocols.html + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html + """ + + def run_task(self, **kwargs) -> dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501 + ... + + def get_waiter(self, x: str) -> Waiter: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501 + ... + + def describe_tasks(self, cluster: str, tasks) -> dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501 + ... + + def stop_task(self, cluster, task, reason: str) -> dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501 + ... + + def describe_task_definition(self, taskDefinition: str) -> dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_task_definition""" # noqa: E501 + ... + + def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family: str) -> dict: + """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.list_tasks""" # noqa: E501 + ... diff --git a/airflow/providers/amazon/aws/hooks/eks.py b/airflow/providers/amazon/aws/hooks/eks.py index d2a795e498dfd..03488dea96771 100644 --- a/airflow/providers/amazon/aws/hooks/eks.py +++ b/airflow/providers/amazon/aws/hooks/eks.py @@ -14,30 +14,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Interact with Amazon EKS, using the boto3 library.""" +from __future__ import annotations + import base64 import json import sys import tempfile -import warnings from contextlib import contextmanager from enum import Enum from functools import partial -from typing import Callable, Dict, Generator, List, Optional +from typing import Callable, Generator -import yaml from botocore.exceptions import ClientError from botocore.signers import RequestSigner from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.utils import yaml from airflow.utils.json import AirflowJsonEncoder -DEFAULT_PAGINATION_TOKEN = '' +DEFAULT_PAGINATION_TOKEN = "" STS_TOKEN_EXPIRES_IN = 60 AUTHENTICATION_API_VERSION = "client.authentication.k8s.io/v1alpha1" -_POD_USERNAME = 'aws' -_CONTEXT_NAME = 'aws' +_POD_USERNAME = "aws" +_CONTEXT_NAME = "aws" class ClusterStates(Enum): @@ -86,7 +86,7 @@ class EksHook(AwsBaseHook): :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - client_type = 'eks' + client_type = "eks" def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = self.client_type @@ -96,9 +96,9 @@ def create_cluster( self, name: str, roleArn: str, - resourcesVpcConfig: Dict, + resourcesVpcConfig: dict, **kwargs, - ) -> Dict: + ) -> dict: """ Creates an Amazon EKS control plane. @@ -108,7 +108,6 @@ def create_cluster( :param resourcesVpcConfig: The VPC configuration used by the cluster control plane. :return: Returns descriptive information about the created EKS Cluster. 
- :rtype: Dict """ eks_client = self.conn @@ -116,19 +115,19 @@ def create_cluster( name=name, roleArn=roleArn, resourcesVpcConfig=resourcesVpcConfig, **kwargs ) - self.log.info("Created Amazon EKS cluster with the name %s.", response.get('cluster').get('name')) + self.log.info("Created Amazon EKS cluster with the name %s.", response.get("cluster").get("name")) return response def create_nodegroup( self, clusterName: str, nodegroupName: str, - subnets: List[str], - nodeRole: Optional[str], + subnets: list[str], + nodeRole: str | None, *, - tags: Optional[Dict] = None, + tags: dict | None = None, **kwargs, - ) -> Dict: + ) -> dict: """ Creates an Amazon EKS managed node group for an Amazon EKS Cluster. @@ -139,17 +138,16 @@ def create_nodegroup( :param tags: Optional tags to apply to your nodegroup. :return: Returns descriptive information about the created EKS Managed Nodegroup. - :rtype: Dict """ eks_client = self.conn # The below tag is mandatory and must have a value of either 'owned' or 'shared' # A value of 'owned' denotes that the subnets are exclusive to the nodegroup. # The 'shared' value allows more than one resource to use the subnet. - cluster_tag_key = f'kubernetes.io/cluster/{clusterName}' + cluster_tag_key = f"kubernetes.io/cluster/{clusterName}" resolved_tags = tags or {} if cluster_tag_key not in resolved_tags: - resolved_tags[cluster_tag_key] = 'owned' + resolved_tags[cluster_tag_key] = "owned" response = eks_client.create_nodegroup( clusterName=clusterName, @@ -162,19 +160,19 @@ def create_nodegroup( self.log.info( "Created an Amazon EKS managed node group named %s in Amazon EKS cluster %s", - response.get('nodegroup').get('nodegroupName'), - response.get('nodegroup').get('clusterName'), + response.get("nodegroup").get("nodegroupName"), + response.get("nodegroup").get("clusterName"), ) return response def create_fargate_profile( self, clusterName: str, - fargateProfileName: Optional[str], - podExecutionRoleArn: Optional[str], - selectors: List, + fargateProfileName: str | None, + podExecutionRoleArn: str | None, + selectors: list, **kwargs, - ) -> Dict: + ) -> dict: """ Creates an AWS Fargate profile for an Amazon EKS cluster. @@ -185,7 +183,6 @@ def create_fargate_profile( :param selectors: The selectors to match for pods to use this Fargate profile. :return: Returns descriptive information about the created Fargate profile. - :rtype: Dict """ eks_client = self.conn @@ -199,28 +196,27 @@ def create_fargate_profile( self.log.info( "Created AWS Fargate profile with the name %s for Amazon EKS cluster %s.", - response.get('fargateProfile').get('fargateProfileName'), - response.get('fargateProfile').get('clusterName'), + response.get("fargateProfile").get("fargateProfileName"), + response.get("fargateProfile").get("clusterName"), ) return response - def delete_cluster(self, name: str) -> Dict: + def delete_cluster(self, name: str) -> dict: """ Deletes the Amazon EKS Cluster control plane. :param name: The name of the cluster to delete. :return: Returns descriptive information about the deleted EKS Cluster. 
- :rtype: Dict """ eks_client = self.conn response = eks_client.delete_cluster(name=name) - self.log.info("Deleted Amazon EKS cluster with the name %s.", response.get('cluster').get('name')) + self.log.info("Deleted Amazon EKS cluster with the name %s.", response.get("cluster").get("name")) return response - def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> Dict: + def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> dict: """ Deletes an Amazon EKS managed node group from a specified cluster. @@ -228,7 +224,6 @@ def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> Dict: :param nodegroupName: The name of the nodegroup to delete. :return: Returns descriptive information about the deleted EKS Managed Nodegroup. - :rtype: Dict """ eks_client = self.conn @@ -236,12 +231,12 @@ def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> Dict: self.log.info( "Deleted Amazon EKS managed node group named %s from Amazon EKS cluster %s.", - response.get('nodegroup').get('nodegroupName'), - response.get('nodegroup').get('clusterName'), + response.get("nodegroup").get("nodegroupName"), + response.get("nodegroup").get("clusterName"), ) return response - def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> Dict: + def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> dict: """ Deletes an AWS Fargate profile from a specified Amazon EKS cluster. @@ -249,7 +244,6 @@ def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> D :param fargateProfileName: The name of the Fargate profile to delete. :return: Returns descriptive information about the deleted Fargate profile. - :rtype: Dict """ eks_client = self.conn @@ -259,12 +253,12 @@ def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> D self.log.info( "Deleted AWS Fargate profile with the name %s from Amazon EKS cluster %s.", - response.get('fargateProfile').get('fargateProfileName'), - response.get('fargateProfile').get('clusterName'), + response.get("fargateProfile").get("fargateProfileName"), + response.get("fargateProfile").get("clusterName"), ) return response - def describe_cluster(self, name: str, verbose: bool = False) -> Dict: + def describe_cluster(self, name: str, verbose: bool = False) -> dict: """ Returns descriptive information about an Amazon EKS Cluster. @@ -272,21 +266,20 @@ def describe_cluster(self, name: str, verbose: bool = False) -> Dict: :param verbose: Provides additional logging if set to True. Defaults to False. :return: Returns descriptive information about a specific EKS Cluster. - :rtype: Dict """ eks_client = self.conn response = eks_client.describe_cluster(name=name) self.log.info( - "Retrieved details for Amazon EKS cluster named %s.", response.get('cluster').get('name') + "Retrieved details for Amazon EKS cluster named %s.", response.get("cluster").get("name") ) if verbose: - cluster_data = response.get('cluster') + cluster_data = response.get("cluster") self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=AirflowJsonEncoder)) return response - def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> Dict: + def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict: """ Returns descriptive information about an Amazon EKS managed node group. 
@@ -295,7 +288,6 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool :param verbose: Provides additional logging if set to True. Defaults to False. :return: Returns descriptive information about a specific EKS Nodegroup. - :rtype: Dict """ eks_client = self.conn @@ -303,11 +295,11 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool self.log.info( "Retrieved details for Amazon EKS managed node group named %s in Amazon EKS cluster %s.", - response.get('nodegroup').get('nodegroupName'), - response.get('nodegroup').get('clusterName'), + response.get("nodegroup").get("nodegroupName"), + response.get("nodegroup").get("clusterName"), ) if verbose: - nodegroup_data = response.get('nodegroup') + nodegroup_data = response.get("nodegroup") self.log.info( "Amazon EKS managed node group details: %s", json.dumps(nodegroup_data, cls=AirflowJsonEncoder), @@ -316,7 +308,7 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool def describe_fargate_profile( self, clusterName: str, fargateProfileName: str, verbose: bool = False - ) -> Dict: + ) -> dict: """ Returns descriptive information about an AWS Fargate profile. @@ -325,7 +317,6 @@ def describe_fargate_profile( :param verbose: Provides additional logging if set to True. Defaults to False. :return: Returns descriptive information about an AWS Fargate profile. - :rtype: Dict """ eks_client = self.conn @@ -335,11 +326,11 @@ def describe_fargate_profile( self.log.info( "Retrieved details for AWS Fargate profile named %s in Amazon EKS cluster %s.", - response.get('fargateProfile').get('fargateProfileName'), - response.get('fargateProfile').get('clusterName'), + response.get("fargateProfile").get("fargateProfileName"), + response.get("fargateProfile").get("clusterName"), ) if verbose: - fargate_profile_data = response.get('fargateProfile') + fargate_profile_data = response.get("fargateProfile") self.log.info( "AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder) ) @@ -352,12 +343,11 @@ def get_cluster_state(self, clusterName: str) -> ClusterStates: :param clusterName: The name of the cluster to check. :return: Returns the current status of a given Amazon EKS Cluster. - :rtype: ClusterStates """ eks_client = self.conn try: - return ClusterStates(eks_client.describe_cluster(name=clusterName).get('cluster').get('status')) + return ClusterStates(eks_client.describe_cluster(name=clusterName).get("cluster").get("status")) except ClientError as ex: if ex.response.get("Error").get("Code") == "ResourceNotFoundException": return ClusterStates.NONEXISTENT @@ -371,7 +361,6 @@ def get_fargate_profile_state(self, clusterName: str, fargateProfileName: str) - :param fargateProfileName: The name of the Fargate profile to check. :return: Returns the current status of a given AWS Fargate profile. - :rtype: AWS FargateProfileStates """ eks_client = self.conn @@ -380,8 +369,8 @@ def get_fargate_profile_state(self, clusterName: str, fargateProfileName: str) - eks_client.describe_fargate_profile( clusterName=clusterName, fargateProfileName=fargateProfileName ) - .get('fargateProfile') - .get('status') + .get("fargateProfile") + .get("status") ) except ClientError as ex: if ex.response.get("Error").get("Code") == "ResourceNotFoundException": @@ -396,15 +385,14 @@ def get_nodegroup_state(self, clusterName: str, nodegroupName: str) -> Nodegroup :param nodegroupName: The name of the nodegroup to check. 
:return: Returns the current status of a given Amazon EKS Nodegroup. - :rtype: NodegroupStates """ eks_client = self.conn try: return NodegroupStates( eks_client.describe_nodegroup(clusterName=clusterName, nodegroupName=nodegroupName) - .get('nodegroup') - .get('status') + .get("nodegroup") + .get("status") ) except ClientError as ex: if ex.response.get("Error").get("Code") == "ResourceNotFoundException": @@ -414,14 +402,13 @@ def get_nodegroup_state(self, clusterName: str, nodegroupName: str) -> Nodegroup def list_clusters( self, verbose: bool = False, - ) -> List: + ) -> list: """ Lists all Amazon EKS Clusters in your AWS account. :param verbose: Provides additional logging if set to True. Defaults to False. :return: A List containing the cluster names. - :rtype: List """ eks_client = self.conn list_cluster_call = partial(eks_client.list_clusters) @@ -432,7 +419,7 @@ def list_nodegroups( self, clusterName: str, verbose: bool = False, - ) -> List: + ) -> list: """ Lists all Amazon EKS managed node groups associated with the specified cluster. @@ -440,7 +427,6 @@ def list_nodegroups( :param verbose: Provides additional logging if set to True. Defaults to False. :return: A List of nodegroup names within the given cluster. - :rtype: List """ eks_client = self.conn list_nodegroups_call = partial(eks_client.list_nodegroups, clusterName=clusterName) @@ -451,7 +437,7 @@ def list_fargate_profiles( self, clusterName: str, verbose: bool = False, - ) -> List: + ) -> list: """ Lists all AWS Fargate profiles associated with the specified cluster. @@ -459,7 +445,6 @@ def list_fargate_profiles( :param verbose: Provides additional logging if set to True. Defaults to False. :return: A list of Fargate profile names within a given cluster. - :rtype: List """ eks_client = self.conn list_fargate_profiles_call = partial(eks_client.list_fargate_profiles, clusterName=clusterName) @@ -468,7 +453,7 @@ def list_fargate_profiles( api_call=list_fargate_profiles_call, response_key="fargateProfileNames", verbose=verbose ) - def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> List: + def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> list: """ Repeatedly calls a provided boto3 API Callable and collates the responses into a List. @@ -477,9 +462,8 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> Lis :param verbose: Provides additional logging if set to True. Defaults to False. :return: A List of the combined results of the provided API call. - :rtype: List """ - name_collection: List = [] + name_collection: list = [] token = DEFAULT_PAGINATION_TOKEN while token is not None: @@ -498,9 +482,7 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> Lis def generate_config_file( self, eks_cluster_name: str, - pod_namespace: Optional[str], - pod_username: Optional[str] = None, - pod_context: Optional[str] = None, + pod_namespace: str | None, ) -> Generator[str, None, None]: """ Writes the kubeconfig file given an EKS Cluster. @@ -508,20 +490,6 @@ def generate_config_file( :param eks_cluster_name: The name of the cluster to generate kubeconfig file for. :param pod_namespace: The namespace to run within kubernetes. 
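A usage sketch for generate_config_file above, assuming the method is exposed as a context manager (the module keeps the contextlib.contextmanager import and the method yields a temporary kubeconfig path); the cluster name is a placeholder:

from airflow.providers.amazon.aws.hooks.eks import EksHook

hook = EksHook(aws_conn_id="aws_default", region_name="us-east-1")

# The temporary kubeconfig file only exists inside the with-block.
with hook.generate_config_file(
    eks_cluster_name="my-cluster",  # hypothetical cluster name
    pod_namespace="default",
) as kubeconfig_path:
    print(f"kubectl --kubeconfig {kubeconfig_path} get pods")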
""" - if pod_username: - warnings.warn( - "This pod_username parameter is deprecated, because changing the value does not make any " - "visible changes to the user.", - DeprecationWarning, - stacklevel=2, - ) - if pod_context: - warnings.warn( - "This pod_context parameter is deprecated, because changing the value does not make any " - "visible changes to the user.", - DeprecationWarning, - stacklevel=2, - ) # Set up the client eks_client = self.conn @@ -588,7 +556,7 @@ def generate_config_file( } config_text = yaml.dump(cluster_config, default_flow_style=False) - with tempfile.NamedTemporaryFile(mode='w') as config_file: + with tempfile.NamedTemporaryFile(mode="w") as config_file: config_file.write(config_text) config_file.flush() yield config_file.name @@ -597,49 +565,34 @@ def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str: session = self.get_session() service_id = self.conn.meta.service_model.service_id sts_url = ( - f'https://sts.{session.region_name}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15' + f"https://sts.{session.region_name}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15" ) signer = RequestSigner( service_id=service_id, region_name=session.region_name, - signing_name='sts', - signature_version='v4', + signing_name="sts", + signature_version="v4", credentials=session.get_credentials(), event_emitter=session.events, ) request_params = { - 'method': 'GET', - 'url': sts_url, - 'body': {}, - 'headers': {'x-k8s-aws-id': eks_cluster_name}, - 'context': {}, + "method": "GET", + "url": sts_url, + "body": {}, + "headers": {"x-k8s-aws-id": eks_cluster_name}, + "context": {}, } signed_url = signer.generate_presigned_url( request_dict=request_params, region_name=session.region_name, expires_in=STS_TOKEN_EXPIRES_IN, - operation_name='', + operation_name="", ) - base64_url = base64.urlsafe_b64encode(signed_url.encode('utf-8')).decode('utf-8') + base64_url = base64.urlsafe_b64encode(signed_url.encode("utf-8")).decode("utf-8") # remove any base64 encoding padding: - return 'k8s-aws-v1.' + base64_url.rstrip("=") - - -class EKSHook(EksHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.eks.EksHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.eks.EksHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) + return "k8s-aws-v1." + base64_url.rstrip("=") diff --git a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py index 47af28845d1aa..f859bef5e861f 100644 --- a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py +++ b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + from time import sleep -from typing import Optional from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -57,7 +58,6 @@ def create_replication_group(self, config: dict) -> dict: :param config: Configuration for creating the replication group :return: Response from ElastiCache create replication group API - :rtype: dict """ return self.conn.create_replication_group(**config) @@ -67,7 +67,6 @@ def delete_replication_group(self, replication_group_id: str) -> dict: :param replication_group_id: ID of replication group to delete :return: Response from ElastiCache delete replication group API - :rtype: dict """ return self.conn.delete_replication_group(ReplicationGroupId=replication_group_id) @@ -77,7 +76,6 @@ def describe_replication_group(self, replication_group_id: str) -> dict: :param replication_group_id: ID of replication group to describe :return: Response from ElastiCache describe replication group API - :rtype: dict """ return self.conn.describe_replication_groups(ReplicationGroupId=replication_group_id) @@ -87,9 +85,8 @@ def get_replication_group_status(self, replication_group_id: str) -> str: :param replication_group_id: ID of replication group to check for status :return: Current status of replication group - :rtype: str """ - return self.describe_replication_group(replication_group_id)['ReplicationGroups'][0]['Status'] + return self.describe_replication_group(replication_group_id)["ReplicationGroups"][0]["Status"] def is_replication_group_available(self, replication_group_id: str) -> bool: """ @@ -97,17 +94,16 @@ def is_replication_group_available(self, replication_group_id: str) -> bool: :param replication_group_id: ID of replication group to check for availability :return: True if available else False - :rtype: bool """ - return self.get_replication_group_status(replication_group_id) == 'available' + return self.get_replication_group_status(replication_group_id) == "available" def wait_for_availability( self, replication_group_id: str, - initial_sleep_time: Optional[float] = None, - exponential_back_off_factor: Optional[float] = None, - max_retries: Optional[int] = None, - ): + initial_sleep_time: float | None = None, + exponential_back_off_factor: float | None = None, + max_retries: int | None = None, + ) -> bool: """ Check if replication group is available or not by performing a describe over it @@ -119,13 +115,12 @@ def wait_for_availability( :param max_retries: Max retries for checking availability of replication group If this is not supplied then this is defaulted to class level value :return: True if replication is available else False - :rtype: bool """ sleep_time = initial_sleep_time or self.initial_poke_interval exponential_back_off_factor = exponential_back_off_factor or self.exponential_back_off_factor max_retries = max_retries or self.max_retries num_tries = 0 - status = 'not-found' + status = "not-found" stop_poking = False while not stop_poking and num_tries <= max_retries: @@ -133,7 +128,7 @@ def wait_for_availability( stop_poking = status in self.TERMINAL_STATES self.log.info( - 'Current status of replication group with ID %s is %s', replication_group_id, status + "Current status of replication group with ID %s is %s", replication_group_id, status ) if not stop_poking: @@ -143,13 +138,13 @@ def wait_for_availability( if num_tries > max_retries: break - self.log.info('Poke retry %s. Sleep time %s seconds. 
Sleeping...', num_tries, sleep_time) + self.log.info("Poke retry %s. Sleep time %s seconds. Sleeping...", num_tries, sleep_time) sleep(sleep_time) sleep_time *= exponential_back_off_factor - if status != 'available': + if status != "available": self.log.warning('Replication group is not available. Current status is "%s"', status) return False @@ -159,9 +154,9 @@ def wait_for_availability( def wait_for_deletion( self, replication_group_id: str, - initial_sleep_time: Optional[float] = None, - exponential_back_off_factor: Optional[float] = None, - max_retries: Optional[int] = None, + initial_sleep_time: float | None = None, + exponential_back_off_factor: float | None = None, + max_retries: int | None = None, ): """ Helper for deleting a replication group ensuring it is either deleted or can't be deleted @@ -174,7 +169,6 @@ def wait_for_deletion( :param max_retries: Max retries for checking availability of replication group If this is not supplied then this is defaulted to class level value :return: Response from ElastiCache delete replication group API and flag to identify if deleted or not - :rtype: (dict, bool) """ deleted = False sleep_time = initial_sleep_time or self.initial_poke_interval @@ -188,12 +182,12 @@ def wait_for_deletion( status = self.get_replication_group_status(replication_group_id=replication_group_id) self.log.info( - 'Current status of replication group with ID %s is %s', replication_group_id, status + "Current status of replication group with ID %s is %s", replication_group_id, status ) # Can only delete if status is `available` # Status becomes `deleting` after this call so this will only run once - if status == 'available': + if status == "available": self.log.info("Initiating delete and then wait for it to finish") response = self.delete_replication_group(replication_group_id=replication_group_id) @@ -213,9 +207,9 @@ def wait_for_deletion( # modifying - Replication group has status deleting which is not valid # for deletion. - message = exp.response['Error']['Message'] + message = exp.response["Error"]["Message"] - self.log.warning('Received error message from AWS ElastiCache API : %s', message) + self.log.warning("Received error message from AWS ElastiCache API : %s", message) if not deleted: num_tries += 1 @@ -224,7 +218,7 @@ def wait_for_deletion( if num_tries > max_retries: break - self.log.info('Poke retry %s. Sleep time %s seconds. Sleeping...', num_tries, sleep_time) + self.log.info("Poke retry %s. Sleep time %s seconds. 
Sleeping...", num_tries, sleep_time) sleep(sleep_time) @@ -235,10 +229,10 @@ def wait_for_deletion( def ensure_delete_replication_group( self, replication_group_id: str, - initial_sleep_time: Optional[float] = None, - exponential_back_off_factor: Optional[float] = None, - max_retries: Optional[int] = None, - ): + initial_sleep_time: float | None = None, + exponential_back_off_factor: float | None = None, + max_retries: int | None = None, + ) -> dict: """ Delete a replication group ensuring it is either deleted or can't be deleted @@ -250,10 +244,9 @@ def ensure_delete_replication_group( :param max_retries: Max retries for checking availability of replication group If this is not supplied then this is defaulted to class level value :return: Response from ElastiCache delete replication group API - :rtype: dict :raises AirflowException: If replication group is not deleted """ - self.log.info('Deleting replication group with ID %s', replication_group_id) + self.log.info("Deleting replication group with ID %s", replication_group_id) response, deleted = self.wait_for_deletion( replication_group_id=replication_group_id, diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index 143bdcdcc8913..5423dd1af83a5 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -15,19 +15,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +import json +import warnings from time import sleep -from typing import Any, Dict, List, Optional +from typing import Any, Callable from botocore.exceptions import ClientError -from airflow.exceptions import AirflowException +from airflow.compat.functools import cached_property +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook class EmrHook(AwsBaseHook): """ - Interact with AWS EMR. emr_conn_id is only necessary for using the - create_job_flow method. + Interact with Amazon Elastic MapReduce Service. + + :param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection `. + This attribute is only necessary when using + the :meth:`~airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow` method. Additional arguments (such as ``aws_conn_id``) may be specified and are passed down to the underlying AwsBaseHook. @@ -36,17 +44,17 @@ class EmrHook(AwsBaseHook): :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - conn_name_attr = 'emr_conn_id' - default_conn_name = 'emr_default' - conn_type = 'emr' - hook_name = 'Amazon Elastic MapReduce' + conn_name_attr = "emr_conn_id" + default_conn_name = "emr_default" + conn_type = "emr" + hook_name = "Amazon Elastic MapReduce" - def __init__(self, emr_conn_id: Optional[str] = default_conn_name, *args, **kwargs) -> None: + def __init__(self, emr_conn_id: str | None = default_conn_name, *args, **kwargs) -> None: self.emr_conn_id = emr_conn_id kwargs["client_type"] = "emr" super().__init__(*args, **kwargs) - def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: List[str]) -> Optional[str]: + def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: list[str]) -> str | None: """ Fetch id of EMR cluster with given name and (optional) states. Will return only if single id is found. 
@@ -58,38 +66,222 @@ def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: List[str response = self.get_conn().list_clusters(ClusterStates=cluster_states) matching_clusters = list( - filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters']) + filter(lambda cluster: cluster["Name"] == emr_cluster_name, response["Clusters"]) ) if len(matching_clusters) == 1: - cluster_id = matching_clusters[0]['Id'] - self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id) + cluster_id = matching_clusters[0]["Id"] + self.log.info("Found cluster name = %s id = %s", emr_cluster_name, cluster_id) return cluster_id elif len(matching_clusters) > 1: - raise AirflowException(f'More than one cluster found for name {emr_cluster_name}') + raise AirflowException(f"More than one cluster found for name {emr_cluster_name}") else: - self.log.info('No cluster found for name %s', emr_cluster_name) + self.log.info("No cluster found for name %s", emr_cluster_name) return None - def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]: - """ - Creates a job flow using the config from the EMR connection. - Keys of the json extra hash may have the arguments of the boto3 - run_job_flow method. - Overrides for this config may be passed as the job_flow_overrides. + def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]: """ - if not self.emr_conn_id: - raise AirflowException('emr_conn_id must be present to use create_job_flow') + Create and start running a new cluster (job flow). - emr_conn = self.get_connection(self.emr_conn_id) + This method uses ``EmrHook.emr_conn_id`` to receive the initial Amazon EMR cluster configuration. + If ``EmrHook.emr_conn_id`` is empty or the connection does not exist, then an empty initial + configuration is used. - config = emr_conn.extra_dejson.copy() + :param job_flow_overrides: Is used to overwrite the parameters in the initial Amazon EMR configuration + cluster. The resulting configuration will be used in the boto3 emr client run_job_flow method. + + .. seealso:: + - :ref:`Amazon Elastic MapReduce Connection ` + - `API RunJobFlow `_ + - `boto3 emr client run_job_flow method `_. + """ + config = {} + if self.emr_conn_id: + try: + emr_conn = self.get_connection(self.emr_conn_id) + except AirflowNotFoundException: + warnings.warn( + f"Unable to find {self.hook_name} Connection ID {self.emr_conn_id!r}, " + "using an empty initial configuration. If you want to get rid of this warning " + "message please provide a valid `emr_conn_id` or set it to None.", + UserWarning, + stacklevel=2, + ) + else: + if emr_conn.conn_type and emr_conn.conn_type != self.conn_type: + warnings.warn( + f"{self.hook_name} Connection expected connection type {self.conn_type!r}, " + f"Connection {self.emr_conn_id!r} has conn_type={emr_conn.conn_type!r}. " + f"This connection might not work correctly.", + UserWarning, + stacklevel=2, + ) + config = emr_conn.extra_dejson.copy() config.update(job_flow_overrides) response = self.get_conn().run_job_flow(**config) return response + def add_job_flow_steps( + self, job_flow_id: str, steps: list[dict] | str | None = None, wait_for_completion: bool = False + ) -> list[str]: + """ + Add new steps to a running cluster. + + :param job_flow_id: The id of the job flow to which the steps are being added + :param steps: A list of the steps to be executed by the job flow + :param wait_for_completion: If True, wait for the steps to be completed. 
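Tying create_job_flow and add_job_flow_steps above together, a hedged sketch; the step definition follows the usual EMR command-runner shape and is illustrative only:

from airflow.providers.amazon.aws.hooks.emr import EmrHook

hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")

# Connection extras (if any) are merged with these overrides before run_job_flow is called.
response = hook.create_job_flow({"Name": "nightly-batch", "ReleaseLabel": "emr-5.36.0"})
job_flow_id = response["JobFlowId"]

step_ids = hook.add_job_flow_steps(
    job_flow_id=job_flow_id,
    steps=[
        {
            "Name": "say hello",
            "ActionOnFailure": "CONTINUE",
            "HadoopJarStep": {"Jar": "command-runner.jar", "Args": ["echo", "hello"]},
        }
    ],
    wait_for_completion=True,  # blocks on the step_complete waiter shown above
)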
Default is False + """ + response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Adding steps failed: {response}") + + self.log.info("Steps %s added to JobFlow", response["StepIds"]) + if wait_for_completion: + waiter = self.get_conn().get_waiter("step_complete") + for step_id in response["StepIds"]: + waiter.wait( + ClusterId=job_flow_id, + StepId=step_id, + WaiterConfig={ + "Delay": 5, + "MaxAttempts": 100, + }, + ) + return response["StepIds"] + + def test_connection(self): + """ + Return failed state for test Amazon Elastic MapReduce Connection (untestable). + + We need to overwrite this method because this hook is based on + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`, + otherwise it will try to test connection to AWS STS by using the default boto3 credential strategy. + """ + msg = ( + f"{self.hook_name!r} Airflow Connection cannot be tested, by design it stores " + f"only key/value pairs and does not make a connection to an external resource." + ) + return False, msg + + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + """Returns custom UI field behaviour for Amazon Elastic MapReduce Connection.""" + return { + "hidden_fields": ["host", "schema", "port", "login", "password"], + "relabeling": { + "extra": "Run Job Flow Configuration", + }, + "placeholders": { + "extra": json.dumps( + { + "Name": "MyClusterName", + "ReleaseLabel": "emr-5.36.0", + "Applications": [{"Name": "Spark"}], + "Instances": { + "InstanceGroups": [ + { + "Name": "Primary node", + "Market": "SPOT", + "InstanceRole": "MASTER", + "InstanceType": "m5.large", + "InstanceCount": 1, + }, + ], + "KeepJobFlowAliveWhenNoSteps": False, + "TerminationProtected": False, + }, + "StepConcurrencyLevel": 2, + }, + indent=2, + ), + }, + } + + +class EmrServerlessHook(AwsBaseHook): + """ + Interact with EMR Serverless API. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + JOB_INTERMEDIATE_STATES = {"PENDING", "RUNNING", "SCHEDULED", "SUBMITTED"} + JOB_FAILURE_STATES = {"FAILED", "CANCELLING", "CANCELLED"} + JOB_SUCCESS_STATES = {"SUCCESS"} + JOB_TERMINAL_STATES = JOB_SUCCESS_STATES.union(JOB_FAILURE_STATES) + + APPLICATION_INTERMEDIATE_STATES = {"CREATING", "STARTING", "STOPPING"} + APPLICATION_FAILURE_STATES = {"STOPPED", "TERMINATED"} + APPLICATION_SUCCESS_STATES = {"CREATED", "STARTED"} + + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["client_type"] = "emr-serverless" + super().__init__(*args, **kwargs) + + @cached_property + def conn(self): + """Get the underlying boto3 EmrServerlessAPIService client (cached)""" + return super().conn + + # This method should be replaced with boto waiters which would implement timeouts and backoff nicely. + def waiter( + self, + get_state_callable: Callable, + get_state_args: dict, + parse_response: list, + desired_state: set, + failure_states: set, + object_type: str, + action: str, + countdown: int = 25 * 60, + check_interval_seconds: int = 60, + ) -> None: + """ + Will run the sensor until it turns True. 
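A sketch of driving the custom EmrServerlessHook.waiter above, assuming the boto3 emr-serverless client's get_application call, which reports the state under application.state; the application id is a placeholder:

from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook

hook = EmrServerlessHook(aws_conn_id="aws_default")
application_id = "00f1abcdexample"  # hypothetical EMR Serverless application id

# Poll until the application is created/started, failing fast on STOPPED/TERMINATED.
hook.waiter(
    get_state_callable=hook.conn.get_application,  # assumed boto3 emr-serverless call
    get_state_args={"applicationId": application_id},
    parse_response=["application", "state"],
    desired_state=hook.APPLICATION_SUCCESS_STATES,
    failure_states=hook.APPLICATION_FAILURE_STATES,
    object_type="application",
    action="started",
)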
+ + :param get_state_callable: A callable to run until it returns True + :param get_state_args: Arguments to pass to get_state_callable + :param parse_response: Dictionary keys to extract state from response of get_state_callable + :param desired_state: Wait until the getter returns this value + :param failure_states: A set of states which indicate failure and should throw an + exception if any are reached before the desired_state + :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc) + :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc) + :param countdown: Total amount of time the waiter should wait for the desired state + before timing out (in seconds). Defaults to 25 * 60 seconds. + :param check_interval_seconds: Number of seconds waiter should wait before attempting + to retry get_state_callable. Defaults to 60 seconds. + """ + response = get_state_callable(**get_state_args) + state: str = self.get_state(response, parse_response) + while state not in desired_state: + if state in failure_states: + raise AirflowException(f"{object_type.title()} reached failure state {state}.") + if countdown >= check_interval_seconds: + countdown -= check_interval_seconds + self.log.info("Waiting for %s to be %s.", object_type.lower(), action.lower()) + sleep(check_interval_seconds) + state = self.get_state(get_state_callable(**get_state_args), parse_response) + else: + message = f"{object_type.title()} still not {action.lower()} after the allocated time limit." + self.log.error(message) + raise RuntimeError(message) + + def get_state(self, response, keys) -> str: + value = response + for key in keys: + if value is not None: + value = value.get(key, None) + return value + class EmrContainerHook(AwsBaseHook): """ @@ -121,19 +313,45 @@ class EmrContainerHook(AwsBaseHook): "CANCEL_PENDING", ) - def __init__(self, *args: Any, virtual_cluster_id: Optional[str] = None, **kwargs: Any) -> None: + def __init__(self, *args: Any, virtual_cluster_id: str | None = None, **kwargs: Any) -> None: super().__init__(client_type="emr-containers", *args, **kwargs) # type: ignore self.virtual_cluster_id = virtual_cluster_id + def create_emr_on_eks_cluster( + self, + virtual_cluster_name: str, + eks_cluster_name: str, + eks_namespace: str, + tags: dict | None = None, + ) -> str: + response = self.conn.create_virtual_cluster( + name=virtual_cluster_name, + containerProvider={ + "id": eks_cluster_name, + "type": "EKS", + "info": {"eksInfo": {"namespace": eks_namespace}}, + }, + tags=tags or {}, + ) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Create EMR EKS Cluster failed: {response}") + else: + self.log.info( + "Create EMR EKS Cluster success - virtual cluster id %s", + response["id"], + ) + return response["id"] + def submit_job( self, name: str, execution_role_arn: str, release_label: str, job_driver: dict, - configuration_overrides: Optional[dict] = None, - client_request_token: Optional[str] = None, - tags: Optional[dict] = None, + configuration_overrides: dict | None = None, + client_request_token: str | None = None, + tags: dict | None = None, ) -> str: """ Submit a job to the EMR Containers API and return the job ID. 
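A minimal usage sketch (not part of the diff) of the `EmrHook` behaviour changed above: a missing `emr_conn_id` now produces a warning and an empty initial configuration, and `add_job_flow_steps` can block on the `step_complete` waiter. The constructor arguments are assumed from the hook's standard signature; connection IDs, cluster overrides, and the step definition are placeholders and not a complete cluster spec.

```python
from airflow.providers.amazon.aws.hooks.emr import EmrHook

# Connection IDs are placeholders.
hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")

# If "emr_default" does not exist, create_job_flow now warns and starts from an
# empty initial configuration instead of failing on the missing connection.
response = hook.create_job_flow({"Name": "example-cluster", "ReleaseLabel": "emr-5.36.0"})
job_flow_id = response["JobFlowId"]

# New helper above: submit steps and optionally block on the "step_complete" waiter.
step_ids = hook.add_job_flow_steps(
    job_flow_id=job_flow_id,
    steps=[
        {
            "Name": "example-step",
            "ActionOnFailure": "CONTINUE",
            "HadoopJarStep": {"Jar": "command-runner.jar", "Args": ["true"]},
        }
    ],
    wait_for_completion=True,
)
```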
@@ -166,17 +384,17 @@ def submit_job( response = self.conn.start_job_run(**params) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Start Job Run failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Start Job Run failed: {response}") else: self.log.info( "Start Job Run success - Job Id %s and virtual cluster id %s", - response['id'], - response['virtualClusterId'], + response["id"], + response["virtualClusterId"], ) - return response['id'] + return response["id"] - def get_job_failure_reason(self, job_id: str) -> Optional[str]: + def get_job_failure_reason(self, job_id: str) -> str | None: """ Fetch the reason for a job failure (e.g. error message). Returns None or reason string. @@ -191,17 +409,17 @@ def get_job_failure_reason(self, job_id: str) -> Optional[str]: virtualClusterId=self.virtual_cluster_id, id=job_id, ) - failure_reason = response['jobRun']['failureReason'] + failure_reason = response["jobRun"]["failureReason"] state_details = response["jobRun"]["stateDetails"] reason = f"{failure_reason} - {state_details}" except KeyError: - self.log.error('Could not get status of the EMR on EKS job') + self.log.error("Could not get status of the EMR on EKS job") except ClientError as ex: - self.log.error('AWS request failed, check logs for more info: %s', ex) + self.log.error("AWS request failed, check logs for more info: %s", ex) return reason - def check_query_status(self, job_id: str) -> Optional[str]: + def check_query_status(self, job_id: str) -> str | None: """ Fetch the status of submitted job run. Returns None or one of valid query states. See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.describe_job_run # noqa: E501 @@ -216,26 +434,43 @@ def check_query_status(self, job_id: str) -> Optional[str]: return response["jobRun"]["state"] except self.conn.exceptions.ResourceNotFoundException: # If the job is not found, we raise an exception as something fatal has happened. - raise AirflowException(f'Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}') + raise AirflowException(f"Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}") except ClientError as ex: # If we receive a generic ClientError, we swallow the exception so that the - self.log.error('AWS request failed, check logs for more info: %s', ex) + self.log.error("AWS request failed, check logs for more info: %s", ex) return None def poll_query_status( - self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30 - ) -> Optional[str]: + self, + job_id: str, + max_tries: int | None = None, + poll_interval: int = 30, + max_polling_attempts: int | None = None, + ) -> str | None: """ Poll the status of submitted job run until query state reaches final state. Returns one of the final states. :param job_id: Id of submitted job run - :param max_tries: Number of times to poll for query state before function exits + :param max_tries: Deprecated - Use max_polling_attempts instead :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR + :param max_polling_attempts: Number of times to poll for query state before function exits :return: str """ + if max_tries: + warnings.warn( + f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed " + "in a future release. 
Please use method `max_polling_attempts` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_polling_attempts and max_polling_attempts != max_tries: + raise Exception("max_polling_attempts must be the same value as max_tries") + else: + max_polling_attempts = max_tries + try_number = 1 - final_query_state = None # Query state when query reaches final state or max_tries reached + final_query_state = None # Query state when query reaches final state or max_polling_attempts reached while True: query_state = self.check_query_status(job_id) @@ -247,14 +482,16 @@ def poll_query_status( break else: self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state) - if max_tries and try_number >= max_tries: # Break loop if max_tries reached + if ( + max_polling_attempts and try_number >= max_polling_attempts + ): # Break loop if max_polling_attempts reached final_query_state = query_state break try_number += 1 sleep(poll_interval) return final_query_state - def stop_query(self, job_id: str) -> Dict: + def stop_query(self, job_id: str) -> dict: """ Cancel the submitted job_run diff --git a/airflow/providers/amazon/aws/hooks/emr_containers.py b/airflow/providers/amazon/aws/hooks/emr_containers.py deleted file mode 100644 index 1e3b7a0ea4aea..0000000000000 --- a/airflow/providers/amazon/aws/hooks/emr_containers.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.", - DeprecationWarning, - stacklevel=2, -) - - -class EMRContainerHook(EmrContainerHook): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.emr.EmrContainerHook`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.amazon.aws.hooks.emr.EmrContainerHook`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/hooks/glacier.py b/airflow/providers/amazon/aws/hooks/glacier.py index 00c4b884ae262..926dc05843509 100644 --- a/airflow/providers/amazon/aws/hooks/glacier.py +++ b/airflow/providers/amazon/aws/hooks/glacier.py @@ -15,9 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
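A short sketch (not part of the diff) of the `EmrContainerHook.poll_query_status` change above, which deprecates `max_tries` in favour of `max_polling_attempts`; the virtual cluster and job run IDs are placeholders.

```python
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook

hook = EmrContainerHook(virtual_cluster_id="example-virtual-cluster-id")

# max_tries is deprecated by this change; max_polling_attempts is the replacement.
final_state = hook.poll_query_status(
    job_id="example-job-run-id",
    poll_interval=30,
    max_polling_attempts=10,
)
```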
+from __future__ import annotations - -from typing import Any, Dict +from typing import Any from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -29,20 +29,20 @@ def __init__(self, aws_conn_id: str = "aws_default") -> None: super().__init__(client_type="glacier") self.aws_conn_id = aws_conn_id - def retrieve_inventory(self, vault_name: str) -> Dict[str, Any]: + def retrieve_inventory(self, vault_name: str) -> dict[str, Any]: """ Initiate an Amazon Glacier inventory-retrieval job :param vault_name: the Glacier vault on which job is executed """ - job_params = {'Type': 'inventory-retrieval'} + job_params = {"Type": "inventory-retrieval"} self.log.info("Retrieving inventory for vault: %s", vault_name) response = self.get_conn().initiate_job(vaultName=vault_name, jobParameters=job_params) self.log.info("Initiated inventory-retrieval job for: %s", vault_name) self.log.info("Retrieval Job ID: %s", response["jobId"]) return response - def retrieve_inventory_results(self, vault_name: str, job_id: str) -> Dict[str, Any]: + def retrieve_inventory_results(self, vault_name: str, job_id: str) -> dict[str, Any]: """ Retrieve the results of an Amazon Glacier inventory-retrieval job @@ -53,7 +53,7 @@ def retrieve_inventory_results(self, vault_name: str, job_id: str) -> Dict[str, response = self.get_conn().get_job_output(vaultName=vault_name, jobId=job_id) return response - def describe_job(self, vault_name: str, job_id: str) -> Dict[str, Any]: + def describe_job(self, vault_name: str, job_id: str) -> dict[str, Any]: """ Retrieve the status of an Amazon S3 Glacier job, such as an inventory-retrieval job @@ -63,5 +63,5 @@ def describe_job(self, vault_name: str, job_id: str) -> Dict[str, Any]: """ self.log.info("Retrieving status for vault: %s and job %s", vault_name, job_id) response = self.get_conn().describe_job(vaultName=vault_name, jobId=job_id) - self.log.info("Job status: %s, code status: %s", response['Action'], response['StatusCode']) + self.log.info("Job status: %s, code status: %s", response["Action"], response["StatusCode"]) return response diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index dcd6d7c4661cd..f318417412858 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -15,14 +15,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import time -import warnings -from typing import Dict, List, Optional + +import boto3 from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +DEFAULT_LOG_SUFFIX = "output" +FAILURE_LOG_SUFFIX = "error" +# A filter value of ' ' translates to "match all". 
+# see: https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html +DEFAULT_LOG_FILTER = " " +FAILURE_LOG_FILTER = "?ERROR ?Exception" + class GlueJobHook(AwsBaseHook): """ @@ -44,15 +52,15 @@ class GlueJobHook(AwsBaseHook): def __init__( self, - s3_bucket: Optional[str] = None, - job_name: Optional[str] = None, - desc: Optional[str] = None, + s3_bucket: str | None = None, + job_name: str | None = None, + desc: str | None = None, concurrent_run_limit: int = 1, - script_location: Optional[str] = None, + script_location: str | None = None, retry_limit: int = 0, - num_of_dpus: Optional[int] = None, - iam_role_name: Optional[str] = None, - create_job_kwargs: Optional[dict] = None, + num_of_dpus: int | None = None, + iam_role_name: str | None = None, + create_job_kwargs: dict | None = None, *args, **kwargs, ): @@ -63,7 +71,7 @@ def __init__( self.retry_limit = retry_limit self.s3_bucket = s3_bucket self.role_name = iam_role_name - self.s3_glue_logs = 'logs/glue-logs/' + self.s3_glue_logs = "logs/glue-logs/" self.create_job_kwargs = create_job_kwargs or {} worker_type_exists = "WorkerType" in self.create_job_kwargs @@ -81,20 +89,20 @@ def __init__( else: self.num_of_dpus = num_of_dpus - kwargs['client_type'] = 'glue' + kwargs["client_type"] = "glue" super().__init__(*args, **kwargs) - def list_jobs(self) -> List: + def list_jobs(self) -> list: """:return: Lists of Jobs""" conn = self.get_conn() return conn.get_jobs() - def get_iam_execution_role(self) -> Dict: + def get_iam_execution_role(self) -> dict: """:return: iam role for job execution""" - session, endpoint_url = self._get_credentials(region_name=self.region_name) - iam_client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify) - try: + iam_client = self.get_session(region_name=self.region_name).client( + "iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify + ) glue_execution_role = iam_client.get_role(RoleName=self.role_name) self.log.info("Iam Role Name: %s", self.role_name) return glue_execution_role @@ -104,9 +112,9 @@ def get_iam_execution_role(self) -> Dict: def initialize_job( self, - script_arguments: Optional[dict] = None, - run_kwargs: Optional[dict] = None, - ) -> Dict[str, str]: + script_arguments: dict | None = None, + run_kwargs: dict | None = None, + ) -> dict[str, str]: """ Initializes connection with AWS Glue to run job @@ -134,34 +142,94 @@ def get_job_state(self, job_name: str, run_id: str) -> str: """ glue_client = self.get_conn() job_run = glue_client.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True) - return job_run['JobRun']['JobRunState'] + return job_run["JobRun"]["JobRunState"] + + def print_job_logs( + self, + job_name: str, + run_id: str, + job_failed: bool = False, + next_token: str | None = None, + ) -> str | None: + """Prints the batch of logs to the Airflow task log and returns nextToken.""" + log_client = boto3.client("logs") + response = {} + + filter_pattern = FAILURE_LOG_FILTER if job_failed else DEFAULT_LOG_FILTER + log_group_prefix = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"] + log_group_suffix = FAILURE_LOG_SUFFIX if job_failed else DEFAULT_LOG_SUFFIX + log_group_name = f"{log_group_prefix}/{log_group_suffix}" + + try: + if next_token: + response = log_client.filter_log_events( + logGroupName=log_group_name, + logStreamNames=[run_id], + filterPattern=filter_pattern, + nextToken=next_token, + ) + else: + response = log_client.filter_log_events( 
+ logGroupName=log_group_name, + logStreamNames=[run_id], + filterPattern=filter_pattern, + ) + if len(response["events"]): + messages = "\t".join([event["message"] for event in response["events"]]) + self.log.info("Glue Job Run Logs:\n\t%s", messages) + + except log_client.exceptions.ResourceNotFoundException: + self.log.warning( + "No new Glue driver logs found. This might be because there are no new logs, " + "or might be an error.\nIf the error persists, check the CloudWatch dashboard " + f"at: https://{self.conn_region_name}.console.aws.amazon.com/cloudwatch/home" + ) + + # If no new log events are available, filter_log_events will return None. + # In that case, check the same token again next pass. + return response.get("nextToken") or next_token - def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]: + def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]: """ Waits until Glue job with job_name completes or fails and return final state if finished. Raises AirflowException when the job failed :param job_name: unique job name per AWS account :param run_id: The job-run ID of the predecessor job run + :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False) :return: Dict of JobRunState and JobRunId """ - failed_states = ['FAILED', 'TIMEOUT'] - finished_states = ['SUCCEEDED', 'STOPPED'] + failed_states = ["FAILED", "TIMEOUT"] + finished_states = ["SUCCEEDED", "STOPPED"] + next_log_token = None + job_failed = False while True: - job_run_state = self.get_job_state(job_name, run_id) - if job_run_state in finished_states: - self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state) - return {'JobRunState': job_run_state, 'JobRunId': run_id} - if job_run_state in failed_states: - job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}" - self.log.info(job_error_message) - raise AirflowException(job_error_message) - else: - self.log.info( - "Polling for AWS Glue Job %s current run state with status %s", job_name, job_run_state - ) - time.sleep(self.JOB_POLL_INTERVAL) + try: + job_run_state = self.get_job_state(job_name, run_id) + if job_run_state in finished_states: + self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state) + return {"JobRunState": job_run_state, "JobRunId": run_id} + if job_run_state in failed_states: + job_failed = True + job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}" + self.log.info(job_error_message) + raise AirflowException(job_error_message) + else: + self.log.info( + "Polling for AWS Glue Job %s current run state with status %s", + job_name, + job_run_state, + ) + time.sleep(self.JOB_POLL_INTERVAL) + finally: + if verbose: + next_log_token = self.print_job_logs( + job_name=job_name, + run_id=run_id, + job_failed=job_failed, + next_token=next_log_token, + ) def get_or_create_glue_job(self) -> str: """ @@ -172,23 +240,29 @@ def get_or_create_glue_job(self) -> str: try: get_job_response = glue_client.get_job(JobName=self.job_name) self.log.info("Job Already exist. Returning Name of the job") - return get_job_response['Job']['Name'] + return get_job_response["Job"]["Name"] except glue_client.exceptions.EntityNotFoundException: self.log.info("Job doesn't exist. 
Now creating and running AWS Glue Job") if self.s3_bucket is None: - raise AirflowException('Could not initialize glue job, error: Specify Parameter `s3_bucket`') - s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}' + raise AirflowException("Could not initialize glue job, error: Specify Parameter `s3_bucket`") + s3_log_path = f"s3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}" execution_role = self.get_iam_execution_role() try: + default_command = { + "Name": "glueetl", + "ScriptLocation": self.script_location, + } + command = self.create_job_kwargs.pop("Command", default_command) + if "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs: create_job_response = glue_client.create_job( Name=self.job_name, Description=self.desc, LogUri=s3_log_path, - Role=execution_role['Role']['Arn'], + Role=execution_role["Role"]["Arn"], ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit}, - Command={"Name": "glueetl", "ScriptLocation": self.script_location}, + Command=command, MaxRetries=self.retry_limit, **self.create_job_kwargs, ) @@ -197,30 +271,14 @@ def get_or_create_glue_job(self) -> str: Name=self.job_name, Description=self.desc, LogUri=s3_log_path, - Role=execution_role['Role']['Arn'], + Role=execution_role["Role"]["Arn"], ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit}, - Command={"Name": "glueetl", "ScriptLocation": self.script_location}, + Command=command, MaxRetries=self.retry_limit, MaxCapacity=self.num_of_dpus, **self.create_job_kwargs, ) - return create_job_response['Name'] + return create_job_response["Name"] except Exception as general_error: self.log.error("Failed to create aws glue job, error: %s", general_error) raise - - -class AwsGlueJobHook(GlueJobHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.glue.GlueJobHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.glue.GlueJobHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py index e77916d09e3b9..5387e19dbf893 100644 --- a/airflow/providers/amazon/aws/hooks/glue_catalog.py +++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py @@ -15,10 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains AWS Glue Catalog Hook""" -import warnings -from typing import Dict, List, Optional, Set +from __future__ import annotations from botocore.exceptions import ClientError @@ -38,16 +36,16 @@ class GlueCatalogHook(AwsBaseHook): """ def __init__(self, *args, **kwargs): - super().__init__(client_type='glue', *args, **kwargs) + super().__init__(client_type="glue", *args, **kwargs) def get_partitions( self, database_name: str, table_name: str, - expression: str = '', - page_size: Optional[int] = None, - max_items: Optional[int] = None, - ) -> Set[tuple]: + expression: str = "", + page_size: int | None = None, + max_items: int | None = None, + ) -> set[tuple]: """ Retrieves the partition values for a table. 
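A minimal usage sketch (not part of the diff) of the `GlueJobHook` changes above. Bucket, role, job, and script names are placeholders, and the `"JobRunId"` key is assumed from the `start_job_run` response returned by `initialize_job`.

```python
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

hook = GlueJobHook(
    job_name="example_glue_job",
    desc="example job",
    s3_bucket="example-bucket",
    iam_role_name="example-glue-role",
    script_location="s3://example-bucket/scripts/example_job.py",
    num_of_dpus=2,
)
job_name = hook.get_or_create_glue_job()
job_run = hook.initialize_job(script_arguments={"--day": "2022-01-01"})

# verbose=True is what makes job_completion stream CloudWatch driver logs
# through the new print_job_logs() helper while it polls the run state.
result = hook.job_completion(job_name=job_name, run_id=job_run["JobRunId"], verbose=True)
```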
@@ -63,19 +61,19 @@ def get_partitions( ``{('2018-01-01','1'), ('2018-01-01','2')}`` """ config = { - 'PageSize': page_size, - 'MaxItems': max_items, + "PageSize": page_size, + "MaxItems": max_items, } - paginator = self.get_conn().get_paginator('get_partitions') + paginator = self.get_conn().get_paginator("get_partitions") response = paginator.paginate( DatabaseName=database_name, TableName=table_name, Expression=expression, PaginationConfig=config ) partitions = set() for page in response: - for partition in page['Partitions']: - partitions.add(tuple(partition['Values'])) + for partition in page["Partitions"]: + partitions.add(tuple(partition["Values"])) return partitions @@ -87,7 +85,6 @@ def check_for_partition(self, database_name: str, table_name: str, expression: s :param table_name: Name of hive table @partition belongs to :expression: Expression that matches the partitions to check for (eg `a = 'b' AND c = 'd'`) - :rtype: bool >>> hook = GlueCatalogHook() >>> t = 'static_babynames_partitioned' @@ -104,7 +101,6 @@ def get_table(self, database_name: str, table_name: str) -> dict: :param database_name: Name of hive database (schema) @table belongs to :param table_name: Name of hive table - :rtype: dict >>> hook = GlueCatalogHook() >>> r = hook.get_table('db', 'table_foo') @@ -112,7 +108,7 @@ def get_table(self, database_name: str, table_name: str) -> dict: """ result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name) - return result['Table'] + return result["Table"] def get_table_location(self, database_name: str, table_name: str) -> str: """ @@ -124,9 +120,9 @@ def get_table_location(self, database_name: str, table_name: str) -> str: """ table = self.get_table(database_name, table_name) - return table['StorageDescriptor']['Location'] + return table["StorageDescriptor"]["Location"] - def get_partition(self, database_name: str, table_name: str, partition_values: List[str]) -> Dict: + def get_partition(self, database_name: str, table_name: str, partition_values: list[str]) -> dict: """ Gets a Partition @@ -136,7 +132,6 @@ def get_partition(self, database_name: str, table_name: str, partition_values: L Please see official AWS documentation for further information. https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartition - :rtype: dict :raises: AirflowException @@ -153,7 +148,7 @@ def get_partition(self, database_name: str, table_name: str, partition_values: L self.log.error("Client error: %s", e) raise AirflowException("AWS request failed, check logs for more info") - def create_partition(self, database_name: str, table_name: str, partition_input: Dict) -> Dict: + def create_partition(self, database_name: str, table_name: str, partition_input: dict) -> dict: """ Creates a new Partition @@ -163,7 +158,6 @@ def create_partition(self, database_name: str, table_name: str, partition_input: Please see official AWS documentation for further information. https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-CreatePartition - :rtype: dict :raises: AirflowException @@ -178,19 +172,3 @@ def create_partition(self, database_name: str, table_name: str, partition_input: except ClientError as e: self.log.error("Client error: %s", e) raise AirflowException("AWS request failed, check logs for more info") - - -class AwsGlueCatalogHook(GlueCatalogHook): - """ - This hook is deprecated. 
- Please use :class:`airflow.providers.amazon.aws.hooks.glue_catalog.GlueCatalogHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.glue_catalog.GlueCatalogHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/glue_crawler.py b/airflow/providers/amazon/aws/hooks/glue_crawler.py index 65f7df8d28566..917b96b2c6549 100644 --- a/airflow/providers/amazon/aws/hooks/glue_crawler.py +++ b/airflow/providers/amazon/aws/hooks/glue_crawler.py @@ -15,15 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys -import warnings -from time import sleep +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from time import sleep +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -40,7 +36,7 @@ class GlueCrawlerHook(AwsBaseHook): """ def __init__(self, *args, **kwargs): - kwargs['client_type'] = 'glue' + kwargs["client_type"] = "glue" super().__init__(*args, **kwargs) @cached_property @@ -70,7 +66,7 @@ def get_crawler(self, crawler_name: str) -> dict: :param crawler_name: unique crawler name per AWS account :return: Nested dictionary of crawler configurations """ - return self.glue_client.get_crawler(Name=crawler_name)['Crawler'] + return self.glue_client.get_crawler(Name=crawler_name)["Crawler"] def update_crawler(self, **crawler_kwargs) -> bool: """ @@ -79,7 +75,7 @@ def update_crawler(self, **crawler_kwargs) -> bool: :param crawler_kwargs: Keyword args that define the configurations used for the crawler :return: True if crawler was updated and false otherwise """ - crawler_name = crawler_kwargs['Name'] + crawler_name = crawler_kwargs["Name"] current_crawler = self.get_crawler(crawler_name) update_config = { @@ -100,7 +96,7 @@ def create_crawler(self, **crawler_kwargs) -> str: :param crawler_kwargs: Keyword args that define the configurations used to create the crawler :return: Name of the crawler """ - crawler_name = crawler_kwargs['Name'] + crawler_name = crawler_kwargs["Name"] self.log.info("Creating crawler: %s", crawler_name) return self.glue_client.create_crawler(**crawler_kwargs) @@ -124,26 +120,26 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) :param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status :return: Crawler's status """ - failed_status = ['FAILED', 'CANCELLED'] + failed_status = ["FAILED", "CANCELLED"] while True: crawler = self.get_crawler(crawler_name) - crawler_state = crawler['State'] - if crawler_state == 'READY': + crawler_state = crawler["State"] + if crawler_state == "READY": self.log.info("State: %s", crawler_state) self.log.info("crawler_config: %s", crawler) - crawler_status = crawler['LastCrawl']['Status'] + crawler_status = crawler["LastCrawl"]["Status"] if crawler_status in failed_status: raise AirflowException(f"Status: {crawler_status}") metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[ - 'CrawlerMetricsList' + "CrawlerMetricsList" ][0] self.log.info("Status: %s", crawler_status) - self.log.info("Last Runtime Duration (seconds): %s", 
metrics['LastRuntimeSeconds']) - self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds']) - self.log.info("Tables Created: %s", metrics['TablesCreated']) - self.log.info("Tables Updated: %s", metrics['TablesUpdated']) - self.log.info("Tables Deleted: %s", metrics['TablesDeleted']) + self.log.info("Last Runtime Duration (seconds): %s", metrics["LastRuntimeSeconds"]) + self.log.info("Median Runtime Duration (seconds): %s", metrics["MedianRuntimeSeconds"]) + self.log.info("Tables Created: %s", metrics["TablesCreated"]) + self.log.info("Tables Updated: %s", metrics["TablesUpdated"]) + self.log.info("Tables Deleted: %s", metrics["TablesDeleted"]) return crawler_status @@ -152,9 +148,9 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) self.log.info("State: %s", crawler_state) metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[ - 'CrawlerMetricsList' + "CrawlerMetricsList" ][0] - time_left = int(metrics['TimeLeftSeconds']) + time_left = int(metrics["TimeLeftSeconds"]) if time_left > 0: self.log.info("Estimated Time Left (seconds): %s", time_left) @@ -162,19 +158,3 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) self.log.info("Crawler should finish soon") sleep(poll_interval) - - -class AwsGlueCrawlerHook(GlueCrawlerHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.glue_crawler.GlueCrawlerHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.glue_crawler.GlueCrawlerHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/kinesis.py b/airflow/providers/amazon/aws/hooks/kinesis.py index f15457c6c28d7..71fe675466005 100644 --- a/airflow/providers/amazon/aws/hooks/kinesis.py +++ b/airflow/providers/amazon/aws/hooks/kinesis.py @@ -15,9 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains AWS Firehose hook""" -import warnings +from __future__ import annotations + from typing import Iterable from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -44,19 +44,3 @@ def __init__(self, delivery_stream: str, *args, **kwargs) -> None: def put_records(self, records: Iterable): """Write batch records to Kinesis Firehose""" return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records) - - -class AwsFirehoseHook(FirehoseHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.kinesis.FirehoseHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.kinesis.FirehoseHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py index b6819d9dba263..2919d37701aed 100644 --- a/airflow/providers/amazon/aws/hooks/lambda_function.py +++ b/airflow/providers/amazon/aws/hooks/lambda_function.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
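A minimal usage sketch (not part of the diff) for the `FirehoseHook` shown above in the kinesis changes, which replaces the removed `AwsFirehoseHook` alias; the delivery stream name and records are placeholders.

```python
from airflow.providers.amazon.aws.hooks.kinesis import FirehoseHook

hook = FirehoseHook(delivery_stream="example-stream", aws_conn_id="aws_default")
# Records follow the boto3 put_record_batch format.
response = hook.put_records(records=[{"Data": b"hello\n"}, {"Data": b"world\n"}])
```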
- """This module contains AWS Lambda hook""" -import warnings -from typing import Any, List, Optional +from __future__ import annotations + +from typing import Any from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -51,11 +51,11 @@ def invoke_lambda( self, *, function_name: str, - invocation_type: Optional[str] = None, - log_type: Optional[str] = None, - client_context: Optional[str] = None, - payload: Optional[str] = None, - qualifier: Optional[str] = None, + invocation_type: str | None = None, + log_type: str | None = None, + client_context: str | None = None, + payload: str | None = None, + qualifier: str | None = None, ): """Invoke Lambda Function. Refer to the boto3 documentation for more info.""" invoke_args = { @@ -76,22 +76,22 @@ def create_lambda( role: str, handler: str, code: dict, - description: Optional[str] = None, - timeout: Optional[int] = None, - memory_size: Optional[int] = None, - publish: Optional[bool] = None, - vpc_config: Optional[Any] = None, - package_type: Optional[str] = None, - dead_letter_config: Optional[Any] = None, - environment: Optional[Any] = None, - kms_key_arn: Optional[str] = None, - tracing_config: Optional[Any] = None, - tags: Optional[Any] = None, - layers: Optional[list] = None, - file_system_configs: Optional[List[Any]] = None, - image_config: Optional[Any] = None, - code_signing_config_arn: Optional[str] = None, - architectures: Optional[List[str]] = None, + description: str | None = None, + timeout: int | None = None, + memory_size: int | None = None, + publish: bool | None = None, + vpc_config: Any | None = None, + package_type: str | None = None, + dead_letter_config: Any | None = None, + environment: Any | None = None, + kms_key_arn: str | None = None, + tracing_config: Any | None = None, + tags: Any | None = None, + layers: list | None = None, + file_system_configs: list[Any] | None = None, + image_config: Any | None = None, + code_signing_config_arn: str | None = None, + architectures: list[str] | None = None, ) -> dict: """Create a Lambda Function""" create_function_args = { @@ -120,19 +120,3 @@ def create_lambda( return self.conn.create_function( **{k: v for k, v in create_function_args.items() if v is not None}, ) - - -class AwsLambdaHook(LambdaHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py index 1c5e5e62fb709..e19ff275efc8c 100644 --- a/airflow/providers/amazon/aws/hooks/logs.py +++ b/airflow/providers/amazon/aws/hooks/logs.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module contains a hook (AwsLogsHook) with some very basic functionality for interacting with AWS CloudWatch. """ -from typing import Dict, Generator, Optional +from __future__ import annotations + +from typing import Generator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -59,7 +60,6 @@ def get_log_events( This is for when there are multiple entries at the same timestamp. 
:param start_from_head: whether to start from the beginning (True) of the log or at the end of the log (False). - :rtype: dict :return: | A CloudWatch log event with the following key-value pairs: | 'timestamp' (int): The time in milliseconds of the event. | 'message' (str): The log event data. @@ -68,7 +68,7 @@ def get_log_events( next_token = None while True: if next_token is not None: - token_arg: Optional[Dict[str, str]] = {'nextToken': next_token} + token_arg: dict[str, str] | None = {"nextToken": next_token} else: token_arg = {} @@ -80,7 +80,7 @@ def get_log_events( **token_arg, ) - events = response['events'] + events = response["events"] event_count = len(events) if event_count > skip: @@ -92,7 +92,7 @@ def get_log_events( yield from events - if next_token != response['nextForwardToken']: - next_token = response['nextForwardToken'] + if next_token != response["nextForwardToken"]: + next_token = response["nextForwardToken"] else: return diff --git a/airflow/providers/amazon/aws/hooks/quicksight.py b/airflow/providers/amazon/aws/hooks/quicksight.py index a7e90c36cf92a..3a03f3a3be2da 100644 --- a/airflow/providers/amazon/aws/hooks/quicksight.py +++ b/airflow/providers/amazon/aws/hooks/quicksight.py @@ -15,21 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -import sys import time from botocore.exceptions import ClientError from airflow import AirflowException +from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sts import StsHook -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - class QuickSightHook(AwsBaseHook): """ @@ -58,7 +54,7 @@ def create_ingestion( ingestion_type: str, wait_for_completion: bool = True, check_interval: int = 30, - ): + ) -> dict: """ Creates and starts a new SPICE ingestion for a dataset. Refreshes the SPICE datasets @@ -70,9 +66,7 @@ def create_ingestion( will check the status of QuickSight Ingestion :return: Returns descriptive information about the created data ingestion having Ingestion ARN, HTTP status, ingestion ID and ingestion status. - :rtype: Dict """ - self.log.info("Creating QuickSight Ingestion for data set id %s.", data_set_id) quicksight_client = self.get_conn() try: @@ -97,7 +91,7 @@ def create_ingestion( self.log.error("Failed to run Amazon QuickSight create_ingestion API, error: %s", general_error) raise - def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str): + def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> str: """ Get the current status of QuickSight Create Ingestion API. 
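A minimal usage sketch (not part of the diff) of `QuickSightHook.create_ingestion` as changed above. The `data_set_id` and `ingestion_id` parameters are assumed from the part of the signature above the hunk shown; all IDs are placeholders, and the AWS account is resolved from the hook's own credentials.

```python
from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook

hook = QuickSightHook(aws_conn_id="aws_default")
ingestion = hook.create_ingestion(
    data_set_id="example-data-set-id",    # assumed parameter, defined above the hunk shown
    ingestion_id="example-ingestion-id",  # assumed parameter, defined above the hunk shown
    ingestion_type="FULL_REFRESH",
    wait_for_completion=True,
    check_interval=30,
)
```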
@@ -105,7 +99,6 @@ def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str): :param data_set_id: QuickSight Data Set ID :param ingestion_id: QuickSight Ingestion ID :return: An QuickSight Ingestion Status - :rtype: str """ try: describe_ingestion_response = self.get_conn().describe_ingestion( @@ -136,7 +129,6 @@ def wait_for_state( will check the status of QuickSight Ingestion :return: response of describe_ingestion call after Ingestion is is done """ - sec = 0 status = self.get_status(aws_account_id, data_set_id, ingestion_id) while status in self.NON_TERMINAL_STATES and status != target_state: diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py index 3539e951cade6..15790e3add96b 100644 --- a/airflow/providers/amazon/aws/hooks/rds.py +++ b/airflow/providers/amazon/aws/hooks/rds.py @@ -16,16 +16,19 @@ # specific language governing permissions and limitations # under the License. """Interact with AWS RDS.""" +from __future__ import annotations -from typing import TYPE_CHECKING +import time +from typing import TYPE_CHECKING, Callable -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.exceptions import AirflowException, AirflowNotFoundException +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook if TYPE_CHECKING: - from mypy_boto3_rds import RDSClient + from mypy_boto3_rds import RDSClient # noqa -class RdsHook(AwsBaseHook): +class RdsHook(AwsGenericHook["RDSClient"]): """ Interact with AWS RDS using proper client from the boto3 library. @@ -39,7 +42,7 @@ class RdsHook(AwsBaseHook): are passed down to the underlying AwsBaseHook. .. seealso:: - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook` :param aws_conn_id: The Airflow connection used for AWS credentials. """ @@ -48,12 +51,299 @@ def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "rds" super().__init__(*args, **kwargs) - @property - def conn(self) -> 'RDSClient': + def get_db_snapshot_state(self, snapshot_id: str) -> str: """ - Get the underlying boto3 RDS client (cached) + Get the current state of a DB instance snapshot. - :return: boto3 RDS client - :rtype: botocore.client.RDS + :param snapshot_id: The ID of the target DB instance snapshot + :return: Returns the status of the DB snapshot as a string (eg. "available") + :rtype: str + :raises AirflowNotFoundException: If the DB instance snapshot does not exist. """ - return super().conn + try: + response = self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "DBSnapshotNotFound": + raise AirflowNotFoundException(e) + raise e + return response["DBSnapshots"][0]["Status"].lower() + + def wait_for_db_snapshot_state( + self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40 + ) -> None: + """ + Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target state is reached. + An error is raised after a max number of attempts. 
+ + :param snapshot_id: The ID of the target DB instance snapshot + :param target_state: Wait until this state is reached + :param check_interval: The amount of time in seconds to wait between attempts + :param max_attempts: The maximum number of attempts to be made + """ + + def poke(): + return self.get_db_snapshot_state(snapshot_id) + + target_state = target_state.lower() + if target_state in ("available", "deleted", "completed"): + waiter = self.conn.get_waiter(f"db_snapshot_{target_state}") # type: ignore + waiter.wait( + DBSnapshotIdentifier=snapshot_id, + WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts}, + ) + else: + self._wait_for_state(poke, target_state, check_interval, max_attempts) + self.log.info("DB snapshot '%s' reached the '%s' state", snapshot_id, target_state) + + def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str: + """ + Get the current state of a DB cluster snapshot. + + :param snapshot_id: The ID of the target DB cluster. + :return: Returns the status of the DB cluster snapshot as a string (eg. "available") + :rtype: str + :raises AirflowNotFoundException: If the DB cluster snapshot does not exist. + """ + try: + response = self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "DBClusterSnapshotNotFoundFault": + raise AirflowNotFoundException(e) + raise e + return response["DBClusterSnapshots"][0]["Status"].lower() + + def wait_for_db_cluster_snapshot_state( + self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40 + ) -> None: + """ + Polls :py:meth:`RDS.Client.describe_db_cluster_snapshots` until the target state is reached. + An error is raised after a max number of attempts. + + :param snapshot_id: The ID of the target DB cluster snapshot + :param target_state: Wait until this state is reached + :param check_interval: The amount of time in seconds to wait between attempts + :param max_attempts: The maximum number of attempts to be made + + .. seealso:: + A list of possible values for target_state: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots + """ + + def poke(): + return self.get_db_cluster_snapshot_state(snapshot_id) + + target_state = target_state.lower() + if target_state in ("available", "deleted"): + waiter = self.conn.get_waiter(f"db_cluster_snapshot_{target_state}") # type: ignore + waiter.wait( + DBClusterSnapshotIdentifier=snapshot_id, + WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts}, + ) + else: + self._wait_for_state(poke, target_state, check_interval, max_attempts) + self.log.info("DB cluster snapshot '%s' reached the '%s' state", snapshot_id, target_state) + + def get_export_task_state(self, export_task_id: str) -> str: + """ + Gets the current state of an RDS snapshot export to Amazon S3. + + :param export_task_id: The identifier of the target snapshot export task. + :return: Returns the status of the snapshot export task as a string (eg. "canceled") + :rtype: str + :raises AirflowNotFoundException: If the export task does not exist. 
+ """ + try: + response = self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "ExportTaskNotFoundFault": + raise AirflowNotFoundException(e) + raise e + return response["ExportTasks"][0]["Status"].lower() + + def wait_for_export_task_state( + self, export_task_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40 + ) -> None: + """ + Polls :py:meth:`RDS.Client.describe_export_tasks` until the target state is reached. + An error is raised after a max number of attempts. + + :param export_task_id: The identifier of the target snapshot export task. + :param target_state: Wait until this state is reached + :param check_interval: The amount of time in seconds to wait between attempts + :param max_attempts: The maximum number of attempts to be made + + .. seealso:: + A list of possible values for target_state: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_export_tasks + """ + + def poke(): + return self.get_export_task_state(export_task_id) + + target_state = target_state.lower() + self._wait_for_state(poke, target_state, check_interval, max_attempts) + self.log.info("export task '%s' reached the '%s' state", export_task_id, target_state) + + def get_event_subscription_state(self, subscription_name: str) -> str: + """ + Gets the current state of an RDS snapshot export to Amazon S3. + + :param subscription_name: The name of the target RDS event notification subscription. + :return: Returns the status of the event subscription as a string (eg. "active") + :rtype: str + :raises AirflowNotFoundException: If the event subscription does not exist. + """ + try: + response = self.conn.describe_event_subscriptions(SubscriptionName=subscription_name) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "SubscriptionNotFoundFault": + raise AirflowNotFoundException(e) + raise e + return response["EventSubscriptionsList"][0]["Status"].lower() + + def wait_for_event_subscription_state( + self, subscription_name: str, target_state: str, check_interval: int = 30, max_attempts: int = 40 + ) -> None: + """ + Polls :py:meth:`RDS.Client.describe_event_subscriptions` until the target state is reached. + An error is raised after a max number of attempts. + + :param subscription_name: The name of the target RDS event notification subscription. + :param target_state: Wait until this state is reached + :param check_interval: The amount of time in seconds to wait between attempts + :param max_attempts: The maximum number of attempts to be made + + .. seealso:: + A list of possible values for target_state: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_event_subscriptions + """ + + def poke(): + return self.get_event_subscription_state(subscription_name) + + target_state = target_state.lower() + self._wait_for_state(poke, target_state, check_interval, max_attempts) + self.log.info("event subscription '%s' reached the '%s' state", subscription_name, target_state) + + def get_db_instance_state(self, db_instance_id: str) -> str: + """ + Get the current state of a DB instance. + + :param snapshot_id: The ID of the target DB instance. + :return: Returns the status of the DB instance as a string (eg. "available") + :rtype: str + :raises AirflowNotFoundException: If the DB instance does not exist. 
+ """ + try: + response = self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "DBInstanceNotFoundFault": + raise AirflowNotFoundException(e) + raise e + return response["DBInstances"][0]["DBInstanceStatus"].lower() + + def wait_for_db_instance_state( + self, db_instance_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40 + ) -> None: + """ + Polls :py:meth:`RDS.Client.describe_db_instances` until the target state is reached. + An error is raised after a max number of attempts. + + :param db_instance_id: The ID of the target DB instance. + :param target_state: Wait until this state is reached + :param check_interval: The amount of time in seconds to wait between attempts + :param max_attempts: The maximum number of attempts to be made + + .. seealso:: + For information about DB instance statuses, see Viewing DB instance status in the Amazon RDS + User Guide. + https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status + """ + + def poke(): + return self.get_db_instance_state(db_instance_id) + + target_state = target_state.lower() + if target_state in ("available", "deleted"): + waiter = self.conn.get_waiter(f"db_instance_{target_state}") # type: ignore + waiter.wait( + DBInstanceIdentifier=db_instance_id, + WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts}, + ) + else: + self._wait_for_state(poke, target_state, check_interval, max_attempts) + self.log.info("DB cluster snapshot '%s' reached the '%s' state", db_instance_id, target_state) + + def get_db_cluster_state(self, db_cluster_id: str) -> str: + """ + Get the current state of a DB cluster. + + :param snapshot_id: The ID of the target DB cluster. + :return: Returns the status of the DB cluster as a string (eg. "available") + :rtype: str + :raises AirflowNotFoundException: If the DB cluster does not exist. + """ + try: + response = self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "DBClusterNotFoundFault": + raise AirflowNotFoundException(e) + raise e + return response["DBClusters"][0]["Status"].lower() + + def wait_for_db_cluster_state( + self, db_cluster_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40 + ) -> None: + """ + Polls :py:meth:`RDS.Client.describe_db_clusters` until the target state is reached. + An error is raised after a max number of attempts. + + :param db_cluster_id: The ID of the target DB cluster. + :param target_state: Wait until this state is reached + :param check_interval: The amount of time in seconds to wait between attempts + :param max_attempts: The maximum number of attempts to be made + + .. seealso:: + For information about DB instance statuses, see Viewing DB instance status in the Amazon RDS + User Guide. 
+ https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status + """ + + def poke(): + return self.get_db_cluster_state(db_cluster_id) + + target_state = target_state.lower() + if target_state in ("available", "deleted"): + waiter = self.conn.get_waiter(f"db_cluster_{target_state}") # type: ignore + waiter.wait( + DBClusterIdentifier=db_cluster_id, + WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts}, + ) + else: + self._wait_for_state(poke, target_state, check_interval, max_attempts) + self.log.info("DB cluster snapshot '%s' reached the '%s' state", db_cluster_id, target_state) + + def _wait_for_state( + self, + poke: Callable[..., str], + target_state: str, + check_interval: int, + max_attempts: int, + ) -> None: + """ + Polls the poke function for the current state until it reaches the target_state. + + :param poke: A function that returns the current state of the target resource as a string. + :param target_state: Wait until this state is reached + :param check_interval: The amount of time in seconds to wait between attempts + :param max_attempts: The maximum number of attempts to be made + """ + state = poke() + tries = 1 + while state != target_state: + self.log.info("Current state is %s", state) + if tries >= max_attempts: + raise AirflowException("Max attempts exceeded") + time.sleep(check_interval) + state = poke() + tries += 1 diff --git a/airflow/providers/amazon/aws/hooks/redshift.py b/airflow/providers/amazon/aws/hooks/redshift.py deleted file mode 100644 index 1766564d273b5..0000000000000 --- a/airflow/providers/amazon/aws/hooks/redshift.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Interact with AWS Redshift clusters.""" -import warnings - -from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook -from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.redshift_cluster` " - "or `airflow.providers.amazon.aws.hooks.redshift_sql` as appropriate.", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["RedshiftHook", "RedshiftSQLHook"] diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py index c7b3c51722141..d85929d062024 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py +++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
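A short sketch (not part of the diff) of the new `RdsHook` state helpers added above in `rds.py`; the snapshot identifier is a placeholder and the snapshot must already exist for the calls to succeed.

```python
from airflow.providers.amazon.aws.hooks.rds import RdsHook

hook = RdsHook(aws_conn_id="aws_default")

state = hook.get_db_snapshot_state("example-db-snapshot")

# "available", "deleted" and "completed" use boto3 waiters; any other target
# state falls back to the generic _wait_for_state() polling loop.
hook.wait_for_db_snapshot_state(
    "example-db-snapshot", target_state="available", check_interval=30, max_attempts=40
)
```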
+from __future__ import annotations -from typing import Any, Dict, List, Optional +import warnings +from typing import Any, Sequence from botocore.exceptions import ClientError @@ -35,6 +37,8 @@ class RedshiftHook(AwsBaseHook): :param aws_conn_id: The Airflow connection used for AWS credentials. """ + template_fields: Sequence[str] = ("cluster_identifier",) + def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "redshift" super().__init__(*args, **kwargs) @@ -45,8 +49,8 @@ def create_cluster( node_type: str, master_username: str, master_user_password: str, - params: Dict[str, Any], - ) -> Dict[str, Any]: + params: dict[str, Any], + ) -> dict[str, Any]: """ Creates a new cluster with the specified parameters @@ -83,16 +87,16 @@ def cluster_status(self, cluster_identifier: str) -> str: :param final_cluster_snapshot_identifier: Optional[str] """ try: - response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)['Clusters'] - return response[0]['ClusterStatus'] if response else None + response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"] + return response[0]["ClusterStatus"] if response else None except self.get_conn().exceptions.ClusterNotFoundFault: - return 'cluster_not_found' + return "cluster_not_found" def delete_cluster( self, cluster_identifier: str, skip_final_cluster_snapshot: bool = True, - final_cluster_snapshot_identifier: Optional[str] = None, + final_cluster_snapshot_identifier: str | None = None, ): """ Delete a cluster and optionally create a snapshot @@ -101,27 +105,27 @@ def delete_cluster( :param skip_final_cluster_snapshot: determines cluster snapshot creation :param final_cluster_snapshot_identifier: name of final cluster snapshot """ - final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or '' + final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or "" response = self.get_conn().delete_cluster( ClusterIdentifier=cluster_identifier, SkipFinalClusterSnapshot=skip_final_cluster_snapshot, FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier, ) - return response['Cluster'] if response['Cluster'] else None + return response["Cluster"] if response["Cluster"] else None - def describe_cluster_snapshots(self, cluster_identifier: str) -> Optional[List[str]]: + def describe_cluster_snapshots(self, cluster_identifier: str) -> list[str] | None: """ Gets a list of snapshots for a cluster :param cluster_identifier: unique identifier of a cluster """ response = self.get_conn().describe_cluster_snapshots(ClusterIdentifier=cluster_identifier) - if 'Snapshots' not in response: + if "Snapshots" not in response: return None - snapshots = response['Snapshots'] + snapshots = response["Snapshots"] snapshots = [snapshot for snapshot in snapshots if snapshot["Status"]] - snapshots.sort(key=lambda x: x['SnapshotCreateTime'], reverse=True) + snapshots.sort(key=lambda x: x["SnapshotCreateTime"], reverse=True) return snapshots def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identifier: str) -> str: @@ -134,17 +138,48 @@ def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identi response = self.get_conn().restore_from_cluster_snapshot( ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier ) - return response['Cluster'] if response['Cluster'] else None + return response["Cluster"] if response["Cluster"] else None - def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: str) -> str: + 
def create_cluster_snapshot( + self, snapshot_identifier: str, cluster_identifier: str, retention_period: int = -1 + ) -> str: """ Creates a snapshot of a cluster :param snapshot_identifier: unique identifier for a snapshot of a cluster :param cluster_identifier: unique identifier of a cluster + :param retention_period: The number of days that a manual snapshot is retained. + If the value is -1, the manual snapshot is retained indefinitely. """ response = self.get_conn().create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, + ManualSnapshotRetentionPeriod=retention_period, ) - return response['Snapshot'] if response['Snapshot'] else None + return response["Snapshot"] if response["Snapshot"] else None + + def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str | None = None): + """ + Return Redshift cluster snapshot status. If cluster snapshot not found return ``None`` + + :param snapshot_identifier: A unique identifier for the snapshot that you are requesting + :param cluster_identifier: (deprecated) The unique identifier of the cluster + the snapshot was created from + """ + if cluster_identifier: + warnings.warn( + "Parameter `cluster_identifier` is deprecated." + "This option will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + + try: + response = self.get_conn().describe_cluster_snapshots( + SnapshotIdentifier=snapshot_identifier, + ) + snapshot = response.get("Snapshots")[0] + snapshot_status: str = snapshot.get("Status") + return snapshot_status + except self.get_conn().exceptions.ClusterSnapshotNotFoundFault: + return None diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index 74459f58a5f8a..a2cb172979107 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -15,16 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import TYPE_CHECKING -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook if TYPE_CHECKING: - from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient + from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa -class RedshiftDataHook(AwsBaseHook): +class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]): """ Interact with AWS Redshift Data, using the boto3 library Hook attribute `conn` has all methods that listed in documentation @@ -37,7 +38,7 @@ class RedshiftDataHook(AwsBaseHook): are passed down to the underlying AwsBaseHook. .. seealso:: - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook` :param aws_conn_id: The Airflow connection used for AWS credentials. 
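A brief usage sketch for the RedshiftHook snapshot helpers changed above; the connection id, cluster and snapshot names are placeholders, and the import path is assumed (it may differ between provider versions).

# Hypothetical usage; names and import path are assumptions, not part of this change.
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")

# Keep the manual snapshot for 7 days instead of retaining it indefinitely (-1).
hook.create_cluster_snapshot(
    snapshot_identifier="my-cluster-snap",
    cluster_identifier="my-cluster",
    retention_period=7,
)

# Returns the snapshot status string, or None if the snapshot does not exist.
status = hook.get_cluster_snapshot_status(snapshot_identifier="my-cluster-snap")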
""" @@ -45,8 +46,3 @@ class RedshiftDataHook(AwsBaseHook): def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "redshift-data" super().__init__(*args, **kwargs) - - @property - def conn(self) -> 'RedshiftDataAPIServiceClient': - """Get the underlying boto3 RedshiftDataAPIService client (cached)""" - return super().conn diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py index 03bb45f7ee128..120ce190ccb1b 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_sql.py +++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py @@ -14,21 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -import sys -from typing import Dict, List, Optional, Union +from __future__ import annotations import redshift_connector from redshift_connector import Connection as RedshiftConnection from sqlalchemy import create_engine from sqlalchemy.engine.url import URL -from airflow.hooks.dbapi import DbApiHook - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from airflow.compat.functools import cached_property +from airflow.providers.common.sql.hooks.sql import DbApiHook class RedshiftSQLHook(DbApiHook): @@ -44,40 +38,40 @@ class RedshiftSQLHook(DbApiHook): get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift """ - conn_name_attr = 'redshift_conn_id' - default_conn_name = 'redshift_default' - conn_type = 'redshift' - hook_name = 'Amazon Redshift' + conn_name_attr = "redshift_conn_id" + default_conn_name = "redshift_default" + conn_type = "redshift" + hook_name = "Amazon Redshift" supports_autocommit = True @staticmethod - def get_ui_field_behavior() -> Dict: + def get_ui_field_behaviour() -> dict: """Returns custom field behavior""" return { "hidden_fields": [], - "relabeling": {'login': 'User', 'schema': 'Database'}, + "relabeling": {"login": "User", "schema": "Database"}, } @cached_property def conn(self): return self.get_connection(self.redshift_conn_id) # type: ignore[attr-defined] - def _get_conn_params(self) -> Dict[str, Union[str, int]]: + def _get_conn_params(self) -> dict[str, str | int]: """Helper method to retrieve connection args""" conn = self.conn - conn_params: Dict[str, Union[str, int]] = {} + conn_params: dict[str, str | int] = {} if conn.login: - conn_params['user'] = conn.login + conn_params["user"] = conn.login if conn.password: - conn_params['password'] = conn.password + conn_params["password"] = conn.password if conn.host: - conn_params['host'] = conn.host + conn_params["host"] = conn.host if conn.port: - conn_params['port'] = conn.port + conn_params["port"] = conn.port if conn.schema: - conn_params['database'] = conn.schema + conn_params["database"] = conn.schema return conn_params @@ -85,10 +79,13 @@ def get_uri(self) -> str: """Overrides DbApiHook get_uri to use redshift_connector sqlalchemy dialect as driver name""" conn_params = self._get_conn_params() - if 'user' in conn_params: - conn_params['username'] = conn_params.pop('user') + if "user" in conn_params: + conn_params["username"] = conn_params.pop("user") - return str(URL(drivername='redshift+redshift_connector', **conn_params)) + # Compatibility: The 'create' factory method was added in SQLAlchemy 1.4 + # to replace calling the default URL constructor directly. 
+ create_url = getattr(URL, "create", URL) + return str(create_url(drivername="redshift+redshift_connector", **conn_params)) def get_sqlalchemy_engine(self, engine_kwargs=None): """Overrides DbApiHook get_sqlalchemy_engine to pass redshift_connector specific kwargs""" @@ -103,13 +100,12 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): return create_engine(self.get_uri(), **engine_kwargs) - def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> Optional[List[str]]: + def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None: """ Helper method that returns the table primary key :param table: Name of the target table - :param table: Name of the target schema, public by default + :param schema: Name of the target schema, public by default :return: Primary key columns list - :rtype: List[str] """ sql = """ select kcu.column_name @@ -129,5 +125,5 @@ def get_conn(self) -> RedshiftConnection: """Returns a redshift_connector.Connection object""" conn_params = self._get_conn_params() conn_kwargs_dejson = self.conn.extra_dejson - conn_kwargs: Dict = {**conn_params, **conn_kwargs_dejson} + conn_kwargs: dict = {**conn_params, **conn_kwargs_dejson} return redshift_connector.connect(**conn_kwargs) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index e7e9f2de508da..e5b1bad5804f2 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -15,22 +15,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - """Interact with AWS S3, using the boto3 library.""" +from __future__ import annotations + import fnmatch import gzip as gz import io import re import shutil +from copy import deepcopy from datetime import datetime from functools import wraps from inspect import signature from io import BytesIO from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast -from urllib.parse import urlparse +from tempfile import NamedTemporaryFile, gettempdir +from typing import Any, Callable, TypeVar, cast +from urllib.parse import urlsplit +from uuid import uuid4 from boto3.s3.transfer import S3Transfer, TransferConfig from botocore.exceptions import ClientError @@ -53,12 +55,12 @@ def provide_bucket_name(func: T) -> T: def wrapper(*args, **kwargs) -> T: bound_args = function_signature.bind(*args, **kwargs) - if 'bucket_name' not in bound_args.arguments: + if "bucket_name" not in bound_args.arguments: self = args[0] if self.aws_conn_id: connection = self.get_connection(self.aws_conn_id) if connection.schema: - bound_args.arguments['bucket_name'] = connection.schema + bound_args.arguments["bucket_name"] = connection.schema return func(*bound_args.args, **bound_args.kwargs) @@ -76,15 +78,15 @@ def unify_bucket_name_and_key(func: T) -> T: def wrapper(*args, **kwargs) -> T: bound_args = function_signature.bind(*args, **kwargs) - if 'wildcard_key' in bound_args.arguments: - key_name = 'wildcard_key' - elif 'key' in bound_args.arguments: - key_name = 'key' + if "wildcard_key" in bound_args.arguments: + key_name = "wildcard_key" + elif "key" in bound_args.arguments: + key_name = "key" else: - raise ValueError('Missing key parameter!') + raise ValueError("Missing key parameter!") - if 'bucket_name' not in bound_args.arguments: - bound_args.arguments['bucket_name'], 
bound_args.arguments[key_name] = S3Hook.parse_s3_url( + if "bucket_name" not in bound_args.arguments: + bound_args.arguments["bucket_name"], bound_args.arguments[key_name] = S3Hook.parse_s3_url( bound_args.arguments[key_name] ) @@ -97,6 +99,15 @@ class S3Hook(AwsBaseHook): """ Interact with AWS S3, using the boto3 library. + :param transfer_config_args: Configuration object for managed S3 transfers. + :param extra_args: Extra arguments that may be passed to the download/upload operations. + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#s3-transfers + + - For allowed upload extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS``. + - For allowed download extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_DOWNLOAD_ARGS``. + Additional arguments (such as ``aws_conn_id``) may be specified and are passed down to the underlying AwsBaseHook. @@ -104,52 +115,67 @@ class S3Hook(AwsBaseHook): :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - conn_type = 's3' - hook_name = 'Amazon S3' - - def __init__(self, *args, **kwargs) -> None: - kwargs['client_type'] = 's3' + def __init__( + self, + aws_conn_id: str | None = AwsBaseHook.default_conn_name, + transfer_config_args: dict | None = None, + extra_args: dict | None = None, + *args, + **kwargs, + ) -> None: + kwargs["client_type"] = "s3" + kwargs["aws_conn_id"] = aws_conn_id - self.extra_args = {} - if 'extra_args' in kwargs: - self.extra_args = kwargs['extra_args'] - if not isinstance(self.extra_args, dict): - raise ValueError(f"extra_args '{self.extra_args!r}' must be of type {dict}") - del kwargs['extra_args'] + if transfer_config_args and not isinstance(transfer_config_args, dict): + raise TypeError(f"transfer_config_args expected dict, got {type(transfer_config_args).__name__}.") + self.transfer_config = TransferConfig(**transfer_config_args or {}) - self.transfer_config = TransferConfig() - if 'transfer_config_args' in kwargs: - transport_config_args = kwargs['transfer_config_args'] - if not isinstance(transport_config_args, dict): - raise ValueError(f"transfer_config_args '{transport_config_args!r} must be of type {dict}") - self.transfer_config = TransferConfig(**transport_config_args) - del kwargs['transfer_config_args'] + if extra_args and not isinstance(extra_args, dict): + raise TypeError(f"extra_args expected dict, got {type(extra_args).__name__}.") + self._extra_args = extra_args or {} super().__init__(*args, **kwargs) + @property + def extra_args(self): + """Return hook's extra arguments (immutable).""" + return deepcopy(self._extra_args) + @staticmethod - def parse_s3_url(s3url: str) -> Tuple[str, str]: + def parse_s3_url(s3url: str) -> tuple[str, str]: """ Parses the S3 Url into a bucket name and key. + See https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html + for valid url formats :param s3url: The S3 Url to parse. 
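A minimal sketch of the reworked S3Hook constructor, assuming an `aws_default` connection; bucket names and the chosen TransferConfig values are placeholders. `transfer_config_args` feed boto3's `TransferConfig`, while `extra_args` are forwarded to upload/download calls.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(
    aws_conn_id="aws_default",
    transfer_config_args={"multipart_threshold": 64 * 1024 * 1024, "max_concurrency": 4},
    extra_args={"ServerSideEncryption": "AES256"},
)

# Non-dict values now raise TypeError, and the extra_args property returns a
# copy, so callers cannot mutate the hook's internal configuration by accident.
assert hook.extra_args == {"ServerSideEncryption": "AES256"}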
:return: the parsed bucket name and key - :rtype: tuple of str """ - parsed_url = urlparse(s3url) - - if not parsed_url.netloc: - raise AirflowException(f'Please provide a bucket_name instead of "{s3url}"') - - bucket_name = parsed_url.netloc - key = parsed_url.path.lstrip('/') - + format = s3url.split("//") + if format[0].lower() == "s3:": + parsed_url = urlsplit(s3url) + if not parsed_url.netloc: + raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"') + + bucket_name = parsed_url.netloc + key = parsed_url.path.lstrip("/") + elif format[0] == "https:": + temp_split = format[1].split(".") + if temp_split[0] == "s3": + split_url = format[1].split("/") + bucket_name = split_url[1] + key = "/".join(split_url[2:]) + elif temp_split[1] == "s3": + bucket_name = temp_split[0] + key = "/".join(format[1].split("/")[1:]) + else: + raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"') return bucket_name, key @staticmethod def get_s3_bucket_key( - bucket: Optional[str], key: str, bucket_param_name: str, key_param_name: str - ) -> Tuple[str, str]: + bucket: str | None, key: str, bucket_param_name: str, key_param_name: str + ) -> tuple[str, str]: """ Get the S3 bucket name and key from either: - bucket name and key. Return the info as it is after checking `key` is a relative path @@ -160,59 +186,64 @@ def get_s3_bucket_key( :param bucket_param_name: The parameter name containing the bucket name :param key_param_name: The parameter name containing the key name :return: the parsed bucket name and key - :rtype: tuple of str """ - if bucket is None: return S3Hook.parse_s3_url(key) - parsed_url = urlparse(key) - if parsed_url.scheme != '' or parsed_url.netloc != '': + parsed_url = urlsplit(key) + if parsed_url.scheme != "" or parsed_url.netloc != "": raise TypeError( - f'If `{bucket_param_name}` is provided, {key_param_name} should be a relative path ' - 'from root level, rather than a full s3:// url' + f"If `{bucket_param_name}` is provided, {key_param_name} should be a relative path " + "from root level, rather than a full s3:// url" ) return bucket, key @provide_bucket_name - def check_for_bucket(self, bucket_name: Optional[str] = None) -> bool: + def check_for_bucket(self, bucket_name: str | None = None) -> bool: """ Check if bucket_name exists. :param bucket_name: the name of the bucket :return: True if it exists and False if not. - :rtype: bool """ try: self.get_conn().head_bucket(Bucket=bucket_name) return True except ClientError as e: - self.log.error(e.response["Error"]["Message"]) + # The head_bucket api is odd in that it cannot return proper + # exception objects, so error codes must be used. Only 200, 404 and 403 + # are ever returned. See the following links for more details: + # https://github.com/boto/boto3/issues/2499 + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.head_bucket + return_code = int(e.response["Error"]["Code"]) + if return_code == 404: + self.log.info('Bucket "%s" does not exist', bucket_name) + elif return_code == 403: + self.log.error( + 'Access to bucket "%s" is forbidden or there was an error with the request', bucket_name + ) + self.log.error(e) return False @provide_bucket_name - def get_bucket(self, bucket_name: Optional[str] = None) -> object: + def get_bucket(self, bucket_name: str | None = None) -> object: """ Returns a boto3.S3.Bucket object :param bucket_name: the name of the bucket :return: the bucket object to the bucket name. 
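The behaviour of the extended `parse_s3_url` above, restated as a few assertions derived from the code; the bucket, key and region strings are placeholders.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# s3:// style
assert S3Hook.parse_s3_url("s3://my-bucket/some/key.txt") == ("my-bucket", "some/key.txt")
# https:// path-style access
assert S3Hook.parse_s3_url(
    "https://s3.us-west-2.amazonaws.com/my-bucket/some/key.txt"
) == ("my-bucket", "some/key.txt")
# https:// virtual-hosted-style access
assert S3Hook.parse_s3_url(
    "https://my-bucket.s3.us-west-2.amazonaws.com/some/key.txt"
) == ("my-bucket", "some/key.txt")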
- :rtype: boto3.S3.Bucket """ - # Buckets have no regions, and we cannot remove the region name from _get_credentials as we would - # break compatibility, so we set it explicitly to None. - session, endpoint_url = self._get_credentials(region_name=None) - s3_resource = session.resource( + s3_resource = self.get_session().resource( "s3", - endpoint_url=endpoint_url, + endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify, ) return s3_resource.Bucket(bucket_name) @provide_bucket_name - def create_bucket(self, bucket_name: Optional[str] = None, region_name: Optional[str] = None) -> None: + def create_bucket(self, bucket_name: str | None = None, region_name: str | None = None) -> None: """ Creates an Amazon S3 bucket. @@ -220,16 +251,22 @@ def create_bucket(self, bucket_name: Optional[str] = None, region_name: Optional :param region_name: The name of the aws region in which to create the bucket. """ if not region_name: - region_name = self.get_conn().meta.region_name - if region_name == 'us-east-1': + if self.conn_region_name == "aws-global": + raise AirflowException( + "Unable to create bucket if `region_name` not set " + "and boto3 configured to use s3 regional endpoints." + ) + region_name = self.conn_region_name + + if region_name == "us-east-1": self.get_conn().create_bucket(Bucket=bucket_name) else: self.get_conn().create_bucket( - Bucket=bucket_name, CreateBucketConfiguration={'LocationConstraint': region_name} + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region_name} ) @provide_bucket_name - def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[str] = None) -> bool: + def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: str | None = None) -> bool: """ Checks that a prefix exists in a bucket @@ -237,10 +274,9 @@ def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[st :param prefix: a key prefix :param delimiter: the delimiter marks key hierarchy. :return: False if the prefix does not exist in the bucket and True if it does. 
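A small sketch of bucket creation with the stricter region handling above; the bucket name, region and connection id are placeholders.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")
if not hook.check_for_bucket("my-bucket"):
    # Outside us-east-1 a LocationConstraint is sent automatically; if no
    # region can be resolved at all, create_bucket now raises instead of
    # silently using a default.
    hook.create_bucket(bucket_name="my-bucket", region_name="eu-west-1")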
- :rtype: bool """ prefix = prefix + delimiter if prefix[-1] != delimiter else prefix - prefix_split = re.split(fr'(\w+[{delimiter}])$', prefix, 1) + prefix_split = re.split(rf"(\w+[{delimiter}])$", prefix, 1) previous_level = prefix_split[0] plist = self.list_prefixes(bucket_name, previous_level, delimiter) return prefix in plist @@ -248,11 +284,11 @@ def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[st @provide_bucket_name def list_prefixes( self, - bucket_name: Optional[str] = None, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, - page_size: Optional[int] = None, - max_items: Optional[int] = None, + bucket_name: str | None = None, + prefix: str | None = None, + delimiter: str | None = None, + page_size: int | None = None, + max_items: int | None = None, ) -> list: """ Lists prefixes in a bucket under prefix @@ -263,29 +299,28 @@ def list_prefixes( :param page_size: pagination size :param max_items: maximum items to return :return: a list of matched prefixes - :rtype: list """ - prefix = prefix or '' - delimiter = delimiter or '' + prefix = prefix or "" + delimiter = delimiter or "" config = { - 'PageSize': page_size, - 'MaxItems': max_items, + "PageSize": page_size, + "MaxItems": max_items, } - paginator = self.get_conn().get_paginator('list_objects_v2') + paginator = self.get_conn().get_paginator("list_objects_v2") response = paginator.paginate( Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config ) - prefixes = [] # type: List[str] + prefixes: list[str] = [] for page in response: - if 'CommonPrefixes' in page: - prefixes.extend(common_prefix['Prefix'] for common_prefix in page['CommonPrefixes']) + if "CommonPrefixes" in page: + prefixes.extend(common_prefix["Prefix"] for common_prefix in page["CommonPrefixes"]) return prefixes def _list_key_object_filter( - self, keys: list, from_datetime: Optional[datetime] = None, to_datetime: Optional[datetime] = None + self, keys: list, from_datetime: datetime | None = None, to_datetime: datetime | None = None ) -> list: def _is_in_period(input_date: datetime) -> bool: if from_datetime is not None and input_date <= from_datetime: @@ -294,20 +329,20 @@ def _is_in_period(input_date: datetime) -> bool: return False return True - return [k['Key'] for k in keys if _is_in_period(k['LastModified'])] + return [k["Key"] for k in keys if _is_in_period(k["LastModified"])] @provide_bucket_name def list_keys( self, - bucket_name: Optional[str] = None, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, - page_size: Optional[int] = None, - max_items: Optional[int] = None, - start_after_key: Optional[str] = None, - from_datetime: Optional[datetime] = None, - to_datetime: Optional[datetime] = None, - object_filter: Optional[Callable[..., list]] = None, + bucket_name: str | None = None, + prefix: str | None = None, + delimiter: str | None = None, + page_size: int | None = None, + max_items: int | None = None, + start_after_key: str | None = None, + from_datetime: datetime | None = None, + to_datetime: datetime | None = None, + object_filter: Callable[..., list] | None = None, ) -> list: """ Lists keys in a bucket under prefix and not containing delimiter @@ -331,8 +366,8 @@ def list_keys( def object_filter( keys: list, - from_datetime: Optional[datetime] = None, - to_datetime: Optional[datetime] = None, + from_datetime: datetime | None = None, + to_datetime: datetime | None = None, ) -> list: def _is_in_period(input_date: datetime) -> bool: if from_datetime is not None and 
input_date < from_datetime: @@ -345,18 +380,17 @@ def _is_in_period(input_date: datetime) -> bool: return [k["Key"] for k in keys if _is_in_period(k["LastModified"])] :return: a list of matched keys - :rtype: list """ - prefix = prefix or '' - delimiter = delimiter or '' - start_after_key = start_after_key or '' + prefix = prefix or "" + delimiter = delimiter or "" + start_after_key = start_after_key or "" self.object_filter_usr = object_filter config = { - 'PageSize': page_size, - 'MaxItems': max_items, + "PageSize": page_size, + "MaxItems": max_items, } - paginator = self.get_conn().get_paginator('list_objects_v2') + paginator = self.get_conn().get_paginator("list_objects_v2") response = paginator.paginate( Bucket=bucket_name, Prefix=prefix, @@ -365,10 +399,10 @@ def _is_in_period(input_date: datetime) -> bool: StartAfter=start_after_key, ) - keys = [] # type: List[str] + keys: list[str] = [] for page in response: - if 'Contents' in page: - keys.extend(iter(page['Contents'])) + if "Contents" in page: + keys.extend(iter(page["Contents"])) if self.object_filter_usr is not None: return self.object_filter_usr(keys, from_datetime, to_datetime) @@ -378,10 +412,10 @@ def _is_in_period(input_date: datetime) -> bool: def get_file_metadata( self, prefix: str, - bucket_name: Optional[str] = None, - page_size: Optional[int] = None, - max_items: Optional[int] = None, - ) -> List: + bucket_name: str | None = None, + page_size: int | None = None, + max_items: int | None = None, + ) -> list: """ Lists metadata objects in a bucket under prefix @@ -390,32 +424,30 @@ def get_file_metadata( :param page_size: pagination size :param max_items: maximum items to return :return: a list of metadata of objects - :rtype: list """ config = { - 'PageSize': page_size, - 'MaxItems': max_items, + "PageSize": page_size, + "MaxItems": max_items, } - paginator = self.get_conn().get_paginator('list_objects_v2') + paginator = self.get_conn().get_paginator("list_objects_v2") response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, PaginationConfig=config) files = [] for page in response: - if 'Contents' in page: - files += page['Contents'] + if "Contents" in page: + files += page["Contents"] return files @provide_bucket_name @unify_bucket_name_and_key - def head_object(self, key: str, bucket_name: Optional[str] = None) -> Optional[dict]: + def head_object(self, key: str, bucket_name: str | None = None) -> dict | None: """ Retrieves metadata of an object :param key: S3 key that will point to the file :param bucket_name: Name of the bucket in which the file is stored :return: metadata of an object - :rtype: dict """ try: return self.get_conn().head_object(Bucket=bucket_name, Key=key) @@ -427,35 +459,30 @@ def head_object(self, key: str, bucket_name: Optional[str] = None) -> Optional[d @provide_bucket_name @unify_bucket_name_and_key - def check_for_key(self, key: str, bucket_name: Optional[str] = None) -> bool: + def check_for_key(self, key: str, bucket_name: str | None = None) -> bool: """ Checks if a key exists in a bucket :param key: S3 key that will point to the file :param bucket_name: Name of the bucket in which the file is stored :return: True if the key exists and False if not. 
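A sketch of the extended `list_keys` signature above; bucket, prefix and the size threshold are placeholders. The filter callable receives the raw `Contents` entries plus the two datetimes and must return a list of keys.

from datetime import datetime, timezone
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Default filtering by last-modified window (S3 timestamps are UTC-aware).
keys_in_january = hook.list_keys(
    bucket_name="my-bucket",
    prefix="logs/2022/",
    from_datetime=datetime(2022, 1, 1, tzinfo=timezone.utc),
    to_datetime=datetime(2022, 2, 1, tzinfo=timezone.utc),
)

# Alternatively, hand in a custom object_filter over the raw "Contents" dicts.
small_keys = hook.list_keys(
    bucket_name="my-bucket",
    prefix="logs/2022/",
    object_filter=lambda keys, from_datetime=None, to_datetime=None: [
        k["Key"] for k in keys if k["Size"] < 1024 * 1024
    ],
)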
- :rtype: bool """ obj = self.head_object(key, bucket_name) return obj is not None @provide_bucket_name @unify_bucket_name_and_key - def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer: + def get_key(self, key: str, bucket_name: str | None = None) -> S3Transfer: """ Returns a boto3.s3.Object :param key: the path to the key :param bucket_name: the name of the bucket :return: the key object from the bucket - :rtype: boto3.s3.Object """ - # Buckets have no regions, and we cannot remove the region name from _get_credentials as we would - # break compatibility, so we set it explicitly to None. - session, endpoint_url = self._get_credentials(region_name=None) - s3_resource = session.resource( + s3_resource = self.get_session().resource( "s3", - endpoint_url=endpoint_url, + endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify, ) @@ -465,28 +492,27 @@ def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer: @provide_bucket_name @unify_bucket_name_and_key - def read_key(self, key: str, bucket_name: Optional[str] = None) -> str: + def read_key(self, key: str, bucket_name: str | None = None) -> str: """ Reads a key from S3 :param key: S3 key that will point to the file :param bucket_name: Name of the bucket in which the file is stored :return: the content of the key - :rtype: str """ obj = self.get_key(key, bucket_name) - return obj.get()['Body'].read().decode('utf-8') + return obj.get()["Body"].read().decode("utf-8") @provide_bucket_name @unify_bucket_name_and_key def select_key( self, key: str, - bucket_name: Optional[str] = None, - expression: Optional[str] = None, - expression_type: Optional[str] = None, - input_serialization: Optional[Dict[str, Any]] = None, - output_serialization: Optional[Dict[str, Any]] = None, + bucket_name: str | None = None, + expression: str | None = None, + expression_type: str | None = None, + input_serialization: dict[str, Any] | None = None, + output_serialization: dict[str, Any] | None = None, ) -> str: """ Reads a key with S3 Select. @@ -498,19 +524,18 @@ def select_key( :param input_serialization: S3 Select input data serialization format :param output_serialization: S3 Select output data serialization format :return: retrieved subset of original data by S3 Select - :rtype: str .. 
seealso:: For more details about S3 Select parameters: - http://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.select_object_content + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.select_object_content """ - expression = expression or 'SELECT * FROM S3Object' - expression_type = expression_type or 'SQL' + expression = expression or "SELECT * FROM S3Object" + expression_type = expression_type or "SQL" if input_serialization is None: - input_serialization = {'CSV': {}} + input_serialization = {"CSV": {}} if output_serialization is None: - output_serialization = {'CSV': {}} + output_serialization = {"CSV": {}} response = self.get_conn().select_object_content( Bucket=bucket_name, @@ -521,14 +546,14 @@ def select_key( OutputSerialization=output_serialization, ) - return b''.join( - event['Records']['Payload'] for event in response['Payload'] if 'Records' in event - ).decode('utf-8') + return b"".join( + event["Records"]["Payload"] for event in response["Payload"] if "Records" in event + ).decode("utf-8") @provide_bucket_name @unify_bucket_name_and_key def check_for_wildcard_key( - self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = '' + self, wildcard_key: str, bucket_name: str | None = None, delimiter: str = "" ) -> bool: """ Checks that a key matching a wildcard expression exists in a bucket @@ -537,7 +562,6 @@ def check_for_wildcard_key( :param bucket_name: the name of the bucket :param delimiter: the delimiter marks key hierarchy :return: True if a key exists and False if not. - :rtype: bool """ return ( self.get_wildcard_key(wildcard_key=wildcard_key, bucket_name=bucket_name, delimiter=delimiter) @@ -547,7 +571,7 @@ def check_for_wildcard_key( @provide_bucket_name @unify_bucket_name_and_key def get_wildcard_key( - self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = '' + self, wildcard_key: str, bucket_name: str | None = None, delimiter: str = "" ) -> S3Transfer: """ Returns a boto3.s3.Object object matching the wildcard expression @@ -556,9 +580,8 @@ def get_wildcard_key( :param bucket_name: the name of the bucket :param delimiter: the delimiter marks key hierarchy :return: the key object from the bucket or None if none has been found. 
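A short sketch of `select_key` as used above, assuming a CSV object with a header row; the key and query are placeholders.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")
rows = hook.select_key(
    key="s3://my-bucket/data/input.csv",
    expression="SELECT s.* FROM S3Object s LIMIT 10",
    input_serialization={"CSV": {"FileHeaderInfo": "USE"}},
    output_serialization={"JSON": {}},
)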
- :rtype: boto3.s3.Object """ - prefix = re.split(r'[\[\*\?]', wildcard_key, 1)[0] + prefix = re.split(r"[\[\*\?]", wildcard_key, 1)[0] key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter) key_matches = [k for k in key_list if fnmatch.fnmatch(k, wildcard_key)] if key_matches: @@ -569,13 +592,13 @@ def get_wildcard_key( @unify_bucket_name_and_key def load_file( self, - filename: Union[Path, str], + filename: Path | str, key: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, replace: bool = False, encrypt: bool = False, gzip: bool = False, - acl_policy: Optional[str] = None, + acl_policy: str | None = None, ) -> None: """ Loads a local file to S3 @@ -598,15 +621,15 @@ def load_file( extra_args = self.extra_args if encrypt: - extra_args['ServerSideEncryption'] = "AES256" + extra_args["ServerSideEncryption"] = "AES256" if gzip: - with open(filename, 'rb') as f_in: - filename_gz = f'{f_in.name}.gz' - with gz.open(filename_gz, 'wb') as f_out: + with open(filename, "rb") as f_in: + filename_gz = f"{f_in.name}.gz" + with gz.open(filename_gz, "wb") as f_out: shutil.copyfileobj(f_in, f_out) filename = filename_gz if acl_policy: - extra_args['ACL'] = acl_policy + extra_args["ACL"] = acl_policy client = self.get_conn() client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config) @@ -617,12 +640,12 @@ def load_string( self, string_data: str, key: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, replace: bool = False, encrypt: bool = False, - encoding: Optional[str] = None, - acl_policy: Optional[str] = None, - compression: Optional[str] = None, + encoding: str | None = None, + acl_policy: str | None = None, + compression: str | None = None, ) -> None: """ Loads a string to S3 @@ -642,18 +665,18 @@ def load_string( object to be uploaded :param compression: Type of compression to use, currently only gzip is supported. """ - encoding = encoding or 'utf-8' + encoding = encoding or "utf-8" bytes_data = string_data.encode(encoding) # Compress string - available_compressions = ['gzip'] + available_compressions = ["gzip"] if compression is not None and compression not in available_compressions: raise NotImplementedError( f"Received {compression} compression type. " f"String can currently be compressed in {available_compressions} only." 
) - if compression == 'gzip': + if compression == "gzip": bytes_data = gz.compress(bytes_data) file_obj = io.BytesIO(bytes_data) @@ -667,10 +690,10 @@ def load_bytes( self, bytes_data: bytes, key: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, replace: bool = False, encrypt: bool = False, - acl_policy: Optional[str] = None, + acl_policy: str | None = None, ) -> None: """ Loads bytes to S3 @@ -698,10 +721,10 @@ def load_file_obj( self, file_obj: BytesIO, key: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, replace: bool = False, encrypt: bool = False, - acl_policy: Optional[str] = None, + acl_policy: str | None = None, ) -> None: """ Loads a file object to S3 @@ -722,19 +745,19 @@ def _upload_file_obj( self, file_obj: BytesIO, key: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, replace: bool = False, encrypt: bool = False, - acl_policy: Optional[str] = None, + acl_policy: str | None = None, ) -> None: if not replace and self.check_for_key(key, bucket_name): raise ValueError(f"The key {key} already exists.") extra_args = self.extra_args if encrypt: - extra_args['ServerSideEncryption'] = "AES256" + extra_args["ServerSideEncryption"] = "AES256" if acl_policy: - extra_args['ACL'] = acl_policy + extra_args["ACL"] = acl_policy client = self.get_conn() client.upload_fileobj( @@ -749,10 +772,10 @@ def copy_object( self, source_bucket_key: str, dest_bucket_key: str, - source_bucket_name: Optional[str] = None, - dest_bucket_name: Optional[str] = None, - source_version_id: Optional[str] = None, - acl_policy: Optional[str] = None, + source_bucket_name: str | None = None, + dest_bucket_name: str | None = None, + source_version_id: str | None = None, + acl_policy: str | None = None, ) -> None: """ Creates a copy of an object that is already stored in S3. @@ -779,17 +802,17 @@ def copy_object( :param acl_policy: The string to specify the canned ACL policy for the object to be copied which is private by default. """ - acl_policy = acl_policy or 'private' + acl_policy = acl_policy or "private" dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key( - dest_bucket_name, dest_bucket_key, 'dest_bucket_name', 'dest_bucket_key' + dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key" ) source_bucket_name, source_bucket_key = self.get_s3_bucket_key( - source_bucket_name, source_bucket_key, 'source_bucket_name', 'source_bucket_key' + source_bucket_name, source_bucket_key, "source_bucket_name", "source_bucket_key" ) - copy_source = {'Bucket': source_bucket_name, 'Key': source_bucket_key, 'VersionId': source_version_id} + copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id} response = self.get_conn().copy_object( Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy ) @@ -803,7 +826,6 @@ def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None: :param bucket_name: Bucket name :param force_delete: Enable this to delete bucket even if not empty :return: None - :rtype: None """ if force_delete: bucket_keys = self.list_keys(bucket_name=bucket_name) @@ -811,7 +833,7 @@ def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None: self.delete_objects(bucket=bucket_name, keys=bucket_keys) self.conn.delete_bucket(Bucket=bucket_name) - def delete_objects(self, bucket: str, keys: Union[str, list]) -> None: + def delete_objects(self, bucket: str, keys: str | list) -> None: """ Delete keys from the bucket. 
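A sketch of uploading a gzip-compressed string and then copying the object; bucket names and keys are placeholders.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")
hook.load_string(
    string_data="id,value\n1,42\n",
    key="staging/report.csv.gz",
    bucket_name="my-bucket",
    replace=True,
    compression="gzip",  # only gzip is currently supported
)
hook.copy_object(
    source_bucket_key="staging/report.csv.gz",
    dest_bucket_key="archive/report.csv.gz",
    source_bucket_name="my-bucket",
    dest_bucket_name="my-bucket",
)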
@@ -834,16 +856,21 @@ def delete_objects(self, bucket: str, keys: Union[str, list]) -> None: # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects for chunk in chunks(keys, chunk_size=1000): response = s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}) - deleted_keys = [x['Key'] for x in response.get("Deleted", [])] + deleted_keys = [x["Key"] for x in response.get("Deleted", [])] self.log.info("Deleted: %s", deleted_keys) if "Errors" in response: - errors_keys = [x['Key'] for x in response.get("Errors", [])] + errors_keys = [x["Key"] for x in response.get("Errors", [])] raise AirflowException(f"Errors when deleting: {errors_keys}") @provide_bucket_name @unify_bucket_name_and_key def download_file( - self, key: str, bucket_name: Optional[str] = None, local_path: Optional[str] = None + self, + key: str, + bucket_name: str | None = None, + local_path: str | None = None, + preserve_file_name: bool = False, + use_autogenerated_subdir: bool = True, ) -> str: """ Downloads a file from the S3 location to the local file system. @@ -852,33 +879,66 @@ def download_file( :param bucket_name: The specific bucket to use. :param local_path: The local path to the downloaded file. If no path is provided it will use the system's temporary directory. + :param preserve_file_name: If you want the downloaded file name to be the same name as it is in S3, + set this parameter to True. When set to False, a random filename will be generated. + Default: False. + :param use_autogenerated_subdir: Pairs with 'preserve_file_name = True' to download the file into a + random generated folder inside the 'local_path', useful to avoid collisions between various tasks + that might download the same file name. Set it to 'False' if you don't want it, and you want a + predictable path. + Default: True. :return: the file name. - :rtype: str """ - self.log.info('Downloading source S3 file from Bucket %s with path %s', bucket_name, key) + self.log.info( + "This function shadows the 'download_file' method of S3 API, but it is not the same. If you " + "want to use the original method from S3 API, please call " + "'S3Hook.get_conn().download_file()'" + ) + + self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key) try: s3_obj = self.get_key(key, bucket_name) except ClientError as e: - if e.response.get('Error', {}).get('Code') == 404: + if e.response.get("Error", {}).get("Code") == 404: raise AirflowException( - f'The source file in Bucket {bucket_name} with path {key} does not exist' + f"The source file in Bucket {bucket_name} with path {key} does not exist" ) else: raise e - with NamedTemporaryFile(dir=local_path, prefix='airflow_tmp_', delete=False) as local_tmp_file: - s3_obj.download_fileobj(local_tmp_file) + if preserve_file_name: + local_dir = local_path if local_path else gettempdir() + subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else "" + filename_in_s3 = s3_obj.key.rsplit("/", 1)[-1] + file_path = Path(local_dir, subdir, filename_in_s3) + + if file_path.is_file(): + self.log.error("file '%s' already exists. 
Failing the task and not overwriting it", file_path) + raise FileExistsError + + file_path.parent.mkdir(exist_ok=True, parents=True) + + file = open(file_path, "wb") + else: + file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) # type: ignore + + with file: + s3_obj.download_fileobj( + file, + ExtraArgs=self.extra_args, + Config=self.transfer_config, + ) - return local_tmp_file.name + return file.name def generate_presigned_url( self, client_method: str, - params: Optional[dict] = None, + params: dict | None = None, expires_in: int = 3600, - http_method: Optional[str] = None, - ) -> Optional[str]: + http_method: str | None = None, + ) -> str | None: """ Generate a presigned url given a client, its method, and arguments @@ -889,7 +949,6 @@ def generate_presigned_url( :param http_method: The http method to use on the generated url. By default, the http method is whatever is used in the method's model. :return: The presigned url. - :rtype: str """ s3_client = self.get_conn() try: @@ -902,17 +961,16 @@ def generate_presigned_url( return None @provide_bucket_name - def get_bucket_tagging(self, bucket_name: Optional[str] = None) -> Optional[List[Dict[str, str]]]: + def get_bucket_tagging(self, bucket_name: str | None = None) -> list[dict[str, str]] | None: """ Gets a List of tags from a bucket. :param bucket_name: The name of the bucket. :return: A List containing the key/value pairs for the tags - :rtype: Optional[List[Dict[str, str]]] """ try: s3_client = self.get_conn() - result = s3_client.get_bucket_tagging(Bucket=bucket_name)['TagSet'] + result = s3_client.get_bucket_tagging(Bucket=bucket_name)["TagSet"] self.log.info("S3 Bucket Tag Info: %s", result) return result except ClientError as e: @@ -922,10 +980,10 @@ def get_bucket_tagging(self, bucket_name: Optional[str] = None) -> Optional[List @provide_bucket_name def put_bucket_tagging( self, - tag_set: Optional[List[Dict[str, str]]] = None, - key: Optional[str] = None, - value: Optional[str] = None, - bucket_name: Optional[str] = None, + tag_set: list[dict[str, str]] | None = None, + key: str | None = None, + value: str | None = None, + bucket_name: str | None = None, ) -> None: """ Overwrites the existing TagSet with provided tags. Must provide either a TagSet or a key/value pair. @@ -935,33 +993,31 @@ def put_bucket_tagging( :param value: The Value for the new TagSet entry. :param bucket_name: The name of the bucket. :return: None - :rtype: None """ self.log.info("S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key, value, tag_set) if not tag_set: tag_set = [] if key and value: - tag_set.append({'Key': key, 'Value': value}) + tag_set.append({"Key": key, "Value": value}) elif not tag_set or (key or value): - message = 'put_bucket_tagging() requires either a predefined TagSet or a key/value pair.' + message = "put_bucket_tagging() requires either a predefined TagSet or a key/value pair." self.log.error(message) raise ValueError(message) try: s3_client = self.get_conn() - s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={'TagSet': tag_set}) + s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": tag_set}) except ClientError as e: self.log.error(e) raise e @provide_bucket_name - def delete_bucket_tagging(self, bucket_name: Optional[str] = None) -> None: + def delete_bucket_tagging(self, bucket_name: str | None = None) -> None: """ Deletes all tags from a bucket. :param bucket_name: The name of the bucket. 
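A sketch of the new download options above, plus a presigned URL for the same object; the bucket, key and local path are placeholders.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Keep the original S3 file name and skip the random sub-directory so the
# resulting path is predictable; an existing file raises FileExistsError.
path = hook.download_file(
    key="exports/report.csv",
    bucket_name="my-bucket",
    local_path="/tmp/exports",
    preserve_file_name=True,
    use_autogenerated_subdir=False,
)

# Presigned GET URL valid for ten minutes.
url = hook.generate_presigned_url(
    client_method="get_object",
    params={"Bucket": "my-bucket", "Key": "exports/report.csv"},
    expires_in=600,
)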
:return: None - :rtype: None """ s3_client = self.get_conn() s3_client.delete_bucket_tagging(Bucket=bucket_name) diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index 2c8c28a738ec3..ae57e832d1b90 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -15,15 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import collections import os +import re import tarfile import tempfile import time import warnings from datetime import datetime from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Set, cast +from typing import Any, Callable, Generator, cast from botocore.exceptions import ClientError @@ -49,10 +52,10 @@ class LogState: # Position is a tuple that includes the last read timestamp and the number of items that were read # at that time. This is used to figure out which event to start with on the next read. -Position = collections.namedtuple('Position', ['timestamp', 'skip']) +Position = collections.namedtuple("Position", ["timestamp", "skip"]) -def argmin(arr, f: Callable) -> Optional[int]: +def argmin(arr, f: Callable) -> int | None: """Return the index, i, in arr that minimizes f(arr[i])""" min_value = None min_idx = None @@ -73,28 +76,28 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de :return: Whether the secondary status message of a training job changed or not. """ - current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions') + current_secondary_status_transitions = current_job_description.get("SecondaryStatusTransitions") if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0: return False prev_job_secondary_status_transitions = ( - prev_job_description.get('SecondaryStatusTransitions') if prev_job_description is not None else None + prev_job_description.get("SecondaryStatusTransitions") if prev_job_description is not None else None ) last_message = ( - prev_job_secondary_status_transitions[-1]['StatusMessage'] + prev_job_secondary_status_transitions[-1]["StatusMessage"] if prev_job_secondary_status_transitions is not None and len(prev_job_secondary_status_transitions) > 0 - else '' + else "" ) - message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage'] + message = current_job_description["SecondaryStatusTransitions"][-1]["StatusMessage"] return message != last_message def secondary_training_status_message( - job_description: Dict[str, List[Any]], prev_description: Optional[dict] + job_description: dict[str, list[Any]], prev_description: dict | None ) -> str: """ Returns a string contains start time and the secondary training job status message. @@ -104,14 +107,14 @@ def secondary_training_status_message( :return: Job status string to be printed. 
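A sketch of the bucket-tagging helpers shown above; bucket name and tags are placeholders. Either a full `tag_set` or a single key/value pair may be given, but not neither.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")
hook.put_bucket_tagging(bucket_name="my-bucket", key="team", value="data-eng")
# Each call overwrites the existing TagSet.
hook.put_bucket_tagging(
    bucket_name="my-bucket",
    tag_set=[{"Key": "team", "Value": "data-eng"}, {"Key": "env", "Value": "prod"}],
)
tags = hook.get_bucket_tagging(bucket_name="my-bucket")
hook.delete_bucket_tagging(bucket_name="my-bucket")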
""" - current_transitions = job_description.get('SecondaryStatusTransitions') + current_transitions = job_description.get("SecondaryStatusTransitions") if current_transitions is None or len(current_transitions) == 0: - return '' + return "" prev_transitions_num = 0 if prev_description is not None: - if prev_description.get('SecondaryStatusTransitions') is not None: - prev_transitions_num = len(prev_description['SecondaryStatusTransitions']) + if prev_description.get("SecondaryStatusTransitions") is not None: + prev_transitions_num = len(prev_description["SecondaryStatusTransitions"]) transitions_to_print = ( current_transitions[-1:] @@ -121,13 +124,13 @@ def secondary_training_status_message( status_strs = [] for transition in transitions_to_print: - message = transition['StatusMessage'] - time_str = timezone.convert_to_utc(cast(datetime, job_description['LastModifiedTime'])).strftime( - '%Y-%m-%d %H:%M:%S' + message = transition["StatusMessage"] + time_str = timezone.convert_to_utc(cast(datetime, job_description["LastModifiedTime"])).strftime( + "%Y-%m-%d %H:%M:%S" ) status_strs.append(f"{time_str} {transition['Status']} - {message}") - return '\n'.join(status_strs) + return "\n".join(status_strs) class SageMakerHook(AwsBaseHook): @@ -141,12 +144,12 @@ class SageMakerHook(AwsBaseHook): :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - non_terminal_states = {'InProgress', 'Stopping'} - endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating', 'RollingBack', 'Deleting'} - failed_states = {'Failed'} + non_terminal_states = {"InProgress", "Stopping"} + endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"} + failed_states = {"Failed"} def __init__(self, *args, **kwargs): - super().__init__(client_type='sagemaker', *args, **kwargs) + super().__init__(client_type="sagemaker", *args, **kwargs) self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id) @@ -164,7 +167,7 @@ def tar_and_s3_upload(self, path: str, key: str, bucket: str) -> None: files = [os.path.join(path, name) for name in os.listdir(path)] else: files = [path] - with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file: + with tarfile.open(mode="w:gz", fileobj=temp_file) as tar_file: for f in files: tar_file.add(f, arcname=os.path.basename(f)) temp_file.seek(0) @@ -175,27 +178,25 @@ def configure_s3_resources(self, config: dict) -> None: Extract the S3 operations from the configuration and execute them. 
:param config: config of SageMaker operation - :rtype: dict """ - s3_operations = config.pop('S3Operations', None) + s3_operations = config.pop("S3Operations", None) if s3_operations is not None: - create_bucket_ops = s3_operations.get('S3CreateBucket', []) - upload_ops = s3_operations.get('S3Upload', []) + create_bucket_ops = s3_operations.get("S3CreateBucket", []) + upload_ops = s3_operations.get("S3Upload", []) for op in create_bucket_ops: - self.s3_hook.create_bucket(bucket_name=op['Bucket']) + self.s3_hook.create_bucket(bucket_name=op["Bucket"]) for op in upload_ops: - if op['Tar']: - self.tar_and_s3_upload(op['Path'], op['Key'], op['Bucket']) + if op["Tar"]: + self.tar_and_s3_upload(op["Path"], op["Key"], op["Bucket"]) else: - self.s3_hook.load_file(op['Path'], op['Key'], op['Bucket']) + self.s3_hook.load_file(op["Path"], op["Key"], op["Bucket"]) def check_s3_url(self, s3url: str) -> bool: """ Check if an S3 URL exists :param s3url: S3 url - :rtype: bool """ bucket, key = S3Hook.parse_s3_url(s3url) if not self.s3_hook.check_for_bucket(bucket_name=bucket): @@ -203,7 +204,7 @@ def check_s3_url(self, s3url: str) -> bool: if ( key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket) - and not self.s3_hook.check_for_prefix(prefix=key, bucket_name=bucket, delimiter='/') + and not self.s3_hook.check_for_prefix(prefix=key, bucket_name=bucket, delimiter="/") ): # check if s3 key exists in the case user provides a single file # or if s3 prefix exists in the case user provides multiple files in @@ -221,9 +222,9 @@ def check_training_config(self, training_config: dict) -> None: :return: None """ if "InputDataConfig" in training_config: - for channel in training_config['InputDataConfig']: - if "S3DataSource" in channel['DataSource']: - self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri']) + for channel in training_config["InputDataConfig"]: + if "S3DataSource" in channel["DataSource"]: + self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"]) def check_tuning_config(self, tuning_config: dict) -> None: """ @@ -232,39 +233,9 @@ def check_tuning_config(self, tuning_config: dict) -> None: :param tuning_config: tuning_config :return: None """ - for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']: - if "S3DataSource" in channel['DataSource']: - self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri']) - - def get_log_conn(self): - """ - This method is deprecated. - Please use :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead. - """ - warnings.warn( - "Method `get_log_conn` has been deprecated. " - "Please use `airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.", - category=DeprecationWarning, - stacklevel=2, - ) - - return self.logs_hook.get_conn() - - def log_stream(self, log_group, stream_name, start_time=0, skip=0): - """ - This method is deprecated. - Please use - :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead. - """ - warnings.warn( - "Method `log_stream` has been deprecated. 
" - "Please use " - "`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.", - category=DeprecationWarning, - stacklevel=2, - ) - - return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip) + for channel in tuning_config["TrainingJobDefinition"]["InputDataConfig"]: + if "S3DataSource" in channel["DataSource"]: + self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"]) def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Generator: """ @@ -283,7 +254,7 @@ def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Ge self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip) for s in streams ] - events: List[Optional[Any]] = [] + events: list[Any | None] = [] for event_stream in event_iters: if not event_stream: events.append(None) @@ -294,7 +265,7 @@ def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Ge events.append(None) while any(events): - i = argmin(events, lambda x: x['timestamp'] if x else 9999999999) or 0 + i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0 yield i, events[i] try: events[i] = next(event_iters[i]) @@ -307,7 +278,7 @@ def create_training_job( wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ): """ Starts a model training job. After training completes, Amazon SageMaker saves @@ -327,7 +298,7 @@ def create_training_job( response = self.get_conn().create_training_job(**config) if print_log: self.check_training_status_with_log( - config['TrainingJobName'], + config["TrainingJobName"], self.non_terminal_states, self.failed_states, wait_for_completion, @@ -336,17 +307,17 @@ def create_training_job( ) elif wait_for_completion: describe_response = self.check_status( - config['TrainingJobName'], - 'TrainingJobStatus', + config["TrainingJobName"], + "TrainingJobStatus", self.describe_training_job, check_interval, max_ingestion_time, ) billable_time = ( - describe_response['TrainingEndTime'] - describe_response['TrainingStartTime'] - ) * describe_response['ResourceConfig']['InstanceCount'] - self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1) + describe_response["TrainingEndTime"] - describe_response["TrainingStartTime"] + ) * describe_response["ResourceConfig"]["InstanceCount"] + self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1) return response @@ -355,7 +326,7 @@ def create_tuning_job( config: dict, wait_for_completion: bool = True, check_interval: int = 30, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ): """ Starts a hyperparameter tuning job. A hyperparameter tuning job finds the @@ -378,8 +349,8 @@ def create_tuning_job( response = self.get_conn().create_hyper_parameter_tuning_job(**config) if wait_for_completion: self.check_status( - config['HyperParameterTuningJobName'], - 'HyperParameterTuningJobStatus', + config["HyperParameterTuningJobName"], + "HyperParameterTuningJobStatus", self.describe_tuning_job, check_interval, max_ingestion_time, @@ -391,7 +362,7 @@ def create_transform_job( config: dict, wait_for_completion: bool = True, check_interval: int = 30, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ): """ Starts a transform job. 
A transform job uses a trained model to get inferences @@ -406,14 +377,14 @@ def create_transform_job( None implies no timeout for any SageMaker job. :return: A response to transform job creation """ - if "S3DataSource" in config['TransformInput']['DataSource']: - self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri']) + if "S3DataSource" in config["TransformInput"]["DataSource"]: + self.check_s3_url(config["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"]) response = self.get_conn().create_transform_job(**config) if wait_for_completion: self.check_status( - config['TransformJobName'], - 'TransformJobStatus', + config["TransformJobName"], + "TransformJobStatus", self.describe_transform_job, check_interval, max_ingestion_time, @@ -425,7 +396,7 @@ def create_processing_job( config: dict, wait_for_completion: bool = True, check_interval: int = 30, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ): """ Use Amazon SageMaker Processing to analyze data and evaluate machine learning @@ -445,8 +416,8 @@ def create_processing_job( response = self.get_conn().create_processing_job(**config) if wait_for_completion: self.check_status( - config['ProcessingJobName'], - 'ProcessingJobStatus', + config["ProcessingJobName"], + "ProcessingJobStatus", self.describe_processing_job, check_interval, max_ingestion_time, @@ -486,7 +457,7 @@ def create_endpoint( config: dict, wait_for_completion: bool = True, check_interval: int = 30, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ): """ When you create a serverless endpoint, SageMaker provisions and manages @@ -511,8 +482,8 @@ def create_endpoint( response = self.get_conn().create_endpoint(**config) if wait_for_completion: self.check_status( - config['EndpointName'], - 'EndpointStatus', + config["EndpointName"], + "EndpointStatus", self.describe_endpoint, check_interval, max_ingestion_time, @@ -525,7 +496,7 @@ def update_endpoint( config: dict, wait_for_completion: bool = True, check_interval: int = 30, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ): """ Deploys the new EndpointConfig specified in the request, switches to using @@ -544,8 +515,8 @@ def update_endpoint( response = self.get_conn().update_endpoint(**config) if wait_for_completion: self.check_status( - config['EndpointName'], - 'EndpointStatus', + config["EndpointName"], + "EndpointStatus", self.describe_endpoint, check_interval, max_ingestion_time, @@ -573,7 +544,7 @@ def describe_training_job_with_log( last_describe_job_call: float, ): """Return the training job info associated with job_name and print CloudWatch logs""" - log_group = '/aws/sagemaker/TrainingJobs' + log_group = "/aws/sagemaker/TrainingJobs" if len(stream_names) < instance_count: # Log streams are created whenever a container starts writing to stdout/err, so this list @@ -582,11 +553,11 @@ def describe_training_job_with_log( try: streams = logs_conn.describe_log_streams( logGroupName=log_group, - logStreamNamePrefix=job_name + '/', - orderBy='LogStreamName', + logStreamNamePrefix=job_name + "/", + orderBy="LogStreamName", limit=instance_count, ) - stream_names = [s['logStreamName'] for s in streams['logStreams']] + stream_names = [s["logStreamName"] for s in streams["logStreams"]] positions.update( [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions] ) @@ -597,12 +568,12 @@ def describe_training_job_with_log( if len(stream_names) > 0: for idx, event in 
self.multi_stream_iter(log_group, stream_names, positions): - self.log.info(event['message']) + self.log.info(event["message"]) ts, count = positions[stream_names[idx]] - if event['timestamp'] == ts: + if event["timestamp"] == ts: positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1) else: - positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1) + positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1) if state == LogState.COMPLETE: return state, last_description, last_describe_job_call @@ -617,7 +588,7 @@ def describe_training_job_with_log( self.log.info(secondary_training_status_message(description, last_description)) last_description = description - status = description['TrainingJobStatus'] + status = description["TrainingJobStatus"] if status not in self.non_terminal_states: state = LogState.JOB_COMPLETE @@ -681,8 +652,8 @@ def check_status( key: str, describe_function: Callable, check_interval: int, - max_ingestion_time: Optional[int] = None, - non_terminal_states: Optional[Set] = None, + max_ingestion_time: int | None = None, + non_terminal_states: set | None = None, ): """ Check status of a SageMaker job @@ -704,34 +675,30 @@ def check_status( non_terminal_states = self.non_terminal_states sec = 0 - running = True - while running: + while True: time.sleep(check_interval) sec += check_interval try: response = describe_function(job_name) status = response[key] - self.log.info('Job still running for %s seconds... current status is %s', sec, status) + self.log.info("Job still running for %s seconds... current status is %s", sec, status) except KeyError: - raise AirflowException('Could not get status of the SageMaker job') + raise AirflowException("Could not get status of the SageMaker job") except ClientError: - raise AirflowException('AWS request failed, check logs for more info') + raise AirflowException("AWS request failed, check logs for more info") - if status in non_terminal_states: - running = True - elif status in self.failed_states: + if status in self.failed_states: raise AirflowException(f"SageMaker job failed because {response['FailureReason']}") - else: - running = False + elif status not in non_terminal_states: + break if max_ingestion_time and sec > max_ingestion_time: # ensure that the job gets killed if the max ingestion time is exceeded - raise AirflowException(f'SageMaker job took more than {max_ingestion_time} seconds') + raise AirflowException(f"SageMaker job took more than {max_ingestion_time} seconds") - self.log.info('SageMaker Job completed') - response = describe_function(job_name) + self.log.info("SageMaker Job completed") return response def check_training_status_with_log( @@ -741,7 +708,7 @@ def check_training_status_with_log( failed_states: set, wait_for_completion: bool, check_interval: int, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ): """ Display the logs for a given training job, optionally tailing them until the @@ -761,8 +728,8 @@ def check_training_status_with_log( sec = 0 description = self.describe_training_job(job_name) self.log.info(secondary_training_status_message(description, None)) - instance_count = description['ResourceConfig']['InstanceCount'] - status = description['TrainingJobStatus'] + instance_count = description["ResourceConfig"]["InstanceCount"] + status = description["TrainingJobStatus"] stream_names: list = [] # The list of log streams positions: dict = {} # The current position in each stream, map of stream name -> position @@ 
-812,21 +779,21 @@ def check_training_status_with_log( if max_ingestion_time and sec > max_ingestion_time: # ensure that the job gets killed if the max ingestion time is exceeded - raise AirflowException(f'SageMaker job took more than {max_ingestion_time} seconds') + raise AirflowException(f"SageMaker job took more than {max_ingestion_time} seconds") if wait_for_completion: - status = last_description['TrainingJobStatus'] + status = last_description["TrainingJobStatus"] if status in failed_states: - reason = last_description.get('FailureReason', '(No reason provided)') - raise AirflowException(f'Error training {job_name}: {status} Reason: {reason}') + reason = last_description.get("FailureReason", "(No reason provided)") + raise AirflowException(f"Error training {job_name}: {status} Reason: {reason}") billable_time = ( - last_description['TrainingEndTime'] - last_description['TrainingStartTime'] + last_description["TrainingEndTime"] - last_description["TrainingStartTime"] ) * instance_count - self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1) + self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1) def list_training_jobs( - self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs - ) -> List[Dict]: + self, name_contains: str | None = None, max_results: int | None = None, **kwargs + ) -> list[dict]: """ This method wraps boto3's `list_training_jobs`. The training job name and max results are configurable via arguments. Other arguments are not, and should be provided via kwargs. Note boto3 expects these in @@ -844,28 +811,42 @@ def list_training_jobs( :param kwargs: (optional) kwargs to boto3's list_training_jobs method :return: results of the list_training_jobs request """ - config = {} + config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs) + list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config) + results = self._list_request( + list_training_jobs_request, "TrainingJobSummaries", max_results=max_results + ) + return results - if name_contains: - if "NameContains" in kwargs: - raise AirflowException("Either name_contains or NameContains can be provided, not both.") - config["NameContains"] = name_contains + def list_transform_jobs( + self, name_contains: str | None = None, max_results: int | None = None, **kwargs + ) -> list[dict]: + """ + This method wraps boto3's `list_transform_jobs`. + The transform job name and max results are configurable via arguments. + Other arguments are not, and should be provided via kwargs. Note boto3 expects these in + CamelCase format, for example: - if "MaxResults" in kwargs and kwargs["MaxResults"] is not None: - if max_results: - raise AirflowException("Either max_results or MaxResults can be provided, not both.") - # Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results - max_results = kwargs["MaxResults"] - del kwargs["MaxResults"] + .. code-block:: python - config.update(kwargs) - list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config) + list_transform_jobs(name_contains="myjob", StatusEquals="Failed") + + .. seealso:: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_transform_jobs + + :param name_contains: (optional) partial name to match + :param max_results: (optional) maximum number of results to return. 
None returns infinite results + :param kwargs: (optional) kwargs to boto3's list_transform_jobs method + :return: results of the list_transform_jobs request + """ + config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs) + list_transform_jobs_request = partial(self.get_conn().list_transform_jobs, **config) results = self._list_request( - list_training_jobs_request, "TrainingJobSummaries", max_results=max_results + list_transform_jobs_request, "TransformJobSummaries", max_results=max_results ) return results - def list_processing_jobs(self, **kwargs) -> List[Dict]: + def list_processing_jobs(self, **kwargs) -> list[dict]: """ This method wraps boto3's `list_processing_jobs`. All arguments should be provided via kwargs. Note boto3 expects these in CamelCase format, for example: @@ -886,9 +867,40 @@ def list_processing_jobs(self, **kwargs) -> List[Dict]: ) return results + def _preprocess_list_request_args( + self, name_contains: str | None = None, max_results: int | None = None, **kwargs + ) -> tuple[dict[str, Any], int | None]: + """ + This method preprocesses the arguments to the boto3's list_* methods. + It will turn arguments name_contains and max_results as boto3 compliant CamelCase format. + This method also makes sure that these two arguments are only set once. + + :param name_contains: boto3 function with arguments + :param max_results: the result key to iterate over + :param kwargs: (optional) kwargs to boto3's list_* method + :return: Tuple with config dict to be passed to boto3's list_* method and max_results parameter + """ + config = {} + + if name_contains: + if "NameContains" in kwargs: + raise AirflowException("Either name_contains or NameContains can be provided, not both.") + config["NameContains"] = name_contains + + if "MaxResults" in kwargs and kwargs["MaxResults"] is not None: + if max_results: + raise AirflowException("Either max_results or MaxResults can be provided, not both.") + # Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results + max_results = kwargs["MaxResults"] + del kwargs["MaxResults"] + + config.update(kwargs) + + return config, max_results + def _list_request( - self, partial_func: Callable, result_key: str, max_results: Optional[int] = None - ) -> List[Dict]: + self, partial_func: Callable, result_key: str, max_results: int | None = None + ) -> list[dict]: """ All AWS boto3 list_* requests return results in batches (if the key "NextToken" is contained in the result, there are more results to fetch). The default AWS batch size is 10, and configurable up to @@ -905,7 +917,7 @@ def _list_request( """ sagemaker_max_results = 100 # Fixed number set by AWS - results: List[Dict] = [] + results: list[dict] = [] next_token = None while True: @@ -929,13 +941,63 @@ def _list_request( next_token = response["NextToken"] def find_processing_job_by_name(self, processing_job_name: str) -> bool: - """Query processing job by name""" + """ + Query processing job by name + + This method is deprecated. + Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`. + """ + warnings.warn( + "This method is deprecated. 
" + "Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`.", + DeprecationWarning, + stacklevel=2, + ) + return bool(self.count_processing_jobs_by_name(processing_job_name)) + + @staticmethod + def _name_matches_pattern( + processing_job_name: str, + found_name: str, + job_name_suffix: str | None = None, + ) -> bool: + pattern = re.compile(f"^{processing_job_name}({job_name_suffix})?$") + return pattern.fullmatch(found_name) is not None + + def count_processing_jobs_by_name( + self, + processing_job_name: str, + job_name_suffix: str | None = None, + throttle_retry_delay: int = 2, + retries: int = 3, + ) -> int: + """ + Returns the number of processing jobs found with the provided name prefix. + :param processing_job_name: The prefix to look for. + :param job_name_suffix: The optional suffix which may be appended to deduplicate an existing job name. + :param throttle_retry_delay: Seconds to wait if a ThrottlingException is hit. + :param retries: The max number of times to retry. + :returns: The number of processing jobs that start with the provided prefix. + """ try: - self.get_conn().describe_processing_job(ProcessingJobName=processing_job_name) - return True + jobs = self.get_conn().list_processing_jobs(NameContains=processing_job_name) + # We want to make sure the job name starts with the provided name, not just contains it. + matching_jobs = [ + job["ProcessingJobName"] + for job in jobs["ProcessingJobSummaries"] + if self._name_matches_pattern(processing_job_name, job["ProcessingJobName"], job_name_suffix) + ] + return len(matching_jobs) except ClientError as e: - if e.response['Error']['Code'] in ['ValidationException', 'ResourceNotFound']: - return False + if e.response["Error"]["Code"] == "ResourceNotFound": + # No jobs found with that name. This is good, return 0. + return 0 + if e.response["Error"]["Code"] == "ThrottlingException" and retries: + # If we hit a ThrottlingException, back off a little and try again. + time.sleep(throttle_retry_delay) + return self.count_processing_jobs_by_name( + processing_job_name, job_name_suffix, throttle_retry_delay * 2, retries - 1 + ) raise def delete_model(self, model_name: str): diff --git a/airflow/providers/amazon/aws/hooks/secrets_manager.py b/airflow/providers/amazon/aws/hooks/secrets_manager.py index f1596d4961f4e..f20c833473236 100644 --- a/airflow/providers/amazon/aws/hooks/secrets_manager.py +++ b/airflow/providers/amazon/aws/hooks/secrets_manager.py @@ -15,11 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations import base64 import json -from typing import Union from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -36,24 +35,23 @@ class SecretsManagerHook(AwsBaseHook): """ def __init__(self, *args, **kwargs): - super().__init__(client_type='secretsmanager', *args, **kwargs) + super().__init__(client_type="secretsmanager", *args, **kwargs) - def get_secret(self, secret_name: str) -> Union[str, bytes]: + def get_secret(self, secret_name: str) -> str | bytes: """ Retrieve secret value from AWS Secrets Manager as a str or bytes reflecting format it stored in the AWS Secrets Manager :param secret_name: name of the secrets. :return: Union[str, bytes] with the information about the secrets - :rtype: Union[str, bytes] """ # Depending on whether the secret is a string or binary, one of # these fields will be populated. 
get_secret_value_response = self.get_conn().get_secret_value(SecretId=secret_name) - if 'SecretString' in get_secret_value_response: - secret = get_secret_value_response['SecretString'] + if "SecretString" in get_secret_value_response: + secret = get_secret_value_response["SecretString"] else: - secret = base64.b64decode(get_secret_value_response['SecretBinary']) + secret = base64.b64decode(get_secret_value_response["SecretBinary"]) return secret def get_secret_as_dict(self, secret_name: str) -> dict: @@ -62,6 +60,5 @@ def get_secret_as_dict(self, secret_name: str) -> dict: :param secret_name: name of the secrets. :return: dict with the information about the secrets - :rtype: dict """ return json.loads(self.get_secret(secret_name)) diff --git a/airflow/providers/amazon/aws/hooks/ses.py b/airflow/providers/amazon/aws/hooks/ses.py index 92dcce7ecb617..b58024cfa439b 100644 --- a/airflow/providers/amazon/aws/hooks/ses.py +++ b/airflow/providers/amazon/aws/hooks/ses.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """This module contains AWS SES Hook""" -import warnings -from typing import Any, Dict, Iterable, List, Optional, Union +from __future__ import annotations + +from typing import Any, Iterable from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.utils.email import build_mime_message @@ -34,23 +35,23 @@ class SesHook(AwsBaseHook): """ def __init__(self, *args, **kwargs) -> None: - kwargs['client_type'] = 'ses' + kwargs["client_type"] = "ses" super().__init__(*args, **kwargs) def send_email( self, mail_from: str, - to: Union[str, Iterable[str]], + to: str | Iterable[str], subject: str, html_content: str, - files: Optional[List[str]] = None, - cc: Optional[Union[str, Iterable[str]]] = None, - bcc: Optional[Union[str, Iterable[str]]] = None, - mime_subtype: str = 'mixed', - mime_charset: str = 'utf-8', - reply_to: Optional[str] = None, - return_path: Optional[str] = None, - custom_headers: Optional[Dict[str, Any]] = None, + files: list[str] | None = None, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + reply_to: str | None = None, + return_path: str | None = None, + custom_headers: dict[str, Any] | None = None, ) -> dict: """ Send email using Amazon Simple Email Service @@ -76,9 +77,9 @@ def send_email( custom_headers = custom_headers or {} if reply_to: - custom_headers['Reply-To'] = reply_to + custom_headers["Reply-To"] = reply_to if return_path: - custom_headers['Return-Path'] = return_path + custom_headers["Return-Path"] = return_path message, recipients = build_mime_message( mail_from=mail_from, @@ -94,20 +95,5 @@ def send_email( ) return ses_client.send_raw_email( - Source=mail_from, Destinations=recipients, RawMessage={'Data': message.as_string()} - ) - - -class SESHook(SesHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.ses.SesHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. 
Please use :class:`airflow.providers.amazon.aws.hooks.ses.SesHook`.", - DeprecationWarning, - stacklevel=2, + Source=mail_from, Destinations=recipients, RawMessage={"Data": message.as_string()} ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/sns.py b/airflow/providers/amazon/aws/hooks/sns.py index fc009d9f9bdae..7b9fcea3dae9c 100644 --- a/airflow/providers/amazon/aws/hooks/sns.py +++ b/airflow/providers/amazon/aws/hooks/sns.py @@ -15,26 +15,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains AWS SNS hook""" +from __future__ import annotations + import json -import warnings -from typing import Dict, Optional, Union from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook def _get_message_attribute(o): if isinstance(o, bytes): - return {'DataType': 'Binary', 'BinaryValue': o} + return {"DataType": "Binary", "BinaryValue": o} if isinstance(o, str): - return {'DataType': 'String', 'StringValue': o} + return {"DataType": "String", "StringValue": o} if isinstance(o, (int, float)): - return {'DataType': 'Number', 'StringValue': str(o)} - if hasattr(o, '__iter__'): - return {'DataType': 'String.Array', 'StringValue': json.dumps(o)} + return {"DataType": "Number", "StringValue": str(o)} + if hasattr(o, "__iter__"): + return {"DataType": "String.Array", "StringValue": json.dumps(o)} raise TypeError( - f'Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got {type(o)}' + f"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got {type(o)}" ) @@ -50,14 +49,14 @@ class SnsHook(AwsBaseHook): """ def __init__(self, *args, **kwargs): - super().__init__(client_type='sns', *args, **kwargs) + super().__init__(client_type="sns", *args, **kwargs) def publish_to_target( self, target_arn: str, message: str, - subject: Optional[str] = None, - message_attributes: Optional[dict] = None, + subject: str | None = None, + message_attributes: dict | None = None, ): """ Publish a message to a topic or an endpoint. @@ -75,33 +74,18 @@ def publish_to_target( - iterable = String.Array """ - publish_kwargs: Dict[str, Union[str, dict]] = { - 'TargetArn': target_arn, - 'MessageStructure': 'json', - 'Message': json.dumps({'default': message}), + publish_kwargs: dict[str, str | dict] = { + "TargetArn": target_arn, + "MessageStructure": "json", + "Message": json.dumps({"default": message}), } # Construct args this way because boto3 distinguishes from missing args and those set to None if subject: - publish_kwargs['Subject'] = subject + publish_kwargs["Subject"] = subject if message_attributes: - publish_kwargs['MessageAttributes'] = { + publish_kwargs["MessageAttributes"] = { key: _get_message_attribute(val) for key, val in message_attributes.items() } return self.get_conn().publish(**publish_kwargs) - - -class AwsSnsHook(SnsHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.sns.SnsHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. 
Please use :class:`airflow.providers.amazon.aws.hooks.sns.SnsHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/hooks/sqs.py b/airflow/providers/amazon/aws/hooks/sqs.py index b94756f63aa26..2d3fd9de20e0f 100644 --- a/airflow/providers/amazon/aws/hooks/sqs.py +++ b/airflow/providers/amazon/aws/hooks/sqs.py @@ -15,10 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains AWS SQS hook""" -import warnings -from typing import Dict, Optional +from __future__ import annotations from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -38,7 +36,7 @@ def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "sqs" super().__init__(*args, **kwargs) - def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Dict: + def create_queue(self, queue_name: str, attributes: dict | None = None) -> dict: """ Create queue using connection object @@ -48,7 +46,6 @@ def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Di :return: dict with the information about the queue For details of the returned value see :py:meth:`SQS.create_queue` - :rtype: dict """ return self.get_conn().create_queue(QueueName=queue_name, Attributes=attributes or {}) @@ -57,8 +54,9 @@ def send_message( queue_url: str, message_body: str, delay_seconds: int = 0, - message_attributes: Optional[Dict] = None, - ) -> Dict: + message_attributes: dict | None = None, + message_group_id: str | None = None, + ) -> dict: """ Send message to the queue @@ -67,29 +65,19 @@ def send_message( :param delay_seconds: seconds to delay the message :param message_attributes: additional attributes for the message (default: None) For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` + :param message_group_id: This applies only to FIFO (first-in-first-out) queues. (default: None) + For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` :return: dict with the information about the message sent For details of the returned value see :py:meth:`botocore.client.SQS.send_message` - :rtype: dict """ - return self.get_conn().send_message( - QueueUrl=queue_url, - MessageBody=message_body, - DelaySeconds=delay_seconds, - MessageAttributes=message_attributes or {}, - ) - - -class SQSHook(SqsHook): - """ - This hook is deprecated. - Please use :class:`airflow.providers.amazon.aws.hooks.sqs.SqsHook`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This hook is deprecated. Please use :class:`airflow.providers.amazon.aws.hooks.sqs.SqsHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) + params = { + "QueueUrl": queue_url, + "MessageBody": message_body, + "DelaySeconds": delay_seconds, + "MessageAttributes": message_attributes or {}, + } + if message_group_id: + params["MessageGroupId"] = message_group_id + + return self.get_conn().send_message(**params) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index 97ffb10c04147..1819be12b7602 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import json -from typing import Optional, Union from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -32,15 +32,15 @@ class StepFunctionHook(AwsBaseHook): :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - def __init__(self, region_name: Optional[str] = None, *args, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "stepfunctions" super().__init__(*args, **kwargs) def start_execution( self, state_machine_arn: str, - name: Optional[str] = None, - state_machine_input: Union[dict, str, None] = None, + name: str | None = None, + state_machine_input: dict | str | None = None, ) -> str: """ Start Execution of the State Machine. @@ -50,21 +50,20 @@ def start_execution( :param name: The name of the execution. :param state_machine_input: JSON data input to pass to the State Machine :return: Execution ARN - :rtype: str """ - execution_args = {'stateMachineArn': state_machine_arn} + execution_args = {"stateMachineArn": state_machine_arn} if name is not None: - execution_args['name'] = name + execution_args["name"] = name if state_machine_input is not None: if isinstance(state_machine_input, str): - execution_args['input'] = state_machine_input + execution_args["input"] = state_machine_input elif isinstance(state_machine_input, dict): - execution_args['input'] = json.dumps(state_machine_input) + execution_args["input"] = json.dumps(state_machine_input) - self.log.info('Executing Step Function State Machine: %s', state_machine_arn) + self.log.info("Executing Step Function State Machine: %s", state_machine_arn) response = self.conn.start_execution(**execution_args) - return response.get('executionArn') + return response.get("executionArn") def describe_execution(self, execution_arn: str) -> dict: """ @@ -73,6 +72,5 @@ def describe_execution(self, execution_arn: str) -> dict: :param execution_arn: ARN of the State Machine Execution :return: Dict with Execution details - :rtype: dict """ return self.get_conn().describe_execution(executionArn=execution_arn) diff --git a/airflow/providers/amazon/aws/hooks/sts.py b/airflow/providers/amazon/aws/hooks/sts.py index aff787ee5d70e..8323ed2163f63 100644 --- a/airflow/providers/amazon/aws/hooks/sts.py +++ b/airflow/providers/amazon/aws/hooks/sts.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
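# Illustrative sketch (not part of the commit above) of the StepFunctionHook API
# shown in the hunk: start an execution and read back its status. The state
# machine ARN and connection id are placeholders.
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook

sfn_hook = StepFunctionHook(aws_conn_id="aws_default")
execution_arn = sfn_hook.start_execution(
    state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:example",
    state_machine_input={"key": "value"},  # dict input is JSON-encoded by the hook
)
execution_status = sfn_hook.describe_execution(execution_arn)["status"]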
+from __future__ import annotations from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -33,9 +34,8 @@ def __init__(self, *args, **kwargs): def get_account_number(self) -> str: """Get the account Number""" - try: - return self.get_conn().get_caller_identity()['Account'] + return self.get_conn().get_caller_identity()["Account"] except Exception as general_error: self.log.error("Failed to get the AWS Account Number, error: %s", general_error) raise diff --git a/airflow/providers/apache/cassandra/example_dags/__init__.py b/airflow/providers/amazon/aws/links/__init__.py similarity index 100% rename from airflow/providers/apache/cassandra/example_dags/__init__.py rename to airflow/providers/amazon/aws/links/__init__.py diff --git a/airflow/providers/amazon/aws/links/base_aws.py b/airflow/providers/amazon/aws/links/base_aws.py new file mode 100644 index 0000000000000..fba2f17e96f39 --- /dev/null +++ b/airflow/providers/amazon/aws/links/base_aws.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from airflow.models import BaseOperatorLink, XCom + +if TYPE_CHECKING: + from airflow.models import BaseOperator + from airflow.models.taskinstance import TaskInstanceKey + from airflow.utils.context import Context + + +BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}" + + +class BaseAwsLink(BaseOperatorLink): + """Base Helper class for constructing AWS Console Link""" + + name: ClassVar[str] + key: ClassVar[str] + format_str: ClassVar[str] + + @staticmethod + def get_aws_domain(aws_partition) -> str | None: + if aws_partition == "aws": + return "aws.amazon.com" + elif aws_partition == "aws-cn": + return "amazonaws.cn" + elif aws_partition == "aws-us-gov": + return "amazonaws-us-gov.com" + + return None + + def format_link(self, **kwargs) -> str: + """ + Format AWS Service Link + + Some AWS Service Link should require additional escaping + in this case this method should be overridden. + """ + try: + return self.format_str.format(**kwargs) + except KeyError: + return "" + + def get_link( + self, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, + ) -> str: + """ + Link to Amazon Web Services Console. 
+ + :param operator: airflow operator + :param ti_key: TaskInstance ID to return link for + :return: link to external system + """ + conf = XCom.get_value(key=self.key, ti_key=ti_key) + return self.format_link(**conf) if conf else "" + + @classmethod + def persist( + cls, context: Context, operator: BaseOperator, region_name: str, aws_partition: str, **kwargs + ) -> None: + """Store link information into XCom""" + if not operator.do_xcom_push: + return + + operator.xcom_push( + context, + key=cls.key, + value={ + "region_name": region_name, + "aws_domain": cls.get_aws_domain(aws_partition), + **kwargs, + }, + ) diff --git a/airflow/providers/amazon/aws/links/batch.py b/airflow/providers/amazon/aws/links/batch.py new file mode 100644 index 0000000000000..432d129a7c328 --- /dev/null +++ b/airflow/providers/amazon/aws/links/batch.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class BatchJobDefinitionLink(BaseAwsLink): + """Helper class for constructing AWS Batch Job Definition Link""" + + name = "Batch Job Definition" + key = "batch_job_definition" + format_str = ( + BASE_AWS_CONSOLE_LINK + "/batch/home?region={region_name}#job-definition/detail/{job_definition_arn}" + ) + + +class BatchJobDetailsLink(BaseAwsLink): + """Helper class for constructing AWS Batch Job Details Link""" + + name = "Batch Job Details" + key = "batch_job_details" + format_str = BASE_AWS_CONSOLE_LINK + "/batch/home?region={region_name}#jobs/detail/{job_id}" + + +class BatchJobQueueLink(BaseAwsLink): + """Helper class for constructing AWS Batch Job Queue Link""" + + name = "Batch Job Queue" + key = "batch_job_queue" + format_str = BASE_AWS_CONSOLE_LINK + "/batch/home?region={region_name}#queues/detail/{job_queue_arn}" diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py new file mode 100644 index 0000000000000..aa739567fb919 --- /dev/null +++ b/airflow/providers/amazon/aws/links/emr.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class EmrClusterLink(BaseAwsLink): + """Helper class for constructing AWS EMR Cluster Link""" + + name = "EMR Cluster" + key = "emr_cluster" + format_str = ( + BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}" + ) diff --git a/airflow/providers/amazon/aws/links/logs.py b/airflow/providers/amazon/aws/links/logs.py new file mode 100644 index 0000000000000..7998191d9226d --- /dev/null +++ b/airflow/providers/amazon/aws/links/logs.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from urllib.parse import quote_plus + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class CloudWatchEventsLink(BaseAwsLink): + """Helper class for constructing AWS CloudWatch Events Link""" + + name = "CloudWatch Events" + key = "cloudwatch_events" + format_str = ( + BASE_AWS_CONSOLE_LINK + + "/cloudwatch/home?region={awslogs_region}#logsV2:log-groups/log-group/{awslogs_group}" + + "/log-events/{awslogs_stream_name}" + ) + + def format_link(self, **kwargs) -> str: + for field in ("awslogs_stream_name", "awslogs_group"): + if field in kwargs: + kwargs[field] = quote_plus(kwargs[field]) + else: + return "" + + return super().format_link(**kwargs) diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index c975a2cb83fc6..e50a6d4d74e2e 100644 --- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -15,17 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
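# Illustrative sketch (not part of the commit above) of the pattern the new link
# classes follow: a subclass only declares ``name``, ``key`` and ``format_str``;
# the placeholder values are stored in XCom via ``persist()`` and resolved again
# by ``get_link()``. The class name, key and URL path below are hypothetical.
from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink


class ExampleResourceLink(BaseAwsLink):
    """Hypothetical helper class for constructing a console link to some AWS resource."""

    name = "Example Resource"
    key = "example_resource"
    format_str = BASE_AWS_CONSOLE_LINK + "/example/home?region={region_name}#/details/{resource_id}"


# Inside an operator's ``execute()`` the link data would be stored with, for example:
#   ExampleResourceLink.persist(
#       context=context, operator=self,
#       region_name="us-east-1", aws_partition="aws", resource_id="demo",
#   )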
-import sys +from __future__ import annotations + from datetime import datetime import watchtower -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.configuration import conf +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin @@ -42,9 +40,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin): :param filename_template: template for file name (local storage) or log stream name (remote) """ - def __init__(self, base_log_folder: str, log_group_arn: str, filename_template: str): + def __init__(self, base_log_folder: str, log_group_arn: str, filename_template: str | None = None): super().__init__(base_log_folder, filename_template) - split_arn = log_group_arn.split(':') + split_arn = log_group_arn.split(":") self.handler = None self.log_group = split_arn[6] @@ -54,30 +52,19 @@ def __init__(self, base_log_folder: str, log_group_arn: str, filename_template: @cached_property def hook(self): """Returns AwsLogsHook.""" - remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID') - try: - from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook - - return AwsLogsHook(aws_conn_id=remote_conn_id, region_name=self.region_name) - except Exception as e: - self.log.error( - 'Could not create an AwsLogsHook with connection id "%s". ' - 'Please make sure that apache-airflow[aws] is installed and ' - 'the Cloudwatch logs connection exists. Exception: "%s"', - remote_conn_id, - e, - ) - return None + return AwsLogsHook( + aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), region_name=self.region_name + ) def _render_filename(self, ti, try_number): # Replace unsupported log group name characters - return super()._render_filename(ti, try_number).replace(':', '_') + return super()._render_filename(ti, try_number).replace(":", "_") def set_context(self, ti): super().set_context(ti) self.handler = watchtower.CloudWatchLogHandler( - log_group=self.log_group, - stream_name=self._render_filename(ti, ti.try_number), + log_group_name=self.log_group, + log_stream_name=self._render_filename(ti, ti.try_number), boto3_client=self.hook.get_conn(), ) @@ -97,11 +84,21 @@ def close(self): def _read(self, task_instance, try_number, metadata=None): stream_name = self._render_filename(task_instance, try_number) - return ( - f'*** Reading remote log from Cloudwatch log_group: {self.log_group} ' - f'log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n', - {'end_of_log': True}, - ) + try: + return ( + f"*** Reading remote log from Cloudwatch log_group: {self.log_group} " + f"log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n", + {"end_of_log": True}, + ) + except Exception as e: + log = ( + f"*** Unable to read remote logs from Cloudwatch (log_group: {self.log_group}, log_stream: " + f"{stream_name})\n*** {str(e)}\n\n" + ) + self.log.error(log) + local_log, metadata = super()._read(task_instance, try_number, metadata) + log += local_log + return log, metadata def get_cloudwatch_logs(self, stream_name: str) -> str: """ @@ -110,21 +107,15 @@ def get_cloudwatch_logs(self, stream_name: str) -> str: :param stream_name: name of the Cloudwatch log stream to get all logs from :return: string of all logs from the given log stream """ - try: - events = list( - self.hook.get_log_events( - 
log_group=self.log_group, log_stream_name=stream_name, start_from_head=True - ) - ) - - return '\n'.join(self._event_to_str(event) for event in events) - except Exception: - msg = f'Could not read remote logs from log_group: {self.log_group} log_stream: {stream_name}.' - self.log.exception(msg) - return msg + events = self.hook.get_log_events( + log_group=self.log_group, + log_stream_name=stream_name, + start_from_head=True, + ) + return "\n".join(self._event_to_str(event) for event in events) def _event_to_str(self, event: dict) -> str: - event_dt = datetime.utcfromtimestamp(event['timestamp'] / 1000.0) - formatted_event_dt = event_dt.strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] - message = event['message'] - return f'[{formatted_event_dt}] {message}' + event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0) + formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] + message = event["message"] + return f"[{formatted_event_dt}] {message}" diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py index ce3da88f16916..8535277b13e70 100644 --- a/airflow/providers/amazon/aws/log/s3_task_handler.py +++ b/airflow/providers/amazon/aws/log/s3_task_handler.py @@ -15,16 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import os import pathlib -import sys - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from airflow.compat.functools import cached_property from airflow.configuration import conf +from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin @@ -36,10 +34,10 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin): uploads to and reads from S3 remote storage. """ - def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template: str): + def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template: str | None = None): super().__init__(base_log_folder, filename_template) self.remote_base = s3_log_folder - self.log_relative_path = '' + self.log_relative_path = "" self._hook = None self.closed = False self.upload_on_close = True @@ -47,20 +45,9 @@ def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template: @cached_property def hook(self): """Returns S3Hook.""" - remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID') - try: - from airflow.providers.amazon.aws.hooks.s3 import S3Hook - - return S3Hook(remote_conn_id, transfer_config_args={"use_threads": False}) - except Exception as e: - self.log.exception( - 'Could not create an S3Hook with connection id "%s". ' - 'Please make sure that apache-airflow[aws] is installed and ' - 'the S3 connection exists. Exception : "%s"', - remote_conn_id, - e, - ) - return None + return S3Hook( + aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), transfer_config_args={"use_threads": False} + ) def set_context(self, ti): super().set_context(ti) @@ -72,7 +59,7 @@ def set_context(self, ti): # Clear the file first so that duplicate data is not uploaded # when re-using the same path (e.g. 
with rescheduled sensors) if self.upload_on_close: - with open(self.handler.baseFilename, 'w'): + with open(self.handler.baseFilename, "w"): pass def close(self): @@ -122,18 +109,18 @@ def _read(self, ti, try_number, metadata=None): log_exists = self.s3_log_exists(remote_loc) except Exception as error: self.log.exception("Failed to verify remote log exists %s.", remote_loc) - log = f'*** Failed to verify remote log exists {remote_loc}.\n{error}\n' + log = f"*** Failed to verify remote log exists {remote_loc}.\n{error}\n" if log_exists: # If S3 remote file exists, we do not fetch logs from task instance # local machine even if there are errors reading remote logs, as # returned remote_log will contain error messages. remote_log = self.s3_read(remote_loc, return_error=True) - log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n' - return log, {'end_of_log': True} + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} else: - log += '*** Falling back to local log\n' - local_log, metadata = super()._read(ti, try_number) + log += "*** Falling back to local log\n" + local_log, metadata = super()._read(ti, try_number, metadata) return log + local_log, metadata def s3_log_exists(self, remote_log_location: str) -> bool: @@ -158,12 +145,12 @@ def s3_read(self, remote_log_location: str, return_error: bool = False) -> str: try: return self.hook.read_key(remote_log_location) except Exception as error: - msg = f'Could not read logs from {remote_log_location} with error: {error}' + msg = f"Could not read logs from {remote_log_location} with error: {error}" self.log.exception(msg) # return error if needed if return_error: return msg - return '' + return "" def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1): """ @@ -179,9 +166,9 @@ def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_ try: if append and self.s3_log_exists(remote_log_location): old_log = self.s3_read(remote_log_location) - log = '\n'.join([old_log, log]) if old_log else log + log = "\n".join([old_log, log]) if old_log else log except Exception: - self.log.exception('Could not verify previous log to append') + self.log.exception("Could not verify previous log to append") # Default to a single retry attempt because s3 upload failures are # rare but occasionally occur. Multiple retry attempts are unlikely @@ -192,11 +179,11 @@ def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_ log, key=remote_log_location, replace=True, - encrypt=conf.getboolean('logging', 'ENCRYPT_S3_LOGS'), + encrypt=conf.getboolean("logging", "ENCRYPT_S3_LOGS"), ) break except Exception: if try_num < max_retry: - self.log.warning('Failed attempt to write logs to %s, will retry', remote_log_location) + self.log.warning("Failed attempt to write logs to %s, will retry", remote_log_location) else: - self.log.exception('Could not write logs to %s', remote_log_location) + self.log.exception("Could not write logs to %s", remote_log_location) diff --git a/airflow/providers/amazon/aws/operators/appflow.py b/airflow/providers/amazon/aws/operators/appflow.py new file mode 100644 index 0000000000000..b077b04bb59dd --- /dev/null +++ b/airflow/providers/amazon/aws/operators/appflow.py @@ -0,0 +1,480 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, cast + +from airflow.compat.functools import cached_property +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.operators.python import ShortCircuitOperator +from airflow.providers.amazon.aws.hooks.appflow import AppflowHook +from airflow.providers.amazon.aws.utils import datetime_to_epoch_ms + +if TYPE_CHECKING: + from mypy_boto3_appflow.type_defs import ( + DescribeFlowExecutionRecordsResponseTypeDef, + ExecutionRecordTypeDef, + TaskTypeDef, + ) + + from airflow.utils.context import Context + + +SUPPORTED_SOURCES = {"salesforce", "zendesk"} +MANDATORY_FILTER_DATE_MSG = "The filter_date argument is mandatory for {entity}!" +NOT_SUPPORTED_SOURCE_MSG = "Source {source} is not supported for {entity}!" + + +class AppflowBaseOperator(BaseOperator): + """ + Amazon Appflow Base Operator class (not supposed to be used directly in DAGs). + + :param source: The source name (Supported: salesforce, zendesk) + :param flow_name: The flow name + :param flow_update: A boolean to enable/disable a flow update before the run + :param source_field: The field name to apply filters + :param filter_date: The date value (or template) to be used in filters. 
+ :param poll_interval: how often in seconds to check the query status + :param aws_conn_id: aws connection to use + :param region: aws region to use + """ + + ui_color = "#2bccbd" + + def __init__( + self, + source: str, + flow_name: str, + flow_update: bool, + source_field: str | None = None, + filter_date: str | None = None, + poll_interval: int = 20, + aws_conn_id: str = "aws_default", + region: str | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if source not in SUPPORTED_SOURCES: + raise ValueError(f"{source} is not a supported source (options: {SUPPORTED_SOURCES})!") + self.filter_date = filter_date + self.flow_name = flow_name + self.source = source + self.source_field = source_field + self.poll_interval = poll_interval + self.aws_conn_id = aws_conn_id + self.region = region + self.flow_update = flow_update + + @cached_property + def hook(self) -> AppflowHook: + """Create and return an AppflowHook.""" + return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + + def execute(self, context: Context) -> None: + self.filter_date_parsed: datetime | None = ( + datetime.fromisoformat(self.filter_date) if self.filter_date else None + ) + self.connector_type = self._get_connector_type() + if self.flow_update: + self._update_flow() + self._run_flow(context) + + def _get_connector_type(self) -> str: + response = self.hook.conn.describe_flow(flowName=self.flow_name) + connector_type = response["sourceFlowConfig"]["connectorType"] + if self.source != connector_type.lower(): + raise ValueError(f"Incompatible source ({self.source} and connector type ({connector_type})!") + return connector_type + + def _update_flow(self) -> None: + self.hook.update_flow_filter(flow_name=self.flow_name, filter_tasks=[], set_trigger_ondemand=True) + + def _run_flow(self, context) -> str: + execution_id = self.hook.run_flow(flow_name=self.flow_name, poll_interval=self.poll_interval) + task_instance = context["task_instance"] + task_instance.xcom_push("execution_id", execution_id) + return execution_id + + +class AppflowRunOperator(AppflowBaseOperator): + """ + Execute a Appflow run with filters as is. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AppflowRunOperator` + + :param source: The source name (Supported: salesforce, zendesk) + :param flow_name: The flow name + :param poll_interval: how often in seconds to check the query status + :param aws_conn_id: aws connection to use + :param region: aws region to use + """ + + def __init__( + self, + source: str, + flow_name: str, + poll_interval: int = 20, + aws_conn_id: str = "aws_default", + region: str | None = None, + **kwargs, + ) -> None: + if source not in {"salesforce", "zendesk"}: + raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunOperator")) + super().__init__( + source=source, + flow_name=flow_name, + flow_update=False, + source_field=None, + filter_date=None, + poll_interval=poll_interval, + aws_conn_id=aws_conn_id, + region=region, + **kwargs, + ) + + +class AppflowRunFullOperator(AppflowBaseOperator): + """ + Execute a Appflow full run removing any filter. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AppflowRunFullOperator` + + :param source: The source name (Supported: salesforce, zendesk) + :param flow_name: The flow name + :param poll_interval: how often in seconds to check the query status + :param aws_conn_id: aws connection to use + :param region: aws region to use + """ + + def __init__( + self, + source: str, + flow_name: str, + poll_interval: int = 20, + aws_conn_id: str = "aws_default", + region: str | None = None, + **kwargs, + ) -> None: + if source not in {"salesforce", "zendesk"}: + raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunFullOperator")) + super().__init__( + source=source, + flow_name=flow_name, + flow_update=True, + source_field=None, + filter_date=None, + poll_interval=poll_interval, + aws_conn_id=aws_conn_id, + region=region, + **kwargs, + ) + + +class AppflowRunBeforeOperator(AppflowBaseOperator): + """ + Execute a Appflow run after updating the filters to select only previous data. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AppflowRunBeforeOperator` + + :param source: The source name (Supported: salesforce) + :param flow_name: The flow name + :param source_field: The field name to apply filters + :param filter_date: The date value (or template) to be used in filters. + :param poll_interval: how often in seconds to check the query status + :param aws_conn_id: aws connection to use + :param region: aws region to use + """ + + template_fields = ("filter_date",) + + def __init__( + self, + source: str, + flow_name: str, + source_field: str, + filter_date: str, + poll_interval: int = 20, + aws_conn_id: str = "aws_default", + region: str | None = None, + **kwargs, + ) -> None: + if not filter_date: + raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity="AppflowRunBeforeOperator")) + if source != "salesforce": + raise ValueError( + NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunBeforeOperator") + ) + super().__init__( + source=source, + flow_name=flow_name, + flow_update=True, + source_field=source_field, + filter_date=filter_date, + poll_interval=poll_interval, + aws_conn_id=aws_conn_id, + region=region, + **kwargs, + ) + + def _update_flow(self) -> None: + if not self.filter_date_parsed: + raise ValueError(f"Invalid filter_date argument parser value: {self.filter_date_parsed}") + if not self.source_field: + raise ValueError(f"Invalid source_field argument value: {self.source_field}") + filter_task: TaskTypeDef = { + "taskType": "Filter", + "connectorOperator": {self.connector_type: "LESS_THAN"}, # type: ignore + "sourceFields": [self.source_field], + "taskProperties": { + "DATA_TYPE": "datetime", + "VALUE": str(datetime_to_epoch_ms(self.filter_date_parsed)), + }, # NOT inclusive + } + self.hook.update_flow_filter( + flow_name=self.flow_name, filter_tasks=[filter_task], set_trigger_ondemand=True + ) + + +class AppflowRunAfterOperator(AppflowBaseOperator): + """ + Execute a Appflow run after updating the filters to select only future data. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AppflowRunAfterOperator` + + :param source: The source name (Supported: salesforce, zendesk) + :param flow_name: The flow name + :param source_field: The field name to apply filters + :param filter_date: The date value (or template) to be used in filters. 
+ :param poll_interval: how often in seconds to check the query status + :param aws_conn_id: aws connection to use + :param region: aws region to use + """ + + template_fields = ("filter_date",) + + def __init__( + self, + source: str, + flow_name: str, + source_field: str, + filter_date: str, + poll_interval: int = 20, + aws_conn_id: str = "aws_default", + region: str | None = None, + **kwargs, + ) -> None: + if not filter_date: + raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity="AppflowRunAfterOperator")) + if source not in {"salesforce", "zendesk"}: + raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunAfterOperator")) + super().__init__( + source=source, + flow_name=flow_name, + flow_update=True, + source_field=source_field, + filter_date=filter_date, + poll_interval=poll_interval, + aws_conn_id=aws_conn_id, + region=region, + **kwargs, + ) + + def _update_flow(self) -> None: + if not self.filter_date_parsed: + raise ValueError(f"Invalid filter_date argument parser value: {self.filter_date_parsed}") + if not self.source_field: + raise ValueError(f"Invalid source_field argument value: {self.source_field}") + filter_task: TaskTypeDef = { + "taskType": "Filter", + "connectorOperator": {self.connector_type: "GREATER_THAN"}, # type: ignore + "sourceFields": [self.source_field], + "taskProperties": { + "DATA_TYPE": "datetime", + "VALUE": str(datetime_to_epoch_ms(self.filter_date_parsed)), + }, # NOT inclusive + } + self.hook.update_flow_filter( + flow_name=self.flow_name, filter_tasks=[filter_task], set_trigger_ondemand=True + ) + + +class AppflowRunDailyOperator(AppflowBaseOperator): + """ + Execute a Appflow run after updating the filters to select only a single day. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AppflowRunDailyOperator` + + :param source: The source name (Supported: salesforce) + :param flow_name: The flow name + :param source_field: The field name to apply filters + :param filter_date: The date value (or template) to be used in filters. 
+ :param poll_interval: how often in seconds to check the query status + :param aws_conn_id: aws connection to use + :param region: aws region to use + """ + + template_fields = ("filter_date",) + + def __init__( + self, + source: str, + flow_name: str, + source_field: str, + filter_date: str, + poll_interval: int = 20, + aws_conn_id: str = "aws_default", + region: str | None = None, + **kwargs, + ) -> None: + if not filter_date: + raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity="AppflowRunDailyOperator")) + if source != "salesforce": + raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunDailyOperator")) + super().__init__( + source=source, + flow_name=flow_name, + flow_update=True, + source_field=source_field, + filter_date=filter_date, + poll_interval=poll_interval, + aws_conn_id=aws_conn_id, + region=region, + **kwargs, + ) + + def _update_flow(self) -> None: + if not self.filter_date_parsed: + raise ValueError(f"Invalid filter_date argument parser value: {self.filter_date_parsed}") + if not self.source_field: + raise ValueError(f"Invalid source_field argument value: {self.source_field}") + start_filter_date = self.filter_date_parsed - timedelta(milliseconds=1) + end_filter_date = self.filter_date_parsed + timedelta(days=1) + filter_task: TaskTypeDef = { + "taskType": "Filter", + "connectorOperator": {self.connector_type: "BETWEEN"}, # type: ignore + "sourceFields": [self.source_field], + "taskProperties": { + "DATA_TYPE": "datetime", + "LOWER_BOUND": str(datetime_to_epoch_ms(start_filter_date)), # NOT inclusive + "UPPER_BOUND": str(datetime_to_epoch_ms(end_filter_date)), # NOT inclusive + }, + } + self.hook.update_flow_filter( + flow_name=self.flow_name, filter_tasks=[filter_task], set_trigger_ondemand=True + ) + + +class AppflowRecordsShortCircuitOperator(ShortCircuitOperator): + """ + Short-circuit in case of a empty Appflow's run. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AppflowRecordsShortCircuitOperator` + + :param flow_name: The flow name + :param appflow_run_task_id: Run task ID from where this operator should extract the execution ID + :param ignore_downstream_trigger_rules: Ignore downstream trigger rules + :param aws_conn_id: aws connection to use + :param region: aws region to use + """ + + ui_color = "#33ffec" # Light blue + + def __init__( + self, + *, + flow_name: str, + appflow_run_task_id: str, + ignore_downstream_trigger_rules: bool = True, + aws_conn_id: str = "aws_default", + region: str | None = None, + **kwargs, + ) -> None: + super().__init__( + python_callable=self._has_new_records_func, + op_kwargs={ + "flow_name": flow_name, + "appflow_run_task_id": appflow_run_task_id, + }, + ignore_downstream_trigger_rules=ignore_downstream_trigger_rules, + **kwargs, + ) + self.aws_conn_id = aws_conn_id + self.region = region + + @staticmethod + def _get_target_execution_id( + records: list[ExecutionRecordTypeDef], execution_id: str + ) -> ExecutionRecordTypeDef | None: + for record in records: + if record.get("executionId") == execution_id: + return record + return None + + @cached_property + def hook(self) -> AppflowHook: + """Create and return an AppflowHook.""" + return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + + def _has_new_records_func(self, **kwargs) -> bool: + appflow_task_id = kwargs["appflow_run_task_id"] + self.log.info("appflow_task_id: %s", appflow_task_id) + flow_name = kwargs["flow_name"] + self.log.info("flow_name: %s", flow_name) + af_client = self.hook.conn + task_instance = kwargs["task_instance"] + execution_id = task_instance.xcom_pull(task_ids=appflow_task_id, key="execution_id") # type: ignore + if not execution_id: + raise AirflowException(f"No execution_id found from task_id {appflow_task_id}!") + self.log.info("execution_id: %s", execution_id) + args = {"flowName": flow_name, "maxResults": 100} + response: DescribeFlowExecutionRecordsResponseTypeDef = cast( + "DescribeFlowExecutionRecordsResponseTypeDef", {} + ) + record = None + + while not record: + if "nextToken" in response: + response = af_client.describe_flow_execution_records(nextToken=response["nextToken"], **args) + else: + response = af_client.describe_flow_execution_records(**args) + record = AppflowRecordsShortCircuitOperator._get_target_execution_id( + response["flowExecutions"], execution_id + ) + if not record and "nextToken" not in response: + raise AirflowException(f"Flow ({execution_id}) without recordsProcessed info.") + + execution = record.get("executionResult", {}) + if "recordsProcessed" not in execution: + raise AirflowException(f"Flow ({execution_id}) without recordsProcessed info!") + records_processed = execution["recordsProcessed"] + self.log.info("records_processed: %d", records_processed) + task_instance.xcom_push("records_processed", records_processed) # type: ignore + return records_processed > 0 diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 6febe2a917532..61c897a4817d3 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -15,16 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
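The three Appflow operators above are designed to be chained: the run operator pushes an "execution_id" XCom, which AppflowRecordsShortCircuitOperator then reads via appflow_run_task_id before deciding whether to skip downstream tasks. A minimal sketch of that wiring, assuming illustrative DAG id, task ids and flow name (none of these values come from this change):

from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.operators.appflow import (
    AppflowRecordsShortCircuitOperator,
    AppflowRunDailyOperator,
)

with DAG(
    dag_id="example_appflow_daily",  # illustrative DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
):
    # Re-run the flow for a single day of Salesforce data, using a templated date.
    campaign_daily = AppflowRunDailyOperator(
        task_id="campaign_daily",
        source="salesforce",
        flow_name="salesforce-campaign",  # illustrative flow name
        source_field="LastModifiedDate",
        filter_date="{{ ds }}",
    )

    # Skip everything downstream when the run above processed no records.
    # The operator pulls the "execution_id" XCom pushed by the run task.
    has_records = AppflowRecordsShortCircuitOperator(
        task_id="campaign_daily_short_circuit",
        flow_name="salesforce-campaign",
        appflow_run_task_id="campaign_daily",  # must match the run task_id above
    )

    process = EmptyOperator(task_id="process_new_records")

    campaign_daily >> has_records >> process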
-# -import sys -import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +import warnings +from typing import TYPE_CHECKING, Any, Sequence +from airflow.compat.functools import cached_property from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.athena import AthenaHook @@ -49,12 +45,14 @@ class AthenaOperator(BaseOperator): :param query_execution_context: Context in which query need to be run :param result_configuration: Dict with path to store results in and config related to encryption :param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena - :param max_tries: Number of times to poll for query state before function exits + :param max_tries: Deprecated - use max_polling_attempts instead. + :param max_polling_attempts: Number of times to poll for query state before function exits + To limit task execution time, use execution_timeout. """ - ui_color = '#44b5e2' - template_fields: Sequence[str] = ('query', 'database', 'output_location') - template_ext: Sequence[str] = ('.sql',) + ui_color = "#44b5e2" + template_fields: Sequence[str] = ("query", "database", "output_location") + template_ext: Sequence[str] = (".sql",) template_fields_renderers = {"query": "sql"} def __init__( @@ -64,12 +62,13 @@ def __init__( database: str, output_location: str, aws_conn_id: str = "aws_default", - client_request_token: Optional[str] = None, + client_request_token: str | None = None, workgroup: str = "primary", - query_execution_context: Optional[Dict[str, str]] = None, - result_configuration: Optional[Dict[str, Any]] = None, + query_execution_context: dict[str, str] | None = None, + result_configuration: dict[str, Any] | None = None, sleep_time: int = 30, - max_tries: Optional[int] = None, + max_tries: int | None = None, + max_polling_attempts: int | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -82,18 +81,30 @@ def __init__( self.query_execution_context = query_execution_context or {} self.result_configuration = result_configuration or {} self.sleep_time = sleep_time - self.max_tries = max_tries - self.query_execution_id = None # type: Optional[str] + self.max_polling_attempts = max_polling_attempts + self.query_execution_id: str | None = None + + if max_tries: + warnings.warn( + f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed " + "in a future release. 
Please use method `max_polling_attempts` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_polling_attempts and max_polling_attempts != max_tries: + raise Exception("max_polling_attempts must be the same value as max_tries") + else: + self.max_polling_attempts = max_tries @cached_property def hook(self) -> AthenaHook: """Create and return an AthenaHook.""" return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) - def execute(self, context: 'Context') -> Optional[str]: + def execute(self, context: Context) -> str | None: """Run Presto Query on Athena""" - self.query_execution_context['Database'] = self.database - self.result_configuration['OutputLocation'] = self.output_location + self.query_execution_context["Database"] = self.database + self.result_configuration["OutputLocation"] = self.output_location self.query_execution_id = self.hook.run_query( self.query, self.query_execution_context, @@ -101,18 +112,21 @@ def execute(self, context: 'Context') -> Optional[str]: self.client_request_token, self.workgroup, ) - query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries) + query_status = self.hook.poll_query_status( + self.query_execution_id, + max_polling_attempts=self.max_polling_attempts, + ) if query_status in AthenaHook.FAILURE_STATES: error_message = self.hook.get_state_change_reason(self.query_execution_id) raise Exception( - f'Final state of Athena job is {query_status}, query_execution_id is ' - f'{self.query_execution_id}. Error: {error_message}' + f"Final state of Athena job is {query_status}, query_execution_id is " + f"{self.query_execution_id}. Error: {error_message}" ) elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES: raise Exception( - f'Final state of Athena job is {query_status}. Max tries of poll status exceeded, ' - f'query_execution_id is {self.query_execution_id}.' + f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, " + f"query_execution_id is {self.query_execution_id}." ) return self.query_execution_id @@ -120,35 +134,19 @@ def execute(self, context: 'Context') -> Optional[str]: def on_kill(self) -> None: """Cancel the submitted athena query""" if self.query_execution_id: - self.log.info('Received a kill signal.') - self.log.info('Stopping Query with executionId - %s', self.query_execution_id) + self.log.info("Received a kill signal.") + self.log.info("Stopping Query with executionId - %s", self.query_execution_id) response = self.hook.stop_query(self.query_execution_id) http_status_code = None try: - http_status_code = response['ResponseMetadata']['HTTPStatusCode'] + http_status_code = response["ResponseMetadata"]["HTTPStatusCode"] except Exception as ex: - self.log.error('Exception while cancelling query: %s', ex) + self.log.error("Exception while cancelling query: %s", ex) finally: if http_status_code is None or http_status_code != 200: - self.log.error('Unable to request query cancel on athena. Exiting') + self.log.error("Unable to request query cancel on athena. Exiting") else: self.log.info( - 'Polling Athena for query with id %s to reach final state', self.query_execution_id + "Polling Athena for query with id %s to reach final state", self.query_execution_id ) self.hook.poll_query_status(self.query_execution_id) - - -class AWSAthenaOperator(AthenaOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.athena.AthenaOperator`. 
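With the rename above, callers should pass max_polling_attempts; the old max_tries spelling still works but emits a DeprecationWarning. A minimal sketch of the new-style call, where the query, database and S3 output location are placeholders:

from airflow.providers.amazon.aws.operators.athena import AthenaOperator

# Poll Athena every 30 seconds and give up after 20 polls.
run_query = AthenaOperator(
    task_id="run_athena_query",
    query="SELECT * FROM my_table LIMIT 10",    # placeholder query
    database="my_database",                     # placeholder database
    output_location="s3://my-athena-results/",  # placeholder bucket
    sleep_time=30,
    max_polling_attempts=20,
)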
- """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. Please use " - "`airflow.providers.amazon.aws.operators.athena.AthenaOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/aws_lambda.py b/airflow/providers/amazon/aws/operators/aws_lambda.py index c2d9d022fbcbc..70d96c094d2a8 100644 --- a/airflow/providers/amazon/aws/operators/aws_lambda.py +++ b/airflow/providers/amazon/aws/operators/aws_lambda.py @@ -15,88 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.lambda_function`.""" +from __future__ import annotations -import json -from typing import TYPE_CHECKING, Optional, Sequence +import warnings -from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook +from airflow.providers.amazon.aws.operators.lambda_function import AwsLambdaInvokeFunctionOperator # noqa -if TYPE_CHECKING: - from airflow.utils.context import Context - - -class AwsLambdaInvokeFunctionOperator(BaseOperator): - """ - Invokes an AWS Lambda function. - You can invoke a function synchronously (and wait for the response), - or asynchronously. - To invoke a function asynchronously, - set `invocation_type` to `Event`. For more details, - review the boto3 Lambda invoke docs. - - :param function_name: The name of the AWS Lambda function, version, or alias. - :param payload: The JSON string that you want to provide to your Lambda function as input. - :param log_type: Set to Tail to include the execution log in the response. Otherwise, set to "None". - :param qualifier: Specify a version or alias to invoke a published version of the function. - :param aws_conn_id: The AWS connection ID to use - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AwsLambdaInvokeFunctionOperator` - - """ - - template_fields: Sequence[str] = ('function_name', 'payload', 'qualifier', 'invocation_type') - ui_color = '#ff7300' - - def __init__( - self, - *, - function_name: str, - log_type: Optional[str] = None, - qualifier: Optional[str] = None, - invocation_type: Optional[str] = None, - client_context: Optional[str] = None, - payload: Optional[str] = None, - aws_conn_id: str = 'aws_default', - **kwargs, - ): - super().__init__(**kwargs) - self.function_name = function_name - self.payload = payload - self.log_type = log_type - self.qualifier = qualifier - self.invocation_type = invocation_type - self.client_context = client_context - self.aws_conn_id = aws_conn_id - - def execute(self, context: 'Context'): - """ - Invokes the target AWS Lambda function from Airflow. - - :return: The response payload from the function, or an error object. 
- """ - hook = LambdaHook(aws_conn_id=self.aws_conn_id) - success_status_codes = [200, 202, 204] - self.log.info("Invoking AWS Lambda function: %s with payload: %s", self.function_name, self.payload) - response = hook.invoke_lambda( - function_name=self.function_name, - invocation_type=self.invocation_type, - log_type=self.log_type, - client_context=self.client_context, - payload=self.payload, - qualifier=self.qualifier, - ) - self.log.info("Lambda response metadata: %r", response.get("ResponseMetadata")) - if response.get("StatusCode") not in success_status_codes: - raise ValueError('Lambda function did not execute', json.dumps(response.get("ResponseMetadata"))) - payload_stream = response.get("Payload") - payload = payload_stream.read().decode() - if "FunctionError" in response: - raise ValueError( - 'Lambda function execution resulted in error', - {"ResponseMetadata": response.get("ResponseMetadata"), "Payload": payload}, - ) - self.log.info('Lambda function invocation succeeded: %r', response.get("ResponseMetadata")) - return payload +warnings.warn( + "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.lambda_function`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index fcd925f9ae32e..9e85afaf4c0ce 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -14,23 +14,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - """ An Airflow operator for AWS Batch services .. seealso:: - - http://boto3.readthedocs.io/en/latest/guide/configuration.html - - http://boto3.readthedocs.io/en/latest/reference/services/batch.html + - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html """ -import warnings -from typing import TYPE_CHECKING, Any, Optional, Sequence +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.providers.amazon.aws.utils import trim_none_values + +if sys.version_info >= (3, 8): + from functools import cached_property +else: + from cached_property import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook +from airflow.providers.amazon.aws.links.batch import ( + BatchJobDefinitionLink, + BatchJobDetailsLink, + BatchJobQueueLink, +) +from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -87,14 +100,32 @@ class BatchOperator(BaseOperator): """ ui_color = "#c3dae0" - arn = None # type: Optional[str] + arn: str | None = None template_fields: Sequence[str] = ( + "job_id", "job_name", + "job_definition", + "job_queue", "overrides", + "array_properties", "parameters", + "waiters", + "tags", + "wait_for_completion", ) template_fields_renderers = {"overrides": "json", "parameters": "json"} + @property + def operator_extra_links(self): + op_extra_links = [BatchJobDetailsLink()] + if self.wait_for_completion: + op_extra_links.extend([BatchJobDefinitionLink(), BatchJobQueueLink()]) + if not self.array_properties: + # There is no CloudWatch Link to the 
parent Batch Job available. + op_extra_links.append(CloudWatchEventsLink()) + + return tuple(op_extra_links) + def __init__( self, *, @@ -102,15 +133,16 @@ def __init__( job_definition: str, job_queue: str, overrides: dict, - array_properties: Optional[dict] = None, - parameters: Optional[dict] = None, - job_id: Optional[str] = None, - waiters: Optional[Any] = None, - max_retries: Optional[int] = None, - status_retries: Optional[int] = None, - aws_conn_id: Optional[str] = None, - region_name: Optional[str] = None, - tags: Optional[dict] = None, + array_properties: dict | None = None, + parameters: dict | None = None, + job_id: str | None = None, + waiters: Any | None = None, + max_retries: int | None = None, + status_retries: int | None = None, + aws_conn_id: str | None = None, + region_name: str | None = None, + tags: dict | None = None, + wait_for_completion: bool = True, **kwargs, ): @@ -124,6 +156,7 @@ def __init__( self.parameters = parameters or {} self.waiters = waiters self.tags = tags or {} + self.wait_for_completion = wait_for_completion self.hook = BatchClientHook( max_retries=max_retries, status_retries=status_retries, @@ -131,20 +164,24 @@ def __init__( region_name=region_name, ) - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Submit and monitor an AWS Batch job :raises: AirflowException """ self.submit_job(context) - self.monitor_job(context) + + if self.wait_for_completion: + self.monitor_job(context) + + return self.job_id def on_kill(self): response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user") self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response) - def submit_job(self, context: 'Context'): + def submit_job(self, context: Context): """ Submit an AWS Batch job @@ -167,14 +204,25 @@ def submit_job(self, context: 'Context'): containerOverrides=self.overrides, tags=self.tags, ) - self.job_id = response["jobId"] - - self.log.info("AWS Batch job (%s) started: %s", self.job_id, response) except Exception as e: - self.log.error("AWS Batch job (%s) failed submission", self.job_id) + self.log.error( + "AWS Batch job failed submission - job definition: %s - on queue %s", + self.job_definition, + self.job_queue, + ) raise AirflowException(e) - def monitor_job(self, context: 'Context'): + self.job_id = response["jobId"] + self.log.info("AWS Batch job (%s) started: %s", self.job_id, response) + BatchJobDetailsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_id=self.job_id, + ) + + def monitor_job(self, context: Context): """ Monitor an AWS Batch job monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout @@ -184,28 +232,152 @@ def monitor_job(self, context: 'Context'): :raises: AirflowException """ if not self.job_id: - raise AirflowException('AWS Batch job - job_id was not found') + raise AirflowException("AWS Batch job - job_id was not found") + + try: + job_desc = self.hook.get_job_description(self.job_id) + job_definition_arn = job_desc["jobDefinition"] + job_queue_arn = job_desc["jobQueue"] + self.log.info( + "AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r", + self.job_id, + job_definition_arn, + job_queue_arn, + ) + except KeyError: + self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id) + else: + BatchJobDefinitionLink.persist( + context=context, + operator=self, + 
region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_definition_arn=job_definition_arn, + ) + BatchJobQueueLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_queue_arn=job_queue_arn, + ) if self.waiters: self.waiters.wait_for_job(self.job_id) else: self.hook.wait_for_job(self.job_id) + awslogs = self.hook.get_job_awslogs_info(self.job_id) + if awslogs: + self.log.info("AWS Batch job (%s) CloudWatch Events details found: %s", self.job_id, awslogs) + CloudWatchEventsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + **awslogs, + ) + self.hook.check_job_success(self.job_id) self.log.info("AWS Batch job (%s) succeeded", self.job_id) -class AwsBatchOperator(BatchOperator): +class BatchCreateComputeEnvironmentOperator(BaseOperator): """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.batch.BatchOperator`. + Create an AWS Batch compute environment + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BatchCreateComputeEnvironmentOperator` + + :param compute_environment_name: the name of the AWS batch compute environment (templated) + + :param environment_type: the type of the compute-environment + + :param state: the state of the compute-environment + + :param compute_resources: details about the resources managed by the compute-environment (templated). + See more details here + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html#Batch.Client.create_compute_environment + + :param unmanaged_v_cpus: the maximum number of vCPU for an unmanaged compute environment. + This parameter is only supported when the ``type`` parameter is set to ``UNMANAGED``. + + :param service_role: the IAM role that allows Batch to make calls to other AWS services on your behalf + (templated) + + :param tags: the tags that you apply to the compute-environment to help you categorize and organize your + resources + + :param max_retries: exponential back-off retries, 4200 = 48 hours; + polling is only used when waiters is None + + :param status_retries: number of HTTP retries to get job status, 10; + polling is only used when waiters is None + + :param aws_conn_id: connection id of AWS credentials / region name. If None, + credential boto3 strategy will be used. + + :param region_name: region name to use in AWS Hook. + Override the region_name in connection (if provided) """ - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. 
" - "Please use :class:`airflow.providers.amazon.aws.operators.batch.BatchOperator`.", - DeprecationWarning, - stacklevel=2, + template_fields: Sequence[str] = ( + "compute_environment_name", + "compute_resources", + "service_role", + ) + template_fields_renderers = {"compute_resources": "json"} + + def __init__( + self, + compute_environment_name: str, + environment_type: str, + state: str, + compute_resources: dict, + unmanaged_v_cpus: int | None = None, + service_role: str | None = None, + tags: dict | None = None, + max_retries: int | None = None, + status_retries: int | None = None, + aws_conn_id: str | None = None, + region_name: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.compute_environment_name = compute_environment_name + self.environment_type = environment_type + self.state = state + self.unmanaged_v_cpus = unmanaged_v_cpus + self.compute_resources = compute_resources + self.service_role = service_role + self.tags = tags or {} + self.max_retries = max_retries + self.status_retries = status_retries + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + @cached_property + def hook(self): + """Create and return a BatchClientHook""" + return BatchClientHook( + max_retries=self.max_retries, + status_retries=self.status_retries, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, ) - super().__init__(*args, **kwargs) + + def execute(self, context: Context): + """Create an AWS batch compute environment""" + kwargs: dict[str, Any] = { + "computeEnvironmentName": self.compute_environment_name, + "type": self.environment_type, + "state": self.state, + "unmanagedvCpus": self.unmanaged_v_cpus, + "computeResources": self.compute_resources, + "serviceRole": self.service_role, + "tags": self.tags, + } + self.hook.client.create_compute_environment(**trim_none_values(kwargs)) + + self.log.info("AWS Batch compute environment created successfully") diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py b/airflow/providers/amazon/aws/operators/cloud_formation.py index 2e7bb8ae121c6..423ec36e598e3 100644 --- a/airflow/providers/amazon/aws/operators/cloud_formation.py +++ b/airflow/providers/amazon/aws/operators/cloud_formation.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains CloudFormation create/delete stack operators.""" -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook @@ -39,20 +41,20 @@ class CloudFormationCreateStackOperator(BaseOperator): :param aws_conn_id: aws connection to uses """ - template_fields: Sequence[str] = ('stack_name',) + template_fields: Sequence[str] = ("stack_name", "cloudformation_parameters") template_ext: Sequence[str] = () - ui_color = '#6b9659' + ui_color = "#6b9659" def __init__( - self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = 'aws_default', **kwargs + self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = "aws_default", **kwargs ): super().__init__(**kwargs) self.stack_name = stack_name self.cloudformation_parameters = cloudformation_parameters self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): - self.log.info('CloudFormation parameters: %s', self.cloudformation_parameters) + def execute(self, context: Context): + self.log.info("CloudFormation parameters: %s", self.cloudformation_parameters) cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id) cloudformation_hook.create_stack(self.stack_name, self.cloudformation_parameters) @@ -72,17 +74,17 @@ class CloudFormationDeleteStackOperator(BaseOperator): :param aws_conn_id: aws connection to uses """ - template_fields: Sequence[str] = ('stack_name',) + template_fields: Sequence[str] = ("stack_name",) template_ext: Sequence[str] = () - ui_color = '#1d472b' - ui_fgcolor = '#FFF' + ui_color = "#1d472b" + ui_fgcolor = "#FFF" def __init__( self, *, stack_name: str, - cloudformation_parameters: Optional[dict] = None, - aws_conn_id: str = 'aws_default', + cloudformation_parameters: dict | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -90,8 +92,8 @@ def __init__( self.stack_name = stack_name self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): - self.log.info('CloudFormation Parameters: %s', self.cloudformation_parameters) + def execute(self, context: Context): + self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters) cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id) cloudformation_hook.delete_stack(self.stack_name, self.cloudformation_parameters) diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py index 5a0c86071139f..508d87e163658 100644 --- a/airflow/providers/amazon/aws/operators/datasync.py +++ b/airflow/providers/amazon/aws/operators/datasync.py @@ -14,13 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Create, get, update, execute and delete an AWS DataSync Task.""" +from __future__ import annotations import logging import random -import warnings -from typing import TYPE_CHECKING, List, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException, AirflowTaskTimeout from airflow.models import BaseOperator @@ -49,6 +48,7 @@ class DataSyncOperator(BaseOperator): consecutive calls to check TaskExecution status. :param max_iterations: Maximum number of consecutive calls to check TaskExecution status. 
+ :param wait_for_completion: If True, wait for the task execution to reach a final state :param task_arn: AWS DataSync TaskArn to use. If None, then this operator will attempt to either search for an existing Task or attempt to create a new Task. :param source_location_uri: Source location URI to search for. All DataSync @@ -122,16 +122,17 @@ def __init__( aws_conn_id: str = "aws_default", wait_interval_seconds: int = 30, max_iterations: int = 60, - task_arn: Optional[str] = None, - source_location_uri: Optional[str] = None, - destination_location_uri: Optional[str] = None, + wait_for_completion: bool = True, + task_arn: str | None = None, + source_location_uri: str | None = None, + destination_location_uri: str | None = None, allow_random_task_choice: bool = False, allow_random_location_choice: bool = False, - create_task_kwargs: Optional[dict] = None, - create_source_location_kwargs: Optional[dict] = None, - create_destination_location_kwargs: Optional[dict] = None, - update_task_kwargs: Optional[dict] = None, - task_execution_kwargs: Optional[dict] = None, + create_task_kwargs: dict | None = None, + create_source_location_kwargs: dict | None = None, + create_destination_location_kwargs: dict | None = None, + update_task_kwargs: dict | None = None, + task_execution_kwargs: dict | None = None, delete_task_after_execution: bool = False, **kwargs, ): @@ -141,6 +142,7 @@ def __init__( self.aws_conn_id = aws_conn_id self.wait_interval_seconds = wait_interval_seconds self.max_iterations = max_iterations + self.wait_for_completion = wait_for_completion self.task_arn = task_arn @@ -175,16 +177,16 @@ def __init__( ) # Others - self.hook: Optional[DataSyncHook] = None + self.hook: DataSyncHook | None = None # Candidates - these are found in AWS as possible things # for us to use - self.candidate_source_location_arns: Optional[List[str]] = None - self.candidate_destination_location_arns: Optional[List[str]] = None - self.candidate_task_arns: Optional[List[str]] = None + self.candidate_source_location_arns: list[str] | None = None + self.candidate_destination_location_arns: list[str] | None = None + self.candidate_task_arns: list[str] | None = None # Actuals - self.source_location_arn: Optional[str] = None - self.destination_location_arn: Optional[str] = None - self.task_execution_arn: Optional[str] = None + self.source_location_arn: str | None = None + self.destination_location_arn: str | None = None + self.task_execution_arn: str | None = None def get_hook(self) -> DataSyncHook: """Create and return DataSyncHook. 
@@ -200,7 +202,7 @@ def get_hook(self) -> DataSyncHook: ) return self.hook - def execute(self, context: 'Context'): + def execute(self, context: Context): # If task_arn was not specified then try to # find 0, 1 or many candidate DataSync Tasks to run if not self.task_arn: @@ -258,7 +260,7 @@ def _get_tasks_and_locations(self) -> None: ) self.log.info("Found candidate DataSync TaskArns %s", self.candidate_task_arns) - def choose_task(self, task_arn_list: list) -> Optional[str]: + def choose_task(self, task_arn_list: list) -> str | None: """Select 1 DataSync TaskArn from a list""" if not task_arn_list: return None @@ -272,7 +274,7 @@ def choose_task(self, task_arn_list: list) -> Optional[str]: return random.choice(task_arn_list) raise AirflowException(f"Unable to choose a Task from {task_arn_list}") - def choose_location(self, location_arn_list: Optional[List[str]]) -> Optional[str]: + def choose_location(self, location_arn_list: list[str] | None) -> str | None: """Select 1 DataSync LocationArn from a list""" if not location_arn_list: return None @@ -292,7 +294,7 @@ def _create_datasync_task(self) -> None: self.source_location_arn = self.choose_location(self.candidate_source_location_arns) if not self.source_location_arn and self.source_location_uri and self.create_source_location_kwargs: - self.log.info('Attempting to create source Location') + self.log.info("Attempting to create source Location") self.source_location_arn = hook.create_location( self.source_location_uri, **self.create_source_location_kwargs ) @@ -307,7 +309,7 @@ def _create_datasync_task(self) -> None: and self.destination_location_uri and self.create_destination_location_kwargs ): - self.log.info('Attempting to create destination Location') + self.log.info("Attempting to create destination Location") self.destination_location_arn = hook.create_location( self.destination_location_uri, **self.create_destination_location_kwargs ) @@ -346,12 +348,15 @@ def _execute_datasync_task(self) -> None: self.task_execution_arn = hook.start_task_execution(self.task_arn, **self.task_execution_kwargs) self.log.info("Started TaskExecutionArn %s", self.task_execution_arn) + if not self.wait_for_completion: + return + # Wait for task execution to complete self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn) try: result = hook.wait_for_task_execution(self.task_execution_arn, max_iterations=self.max_iterations) except (AirflowTaskTimeout, AirflowException) as e: - self.log.error('Cancelling TaskExecution after Exception: %s', e) + self.log.error("Cancelling TaskExecution after Exception: %s", e) self._cancel_datasync_task_execution() raise self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn) @@ -361,11 +366,11 @@ def _execute_datasync_task(self) -> None: # Log some meaningful statuses level = logging.ERROR if not result else logging.INFO - self.log.log(level, 'Status=%s', task_execution_description['Status']) - if 'Result' in task_execution_description: - for k, v in task_execution_description['Result'].items(): - if 'Status' in k or 'Error' in k: - self.log.log(level, '%s=%s', k, v) + self.log.log(level, "Status=%s", task_execution_description["Status"]) + if "Result" in task_execution_description: + for k, v in task_execution_description["Result"].items(): + if "Status" in k or "Error" in k: + self.log.log(level, "%s=%s", k, v) if not result: raise AirflowException(f"Failed TaskExecutionArn {self.task_execution_arn}") @@ -379,7 +384,7 @@ def _cancel_datasync_task_execution(self): 
self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn) def on_kill(self): - self.log.error('Cancelling TaskExecution after task was killed') + self.log.error("Cancelling TaskExecution after task was killed") self._cancel_datasync_task_execution() def _delete_datasync_task(self) -> None: @@ -393,23 +398,7 @@ def _delete_datasync_task(self) -> None: hook.delete_task(self.task_arn) self.log.info("Task Deleted") - def _get_location_arns(self, location_uri) -> List[str]: + def _get_location_arns(self, location_uri) -> list[str]: location_arns = self.get_hook().get_location_arns(location_uri) self.log.info("Found LocationArns %s for LocationUri %s", location_arns, location_uri) return location_arns - - -class AWSDataSyncOperator(DataSyncOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.datasync.DataSyncOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. Please use " - "`airflow.providers.amazon.aws.operators.datasync.DataSyncHook`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/dms.py b/airflow/providers/amazon/aws/operators/dms.py index aca515cfed3a9..6303afcbaf318 100644 --- a/airflow/providers/amazon/aws/operators/dms.py +++ b/airflow/providers/amazon/aws/operators/dms.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Dict, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dms import DmsHook @@ -47,13 +49,13 @@ class DmsCreateTaskOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'replication_task_id', - 'source_endpoint_arn', - 'target_endpoint_arn', - 'replication_instance_arn', - 'table_mappings', - 'migration_type', - 'create_task_kwargs', + "replication_task_id", + "source_endpoint_arn", + "target_endpoint_arn", + "replication_instance_arn", + "table_mappings", + "migration_type", + "create_task_kwargs", ) template_ext: Sequence[str] = () template_fields_renderers = { @@ -69,9 +71,9 @@ def __init__( target_endpoint_arn: str, replication_instance_arn: str, table_mappings: dict, - migration_type: str = 'full-load', - create_task_kwargs: Optional[dict] = None, - aws_conn_id: str = 'aws_default', + migration_type: str = "full-load", + create_task_kwargs: dict | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -84,7 +86,7 @@ def __init__( self.create_task_kwargs = create_task_kwargs or {} self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Creates AWS DMS replication task from Airflow @@ -122,22 +124,22 @@ class DmsDeleteTaskOperator(BaseOperator): maintained on each worker node). 
""" - template_fields: Sequence[str] = ('replication_task_arn',) + template_fields: Sequence[str] = ("replication_task_arn",) template_ext: Sequence[str] = () - template_fields_renderers: Dict[str, str] = {} + template_fields_renderers: dict[str, str] = {} def __init__( self, *, - replication_task_arn: Optional[str] = None, - aws_conn_id: str = 'aws_default', + replication_task_arn: str | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) self.replication_task_arn = replication_task_arn self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Deletes AWS DMS replication task from Airflow @@ -164,27 +166,26 @@ class DmsDescribeTasksOperator(BaseOperator): maintained on each worker node). """ - template_fields: Sequence[str] = ('describe_tasks_kwargs',) + template_fields: Sequence[str] = ("describe_tasks_kwargs",) template_ext: Sequence[str] = () - template_fields_renderers: Dict[str, str] = {'describe_tasks_kwargs': 'json'} + template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"} def __init__( self, *, - describe_tasks_kwargs: Optional[dict] = None, - aws_conn_id: str = 'aws_default', + describe_tasks_kwargs: dict | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) self.describe_tasks_kwargs = describe_tasks_kwargs or {} self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context) -> tuple[str | None, list]: """ Describes AWS DMS replication tasks from Airflow :return: Marker and list of replication tasks - :rtype: (Optional[str], list) """ dms_hook = DmsHook(aws_conn_id=self.aws_conn_id) return dms_hook.describe_replication_tasks(**self.describe_tasks_kwargs) @@ -210,20 +211,20 @@ class DmsStartTaskOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'replication_task_arn', - 'start_replication_task_type', - 'start_task_kwargs', + "replication_task_arn", + "start_replication_task_type", + "start_task_kwargs", ) template_ext: Sequence[str] = () - template_fields_renderers = {'start_task_kwargs': 'json'} + template_fields_renderers = {"start_task_kwargs": "json"} def __init__( self, *, replication_task_arn: str, - start_replication_task_type: str = 'start-replication', - start_task_kwargs: Optional[dict] = None, - aws_conn_id: str = 'aws_default', + start_replication_task_type: str = "start-replication", + start_task_kwargs: dict | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -232,7 +233,7 @@ def __init__( self.start_task_kwargs = start_task_kwargs or {} self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Starts AWS DMS replication task from Airflow @@ -264,22 +265,22 @@ class DmsStopTaskOperator(BaseOperator): maintained on each worker node). 
""" - template_fields: Sequence[str] = ('replication_task_arn',) + template_fields: Sequence[str] = ("replication_task_arn",) template_ext: Sequence[str] = () - template_fields_renderers: Dict[str, str] = {} + template_fields_renderers: dict[str, str] = {} def __init__( self, *, - replication_task_arn: Optional[str] = None, - aws_conn_id: str = 'aws_default', + replication_task_arn: str | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) self.replication_task_arn = replication_task_arn self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Stops AWS DMS replication task from Airflow diff --git a/airflow/providers/amazon/aws/operators/dms_create_task.py b/airflow/providers/amazon/aws/operators/dms_create_task.py deleted file mode 100644 index 0443772dc6aee..0000000000000 --- a/airflow/providers/amazon/aws/operators/dms_create_task.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.dms import DmsCreateTaskOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/dms_delete_task.py b/airflow/providers/amazon/aws/operators/dms_delete_task.py deleted file mode 100644 index ddb929efce1ce..0000000000000 --- a/airflow/providers/amazon/aws/operators/dms_delete_task.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.dms import DmsDeleteTaskOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.dms`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/dms_describe_tasks.py b/airflow/providers/amazon/aws/operators/dms_describe_tasks.py deleted file mode 100644 index 022c028c89b96..0000000000000 --- a/airflow/providers/amazon/aws/operators/dms_describe_tasks.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.dms import DmsDescribeTasksOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/dms_start_task.py b/airflow/providers/amazon/aws/operators/dms_start_task.py deleted file mode 100644 index ba78ea60bed01..0000000000000 --- a/airflow/providers/amazon/aws/operators/dms_start_task.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.dms import DmsStartTaskOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/dms_stop_task.py b/airflow/providers/amazon/aws/operators/dms_stop_task.py deleted file mode 100644 index c4ad5f3a4a7cf..0000000000000 --- a/airflow/providers/amazon/aws/operators/dms_stop_task.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.dms import DmsStopTaskOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/ec2.py b/airflow/providers/amazon/aws/operators/ec2.py index 133596929e423..60cb43a32dd09 100644 --- a/airflow/providers/amazon/aws/operators/ec2.py +++ b/airflow/providers/amazon/aws/operators/ec2.py @@ -15,9 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook @@ -50,7 +50,7 @@ def __init__( *, instance_id: str, aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, + region_name: str | None = None, check_interval: float = 15, **kwargs, ): @@ -60,7 +60,7 @@ def __init__( self.region_name = region_name self.check_interval = check_interval - def execute(self, context: 'Context'): + def execute(self, context: Context): ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Starting EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) @@ -96,7 +96,7 @@ def __init__( *, instance_id: str, aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, + region_name: str | None = None, check_interval: float = 15, **kwargs, ): @@ -106,7 +106,7 @@ def __init__( self.region_name = region_name self.check_interval = check_interval - def execute(self, context: 'Context'): + def execute(self, context: Context): ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Stopping EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) diff --git a/airflow/providers/amazon/aws/operators/ec2_start_instance.py b/airflow/providers/amazon/aws/operators/ec2_start_instance.py deleted file mode 100644 index c2c25e5708b0a..0000000000000 --- a/airflow/providers/amazon/aws/operators/ec2_start_instance.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ec2`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py deleted file mode 100644 index ddafa21c5bd5e..0000000000000 --- a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ec2`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.ec2 import EC2StopInstanceOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index d1112edf445ee..0652589ea55bf 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -15,163 +15,261 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import re import sys -import time import warnings -from collections import deque -from datetime import datetime, timedelta -from logging import Logger -from threading import Event, Thread -from typing import Dict, Generator, Optional, Sequence +from datetime import timedelta +from typing import TYPE_CHECKING, Sequence -from botocore.exceptions import ClientError, ConnectionClosedError -from botocore.waiter import Waiter +import boto3 +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator, XCom from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart + +# TODO: Remove the following import when EcsProtocol and EcsTaskLogFetcher deprecations are removed. 
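Apart from the typing changes, the EC2 operators behave as before: each polls the instance every check_interval seconds until the target state is reached. A minimal sketch with a placeholder instance id and region:

from airflow.providers.amazon.aws.operators.ec2 import (
    EC2StartInstanceOperator,
    EC2StopInstanceOperator,
)

# Both operators block until the instance reaches the desired state
# ("running" / "stopped"), checking every check_interval seconds.
start_instance = EC2StartInstanceOperator(
    task_id="start_instance",
    instance_id="i-0123456789abcdef0",  # placeholder instance id
    region_name="us-east-1",            # placeholder region
    check_interval=15,
)

stop_instance = EC2StopInstanceOperator(
    task_id="stop_instance",
    instance_id="i-0123456789abcdef0",  # placeholder instance id
    region_name="us-east-1",
    check_interval=15,
)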
+from airflow.providers.amazon.aws.hooks import ecs from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook -from airflow.typing_compat import Protocol, runtime_checkable +from airflow.providers.amazon.aws.hooks.ecs import ( + EcsClusterStates, + EcsHook, + EcsTaskDefinitionStates, + should_retry_eni, +) +from airflow.providers.amazon.aws.sensors.ecs import EcsClusterStateSensor, EcsTaskDefinitionStateSensor from airflow.utils.session import provide_session +if TYPE_CHECKING: + from airflow.utils.context import Context -def should_retry(exception: Exception): - """Check if exception is related to ECS resource quota (CPU, MEM).""" - if isinstance(exception, EcsOperatorError): - return any( - quota_reason in failure['reason'] - for quota_reason in ['RESOURCE:MEMORY', 'RESOURCE:CPU'] - for failure in exception.failures - ) - return False +DEFAULT_CONN_ID = "aws_default" -def should_retry_eni(exception: Exception): - """Check if exception is related to ENI (Elastic Network Interfaces).""" - if isinstance(exception, EcsTaskFailToStart): - return any( - eni_reason in exception.message - for eni_reason in ['network interface provisioning', 'ResourceInitializationError'] - ) - return False +class EcsBaseOperator(BaseOperator): + """This is the base operator for all Elastic Container Service operators.""" + + def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs): + self.aws_conn_id = aws_conn_id + self.region = region + super().__init__(**kwargs) + + @cached_property + def hook(self) -> EcsHook: + """Create and return an EcsHook.""" + return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + @cached_property + def client(self) -> boto3.client: + """Create and return the EcsHook's client.""" + return self.hook.conn -@runtime_checkable -class EcsProtocol(Protocol): + def execute(self, context: Context): + """Must overwrite in child classes.""" + raise NotImplementedError("Please implement execute() in subclass") + + +class EcsCreateClusterOperator(EcsBaseOperator): """ - A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on - :py:meth:`.EcsOperator.client`. + Creates an AWS ECS cluster. .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EcsCreateClusterOperator` - - https://mypy.readthedocs.io/en/latest/protocols.html - - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html + :param cluster_name: The name of your cluster. If you don't specify a name for your + cluster, you create a cluster that's named default. + :param create_cluster_kwargs: Extra arguments for Cluster Creation. + :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) """ - def run_task(self, **kwargs) -> Dict: - """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501 - ... - - def get_waiter(self, x: str) -> Waiter: - """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501 - ... + template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion") - def describe_tasks(self, cluster: str, tasks) -> Dict: - """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501 - ... 
+ def __init__( + self, + *, + cluster_name: str, + create_cluster_kwargs: dict | None = None, + wait_for_completion: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.cluster_name = cluster_name + self.create_cluster_kwargs = create_cluster_kwargs or {} + self.wait_for_completion = wait_for_completion - def stop_task(self, cluster, task, reason: str) -> Dict: - """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501 - ... + def execute(self, context: Context): + self.log.info( + "Creating cluster %s using the following values: %s", + self.cluster_name, + self.create_cluster_kwargs, + ) + result = self.client.create_cluster(clusterName=self.cluster_name, **self.create_cluster_kwargs) - def describe_task_definition(self, taskDefinition: str) -> Dict: - """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_task_definition""" # noqa: E501 - ... + if self.wait_for_completion: + while not EcsClusterStateSensor( + task_id="await_cluster", + cluster_name=self.cluster_name, + ).poke(context): + # The sensor has a built-in delay and will try again until + # the cluster is ready or has reached a failed state. + pass - def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family: str) -> Dict: - """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.list_tasks""" # noqa: E501 - ... + return result["cluster"] -class EcsTaskLogFetcher(Thread): +class EcsDeleteClusterOperator(EcsBaseOperator): """ - Fetches Cloudwatch log events with specific interval as a thread - and sends the log events to the info channel of the provided logger. + Deletes an AWS ECS cluster. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EcsDeleteClusterOperator` + + :param cluster_name: The short name or full Amazon Resource Name (ARN) of the cluster to delete. + :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) """ + template_fields: Sequence[str] = ("cluster_name", "wait_for_completion") + def __init__( self, *, - aws_conn_id: Optional[str] = 'aws_default', - region_name: Optional[str] = None, - log_group: str, - log_stream_name: str, - fetch_interval: timedelta, - logger: Logger, - ): - super().__init__() - self._event = Event() + cluster_name: str, + wait_for_completion: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.cluster_name = cluster_name + self.wait_for_completion = wait_for_completion + + def execute(self, context: Context): + self.log.info("Deleting cluster %s.", self.cluster_name) + result = self.client.delete_cluster(cluster=self.cluster_name) + + if self.wait_for_completion: + while not EcsClusterStateSensor( + task_id="await_cluster_delete", + cluster_name=self.cluster_name, + target_state=EcsClusterStates.INACTIVE, + failure_states={EcsClusterStates.FAILED}, + ).poke(context): + # The sensor has a built-in delay and will try again until + # the cluster is deleted or reaches a failed state. + pass + + return result["cluster"] + + +class EcsDeregisterTaskDefinitionOperator(EcsBaseOperator): + """ + Deregister a task definition on AWS ECS. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EcsDeregisterTaskDefinitionOperator` + + :param task_definition: The family and revision (family:revision) or full Amazon Resource Name (ARN) + of the task definition to deregister. If you use a family name, you must specify a revision. + :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) + """ + + template_fields: Sequence[str] = ("task_definition", "wait_for_completion") + + def __init__(self, *, task_definition: str, wait_for_completion: bool = True, **kwargs): + super().__init__(**kwargs) + self.task_definition = task_definition + self.wait_for_completion = wait_for_completion - self.fetch_interval = fetch_interval + def execute(self, context: Context): + self.log.info("Deregistering task definition %s.", self.task_definition) + result = self.client.deregister_task_definition(taskDefinition=self.task_definition) - self.logger = logger - self.log_group = log_group - self.log_stream_name = log_stream_name + if self.wait_for_completion: + while not EcsTaskDefinitionStateSensor( + task_id="await_deregister_task_definition", + task_definition=self.task_definition, + target_state=EcsTaskDefinitionStates.INACTIVE, + ).poke(context): + # The sensor has a built-in delay and will try again until the + # task definition is deregistered or reaches a failed state. + pass - self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, region_name=region_name) + return result["taskDefinition"]["taskDefinitionArn"] - def run(self) -> None: - logs_to_skip = 0 - while not self.is_stopped(): - time.sleep(self.fetch_interval.total_seconds()) - log_events = self._get_log_events(logs_to_skip) - for log_event in log_events: - self.logger.info(self._event_to_str(log_event)) - logs_to_skip += 1 - def _get_log_events(self, skip: int = 0) -> Generator: - try: - yield from self.hook.get_log_events(self.log_group, self.log_stream_name, skip=skip) - except ClientError as error: - if error.response['Error']['Code'] != 'ResourceNotFoundException': - self.logger.warning('Error on retrieving Cloudwatch log events', error) +class EcsRegisterTaskDefinitionOperator(EcsBaseOperator): + """ + Register a task definition on AWS ECS. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EcsRegisterTaskDefinitionOperator` - yield from () - except ConnectionClosedError as error: - self.logger.warning('ConnectionClosedError on retrieving Cloudwatch log events', error) - yield from () + :param family: The family name of a task definition to create. + :param container_definitions: A list of container definitions in JSON format that describe + the different containers that make up your task. + :param register_task_kwargs: Extra arguments for Register Task Definition. + :param wait_for_completion: If True, waits for creation of the cluster to complete. 
(default: True) + """ - def _event_to_str(self, event: dict) -> str: - event_dt = datetime.utcfromtimestamp(event['timestamp'] / 1000.0) - formatted_event_dt = event_dt.strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] - message = event['message'] - return f'[{formatted_event_dt}] {message}' + template_fields: Sequence[str] = ( + "family", + "container_definitions", + "register_task_kwargs", + "wait_for_completion", + ) - def get_last_log_messages(self, number_messages) -> list: - return [log['message'] for log in deque(self._get_log_events(), maxlen=number_messages)] + def __init__( + self, + *, + family: str, + container_definitions: list[dict], + register_task_kwargs: dict | None = None, + wait_for_completion: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.family = family + self.container_definitions = container_definitions + self.register_task_kwargs = register_task_kwargs or {} + self.wait_for_completion = wait_for_completion - def get_last_log_message(self) -> Optional[str]: - try: - return self.get_last_log_messages(1)[0] - except IndexError: - return None + def execute(self, context: Context): + self.log.info( + "Registering task definition %s using the following values: %s", + self.family, + self.register_task_kwargs, + ) + self.log.info("Using container definition %s", self.container_definitions) + response = self.client.register_task_definition( + family=self.family, + containerDefinitions=self.container_definitions, + **self.register_task_kwargs, + ) + task_arn = response["taskDefinition"]["taskDefinitionArn"] - def is_stopped(self) -> bool: - return self._event.is_set() + if self.wait_for_completion: + while not EcsTaskDefinitionStateSensor( + task_id="await_register_task_definition", task_definition=task_arn + ).poke(context): + # The sensor has a built-in delay and will try again until + # the task definition is registered or reaches a failed state. + pass - def stop(self): - self._event.set() + context["ti"].xcom_push(key="task_definition_arn", value=task_arn) + return task_arn -class EcsOperator(BaseOperator): +class EcsRunTaskOperator(EcsBaseOperator): """ Execute a task on AWS ECS (Elastic Container Service) .. seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:EcsOperator` + :ref:`howto/operator:EcsRunTaskOperator` :param task_definition: the task definition name on Elastic Container Service :param cluster: the cluster name on Elastic Container Service @@ -179,7 +277,7 @@ class EcsOperator(BaseOperator): https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task :param aws_conn_id: connection id of AWS credentials / region name. If None, credential boto3 strategy will be used - (http://boto3.readthedocs.io/en/latest/guide/configuration.html). + (https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html). :param region_name: region name to use in AWS Hook. Override the region_name in connection (if provided) :param launch_type: the launch type on which to run your task ('EC2', 'EXTERNAL', or 'FARGATE') @@ -216,15 +314,35 @@ class EcsOperator(BaseOperator): :param number_logs_exception: Number of lines from the last Cloudwatch logs to return in the AirflowException if an ECS task is stopped (to receive Airflow alerts with the logs of what failed in the code running in ECS). + :param wait_for_completion: If True, waits for creation of the cluster to complete. 
(default: True) """ - ui_color = '#f0ede4' - template_fields: Sequence[str] = ('overrides',) + ui_color = "#f0ede4" + template_fields: Sequence[str] = ( + "task_definition", + "cluster", + "overrides", + "launch_type", + "capacity_provider_strategy", + "group", + "placement_constraints", + "placement_strategy", + "platform_version", + "network_configuration", + "tags", + "awslogs_group", + "awslogs_region", + "awslogs_stream_prefix", + "awslogs_fetch_interval", + "propagate_tags", + "reattach", + "number_logs_exception", + "wait_for_completion", + ) template_fields_renderers = { "overrides": "json", "network_configuration": "json", "tags": "json", - "quota_retry": "json", } REATTACH_XCOM_KEY = "ecs_task_arn" REATTACH_XCOM_TASK_ID_TEMPLATE = "{task_id}_task_arn" @@ -235,30 +353,27 @@ def __init__( task_definition: str, cluster: str, overrides: dict, - aws_conn_id: Optional[str] = None, - region_name: Optional[str] = None, - launch_type: str = 'EC2', - capacity_provider_strategy: Optional[list] = None, - group: Optional[str] = None, - placement_constraints: Optional[list] = None, - placement_strategy: Optional[list] = None, - platform_version: Optional[str] = None, - network_configuration: Optional[dict] = None, - tags: Optional[dict] = None, - awslogs_group: Optional[str] = None, - awslogs_region: Optional[str] = None, - awslogs_stream_prefix: Optional[str] = None, + launch_type: str = "EC2", + capacity_provider_strategy: list | None = None, + group: str | None = None, + placement_constraints: list | None = None, + placement_strategy: list | None = None, + platform_version: str | None = None, + network_configuration: dict | None = None, + tags: dict | None = None, + awslogs_group: str | None = None, + awslogs_region: str | None = None, + awslogs_stream_prefix: str | None = None, awslogs_fetch_interval: timedelta = timedelta(seconds=30), - propagate_tags: Optional[str] = None, - quota_retry: Optional[dict] = None, + propagate_tags: str | None = None, + quota_retry: dict | None = None, reattach: bool = False, number_logs_exception: int = 10, + wait_for_completion: bool = True, **kwargs, ): super().__init__(**kwargs) - self.aws_conn_id = aws_conn_id - self.region_name = region_name self.task_definition = task_definition self.cluster = cluster self.overrides = overrides @@ -280,29 +395,26 @@ def __init__( self.number_logs_exception = number_logs_exception if self.awslogs_region is None: - self.awslogs_region = region_name + self.awslogs_region = self.region - self.hook: Optional[AwsBaseHook] = None - self.client: Optional[EcsProtocol] = None - self.arn: Optional[str] = None + self.arn: str | None = None self.retry_args = quota_retry - self.task_log_fetcher: Optional[EcsTaskLogFetcher] = None + self.task_log_fetcher: EcsTaskLogFetcher | None = None + self.wait_for_completion = wait_for_completion @provide_session def execute(self, context, session=None): self.log.info( - 'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster + "Running ECS Task - Task definition: %s - on cluster %s", self.task_definition, self.cluster ) - self.log.info('EcsOperator overrides: %s', self.overrides) - - self.client = self.get_hook().get_conn() + self.log.info("EcsOperator overrides: %s", self.overrides) if self.reattach: self._try_reattach_task(context) self._start_wait_check_task(context) - self.log.info('ECS Task has been successfully executed') + self.log.info("ECS Task has been successfully executed") if self.reattach: # Clear the XCom value storing the ECS task ARN if the 
task has completed @@ -321,18 +433,20 @@ def _start_wait_check_task(self, context): self._start_task(context) if self._aws_logs_enabled(): - self.log.info('Starting ECS Task Log Fetcher') + self.log.info("Starting ECS Task Log Fetcher") self.task_log_fetcher = self._get_task_log_fetcher() self.task_log_fetcher.start() try: - self._wait_for_task_ended() + if self.wait_for_completion: + self._wait_for_task_ended() finally: self.task_log_fetcher.stop() self.task_log_fetcher.join() else: - self._wait_for_task_ended() + if self.wait_for_completion: + self._wait_for_task_ended() self._check_success_task() @@ -341,68 +455,54 @@ def _xcom_del(self, session, task_id): def _start_task(self, context): run_opts = { - 'cluster': self.cluster, - 'taskDefinition': self.task_definition, - 'overrides': self.overrides, - 'startedBy': self.owner, + "cluster": self.cluster, + "taskDefinition": self.task_definition, + "overrides": self.overrides, + "startedBy": self.owner, } if self.capacity_provider_strategy: - run_opts['capacityProviderStrategy'] = self.capacity_provider_strategy + run_opts["capacityProviderStrategy"] = self.capacity_provider_strategy elif self.launch_type: - run_opts['launchType'] = self.launch_type + run_opts["launchType"] = self.launch_type if self.platform_version is not None: - run_opts['platformVersion'] = self.platform_version + run_opts["platformVersion"] = self.platform_version if self.group is not None: - run_opts['group'] = self.group + run_opts["group"] = self.group if self.placement_constraints is not None: - run_opts['placementConstraints'] = self.placement_constraints + run_opts["placementConstraints"] = self.placement_constraints if self.placement_strategy is not None: - run_opts['placementStrategy'] = self.placement_strategy + run_opts["placementStrategy"] = self.placement_strategy if self.network_configuration is not None: - run_opts['networkConfiguration'] = self.network_configuration + run_opts["networkConfiguration"] = self.network_configuration if self.tags is not None: - run_opts['tags'] = [{'key': k, 'value': v} for (k, v) in self.tags.items()] + run_opts["tags"] = [{"key": k, "value": v} for (k, v) in self.tags.items()] if self.propagate_tags is not None: - run_opts['propagateTags'] = self.propagate_tags + run_opts["propagateTags"] = self.propagate_tags response = self.client.run_task(**run_opts) - failures = response['failures'] + failures = response["failures"] if len(failures) > 0: raise EcsOperatorError(failures, response) - self.log.info('ECS Task started: %s', response) + self.log.info("ECS Task started: %s", response) - self.arn = response['tasks'][0]['taskArn'] + self.arn = response["tasks"][0]["taskArn"] self.ecs_task_id = self.arn.split("/")[-1] self.log.info("ECS task ID is: %s", self.ecs_task_id) if self.reattach: # Save the task ARN in XCom to be able to reattach it if needed - self._xcom_set( - context, - key=self.REATTACH_XCOM_KEY, - value=self.arn, - task_id=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id), - ) - - def _xcom_set(self, context, key, value, task_id): - XCom.set( - key=key, - value=value, - task_id=task_id, - dag_id=self.dag_id, - run_id=context["run_id"], - ) + self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn) def _try_reattach_task(self, context): task_def_resp = self.client.describe_task_definition(taskDefinition=self.task_definition) - ecs_task_family = task_def_resp['taskDefinition']['family'] + ecs_task_family = task_def_resp["taskDefinition"]["family"] list_tasks_resp = self.client.list_tasks( - 
cluster=self.cluster, desiredStatus='RUNNING', family=ecs_task_family + cluster=self.cluster, desiredStatus="RUNNING", family=ecs_task_family ) - running_tasks = list_tasks_resp['taskArns'] + running_tasks = list_tasks_resp["taskArns"] # Check if the ECS task previously launched is already running previous_task_arn = self.xcom_pull( @@ -421,7 +521,7 @@ def _wait_for_task_ended(self) -> None: if not self.client or not self.arn: return - waiter = self.client.get_waiter('tasks_stopped') + waiter = self.client.get_waiter("tasks_stopped") waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow waiter.wait(cluster=self.cluster, tasks=[self.arn]) @@ -430,12 +530,13 @@ def _wait_for_task_ended(self) -> None: def _aws_logs_enabled(self): return self.awslogs_group and self.awslogs_stream_prefix - def _get_task_log_fetcher(self) -> EcsTaskLogFetcher: + # TODO: When the deprecation wrapper below is removed, please fix the following return type hint. + def _get_task_log_fetcher(self) -> ecs.EcsTaskLogFetcher: if not self.awslogs_group: raise ValueError("must specify awslogs_group to fetch task logs") log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}" - return EcsTaskLogFetcher( + return ecs.EcsTaskLogFetcher( aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region, log_group=self.awslogs_group, @@ -449,14 +550,14 @@ def _check_success_task(self) -> None: return response = self.client.describe_tasks(cluster=self.cluster, tasks=[self.arn]) - self.log.info('ECS Task stopped, check status: %s', response) + self.log.info("ECS Task stopped, check status: %s", response) - if len(response.get('failures', [])) > 0: + if len(response.get("failures", [])) > 0: raise AirflowException(response) - for task in response['tasks']: + for task in response["tasks"]: - if task.get('stopCode', '') == 'TaskFailedToStart': + if task.get("stopCode", "") == "TaskFailedToStart": # Reset task arn here otherwise the retry run will not start # a new task but keep polling the old dead one # I'm not resetting it for other exceptions here because @@ -468,14 +569,14 @@ def _check_success_task(self) -> None: # successfully finished, but there is no other indication of failure # in the response. 
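# Illustrative aside, not part of the patch: the ``stoppedReason`` pattern checked a few
# lines below matches the standard "host stopped/terminated" messages reported by ECS,
# for example (the instance id here is made up):
import re

assert re.match(
    r"Host EC2 \(instance .+?\) (stopped|terminated)\.",
    "Host EC2 (instance i-0abc123def4567890) stopped.",
)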
# https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html - if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.', task.get('stoppedReason', '')): + if re.match(r"Host EC2 \(instance .+?\) (stopped|terminated)\.", task.get("stoppedReason", "")): raise AirflowException( f"The task was stopped because the host instance terminated:" f" {task.get('stoppedReason', '')}" ) - containers = task['containers'] + containers = task["containers"] for container in containers: - if container.get('lastStatus') == 'STOPPED' and container.get('exitCode', 1) != 0: + if container.get("lastStatus") == "STOPPED" and container.get("exitCode", 1) != 0: if self.task_log_fetcher: last_logs = "\n".join( self.task_log_fetcher.get_last_log_messages(self.number_logs_exception) @@ -485,23 +586,15 @@ def _check_success_task(self) -> None: f"logs from Cloudwatch:\n{last_logs}" ) else: - raise AirflowException(f'This task is not in success state {task}') - elif container.get('lastStatus') == 'PENDING': - raise AirflowException(f'This task is still pending {task}') - elif 'error' in container.get('reason', '').lower(): + raise AirflowException(f"This task is not in success state {task}") + elif container.get("lastStatus") == "PENDING": + raise AirflowException(f"This task is still pending {task}") + elif "error" in container.get("reason", "").lower(): raise AirflowException( f"This containers encounter an error during launching: " f"{container.get('reason', '').lower()}" ) - def get_hook(self) -> AwsBaseHook: - """Create and return an AwsHook.""" - if self.hook: - return self.hook - - self.hook = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type='ecs', region_name=self.region_name) - return self.hook - def on_kill(self) -> None: if not self.client or not self.arn: return @@ -510,53 +603,56 @@ def on_kill(self) -> None: self.task_log_fetcher.stop() response = self.client.stop_task( - cluster=self.cluster, task=self.arn, reason='Task killed by the user' + cluster=self.cluster, task=self.arn, reason="Task killed by the user" ) self.log.info(response) -class ECSOperator(EcsOperator): +class EcsOperator(EcsRunTaskOperator): """ This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsOperator`. + Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator`. """ def __init__(self, *args, **kwargs): warnings.warn( "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.ecs.EcsOperator`.", + "Please use `airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator`.", DeprecationWarning, stacklevel=2, ) super().__init__(*args, **kwargs) -class ECSTaskLogFetcher(EcsTaskLogFetcher): +class EcsTaskLogFetcher(ecs.EcsTaskLogFetcher): """ This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`. + Please use :class:`airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher`. """ + # TODO: Note to deprecator, Be sure to fix the use of `ecs.EcsTaskLogFetcher` + # in the Operators above when you remove this wrapper class. def __init__(self, *args, **kwargs): warnings.warn( "This class is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.", + "Please use `airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher`.", DeprecationWarning, stacklevel=2, ) super().__init__(*args, **kwargs) -class ECSProtocol(EcsProtocol): +class EcsProtocol(ecs.EcsProtocol): """ This class is deprecated. 
- Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsProtocol`. + Please use :class:`airflow.providers.amazon.aws.hooks.ecs.EcsProtocol`. """ + # TODO: Note to deprecator, Be sure to fix the use of `ecs.EcsProtocol` + # in the Operators above when you remove this wrapper class. def __init__(self): warnings.warn( - "This class is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.", + "This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.ecs.EcsProtocol`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index 8e97015d86182..70d74f4cda849 100644 --- a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains Amazon EKS operators.""" +from __future__ import annotations + import warnings from ast import literal_eval from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, List, Sequence, cast from airflow import AirflowException from airflow.models import BaseOperator @@ -32,21 +33,20 @@ CHECK_INTERVAL_SECONDS = 15 TIMEOUT_SECONDS = 25 * 60 -DEFAULT_COMPUTE_TYPE = 'nodegroup' -DEFAULT_CONN_ID = 'aws_default' -DEFAULT_FARGATE_PROFILE_NAME = 'profile' -DEFAULT_NAMESPACE_NAME = 'default' -DEFAULT_NODEGROUP_NAME = 'nodegroup' -DEFAULT_POD_NAME = 'pod' +DEFAULT_COMPUTE_TYPE = "nodegroup" +DEFAULT_CONN_ID = "aws_default" +DEFAULT_FARGATE_PROFILE_NAME = "profile" +DEFAULT_NAMESPACE_NAME = "default" +DEFAULT_NODEGROUP_NAME = "nodegroup" ABORT_MSG = "{compute} are still active after the allocated time limit. Aborting." CAN_NOT_DELETE_MSG = "A cluster can not be deleted with attached {compute}. Deleting {count} {compute}." MISSING_ARN_MSG = "Creating an {compute} requires {requirement} to be passed in." SUCCESS_MSG = "No {compute} remain, deleting cluster." 
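# Illustrative aside, not part of the patch: a minimal sketch of how the new ECS
# cluster / task-definition / run-task operators introduced above could be wired together.
# Operator names and constructor parameters come from the hunks above; the cluster name,
# task family, container image and memory value are placeholders, and the container
# definition is abbreviated.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import (
    EcsCreateClusterOperator,
    EcsDeleteClusterOperator,
    EcsDeregisterTaskDefinitionOperator,
    EcsRegisterTaskDefinitionOperator,
    EcsRunTaskOperator,
)

with DAG(dag_id="ecs_lifecycle_sketch", start_date=datetime(2022, 1, 1), schedule_interval=None):
    create_cluster = EcsCreateClusterOperator(task_id="create_cluster", cluster_name="demo-cluster")
    register_task = EcsRegisterTaskDefinitionOperator(
        task_id="register_task",
        family="demo-family",
        container_definitions=[{"name": "demo", "image": "busybox", "memory": 128}],
    )
    run_task = EcsRunTaskOperator(
        task_id="run_task",
        cluster="demo-cluster",
        task_definition="demo-family",
        launch_type="EC2",
        overrides={},  # boto3-style overrides, left empty for brevity
    )
    # The register operator returns (and XCom-pushes) the new task definition ARN.
    deregister_task = EcsDeregisterTaskDefinitionOperator(
        task_id="deregister_task", task_definition=register_task.output
    )
    delete_cluster = EcsDeleteClusterOperator(task_id="delete_cluster", cluster_name="demo-cluster")

    create_cluster >> register_task >> run_task >> deregister_task >> delete_cluster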
-SUPPORTED_COMPUTE_VALUES = frozenset({'nodegroup', 'fargate'}) -NODEGROUP_FULL_NAME = 'Amazon EKS managed node groups' -FARGATE_FULL_NAME = 'AWS Fargate profiles' +SUPPORTED_COMPUTE_VALUES = frozenset({"nodegroup", "fargate"}) +NODEGROUP_FULL_NAME = "Amazon EKS managed node groups" +FARGATE_FULL_NAME = "AWS Fargate profiles" class EksCreateClusterOperator(BaseOperator): @@ -125,18 +125,18 @@ def __init__( self, cluster_name: str, cluster_role_arn: str, - resources_vpc_config: Dict[str, Any], - compute: Optional[str] = DEFAULT_COMPUTE_TYPE, - create_cluster_kwargs: Optional[Dict] = None, + resources_vpc_config: dict[str, Any], + compute: str | None = DEFAULT_COMPUTE_TYPE, + create_cluster_kwargs: dict | None = None, nodegroup_name: str = DEFAULT_NODEGROUP_NAME, - nodegroup_role_arn: Optional[str] = None, - create_nodegroup_kwargs: Optional[Dict] = None, + nodegroup_role_arn: str | None = None, + create_nodegroup_kwargs: dict | None = None, fargate_profile_name: str = DEFAULT_FARGATE_PROFILE_NAME, - fargate_pod_execution_role_arn: Optional[str] = None, - fargate_selectors: Optional[List] = None, - create_fargate_profile_kwargs: Optional[Dict] = None, + fargate_pod_execution_role_arn: str | None = None, + fargate_selectors: list | None = None, + create_fargate_profile_kwargs: dict | None = None, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ) -> None: self.compute = compute @@ -155,18 +155,18 @@ def __init__( self.region = region super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): if self.compute: if self.compute not in SUPPORTED_COMPUTE_VALUES: raise ValueError("Provided compute type is not supported.") - elif (self.compute == 'nodegroup') and not self.nodegroup_role_arn: + elif (self.compute == "nodegroup") and not self.nodegroup_role_arn: raise ValueError( - MISSING_ARN_MSG.format(compute=NODEGROUP_FULL_NAME, requirement='nodegroup_role_arn') + MISSING_ARN_MSG.format(compute=NODEGROUP_FULL_NAME, requirement="nodegroup_role_arn") ) - elif (self.compute == 'fargate') and not self.fargate_pod_execution_role_arn: + elif (self.compute == "fargate") and not self.fargate_pod_execution_role_arn: raise ValueError( MISSING_ARN_MSG.format( - compute=FARGATE_FULL_NAME, requirement='fargate_pod_execution_role_arn' + compute=FARGATE_FULL_NAME, requirement="fargate_pod_execution_role_arn" ) ) @@ -205,15 +205,15 @@ def execute(self, context: 'Context'): eks_hook.delete_cluster(name=self.cluster_name) raise RuntimeError(message) - if self.compute == 'nodegroup': + if self.compute == "nodegroup": eks_hook.create_nodegroup( clusterName=self.cluster_name, nodegroupName=self.nodegroup_name, - subnets=cast(List[str], self.resources_vpc_config.get('subnetIds')), + subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")), nodeRole=self.nodegroup_role_arn, **self.create_nodegroup_kwargs, ) - elif self.compute == 'fargate': + elif self.compute == "fargate": eks_hook.create_fargate_profile( clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name, @@ -261,12 +261,12 @@ class EksCreateNodegroupOperator(BaseOperator): def __init__( self, cluster_name: str, - nodegroup_subnets: Union[List[str], str], + nodegroup_subnets: list[str] | str, nodegroup_role_arn: str, nodegroup_name: str = DEFAULT_NODEGROUP_NAME, - create_nodegroup_kwargs: Optional[Dict] = None, + create_nodegroup_kwargs: dict | None = None, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + 
region: str | None = None, **kwargs, ) -> None: self.cluster_name = cluster_name @@ -278,9 +278,9 @@ def __init__( self.nodegroup_subnets = nodegroup_subnets super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): if isinstance(self.nodegroup_subnets, str): - nodegroup_subnets_list: List[str] = [] + nodegroup_subnets_list: list[str] = [] if self.nodegroup_subnets != "": try: nodegroup_subnets_list = cast(List, literal_eval(self.nodegroup_subnets)) @@ -344,11 +344,11 @@ def __init__( self, cluster_name: str, pod_execution_role_arn: str, - selectors: List, - fargate_profile_name: Optional[str] = DEFAULT_FARGATE_PROFILE_NAME, - create_fargate_profile_kwargs: Optional[Dict] = None, + selectors: list, + fargate_profile_name: str | None = DEFAULT_FARGATE_PROFILE_NAME, + create_fargate_profile_kwargs: dict | None = None, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ) -> None: self.cluster_name = cluster_name @@ -360,7 +360,7 @@ def __init__( self.region = region super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -408,7 +408,7 @@ def __init__( cluster_name: str, force_delete_compute: bool = False, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ) -> None: self.cluster_name = cluster_name @@ -417,7 +417,7 @@ def __init__( self.region = region super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -528,7 +528,7 @@ def __init__( cluster_name: str, nodegroup_name: str, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ) -> None: self.cluster_name = cluster_name @@ -537,7 +537,7 @@ def __init__( self.region = region super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -577,7 +577,7 @@ def __init__( cluster_name: str, fargate_profile_name: str, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -586,7 +586,7 @@ def __init__( self.aws_conn_id = aws_conn_id self.region = region - def execute(self, context: 'Context'): + def execute(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -646,23 +646,14 @@ def __init__( # file is stored locally in the worker and not in the cluster. in_cluster: bool = False, namespace: str = DEFAULT_NAMESPACE_NAME, - pod_context: Optional[str] = None, - pod_name: Optional[str] = None, - pod_username: Optional[str] = None, + pod_context: str | None = None, + pod_name: str | None = None, + pod_username: str | None = None, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, - is_delete_operator_pod: Optional[bool] = None, + region: str | None = None, + is_delete_operator_pod: bool | None = None, **kwargs, ) -> None: - if pod_name is None: - warnings.warn( - "Default value of pod name is deprecated. " - "We recommend that you pass pod name explicitly. 
", - DeprecationWarning, - stacklevel=2, - ) - pod_name = DEFAULT_POD_NAME - if is_delete_operator_pod is None: warnings.warn( f"You have not set parameter `is_delete_operator_pod` in class {self.__class__.__name__}. " @@ -687,26 +678,12 @@ def __init__( is_delete_operator_pod=is_delete_operator_pod, **kwargs, ) - if pod_username: - warnings.warn( - "This pod_username parameter is deprecated, because changing the value does not make any " - "visible changes to the user.", - DeprecationWarning, - stacklevel=2, - ) - if pod_context: - warnings.warn( - "This pod_context parameter is deprecated, because changing the value does not make any " - "visible changes to the user.", - DeprecationWarning, - stacklevel=2, - ) # There is no need to manage the kube_config file, as it will be generated automatically. # All Kubernetes parameters (except config_file) are also valid for the EksPodOperator. if self.config_file: raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.") - def execute(self, context: 'Context'): + def execute(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -715,115 +692,3 @@ def execute(self, context: 'Context'): eks_cluster_name=self.cluster_name, pod_namespace=self.namespace ) as self.config_file: return super().execute(context) - - -class EKSCreateClusterOperator(EksCreateClusterOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.eks.EksCreateClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.eks.EksCreateClusterOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSCreateNodegroupOperator(EksCreateNodegroupOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.eks.EksCreateNodegroupOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.eks.EksCreateNodegroupOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSCreateFargateProfileOperator(EksCreateFargateProfileOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.eks.EksCreateFargateProfileOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.eks.EksCreateFargateProfileOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSDeleteClusterOperator(EksDeleteClusterOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.eks.EksDeleteClusterOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.eks.EksDeleteClusterOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSDeleteNodegroupOperator(EksDeleteNodegroupOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.eks.EksDeleteNodegroupOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. 
" - "Please use `airflow.providers.amazon.aws.operators.eks.EksDeleteNodegroupOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSDeleteFargateProfileOperator(EksDeleteFargateProfileOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.eks.EksDeleteFargateProfileOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.eks.EksDeleteFargateProfileOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSPodOperator(EksPodOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.eks.EksPodOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.eks.EksPodOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index a1f3fa753d817..63659e8188b9d 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -15,27 +15,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import ast -import sys -from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +import warnings +from typing import TYPE_CHECKING, Any, Sequence from uuid import uuid4 from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink, XCom -from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook +from airflow.providers.amazon.aws.links.emr import EmrClusterLink if TYPE_CHECKING: - from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook +from airflow.compat.functools import cached_property class EmrAddStepsOperator(BaseOperator): @@ -55,28 +50,29 @@ class EmrAddStepsOperator(BaseOperator): :param aws_conn_id: aws connection to uses :param steps: boto3 style steps or reference to a steps file (must be '.json') to be added to the jobflow. (templated) + :param wait_for_completion: If True, the operator will wait for all the steps to be completed. :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id. 
""" - template_fields: Sequence[str] = ('job_flow_id', 'job_flow_name', 'cluster_states', 'steps') - template_ext: Sequence[str] = ('.json',) + template_fields: Sequence[str] = ("job_flow_id", "job_flow_name", "cluster_states", "steps") + template_ext: Sequence[str] = (".json",) template_fields_renderers = {"steps": "json"} - ui_color = '#f9c915' + ui_color = "#f9c915" + operator_extra_links = (EmrClusterLink(),) def __init__( self, *, - job_flow_id: Optional[str] = None, - job_flow_name: Optional[str] = None, - cluster_states: Optional[List[str]] = None, - aws_conn_id: str = 'aws_default', - steps: Optional[Union[List[dict], str]] = None, + job_flow_id: str | None = None, + job_flow_name: str | None = None, + cluster_states: list[str] | None = None, + aws_conn_id: str = "aws_default", + steps: list[dict] | str | None = None, + wait_for_completion: bool = False, **kwargs, ): - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") if not (job_flow_id is None) ^ (job_flow_name is None): - raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.') + raise AirflowException("Exactly one of job_flow_id or job_flow_name must be specified.") super().__init__(**kwargs) cluster_states = cluster_states or [] steps = steps or [] @@ -85,23 +81,30 @@ def __init__( self.job_flow_name = job_flow_name self.cluster_states = cluster_states self.steps = steps + self.wait_for_completion = wait_for_completion - def execute(self, context: 'Context') -> List[str]: + def execute(self, context: Context) -> list[str]: emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) - emr = emr_hook.get_conn() - job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name( str(self.job_flow_name), self.cluster_states ) if not job_flow_id: - raise AirflowException(f'No cluster found for name: {self.job_flow_name}') + raise AirflowException(f"No cluster found for name: {self.job_flow_name}") if self.do_xcom_push: - context['ti'].xcom_push(key='job_flow_id', value=job_flow_id) + context["ti"].xcom_push(key="job_flow_id", value=job_flow_id) + + EmrClusterLink.persist( + context=context, + operator=self, + region_name=emr_hook.conn_region_name, + aws_partition=emr_hook.conn_partition, + job_flow_id=job_flow_id, + ) - self.log.info('Adding steps to %s', job_flow_id) + self.log.info("Adding steps to %s", job_flow_id) # steps may arrive as a string representing a list # e.g. if we used XCom or a file then: steps="[{ step1 }, { step2 }]" @@ -109,19 +112,73 @@ def execute(self, context: 'Context') -> List[str]: if isinstance(steps, str): steps = ast.literal_eval(steps) - response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps) + return emr_hook.add_job_flow_steps(job_flow_id=job_flow_id, steps=steps, wait_for_completion=True) - if not response['ResponseMetadata']['HTTPStatusCode'] == 200: - raise AirflowException(f'Adding steps failed: {response}') - else: - self.log.info('Steps %s added to JobFlow', response['StepIds']) - return response['StepIds'] + +class EmrEksCreateClusterOperator(BaseOperator): + """ + An operator that creates EMR on EKS virtual clusters. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrEksCreateClusterOperator` + + :param virtual_cluster_name: The name of the EMR EKS virtual cluster to create. + :param eks_cluster_name: The EKS cluster used by the EMR virtual cluster. + :param eks_namespace: namespace used by the EKS cluster. 
+ :param virtual_cluster_id: The EMR on EKS virtual cluster id. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param tags: The tags assigned to created cluster. + Defaults to None + """ + + template_fields: Sequence[str] = ( + "virtual_cluster_name", + "eks_cluster_name", + "eks_namespace", + ) + ui_color = "#f9c915" + + def __init__( + self, + *, + virtual_cluster_name: str, + eks_cluster_name: str, + eks_namespace: str, + virtual_cluster_id: str = "", + aws_conn_id: str = "aws_default", + tags: dict | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.virtual_cluster_name = virtual_cluster_name + self.eks_cluster_name = eks_cluster_name + self.eks_namespace = eks_namespace + self.virtual_cluster_id = virtual_cluster_id + self.aws_conn_id = aws_conn_id + self.tags = tags + + @cached_property + def hook(self) -> EmrContainerHook: + """Create and return an EmrContainerHook.""" + return EmrContainerHook(self.aws_conn_id) + + def execute(self, context: Context) -> str | None: + """Create EMR on EKS virtual Cluster""" + self.virtual_cluster_id = self.hook.create_emr_on_eks_cluster( + self.virtual_cluster_name, self.eks_cluster_name, self.eks_namespace, self.tags + ) + return self.virtual_cluster_id class EmrContainerOperator(BaseOperator): """ An operator that submits jobs to EMR on EKS virtual clusters. + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrContainerOperator` + :param name: The name of the job run. :param virtual_cluster_id: The EMR on EKS virtual cluster ID :param execution_role_arn: The IAM role ARN associated with the job run. @@ -133,8 +190,10 @@ class EmrContainerOperator(BaseOperator): Use this if you want to specify a unique ID to prevent two jobs from getting started. If no token is provided, a UUIDv4 token will be generated for you. :param aws_conn_id: The Airflow connection used for AWS credentials. + :param wait_for_completion: Whether or not to wait in the operator for the job to complete. :param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR - :param max_tries: Maximum number of times to wait for the job run to finish. + :param max_tries: Deprecated - use max_polling_attempts instead. + :param max_polling_attempts: Maximum number of times to wait for the job run to finish. Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. :param tags: The tags assigned to job runs. 
Defaults to None @@ -157,12 +216,14 @@ def __init__( execution_role_arn: str, release_label: str, job_driver: dict, - configuration_overrides: Optional[dict] = None, - client_request_token: Optional[str] = None, + configuration_overrides: dict | None = None, + client_request_token: str | None = None, aws_conn_id: str = "aws_default", + wait_for_completion: bool = True, poll_interval: int = 30, - max_tries: Optional[int] = None, - tags: Optional[dict] = None, + max_tries: int | None = None, + tags: dict | None = None, + max_polling_attempts: int | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -174,10 +235,23 @@ def __init__( self.configuration_overrides = configuration_overrides or {} self.aws_conn_id = aws_conn_id self.client_request_token = client_request_token or str(uuid4()) + self.wait_for_completion = wait_for_completion self.poll_interval = poll_interval - self.max_tries = max_tries + self.max_polling_attempts = max_polling_attempts self.tags = tags - self.job_id: Optional[str] = None + self.job_id: str | None = None + + if max_tries: + warnings.warn( + f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed " + "in a future release. Please use method `max_polling_attempts` instead.", + DeprecationWarning, + stacklevel=2, + ) + if max_polling_attempts and max_polling_attempts != max_tries: + raise Exception("max_polling_attempts must be the same value as max_tries") + else: + self.max_polling_attempts = max_tries @cached_property def hook(self) -> EmrContainerHook: @@ -187,7 +261,7 @@ def hook(self) -> EmrContainerHook: virtual_cluster_id=self.virtual_cluster_id, ) - def execute(self, context: 'Context') -> Optional[str]: + def execute(self, context: Context) -> str | None: """Run job on EMR Containers""" self.job_id = self.hook.submit_job( self.name, @@ -198,20 +272,25 @@ def execute(self, context: 'Context') -> Optional[str]: self.client_request_token, self.tags, ) - query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval) - - if query_status in EmrContainerHook.FAILURE_STATES: - error_message = self.hook.get_job_failure_reason(self.job_id) - raise AirflowException( - f"EMR Containers job failed. Final state is {query_status}. " - f"query_execution_id is {self.job_id}. Error: {error_message}" - ) - elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES: - raise AirflowException( - f"Final state of EMR Containers job is {query_status}. " - f"Max tries of poll status exceeded, query_execution_id is {self.job_id}." + if self.wait_for_completion: + query_status = self.hook.poll_query_status( + self.job_id, + max_polling_attempts=self.max_polling_attempts, + poll_interval=self.poll_interval, ) + if query_status in EmrContainerHook.FAILURE_STATES: + error_message = self.hook.get_job_failure_reason(self.job_id) + raise AirflowException( + f"EMR Containers job failed. Final state is {query_status}. " + f"query_execution_id is {self.job_id}. Error: {error_message}" + ) + elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES: + raise AirflowException( + f"Final state of EMR Containers job is {query_status}. " + f"Max tries of poll status exceeded, query_execution_id is {self.job_id}." + ) + return self.job_id def on_kill(self) -> None: @@ -235,38 +314,6 @@ def on_kill(self) -> None: self.hook.poll_query_status(self.job_id) -class EmrClusterLink(BaseOperatorLink): - """Operator link for EmrCreateJobFlowOperator. 
It allows users to access the EMR Cluster""" - - name = 'EMR Cluster' - - def get_link( - self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, - ) -> str: - """ - Get link to EMR cluster. - - :param operator: operator - :param dttm: datetime - :return: url link - """ - if ti_key is not None: - flow_id = XCom.get_value(key="return_value", ti_key=ti_key) - else: - assert dttm - flow_id = XCom.get_one( - key="return_value", dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm - ) - return ( - f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}' - if flow_id - else '' - ) - - class EmrCreateJobFlowOperator(BaseOperator): """ Creates an EMR JobFlow, reading the config from the EMR connection. @@ -277,57 +324,70 @@ class EmrCreateJobFlowOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:EmrCreateJobFlowOperator` - :param aws_conn_id: aws connection to uses - :param emr_conn_id: emr connection to use + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node) + :param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection `. + Use to receive an initial Amazon EMR cluster configuration: + ``boto3.client('emr').run_job_flow`` request body. + If this is None or empty or the connection does not exist, + then an empty initial configuration is used. :param job_flow_overrides: boto3 style arguments or reference to an arguments file - (must be '.json') to override emr_connection extra. (templated) + (must be '.json') to override specific ``emr_conn_id`` extra parameters. 
(templated) :param region_name: Region named passed to EmrHook """ - template_fields: Sequence[str] = ('job_flow_overrides',) - template_ext: Sequence[str] = ('.json',) + template_fields: Sequence[str] = ("job_flow_overrides",) + template_ext: Sequence[str] = (".json",) template_fields_renderers = {"job_flow_overrides": "json"} - ui_color = '#f9c915' + ui_color = "#f9c915" operator_extra_links = (EmrClusterLink(),) def __init__( self, *, - aws_conn_id: str = 'aws_default', - emr_conn_id: str = 'emr_default', - job_flow_overrides: Optional[Union[str, Dict[str, Any]]] = None, - region_name: Optional[str] = None, + aws_conn_id: str = "aws_default", + emr_conn_id: str | None = "emr_default", + job_flow_overrides: str | dict[str, Any] | None = None, + region_name: str | None = None, **kwargs, ): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.emr_conn_id = emr_conn_id - if job_flow_overrides is None: - job_flow_overrides = {} - self.job_flow_overrides = job_flow_overrides + self.job_flow_overrides = job_flow_overrides or {} self.region_name = region_name - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: emr = EmrHook( aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name ) self.log.info( - 'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s', self.aws_conn_id, self.emr_conn_id + "Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s", self.aws_conn_id, self.emr_conn_id ) - if isinstance(self.job_flow_overrides, str): - job_flow_overrides: Dict[str, Any] = ast.literal_eval(self.job_flow_overrides) + job_flow_overrides: dict[str, Any] = ast.literal_eval(self.job_flow_overrides) self.job_flow_overrides = job_flow_overrides else: job_flow_overrides = self.job_flow_overrides response = emr.create_job_flow(job_flow_overrides) - if not response['ResponseMetadata']['HTTPStatusCode'] == 200: - raise AirflowException(f'JobFlow creation failed: {response}') + if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: + raise AirflowException(f"JobFlow creation failed: {response}") else: - self.log.info('JobFlow with id %s created', response['JobFlowId']) - return response['JobFlowId'] + job_flow_id = response["JobFlowId"] + self.log.info("JobFlow with id %s created", job_flow_id) + EmrClusterLink.persist( + context=context, + operator=self, + region_name=emr.conn_region_name, + aws_partition=emr.conn_partition, + job_flow_id=job_flow_id, + ) + return job_flow_id class EmrModifyClusterOperator(BaseOperator): @@ -344,38 +404,44 @@ class EmrModifyClusterOperator(BaseOperator): :param do_xcom_push: if True, cluster_id is pushed to XCom with key cluster_id. 
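# Illustrative aside, not part of the patch: a minimal sketch chaining the EMR job-flow
# operators documented above (inside a DAG context). ``JOB_FLOW_OVERRIDES`` and the step
# definition are placeholders for the usual boto3 ``run_job_flow`` / ``add_job_flow_steps``
# request shapes; the create operator returns the new job flow id, which the downstream
# operators pick up via XCom.
from airflow.providers.amazon.aws.operators.emr import (
    EmrAddStepsOperator,
    EmrCreateJobFlowOperator,
    EmrModifyClusterOperator,
    EmrTerminateJobFlowOperator,
)

JOB_FLOW_OVERRIDES = {"Name": "demo-cluster"}  # placeholder; a real request body is much richer

create_job_flow = EmrCreateJobFlowOperator(
    task_id="create_job_flow",
    job_flow_overrides=JOB_FLOW_OVERRIDES,
)
add_steps = EmrAddStepsOperator(
    task_id="add_steps",
    job_flow_id=create_job_flow.output,
    steps=[
        {
            "Name": "demo_step",
            "ActionOnFailure": "CONTINUE",
            "HadoopJarStep": {"Jar": "command-runner.jar", "Args": ["echo", "hello"]},
        }
    ],
)
modify_cluster = EmrModifyClusterOperator(
    task_id="modify_cluster",
    cluster_id=create_job_flow.output,
    step_concurrency_level=2,
)
terminate_job_flow = EmrTerminateJobFlowOperator(
    task_id="terminate_job_flow",
    job_flow_id=create_job_flow.output,
)

create_job_flow >> add_steps >> modify_cluster >> terminate_job_flow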
""" - template_fields: Sequence[str] = ('cluster_id', 'step_concurrency_level') + template_fields: Sequence[str] = ("cluster_id", "step_concurrency_level") template_ext: Sequence[str] = () - ui_color = '#f9c915' + ui_color = "#f9c915" + operator_extra_links = (EmrClusterLink(),) def __init__( - self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = 'aws_default', **kwargs + self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = "aws_default", **kwargs ): - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.cluster_id = cluster_id self.step_concurrency_level = step_concurrency_level - def execute(self, context: 'Context') -> int: + def execute(self, context: Context) -> int: emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) - emr = emr_hook.get_conn() if self.do_xcom_push: - context['ti'].xcom_push(key='cluster_id', value=self.cluster_id) + context["ti"].xcom_push(key="cluster_id", value=self.cluster_id) + + EmrClusterLink.persist( + context=context, + operator=self, + region_name=emr_hook.conn_region_name, + aws_partition=emr_hook.conn_partition, + job_flow_id=self.cluster_id, + ) - self.log.info('Modifying cluster %s', self.cluster_id) + self.log.info("Modifying cluster %s", self.cluster_id) response = emr.modify_cluster( ClusterId=self.cluster_id, StepConcurrencyLevel=self.step_concurrency_level ) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Modify cluster failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Modify cluster failed: {response}") else: - self.log.info('Steps concurrency level %d', response['StepConcurrencyLevel']) - return response['StepConcurrencyLevel'] + self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"]) + return response["StepConcurrencyLevel"] class EmrTerminateJobFlowOperator(BaseOperator): @@ -390,22 +456,292 @@ class EmrTerminateJobFlowOperator(BaseOperator): :param aws_conn_id: aws connection to uses """ - template_fields: Sequence[str] = ('job_flow_id',) + template_fields: Sequence[str] = ("job_flow_id",) template_ext: Sequence[str] = () - ui_color = '#f9c915' + ui_color = "#f9c915" + operator_extra_links = (EmrClusterLink(),) - def __init__(self, *, job_flow_id: str, aws_conn_id: str = 'aws_default', **kwargs): + def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default", **kwargs): super().__init__(**kwargs) self.job_flow_id = job_flow_id self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context') -> None: - emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn() + def execute(self, context: Context) -> None: + emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) + emr = emr_hook.get_conn() + + EmrClusterLink.persist( + context=context, + operator=self, + region_name=emr_hook.conn_region_name, + aws_partition=emr_hook.conn_partition, + job_flow_id=self.job_flow_id, + ) - self.log.info('Terminating JobFlow %s', self.job_flow_id) + self.log.info("Terminating JobFlow %s", self.job_flow_id) response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id]) - if not response['ResponseMetadata']['HTTPStatusCode'] == 200: - raise AirflowException(f'JobFlow termination failed: {response}') + if not response["ResponseMetadata"]["HTTPStatusCode"] == 200: + raise AirflowException(f"JobFlow termination failed: {response}") else: - self.log.info('JobFlow with id %s 
terminated', self.job_flow_id) + self.log.info("JobFlow with id %s terminated", self.job_flow_id) + + +class EmrServerlessCreateApplicationOperator(BaseOperator): + """ + Operator to create Serverless EMR Application + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessCreateApplicationOperator` + + :param release_label: The EMR release version associated with the application. + :param job_type: The type of application you want to start, such as Spark or Hive. + :param wait_for_completion: If true, wait for the Application to start before returning. Default to True + :param client_request_token: The client idempotency token of the application to create. + Its value must be unique for each request. + :param config: Optional dictionary for arbitrary parameters to the boto API create_application call. + :param aws_conn_id: AWS connection to use + """ + + def __init__( + self, + release_label: str, + job_type: str, + client_request_token: str = "", + config: dict | None = None, + wait_for_completion: bool = True, + aws_conn_id: str = "aws_default", + **kwargs, + ): + self.aws_conn_id = aws_conn_id + self.release_label = release_label + self.job_type = job_type + self.wait_for_completion = wait_for_completion + self.kwargs = kwargs + self.config = config or {} + super().__init__(**kwargs) + + self.client_request_token = client_request_token or str(uuid4()) + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook.""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + def execute(self, context: Context): + response = self.hook.conn.create_application( + clientToken=self.client_request_token, + releaseLabel=self.release_label, + type=self.job_type, + **self.config, + ) + application_id = response["applicationId"] + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Application Creation failed: {response}") + + self.log.info("EMR serverless application created: %s", application_id) + + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={"applicationId": application_id}, + parse_response=["application", "state"], + desired_state={"CREATED"}, + failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, + object_type="application", + action="created", + ) + + self.log.info("Starting application %s", application_id) + self.hook.conn.start_application(applicationId=application_id) + + if self.wait_for_completion: + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={"applicationId": application_id}, + parse_response=["application", "state"], + desired_state={"STARTED"}, + failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, + object_type="application", + action="started", + ) + + return application_id + + +class EmrServerlessStartJobOperator(BaseOperator): + """ + Operator to start EMR Serverless job. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessStartJobOperator` + + :param application_id: ID of the EMR Serverless application to start. + :param execution_role_arn: ARN of role to perform action. + :param job_driver: Driver that the job runs on. + :param configuration_overrides: Configuration specifications to override existing configurations. 
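# Illustrative aside, not part of the patch: a minimal sketch of the new EMR Serverless
# operators (inside a DAG context). The release label, IAM role ARN and Spark entry point
# are placeholders, and the ``job_driver`` shape follows the boto3 ``start_job_run`` API.
from airflow.providers.amazon.aws.operators.emr import (
    EmrServerlessCreateApplicationOperator,
    EmrServerlessDeleteApplicationOperator,
    EmrServerlessStartJobOperator,
)

create_app = EmrServerlessCreateApplicationOperator(
    task_id="create_app",
    release_label="emr-6.6.0",  # placeholder release label
    job_type="SPARK",
)
start_job = EmrServerlessStartJobOperator(
    task_id="start_job",
    application_id=create_app.output,
    execution_role_arn="arn:aws:iam::123456789012:role/emr-serverless-role",  # placeholder
    job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/scripts/job.py"}},  # placeholder
    configuration_overrides={},
)
delete_app = EmrServerlessDeleteApplicationOperator(
    task_id="delete_app",
    application_id=create_app.output,
)

create_app >> start_job >> delete_app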
+ :param client_request_token: The client idempotency token of the application to create. + Its value must be unique for each request. + :param config: Optional dictionary for arbitrary parameters to the boto API start_job_run call. + :param wait_for_completion: If true, waits for the job to start before returning. Defaults to True. + :param aws_conn_id: AWS connection to use. + :param name: Name for the EMR Serverless job. If not provided, a default name will be assigned. + """ + + template_fields: Sequence[str] = ( + "application_id", + "execution_role_arn", + "job_driver", + "configuration_overrides", + ) + + def __init__( + self, + application_id: str, + execution_role_arn: str, + job_driver: dict, + configuration_overrides: dict | None, + client_request_token: str = "", + config: dict | None = None, + wait_for_completion: bool = True, + aws_conn_id: str = "aws_default", + name: str | None = None, + **kwargs, + ): + self.aws_conn_id = aws_conn_id + self.application_id = application_id + self.execution_role_arn = execution_role_arn + self.job_driver = job_driver + self.configuration_overrides = configuration_overrides + self.wait_for_completion = wait_for_completion + self.config = config or {} + self.name = name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}") + super().__init__(**kwargs) + + self.client_request_token = client_request_token or str(uuid4()) + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook.""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + def execute(self, context: Context) -> dict: + self.log.info("Starting job on Application: %s", self.application_id) + + app_state = self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"] + if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES: + self.hook.conn.start_application(applicationId=self.application_id) + + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={"applicationId": self.application_id}, + parse_response=["application", "state"], + desired_state={"STARTED"}, + failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, + object_type="application", + action="started", + ) + + response = self.hook.conn.start_job_run( + clientToken=self.client_request_token, + applicationId=self.application_id, + executionRoleArn=self.execution_role_arn, + jobDriver=self.job_driver, + configurationOverrides=self.configuration_overrides, + name=self.name, + **self.config, + ) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"EMR serverless job failed to start: {response}") + + self.log.info("EMR serverless job started: %s", response["jobRunId"]) + if self.wait_for_completion: + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_job_run, + get_state_args={ + "applicationId": self.application_id, + "jobRunId": response["jobRunId"], + }, + parse_response=["jobRun", "state"], + desired_state=EmrServerlessHook.JOB_SUCCESS_STATES, + failure_states=EmrServerlessHook.JOB_FAILURE_STATES, + object_type="job", + action="run", + ) + return response["jobRunId"] + + +class EmrServerlessDeleteApplicationOperator(BaseOperator): + """ + Operator to delete EMR Serverless application + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessDeleteApplicationOperator` + + :param application_id: ID of the EMR Serverless application to delete. + :param wait_for_completion: If true, wait for the Application to start before returning. Default to True + :param aws_conn_id: AWS connection to use + """ + + template_fields: Sequence[str] = ("application_id",) + + def __init__( + self, + application_id: str, + wait_for_completion: bool = True, + aws_conn_id: str = "aws_default", + **kwargs, + ): + self.aws_conn_id = aws_conn_id + self.application_id = application_id + self.wait_for_completion = wait_for_completion + super().__init__(**kwargs) + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook.""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + def execute(self, context: Context) -> None: + self.log.info("Stopping application: %s", self.application_id) + self.hook.conn.stop_application(applicationId=self.application_id) + + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={ + "applicationId": self.application_id, + }, + parse_response=["application", "state"], + desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES, + failure_states=set(), + object_type="application", + action="stopped", + ) + + self.log.info("Deleting application: %s", self.application_id) + response = self.hook.conn.delete_application(applicationId=self.application_id) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Application deletion failed: {response}") + + if self.wait_for_completion: + # This should be replaced with a boto waiter when available. + self.hook.waiter( + get_state_callable=self.hook.conn.get_application, + get_state_args={"applicationId": self.application_id}, + parse_response=["application", "state"], + desired_state={"TERMINATED"}, + failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES, + object_type="application", + action="deleted", + ) + + self.log.info("EMR serverless application deleted") diff --git a/airflow/providers/amazon/aws/operators/emr_add_steps.py b/airflow/providers/amazon/aws/operators/emr_add_steps.py deleted file mode 100644 index c99be43e8cdf5..0000000000000 --- a/airflow/providers/amazon/aws/operators/emr_add_steps.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.emr`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/emr_containers.py b/airflow/providers/amazon/aws/operators/emr_containers.py deleted file mode 100644 index 6e81b0bddea0c..0000000000000 --- a/airflow/providers/amazon/aws/operators/emr_containers.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.", - DeprecationWarning, - stacklevel=2, -) - - -class EMRContainerOperator(EmrContainerOperator): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py deleted file mode 100644 index 18c052fa2c835..0000000000000 --- a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr import EmrClusterLink, EmrCreateJobFlowOperator - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.emr`.", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["EmrClusterLink", "EmrCreateJobFlowOperator"] diff --git a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py b/airflow/providers/amazon/aws/operators/emr_modify_cluster.py deleted file mode 100644 index 71b44d5364e42..0000000000000 --- a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr import EmrClusterLink, EmrModifyClusterOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py b/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py deleted file mode 100644 index f924393a443ad..0000000000000 --- a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/glacier.py b/airflow/providers/amazon/aws/operators/glacier.py index 337492a4523e1..4e7c8b5e17421 100644 --- a/airflow/providers/amazon/aws/operators/glacier.py +++ b/airflow/providers/amazon/aws/operators/glacier.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
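The EMR Serverless operators added earlier in this patch are intended to be composed: create an application, run a job on it, then delete the application. The sketch below is illustrative only; it assumes the classes are importable from airflow.providers.amazon.aws.operators.emr alongside the other EMR operators, and the DAG id, IAM role ARN, S3 URIs, release label, and monitoring configuration are placeholders rather than values taken from this change.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr import (
    EmrServerlessCreateApplicationOperator,
    EmrServerlessDeleteApplicationOperator,
    EmrServerlessStartJobOperator,
)

# Placeholder values -- substitute a real execution role ARN and S3 locations.
EXECUTION_ROLE_ARN = "arn:aws:iam::123456789012:role/my-emr-serverless-role"
SPARK_JOB_DRIVER = {"sparkSubmit": {"entryPoint": "s3://my-bucket/scripts/pi.py"}}
S3_LOGS_URI = "s3://my-bucket/emr-serverless-logs/"

with DAG(
    dag_id="example_emr_serverless",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Creates the application, waits for it to reach the STARTED state and
    # returns the application id (pushed to XCom as the task's return value).
    create_app = EmrServerlessCreateApplicationOperator(
        task_id="create_emr_serverless_app",
        release_label="emr-6.6.0",
        job_type="SPARK",
    )

    # application_id is a templated field, so the XComArg from the previous
    # task resolves at runtime. The configuration override shown is illustrative.
    start_job = EmrServerlessStartJobOperator(
        task_id="start_emr_serverless_job",
        application_id=create_app.output,
        execution_role_arn=EXECUTION_ROLE_ARN,
        job_driver=SPARK_JOB_DRIVER,
        configuration_overrides={
            "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": S3_LOGS_URI}}
        },
    )

    # Stops the application and deletes it once the job run has finished.
    delete_app = EmrServerlessDeleteApplicationOperator(
        task_id="delete_emr_serverless_app",
        application_id=create_app.output,
    )

    create_app >> start_job >> delete_app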
+from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator @@ -49,6 +51,56 @@ def __init__( self.aws_conn_id = aws_conn_id self.vault_name = vault_name - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GlacierHook(aws_conn_id=self.aws_conn_id) return hook.retrieve_inventory(vault_name=self.vault_name) + + +class GlacierUploadArchiveOperator(BaseOperator): + """ + This operator add an archive to an Amazon S3 Glacier vault + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GlacierUploadArchiveOperator` + + :param vault_name: The name of the vault + :param body: A bytes or seekable file-like object. The data to upload. + :param checksum: The SHA256 tree hash of the data being uploaded. + This parameter is automatically populated if it is not provided + :param archive_description: The description of the archive you are uploading + :param account_id: (Optional) AWS account ID of the account that owns the vault. + Defaults to the credentials used to sign the request + :param aws_conn_id: The reference to the AWS connection details + """ + + template_fields: Sequence[str] = ("vault_name",) + + def __init__( + self, + *, + vault_name: str, + body: object, + checksum: str | None = None, + archive_description: str | None = None, + account_id: str | None = None, + aws_conn_id="aws_default", + **kwargs, + ): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.account_id = account_id + self.vault_name = vault_name + self.body = body + self.checksum = checksum + self.archive_description = archive_description + + def execute(self, context: Context): + hook = GlacierHook(aws_conn_id=self.aws_conn_id) + return hook.get_conn().upload_archive( + accountId=self.account_id, + vaultName=self.vault_name, + archiveDescription=self.archive_description, + body=self.body, + checksum=self.checksum, + ) diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index cb76e784300e0..37a08a216a0b4 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import os.path -import warnings -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue import GlueJobHook @@ -51,33 +51,41 @@ class GlueJobOperator(BaseOperator): :param create_job_kwargs: Extra arguments for Glue Job Creation :param run_job_kwargs: Extra arguments for Glue Job Run :param wait_for_completion: Whether or not wait for job run completion. (default: True) + :param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. 
(default: False) """ - template_fields: Sequence[str] = ('script_args',) + template_fields: Sequence[str] = ( + "job_name", + "script_location", + "script_args", + "s3_bucket", + "iam_role_name", + ) template_ext: Sequence[str] = () template_fields_renderers = { "script_args": "json", "create_job_kwargs": "json", } - ui_color = '#ededed' + ui_color = "#ededed" def __init__( self, *, - job_name: str = 'aws_glue_default_job', - job_desc: str = 'AWS Glue Job with Airflow', - script_location: str, - concurrent_run_limit: Optional[int] = None, - script_args: Optional[dict] = None, + job_name: str = "aws_glue_default_job", + job_desc: str = "AWS Glue Job with Airflow", + script_location: str | None = None, + concurrent_run_limit: int | None = None, + script_args: dict | None = None, retry_limit: int = 0, - num_of_dpus: Optional[int] = None, - aws_conn_id: str = 'aws_default', - region_name: Optional[str] = None, - s3_bucket: Optional[str] = None, - iam_role_name: Optional[str] = None, - create_job_kwargs: Optional[dict] = None, - run_job_kwargs: Optional[dict] = None, + num_of_dpus: int | None = None, + aws_conn_id: str = "aws_default", + region_name: str | None = None, + s3_bucket: str | None = None, + iam_role_name: str | None = None, + create_job_kwargs: dict | None = None, + run_job_kwargs: dict | None = None, wait_for_completion: bool = True, + verbose: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -93,12 +101,13 @@ def __init__( self.s3_bucket = s3_bucket self.iam_role_name = iam_role_name self.s3_protocol = "s3://" - self.s3_artifacts_prefix = 'artifacts/glue-scripts/' + self.s3_artifacts_prefix = "artifacts/glue-scripts/" self.create_job_kwargs = create_job_kwargs self.run_job_kwargs = run_job_kwargs or {} self.wait_for_completion = wait_for_completion + self.verbose = verbose - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Executes AWS Glue Job from Airflow @@ -135,29 +144,13 @@ def execute(self, context: 'Context'): ) glue_job_run = glue_job.initialize_job(self.script_args, self.run_job_kwargs) if self.wait_for_completion: - glue_job_run = glue_job.job_completion(self.job_name, glue_job_run['JobRunId']) + glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose) self.log.info( "AWS Glue Job: %s status: %s. Run Id: %s", self.job_name, - glue_job_run['JobRunState'], - glue_job_run['JobRunId'], + glue_job_run["JobRunState"], + glue_job_run["JobRunId"], ) else: - self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run['JobRunId']) - return glue_job_run['JobRunId'] - - -class AwsGlueJobOperator(GlueJobOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.glue.GlueJobOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.operators.glue.GlueJobOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) + self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"]) + return glue_job_run["JobRunId"] diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py index a584301f926d1..1e30be1b262bd 100644 --- a/airflow/providers/amazon/aws/operators/glue_crawler.py +++ b/airflow/providers/amazon/aws/operators/glue_crawler.py @@ -15,19 +15,14 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -import sys -import warnings -from typing import TYPE_CHECKING +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence if TYPE_CHECKING: from airflow.utils.context import Context - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook @@ -48,12 +43,13 @@ class GlueCrawlerOperator(BaseOperator): :param wait_for_completion: Whether or not wait for crawl execution completion. (default: True) """ - ui_color = '#ededed' + template_fields: Sequence[str] = ("config",) + ui_color = "#ededed" def __init__( self, config, - aws_conn_id='aws_default', + aws_conn_id="aws_default", poll_interval: int = 5, wait_for_completion: bool = True, **kwargs, @@ -69,13 +65,13 @@ def hook(self) -> GlueCrawlerHook: """Create and return an GlueCrawlerHook.""" return GlueCrawlerHook(self.aws_conn_id) - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Executes AWS Glue Crawler from Airflow :return: the name of the current glue crawler. """ - crawler_name = self.config['Name'] + crawler_name = self.config["Name"] if self.hook.has_crawler(crawler_name): self.hook.update_crawler(**self.config) else: @@ -88,19 +84,3 @@ def execute(self, context: 'Context'): self.hook.wait_for_crawler_completion(crawler_name=crawler_name, poll_interval=self.poll_interval) return crawler_name - - -class AwsGlueCrawlerOperator(GlueCrawlerOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.glue_crawler.GlueCrawlerOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.operators.glue_crawler.GlueCrawlerOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py new file mode 100644 index 0000000000000..59e0012b0530f --- /dev/null +++ b/airflow/providers/amazon/aws/operators/lambda_function.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
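With this change GlueCrawlerOperator templates its config and GlueJobOperator gains a verbose flag plus additional templated fields (job_name, script_location, s3_bucket, iam_role_name). A minimal usage sketch, assuming placeholder bucket names, role names, and crawler configuration:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.glue import GlueJobOperator
from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator

# Placeholder crawler configuration -- "Name" is the only key the operator
# itself reads; the remaining keys are passed through to the Glue crawler API.
GLUE_CRAWLER_CONFIG = {
    "Name": "example_crawler",
    "Role": "my-glue-service-role",
    "DatabaseName": "example_db",
    "Targets": {"S3Targets": [{"Path": "s3://my-bucket/input/"}]},
}

with DAG(
    dag_id="example_glue",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    crawl = GlueCrawlerOperator(
        task_id="crawl_s3_data",
        config=GLUE_CRAWLER_CONFIG,
        wait_for_completion=True,
    )

    # verbose=True mirrors the Glue job run logs into the Airflow task log.
    submit_job = GlueJobOperator(
        task_id="run_glue_job",
        job_name="example_glue_job",
        script_location="s3://my-bucket/scripts/etl.py",
        s3_bucket="my-bucket",
        iam_role_name="my-glue-service-role",
        num_of_dpus=2,
        verbose=True,
    )

    crawl >> submit_job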
+from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Sequence + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AwsLambdaInvokeFunctionOperator(BaseOperator): + """ + Invokes an AWS Lambda function. + You can invoke a function synchronously (and wait for the response), + or asynchronously. + To invoke a function asynchronously, + set `invocation_type` to `Event`. For more details, + review the boto3 Lambda invoke docs. + + :param function_name: The name of the AWS Lambda function, version, or alias. + :param payload: The JSON string that you want to provide to your Lambda function as input. + :param log_type: Set to Tail to include the execution log in the response. Otherwise, set to "None". + :param qualifier: Specify a version or alias to invoke a published version of the function. + :param aws_conn_id: The AWS connection ID to use + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AwsLambdaInvokeFunctionOperator` + + """ + + template_fields: Sequence[str] = ("function_name", "payload", "qualifier", "invocation_type") + ui_color = "#ff7300" + + def __init__( + self, + *, + function_name: str, + log_type: str | None = None, + qualifier: str | None = None, + invocation_type: str | None = None, + client_context: str | None = None, + payload: str | None = None, + aws_conn_id: str = "aws_default", + **kwargs, + ): + super().__init__(**kwargs) + self.function_name = function_name + self.payload = payload + self.log_type = log_type + self.qualifier = qualifier + self.invocation_type = invocation_type + self.client_context = client_context + self.aws_conn_id = aws_conn_id + + def execute(self, context: Context): + """ + Invokes the target AWS Lambda function from Airflow. + + :return: The response payload from the function, or an error object. + """ + hook = LambdaHook(aws_conn_id=self.aws_conn_id) + success_status_codes = [200, 202, 204] + self.log.info("Invoking AWS Lambda function: %s with payload: %s", self.function_name, self.payload) + response = hook.invoke_lambda( + function_name=self.function_name, + invocation_type=self.invocation_type, + log_type=self.log_type, + client_context=self.client_context, + payload=self.payload, + qualifier=self.qualifier, + ) + self.log.info("Lambda response metadata: %r", response.get("ResponseMetadata")) + if response.get("StatusCode") not in success_status_codes: + raise ValueError("Lambda function did not execute", json.dumps(response.get("ResponseMetadata"))) + payload_stream = response.get("Payload") + payload = payload_stream.read().decode() + if "FunctionError" in response: + raise ValueError( + "Lambda function execution resulted in error", + {"ResponseMetadata": response.get("ResponseMetadata"), "Payload": payload}, + ) + self.log.info("Lambda function invocation succeeded: %r", response.get("ResponseMetadata")) + return payload diff --git a/airflow/providers/amazon/aws/operators/quicksight.py b/airflow/providers/amazon/aws/operators/quicksight.py index a9da61bdfc3bf..85514af806c2f 100644 --- a/airflow/providers/amazon/aws/operators/quicksight.py +++ b/airflow/providers/amazon/aws/operators/quicksight.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
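The new AwsLambdaInvokeFunctionOperator supports both synchronous invocation, where the decoded response payload is returned (and therefore pushed to XCom), and asynchronous invocation by setting invocation_type to "Event". A minimal sketch with a hypothetical function name and payload:

import json
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.lambda_function import AwsLambdaInvokeFunctionOperator

with DAG(
    dag_id="example_lambda_invoke",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Synchronous invocation: waits for the function and returns its payload.
    invoke_sync = AwsLambdaInvokeFunctionOperator(
        task_id="invoke_lambda_sync",
        function_name="my-function",  # placeholder function name
        payload=json.dumps({"greeting": "hello"}),
    )

    # Asynchronous invocation: Lambda only queues the event, so no response
    # payload is returned.
    invoke_async = AwsLambdaInvokeFunctionOperator(
        task_id="invoke_lambda_async",
        function_name="my-function",
        invocation_type="Event",
        payload=json.dumps({"greeting": "hello"}),
    )

    invoke_sync >> invoke_async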
+from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook @@ -71,7 +72,7 @@ def __init__( wait_for_completion: bool = True, check_interval: int = 30, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ): self.data_set_id = data_set_id @@ -83,12 +84,12 @@ def __init__( self.region = region super().__init__(**kwargs) - def execute(self, context: "Context"): + def execute(self, context: Context): hook = QuickSightHook( aws_conn_id=self.aws_conn_id, region_name=self.region, ) - self.log.info("Running the Amazon QuickSight SPICE Ingestion on Dataset ID: %s)", self.data_set_id) + self.log.info("Running the Amazon QuickSight SPICE Ingestion on Dataset ID: %s", self.data_set_id) return hook.create_ingestion( data_set_id=self.data_set_id, ingestion_id=self.ingestion_id, diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py index a527107e80a46..c10e969c8bc2d 100644 --- a/airflow/providers/amazon/aws/operators/rds.py +++ b/airflow/providers/amazon/aws/operators/rds.py @@ -15,14 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json -import time -from typing import TYPE_CHECKING, List, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from mypy_boto3_rds.type_defs import TagTypeDef -from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.utils.rds import RdsDbType @@ -37,64 +36,14 @@ class RdsBaseOperator(BaseOperator): ui_color = "#eeaa88" ui_fgcolor = "#ffffff" - def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: Optional[dict] = None, **kwargs): + def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs): hook_params = hook_params or {} self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params) super().__init__(*args, **kwargs) self._await_interval = 60 # seconds - def _describe_item(self, item_type: str, item_name: str) -> list: - - if item_type == 'instance_snapshot': - db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name) - return db_snaps['DBSnapshots'] - elif item_type == 'cluster_snapshot': - cl_snaps = self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name) - return cl_snaps['DBClusterSnapshots'] - elif item_type == 'export_task': - exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name) - return exports['ExportTasks'] - elif item_type == 'event_subscription': - subscriptions = self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name) - return subscriptions['EventSubscriptionsList'] - else: - raise AirflowException(f"Method for {item_type} is not implemented") - - def _await_status( - self, - item_type: str, - item_name: str, - wait_statuses: Optional[List[str]] = None, - ok_statuses: Optional[List[str]] = None, - error_statuses: Optional[List[str]] = None, - ) -> None: - """ - Continuously gets item description from `_describe_item()` and waits until: - - status is in `wait_statuses` - - status not in `ok_statuses` and `error_statuses` - """ - while True: - items = 
self._describe_item(item_type, item_name) - - if len(items) == 0: - raise AirflowException(f"There is no {item_type} with identifier {item_name}") - if len(items) > 1: - raise AirflowException(f"There are {len(items)} {item_type} with identifier {item_name}") - - if wait_statuses and items[0]['Status'].lower() in wait_statuses: - time.sleep(self._await_interval) - continue - elif ok_statuses and items[0]['Status'].lower() in ok_statuses: - break - elif error_statuses and items[0]['Status'].lower() in error_statuses: - raise AirflowException(f"Item has error status ({error_statuses}): {items[0]}") - else: - raise AirflowException(f"Item has uncertain status: {items[0]}") - - return None - - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: """Different implementations for snapshots, tasks and events""" raise NotImplementedError @@ -117,6 +66,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): :param db_snapshot_identifier: The identifier for the DB snapshot :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},] `USER Tagging `__ + :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True) """ template_fields = ("db_snapshot_identifier", "db_identifier", "tags") @@ -127,7 +77,8 @@ def __init__( db_type: str, db_identifier: str, db_snapshot_identifier: str, - tags: Optional[Sequence[TagTypeDef]] = None, + tags: Sequence[TagTypeDef] | None = None, + wait_for_completion: bool = True, aws_conn_id: str = "aws_conn_id", **kwargs, ): @@ -136,8 +87,9 @@ def __init__( self.db_identifier = db_identifier self.db_snapshot_identifier = db_snapshot_identifier self.tags = tags or [] + self.wait_for_completion = wait_for_completion - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: self.log.info( "Starting to create snapshot of RDS %s '%s': %s", self.db_type, @@ -152,12 +104,8 @@ def execute(self, context: 'Context') -> str: Tags=self.tags, ) create_response = json.dumps(create_instance_snap, default=str) - self._await_status( - 'instance_snapshot', - self.db_snapshot_identifier, - wait_statuses=['creating'], - ok_statuses=['available'], - ) + if self.wait_for_completion: + self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier, target_state="available") else: create_cluster_snap = self.hook.conn.create_db_cluster_snapshot( DBClusterIdentifier=self.db_identifier, @@ -165,13 +113,10 @@ def execute(self, context: 'Context') -> str: Tags=self.tags, ) create_response = json.dumps(create_cluster_snap, default=str) - self._await_status( - 'cluster_snapshot', - self.db_snapshot_identifier, - wait_statuses=['creating'], - ok_statuses=['available'], - ) - + if self.wait_for_completion: + self.hook.wait_for_db_cluster_snapshot_state( + self.db_snapshot_identifier, target_state="available" + ) return create_response @@ -196,6 +141,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator): :param target_custom_availability_zone: The external custom Availability Zone identifier for the target Only when db_type='instance' :param source_region: The ID of the region that contains the snapshot to be copied + :param wait_for_completion: If True, waits for snapshot copy to complete. 
(default: True) """ template_fields = ( @@ -213,12 +159,13 @@ def __init__( source_db_snapshot_identifier: str, target_db_snapshot_identifier: str, kms_key_id: str = "", - tags: Optional[Sequence[TagTypeDef]] = None, + tags: Sequence[TagTypeDef] | None = None, copy_tags: bool = False, pre_signed_url: str = "", option_group_name: str = "", target_custom_availability_zone: str = "", source_region: str = "", + wait_for_completion: bool = True, aws_conn_id: str = "aws_default", **kwargs, ): @@ -234,8 +181,9 @@ def __init__( self.option_group_name = option_group_name self.target_custom_availability_zone = target_custom_availability_zone self.source_region = source_region + self.wait_for_completion = wait_for_completion - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: self.log.info( "Starting to copy snapshot '%s' as '%s'", self.source_db_snapshot_identifier, @@ -255,12 +203,10 @@ def execute(self, context: 'Context') -> str: SourceRegion=self.source_region, ) copy_response = json.dumps(copy_instance_snap, default=str) - self._await_status( - 'instance_snapshot', - self.target_db_snapshot_identifier, - wait_statuses=['creating'], - ok_statuses=['available'], - ) + if self.wait_for_completion: + self.hook.wait_for_db_snapshot_state( + self.target_db_snapshot_identifier, target_state="available" + ) else: copy_cluster_snap = self.hook.conn.copy_db_cluster_snapshot( SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier, @@ -272,13 +218,10 @@ def execute(self, context: 'Context') -> str: SourceRegion=self.source_region, ) copy_response = json.dumps(copy_cluster_snap, default=str) - self._await_status( - 'cluster_snapshot', - self.target_db_snapshot_identifier, - wait_statuses=['copying'], - ok_statuses=['available'], - ) - + if self.wait_for_completion: + self.hook.wait_for_db_cluster_snapshot_state( + self.target_db_snapshot_identifier, target_state="available" + ) return copy_response @@ -301,6 +244,7 @@ def __init__( *, db_type: str, db_snapshot_identifier: str, + wait_for_completion: bool = True, aws_conn_id: str = "aws_default", **kwargs, ): @@ -308,8 +252,9 @@ def __init__( self.db_type = RdsDbType(db_type) self.db_snapshot_identifier = db_snapshot_identifier + self.wait_for_completion = wait_for_completion - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: self.log.info("Starting to delete snapshot '%s'", self.db_snapshot_identifier) if self.db_type.value == "instance": @@ -317,11 +262,17 @@ def execute(self, context: 'Context') -> str: DBSnapshotIdentifier=self.db_snapshot_identifier, ) delete_response = json.dumps(delete_instance_snap, default=str) + if self.wait_for_completion: + self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier, target_state="deleted") else: delete_cluster_snap = self.hook.conn.delete_db_cluster_snapshot( DBClusterSnapshotIdentifier=self.db_snapshot_identifier, ) delete_response = json.dumps(delete_cluster_snap, default=str) + if self.wait_for_completion: + self.hook.wait_for_db_cluster_snapshot_state( + self.db_snapshot_identifier, target_state="deleted" + ) return delete_response @@ -341,6 +292,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator): :param kms_key_id: The ID of the Amazon Web Services KMS key to use to encrypt the snapshot. :param s3_prefix: The Amazon S3 bucket prefix to use as the file name and path of the exported snapshot. :param export_only: The data to be exported from the snapshot. 
+ :param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True) """ template_fields = ( @@ -361,8 +313,9 @@ def __init__( s3_bucket_name: str, iam_role_arn: str, kms_key_id: str, - s3_prefix: str = '', - export_only: Optional[List[str]] = None, + s3_prefix: str = "", + export_only: list[str] | None = None, + wait_for_completion: bool = True, aws_conn_id: str = "aws_default", **kwargs, ): @@ -375,8 +328,9 @@ def __init__( self.kms_key_id = kms_key_id self.s3_prefix = s3_prefix self.export_only = export_only or [] + self.wait_for_completion = wait_for_completion - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: self.log.info("Starting export task %s for snapshot %s", self.export_task_identifier, self.source_arn) start_export = self.hook.conn.start_export_task( @@ -389,14 +343,8 @@ def execute(self, context: 'Context') -> str: ExportOnly=self.export_only, ) - self._await_status( - 'export_task', - self.export_task_identifier, - wait_statuses=['starting', 'in_progress'], - ok_statuses=['complete'], - error_statuses=['canceling', 'canceled'], - ) - + if self.wait_for_completion: + self.hook.wait_for_export_task_state(self.export_task_identifier, target_state="complete") return json.dumps(start_export, default=str) @@ -409,6 +357,7 @@ class RdsCancelExportTaskOperator(RdsBaseOperator): :ref:`howto/operator:RdsCancelExportTaskOperator` :param export_task_identifier: The identifier of the snapshot export task to cancel + :param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True) """ template_fields = ("export_task_identifier",) @@ -417,26 +366,24 @@ def __init__( self, *, export_task_identifier: str, + wait_for_completion: bool = True, aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(aws_conn_id=aws_conn_id, **kwargs) self.export_task_identifier = export_task_identifier + self.wait_for_completion = wait_for_completion - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: self.log.info("Canceling export task %s", self.export_task_identifier) cancel_export = self.hook.conn.cancel_export_task( ExportTaskIdentifier=self.export_task_identifier, ) - self._await_status( - 'export_task', - self.export_task_identifier, - wait_statuses=['canceling'], - ok_statuses=['canceled'], - ) + if self.wait_for_completion: + self.hook.wait_for_export_task_state(self.export_task_identifier, target_state="canceled") return json.dumps(cancel_export, default=str) @@ -458,6 +405,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator): :param enabled: A value that indicates whether to activate the subscription (default True)l :param tags: A list of tags in format `[{"Key": "something", "Value": "something"},] `USER Tagging `__ + :param wait_for_completion: If True, waits for creation of the subscription to complete. 
(default: True) """ template_fields = ( @@ -475,10 +423,11 @@ def __init__( subscription_name: str, sns_topic_arn: str, source_type: str = "", - event_categories: Optional[Sequence[str]] = None, - source_ids: Optional[Sequence[str]] = None, + event_categories: Sequence[str] | None = None, + source_ids: Sequence[str] | None = None, enabled: bool = True, - tags: Optional[Sequence[TagTypeDef]] = None, + tags: Sequence[TagTypeDef] | None = None, + wait_for_completion: bool = True, aws_conn_id: str = "aws_default", **kwargs, ): @@ -491,8 +440,9 @@ def __init__( self.source_ids = source_ids or [] self.enabled = enabled self.tags = tags or [] + self.wait_for_completion = wait_for_completion - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: self.log.info("Creating event subscription '%s' to '%s'", self.subscription_name, self.sns_topic_arn) create_subscription = self.hook.conn.create_event_subscription( @@ -504,13 +454,9 @@ def execute(self, context: 'Context') -> str: Enabled=self.enabled, Tags=self.tags, ) - self._await_status( - 'event_subscription', - self.subscription_name, - wait_statuses=['creating'], - ok_statuses=['active'], - ) + if self.wait_for_completion: + self.hook.wait_for_event_subscription_state(self.subscription_name, target_state="active") return json.dumps(create_subscription, default=str) @@ -538,7 +484,7 @@ def __init__( self.subscription_name = subscription_name - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: self.log.info( "Deleting event subscription %s", self.subscription_name, @@ -551,6 +497,225 @@ def execute(self, context: 'Context') -> str: return json.dumps(delete_subscription, default=str) +class RdsCreateDbInstanceOperator(RdsBaseOperator): + """ + Creates an RDS DB instance + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RdsCreateDbInstanceOperator` + + :param db_instance_identifier: The DB instance identifier, must start with a letter and + contain from 1 to 63 letters, numbers, or hyphens + :param db_instance_class: The compute and memory capacity of the DB instance, for example db.m5.large + :param engine: The name of the database engine to be used for this instance + :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance`` + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param wait_for_completion: If True, waits for creation of the DB instance to complete. 
(default: True) + """ + + template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs") + + def __init__( + self, + *, + db_instance_identifier: str, + db_instance_class: str, + engine: str, + rds_kwargs: dict | None = None, + aws_conn_id: str = "aws_default", + wait_for_completion: bool = True, + **kwargs, + ): + super().__init__(aws_conn_id=aws_conn_id, **kwargs) + + self.db_instance_identifier = db_instance_identifier + self.db_instance_class = db_instance_class + self.engine = engine + self.rds_kwargs = rds_kwargs or {} + self.wait_for_completion = wait_for_completion + + def execute(self, context: Context) -> str: + self.log.info("Creating new DB instance %s", self.db_instance_identifier) + + create_db_instance = self.hook.conn.create_db_instance( + DBInstanceIdentifier=self.db_instance_identifier, + DBInstanceClass=self.db_instance_class, + Engine=self.engine, + **self.rds_kwargs, + ) + + if self.wait_for_completion: + self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="available") + return json.dumps(create_db_instance, default=str) + + +class RdsDeleteDbInstanceOperator(RdsBaseOperator): + """ + Deletes an RDS DB Instance + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RdsDeleteDbInstanceOperator` + + :param db_instance_identifier: The DB instance identifier for the DB instance to be deleted + :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance`` + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True) + """ + + template_fields = ("db_instance_identifier", "rds_kwargs") + + def __init__( + self, + *, + db_instance_identifier: str, + rds_kwargs: dict | None = None, + aws_conn_id: str = "aws_default", + wait_for_completion: bool = True, + **kwargs, + ): + super().__init__(aws_conn_id=aws_conn_id, **kwargs) + self.db_instance_identifier = db_instance_identifier + self.rds_kwargs = rds_kwargs or {} + self.wait_for_completion = wait_for_completion + + def execute(self, context: Context) -> str: + self.log.info("Deleting DB instance %s", self.db_instance_identifier) + + delete_db_instance = self.hook.conn.delete_db_instance( + DBInstanceIdentifier=self.db_instance_identifier, + **self.rds_kwargs, + ) + + if self.wait_for_completion: + self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="deleted") + return json.dumps(delete_db_instance, default=str) + + +class RdsStartDbOperator(RdsBaseOperator): + """ + Starts an RDS DB instance / cluster + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RdsStartDbOperator` + + :param db_identifier: The AWS identifier of the DB to start + :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance") + :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default") + :param wait_for_completion: If True, waits for DB to start. 
(default: True) + """ + + template_fields = ("db_identifier", "db_type") + + def __init__( + self, + *, + db_identifier: str, + db_type: RdsDbType | str = RdsDbType.INSTANCE, + aws_conn_id: str = "aws_default", + wait_for_completion: bool = True, + **kwargs, + ): + super().__init__(aws_conn_id=aws_conn_id, **kwargs) + self.db_identifier = db_identifier + self.db_type = db_type + self.wait_for_completion = wait_for_completion + + def execute(self, context: Context) -> str: + self.db_type = RdsDbType(self.db_type) + start_db_response = self._start_db() + if self.wait_for_completion: + self._wait_until_db_available() + return json.dumps(start_db_response, default=str) + + def _start_db(self): + self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier) + if self.db_type == RdsDbType.INSTANCE: + response = self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier) + else: + response = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier) + return response + + def _wait_until_db_available(self): + self.log.info("Waiting for DB %s to reach 'available' state", self.db_type.value) + if self.db_type == RdsDbType.INSTANCE: + self.hook.wait_for_db_instance_state(self.db_identifier, target_state="available") + else: + self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="available") + + +class RdsStopDbOperator(RdsBaseOperator): + """ + Stops an RDS DB instance / cluster + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RdsStopDbOperator` + + :param db_identifier: The AWS identifier of the DB to stop + :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance") + :param db_snapshot_identifier: The instance identifier of the DB Snapshot to create before + stopping the DB instance. The default value (None) skips snapshot creation. This + parameter is ignored when ``db_type`` is "cluster" + :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default") + :param wait_for_completion: If True, waits for DB to stop. (default: True) + """ + + template_fields = ("db_identifier", "db_snapshot_identifier", "db_type") + + def __init__( + self, + *, + db_identifier: str, + db_type: RdsDbType | str = RdsDbType.INSTANCE, + db_snapshot_identifier: str | None = None, + aws_conn_id: str = "aws_default", + wait_for_completion: bool = True, + **kwargs, + ): + super().__init__(aws_conn_id=aws_conn_id, **kwargs) + self.db_identifier = db_identifier + self.db_type = db_type + self.db_snapshot_identifier = db_snapshot_identifier + self.wait_for_completion = wait_for_completion + + def execute(self, context: Context) -> str: + self.db_type = RdsDbType(self.db_type) + stop_db_response = self._stop_db() + if self.wait_for_completion: + self._wait_until_db_stopped() + return json.dumps(stop_db_response, default=str) + + def _stop_db(self): + self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier) + if self.db_type == RdsDbType.INSTANCE: + conn_params = {"DBInstanceIdentifier": self.db_identifier} + # The db snapshot parameter is optional, but the AWS SDK raises an exception + # if passed a null value. Only set snapshot id if value is present. + if self.db_snapshot_identifier: + conn_params["DBSnapshotIdentifier"] = self.db_snapshot_identifier + response = self.hook.conn.stop_db_instance(**conn_params) + else: + if self.db_snapshot_identifier: + self.log.warning( + "'db_snapshot_identifier' does not apply to db clusters. 
" + "Remove it to silence this warning." + ) + response = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.db_identifier) + return response + + def _wait_until_db_stopped(self): + self.log.info("Waiting for DB %s to reach 'stopped' state", self.db_type.value) + if self.db_type == RdsDbType.INSTANCE: + self.hook.wait_for_db_instance_state(self.db_identifier, target_state="stopped") + else: + self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="stopped") + + __all__ = [ "RdsCreateDbSnapshotOperator", "RdsCopyDbSnapshotOperator", @@ -559,4 +724,8 @@ def execute(self, context: 'Context') -> str: "RdsDeleteEventSubscriptionOperator", "RdsStartExportTaskOperator", "RdsCancelExportTaskOperator", + "RdsCreateDbInstanceOperator", + "RdsDeleteDbInstanceOperator", + "RdsStartDbOperator", + "RdsStopDbOperator", ] diff --git a/airflow/providers/amazon/aws/operators/redshift.py b/airflow/providers/amazon/aws/operators/redshift.py deleted file mode 100644 index efcf44a797dea..0000000000000 --- a/airflow/providers/amazon/aws/operators/redshift.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import warnings - -from airflow.providers.amazon.aws.operators.redshift_cluster import ( - RedshiftPauseClusterOperator, - RedshiftResumeClusterOperator, -) -from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.redshift_sql` " - "or `airflow.providers.amazon.aws.operators.redshift_cluster` as appropriate.", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["RedshiftSQLOperator", "RedshiftPauseClusterOperator", "RedshiftResumeClusterOperator"] diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 340e9577efef4..2da0fbf23a73d 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence +from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook @@ -88,6 +91,7 @@ class RedshiftCreateClusterOperator(BaseOperator): "cluster_type", "node_type", "number_of_nodes", + "vpc_security_group_ids", ) ui_color = "#eeaa11" ui_fgcolor = "#ffffff" @@ -102,32 +106,32 @@ def __init__( cluster_type: str = "multi-node", db_name: str = "dev", number_of_nodes: int = 1, - cluster_security_groups: Optional[List[str]] = None, - vpc_security_group_ids: Optional[List[str]] = None, - cluster_subnet_group_name: Optional[str] = None, - availability_zone: Optional[str] = None, - preferred_maintenance_window: Optional[str] = None, - cluster_parameter_group_name: Optional[str] = None, + cluster_security_groups: list[str] | None = None, + vpc_security_group_ids: list[str] | None = None, + cluster_subnet_group_name: str | None = None, + availability_zone: str | None = None, + preferred_maintenance_window: str | None = None, + cluster_parameter_group_name: str | None = None, automated_snapshot_retention_period: int = 1, - manual_snapshot_retention_period: Optional[int] = None, + manual_snapshot_retention_period: int | None = None, port: int = 5439, cluster_version: str = "1.0", allow_version_upgrade: bool = True, publicly_accessible: bool = True, encrypted: bool = False, - hsm_client_certificate_identifier: Optional[str] = None, - hsm_configuration_identifier: Optional[str] = None, - elastic_ip: Optional[str] = None, - tags: Optional[List[Any]] = None, - kms_key_id: Optional[str] = None, + hsm_client_certificate_identifier: str | None = None, + hsm_configuration_identifier: str | None = None, + elastic_ip: str | None = None, + tags: list[Any] | None = None, + kms_key_id: str | None = None, enhanced_vpc_routing: bool = False, - additional_info: Optional[str] = None, - iam_roles: Optional[List[str]] = None, - maintenance_track_name: Optional[str] = None, - snapshot_schedule_identifier: Optional[str] = None, - availability_zone_relocation: Optional[bool] = None, - aqua_configuration_status: Optional[str] = None, - default_iam_role_arn: Optional[str] = None, + additional_info: str | None = None, + iam_roles: list[str] | None = None, + maintenance_track_name: str | None = None, + snapshot_schedule_identifier: str | None = None, + availability_zone_relocation: bool | None = None, + aqua_configuration_status: str | None = None, + default_iam_role_arn: str | None = None, aws_conn_id: str = "aws_default", **kwargs, ): @@ -168,10 +172,10 @@ def __init__( self.aws_conn_id = aws_conn_id self.kwargs = kwargs - def execute(self, context: 'Context'): + def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) self.log.info("Creating Redshift cluster %s", self.cluster_identifier) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if self.db_name: params["DBName"] = self.db_name if self.cluster_type: @@ -242,6 +246,130 @@ def execute(self, context: 'Context'): self.log.info(cluster) +class RedshiftCreateClusterSnapshotOperator(BaseOperator): + """ + Creates a manual snapshot of the specified cluster. The cluster must be in the available state + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RedshiftCreateClusterSnapshotOperator` + + :param snapshot_identifier: A unique identifier for the snapshot that you are requesting + :param cluster_identifier: The cluster identifier for which you want a snapshot + :param retention_period: The number of days that a manual snapshot is retained. + If the value is -1, the manual snapshot is retained indefinitely. + :param wait_for_completion: Whether wait for the cluster snapshot to be in ``available`` state + :param poll_interval: Time (in seconds) to wait between two consecutive calls to check state + :param max_attempt: The maximum number of attempts to be made to check the state + :param aws_conn_id: The Airflow connection used for AWS credentials. + The default connection id is ``aws_default`` + """ + + template_fields: Sequence[str] = ( + "cluster_identifier", + "snapshot_identifier", + ) + + def __init__( + self, + *, + snapshot_identifier: str, + cluster_identifier: str, + retention_period: int = -1, + wait_for_completion: bool = False, + poll_interval: int = 15, + max_attempt: int = 20, + aws_conn_id: str = "aws_default", + **kwargs, + ): + super().__init__(**kwargs) + self.snapshot_identifier = snapshot_identifier + self.cluster_identifier = cluster_identifier + self.retention_period = retention_period + self.wait_for_completion = wait_for_completion + self.poll_interval = poll_interval + self.max_attempt = max_attempt + self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id) + + def execute(self, context: Context) -> Any: + cluster_state = self.redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier) + if cluster_state != "available": + raise AirflowException( + "Redshift cluster must be in available state. " + f"Redshift cluster current state is {cluster_state}" + ) + + self.redshift_hook.create_cluster_snapshot( + cluster_identifier=self.cluster_identifier, + snapshot_identifier=self.snapshot_identifier, + retention_period=self.retention_period, + ) + + if self.wait_for_completion: + self.redshift_hook.get_conn().get_waiter("snapshot_available").wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": self.max_attempt, + }, + ) + + +class RedshiftDeleteClusterSnapshotOperator(BaseOperator): + """ + Deletes the specified manual snapshot + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RedshiftDeleteClusterSnapshotOperator` + + :param snapshot_identifier: A unique identifier for the snapshot that you are requesting + :param cluster_identifier: The unique identifier of the cluster the snapshot was created from + :param wait_for_completion: Whether wait for cluster deletion or not + The default value is ``True`` + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ The default connection id is ``aws_default`` + :param poll_interval: Time (in seconds) to wait between two consecutive calls to check snapshot state + """ + + template_fields: Sequence[str] = ( + "cluster_identifier", + "snapshot_identifier", + ) + + def __init__( + self, + *, + snapshot_identifier: str, + cluster_identifier: str, + wait_for_completion: bool = True, + aws_conn_id: str = "aws_default", + poll_interval: int = 10, + **kwargs, + ): + super().__init__(**kwargs) + self.snapshot_identifier = snapshot_identifier + self.cluster_identifier = cluster_identifier + self.wait_for_completion = wait_for_completion + self.poll_interval = poll_interval + self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id) + + def execute(self, context: Context) -> Any: + self.redshift_hook.get_conn().delete_cluster_snapshot( + SnapshotClusterIdentifier=self.cluster_identifier, + SnapshotIdentifier=self.snapshot_identifier, + ) + + if self.wait_for_completion: + while self.get_status() is not None: + time.sleep(self.poll_interval) + + def get_status(self) -> str: + return self.redshift_hook.get_cluster_snapshot_status( + snapshot_identifier=self.snapshot_identifier, + ) + + class RedshiftResumeClusterOperator(BaseOperator): """ Resume a paused AWS Redshift Cluster @@ -268,17 +396,27 @@ def __init__( super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.aws_conn_id = aws_conn_id + # These parameters are added to address an issue with the boto3 API where the API + # prematurely reports the cluster as available to receive requests. This causes the cluster + # to reject initial attempts to resume the cluster despite reporting the correct state. + self._attempts = 10 + self._attempt_interval = 15 - def execute(self, context: 'Context'): + def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier) - if cluster_state == 'paused': - self.log.info("Starting Redshift cluster %s", self.cluster_identifier) - redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier) - else: - self.log.warning( - "Unable to resume cluster since cluster is currently in status: %s", cluster_state - ) + + while self._attempts >= 1: + try: + redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier) + return + except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: + self._attempts = self._attempts - 1 + + if self._attempts > 0: + self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts) + time.sleep(self._attempt_interval) + else: + raise error class RedshiftPauseClusterOperator(BaseOperator): @@ -307,17 +445,27 @@ def __init__( super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.aws_conn_id = aws_conn_id + # These parameters are added to address an issue with the boto3 API where the API + # prematurely reports the cluster as available to receive requests. This causes the cluster + # to reject initial attempts to pause the cluster despite reporting the correct state. 
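# The two snapshot operators added above can be wired into a DAG as follows.
# This is a minimal usage sketch, not part of the change itself: the DAG id,
# task ids, cluster and snapshot identifiers are illustrative placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import (
    RedshiftCreateClusterSnapshotOperator,
    RedshiftDeleteClusterSnapshotOperator,
)

with DAG("example_redshift_snapshot", start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False):
    create_snapshot = RedshiftCreateClusterSnapshotOperator(
        task_id="create_snapshot",
        cluster_identifier="my-redshift-cluster",   # the cluster must be in the "available" state
        snapshot_identifier="my-redshift-snapshot",
        retention_period=7,
        wait_for_completion=True,   # blocks until the "snapshot_available" waiter succeeds
    )
    delete_snapshot = RedshiftDeleteClusterSnapshotOperator(
        task_id="delete_snapshot",
        cluster_identifier="my-redshift-cluster",
        snapshot_identifier="my-redshift-snapshot",   # polls until the snapshot is gone (wait_for_completion defaults to True)
    )
    create_snapshot >> delete_snapshot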
+ self._attempts = 10 + self._attempt_interval = 15 - def execute(self, context: 'Context'): + def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier) - if cluster_state == 'available': - self.log.info("Pausing Redshift cluster %s", self.cluster_identifier) - redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) - else: - self.log.warning( - "Unable to pause cluster since cluster is currently in status: %s", cluster_state - ) + + while self._attempts >= 1: + try: + redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) + return + except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: + self._attempts = self._attempts - 1 + + if self._attempts > 0: + self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts) + time.sleep(self._attempt_interval) + else: + raise error class RedshiftDeleteClusterOperator(BaseOperator): @@ -346,7 +494,7 @@ def __init__( *, cluster_identifier: str, skip_final_cluster_snapshot: bool = True, - final_cluster_snapshot_identifier: Optional[str] = None, + final_cluster_snapshot_identifier: str | None = None, wait_for_completion: bool = True, aws_conn_id: str = "aws_default", poll_interval: float = 30.0, @@ -360,24 +508,16 @@ def __init__( self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id) self.poll_interval = poll_interval - def execute(self, context: 'Context'): - self.delete_cluster() - - if self.wait_for_completion: - cluster_status: str = self.check_status() - while cluster_status != "cluster_not_found": - self.log.info( - "cluster status is %s. Sleeping for %s seconds.", cluster_status, self.poll_interval - ) - time.sleep(self.poll_interval) - cluster_status = self.check_status() - - def delete_cluster(self) -> None: + def execute(self, context: Context): self.redshift_hook.delete_cluster( cluster_identifier=self.cluster_identifier, skip_final_cluster_snapshot=self.skip_final_cluster_snapshot, final_cluster_snapshot_identifier=self.final_cluster_snapshot_identifier, ) - def check_status(self) -> str: - return self.redshift_hook.cluster_status(self.cluster_identifier) + if self.wait_for_completion: + waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted") + waiter.wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 30}, + ) diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py index f2d47da655835..416ce8146ef6d 100644 --- a/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/airflow/providers/amazon/aws/operators/redshift_data.py @@ -15,17 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
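# Sketch of the cluster lifecycle operators touched above: pause/resume now retry
# on InvalidClusterStateFault, and delete relies on the boto3 "cluster_deleted"
# waiter when wait_for_completion is True. Identifiers and task ids are
# illustrative; the tasks are assumed to live inside a `with DAG(...):` block.
from airflow.providers.amazon.aws.operators.redshift_cluster import (
    RedshiftDeleteClusterOperator,
    RedshiftPauseClusterOperator,
    RedshiftResumeClusterOperator,
)

pause_cluster = RedshiftPauseClusterOperator(
    task_id="pause_cluster",
    cluster_identifier="my-redshift-cluster",
)
resume_cluster = RedshiftResumeClusterOperator(
    task_id="resume_cluster",
    cluster_identifier="my-redshift-cluster",
)
delete_cluster = RedshiftDeleteClusterOperator(
    task_id="delete_cluster",
    cluster_identifier="my-redshift-cluster",
    skip_final_cluster_snapshot=True,
    wait_for_completion=True,   # waits via get_waiter("cluster_deleted")
)
pause_cluster >> resume_cluster >> delete_cluster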
-import sys -from time import sleep -from typing import TYPE_CHECKING, Any, Dict, Optional +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from time import sleep +from typing import TYPE_CHECKING, Any +from airflow.compat.functools import cached_property from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook +from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: from airflow.utils.context import Context @@ -40,7 +38,7 @@ class RedshiftDataOperator(BaseOperator): :ref:`howto/operator:RedshiftDataOperator` :param database: the name of the database - :param sql: the SQL statement text to run + :param sql: the SQL statement or list of SQL statement to run :param cluster_identifier: unique identifier of a cluster :param db_user: the database username :param parameters: the parameters for the SQL statement @@ -54,32 +52,32 @@ class RedshiftDataOperator(BaseOperator): """ template_fields = ( - 'cluster_identifier', - 'database', - 'sql', - 'db_user', - 'parameters', - 'statement_name', - 'aws_conn_id', - 'region', + "cluster_identifier", + "database", + "sql", + "db_user", + "parameters", + "statement_name", + "aws_conn_id", + "region", ) - template_ext = ('.sql',) - template_fields_renderers = {'sql': 'sql'} + template_ext = (".sql",) + template_fields_renderers = {"sql": "sql"} def __init__( self, database: str, - sql: str, - cluster_identifier: Optional[str] = None, - db_user: Optional[str] = None, - parameters: Optional[list] = None, - secret_arn: Optional[str] = None, - statement_name: Optional[str] = None, + sql: str | list, + cluster_identifier: str | None = None, + db_user: str | None = None, + parameters: list | None = None, + secret_arn: str | None = None, + statement_name: str | None = None, with_event: bool = False, await_result: bool = True, poll_interval: int = 10, - aws_conn_id: str = 'aws_default', - region: Optional[str] = None, + aws_conn_id: str = "aws_default", + region: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -109,7 +107,7 @@ def hook(self) -> RedshiftDataHook: return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region) def execute_query(self): - kwargs: Dict[str, Any] = { + kwargs: dict[str, Any] = { "ClusterIdentifier": self.cluster_identifier, "Database": self.database, "Sql": self.sql, @@ -120,9 +118,22 @@ def execute_query(self): "StatementName": self.statement_name, } - filter_values = {key: val for key, val in kwargs.items() if val is not None} - resp = self.hook.conn.execute_statement(**filter_values) - return resp['Id'] + resp = self.hook.conn.execute_statement(**trim_none_values(kwargs)) + return resp["Id"] + + def execute_batch_query(self): + kwargs: dict[str, Any] = { + "ClusterIdentifier": self.cluster_identifier, + "Database": self.database, + "Sqls": self.sql, + "DbUser": self.db_user, + "Parameters": self.parameters, + "WithEvent": self.with_event, + "SecretArn": self.secret_arn, + "StatementName": self.statement_name, + } + resp = self.hook.conn.batch_execute_statement(**trim_none_values(kwargs)) + return resp["Id"] def wait_for_results(self, statement_id): while True: @@ -130,20 +141,22 @@ def wait_for_results(self, statement_id): resp = self.hook.conn.describe_statement( Id=statement_id, ) - status = resp['Status'] - if status == 'FINISHED': + status = resp["Status"] + if status == "FINISHED": return status 
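# With execute_batch_query added above (and the dispatch in execute() shown in
# the next hunk), `sql` may now be a list of statements as well as a single
# string. A minimal sketch with illustrative cluster/database values:
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

setup_tables = RedshiftDataOperator(
    task_id="setup_tables",
    cluster_identifier="my-redshift-cluster",
    database="dev",
    db_user="awsuser",
    sql=[
        "CREATE TABLE IF NOT EXISTS fruit (fruit_id INT, name VARCHAR(64))",
        "INSERT INTO fruit VALUES (1, 'Banana'), (2, 'Apple')",
    ],   # a list is routed through batch_execute_statement
    await_result=True,   # poll describe_statement until FINISHED
    poll_interval=10,
)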
- elif status == 'FAILED' or status == 'ABORTED': + elif status == "FAILED" or status == "ABORTED": raise ValueError(f"Statement {statement_id!r} terminated with status {status}.") else: self.log.info("Query %s", status) sleep(self.poll_interval) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Execute a statement against Amazon Redshift""" self.log.info("Executing statement: %s", self.sql) - - self.statement_id = self.execute_query() + if isinstance(self.sql, list): + self.statement_id = self.execute_batch_query() + else: + self.statement_id = self.execute_query() if self.await_result: self.wait_for_results(self.statement_id) @@ -153,10 +166,10 @@ def execute(self, context: 'Context') -> None: def on_kill(self) -> None: """Cancel the submitted redshift query""" if self.statement_id: - self.log.info('Received a kill signal.') - self.log.info('Stopping Query with statementId - %s', self.statement_id) + self.log.info("Received a kill signal.") + self.log.info("Stopping Query with statementId - %s", self.statement_id) try: self.hook.conn.cancel_statement(Id=self.statement_id) except Exception as ex: - self.log.error('Unable to cancel query. Exiting. %s', ex) + self.log.error("Unable to cancel query. Exiting. %s", ex) diff --git a/airflow/providers/amazon/aws/operators/redshift_sql.py b/airflow/providers/amazon/aws/operators/redshift_sql.py index c7ad77acb5341..b425ac262d17f 100644 --- a/airflow/providers/amazon/aws/operators/redshift_sql.py +++ b/airflow/providers/amazon/aws/operators/redshift_sql.py @@ -14,18 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +import warnings +from typing import Sequence -from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook -from airflow.www import utils as wwwutils +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -if TYPE_CHECKING: - from airflow.utils.context import Context - -class RedshiftSQLOperator(BaseOperator): +class RedshiftSQLOperator(SQLExecuteQueryOperator): """ Executes SQL Statements against an Amazon Redshift cluster @@ -43,36 +40,18 @@ class RedshiftSQLOperator(BaseOperator): (default value: False) """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. - template_fields_renderers = { - "sql": "postgresql" if "postgresql" in wwwutils.get_attr_renderer() else "sql" - } - - def __init__( - self, - *, - sql: Union[str, Iterable[str]], - redshift_conn_id: str = 'redshift_default', - parameters: Optional[dict] = None, - autocommit: bool = True, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.redshift_conn_id = redshift_conn_id - self.sql = sql - self.autocommit = autocommit - self.parameters = parameters - - def get_hook(self) -> RedshiftSQLHook: - """Create and return RedshiftSQLHook. - :return RedshiftSQLHook: A RedshiftSQLHook instance. 
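# RedshiftSQLOperator now just subclasses the common-sql operator (and, per the
# following hunk, is deprecated in its favour), so new DAGs can target Redshift
# through SQLExecuteQueryOperator directly. Connection id and SQL are illustrative.
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

count_fruit = SQLExecuteQueryOperator(
    task_id="count_fruit",
    conn_id="redshift_default",   # an Amazon Redshift connection
    sql="SELECT COUNT(*) FROM fruit",
)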
- """ - return RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) - - def execute(self, context: 'Context') -> None: - """Execute a statement against Amazon Redshift""" - self.log.info("Executing statement: %s", self.sql) - hook = self.get_hook() - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + template_fields: Sequence[str] = ( + "sql", + "conn_id", + ) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "postgresql"} + + def __init__(self, *, redshift_conn_id: str = "redshift_default", **kwargs) -> None: + super().__init__(conn_id=redshift_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py index 7fbef6629c7f7..91732084017d1 100644 --- a/airflow/providers/amazon/aws/operators/s3.py +++ b/airflow/providers/amazon/aws/operators/s3.py @@ -15,13 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - """This module contains AWS S3 operators.""" +from __future__ import annotations + import subprocess import sys from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -57,17 +57,16 @@ def __init__( self, *, bucket_name: str, - aws_conn_id: Optional[str] = "aws_default", - region_name: Optional[str] = None, + aws_conn_id: str | None = "aws_default", + region_name: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name self.region_name = region_name self.aws_conn_id = aws_conn_id - self.region_name = region_name - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) if not s3_hook.check_for_bucket(self.bucket_name): s3_hook.create_bucket(bucket_name=self.bucket_name, region_name=self.region_name) @@ -99,7 +98,7 @@ def __init__( self, bucket_name: str, force_delete: bool = False, - aws_conn_id: Optional[str] = "aws_default", + aws_conn_id: str | None = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -107,7 +106,7 @@ def __init__( self.force_delete = force_delete self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) if s3_hook.check_for_bucket(self.bucket_name): s3_hook.delete_bucket(bucket_name=self.bucket_name, force_delete=self.force_delete) @@ -134,12 +133,12 @@ class S3GetBucketTaggingOperator(BaseOperator): template_fields: Sequence[str] = ("bucket_name",) - def __init__(self, bucket_name: str, aws_conn_id: Optional[str] = "aws_default", **kwargs) -> None: + def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) if s3_hook.check_for_bucket(self.bucket_name): @@ -177,10 +176,10 @@ class S3PutBucketTaggingOperator(BaseOperator): def __init__( self, bucket_name: str, - key: Optional[str] = 
None, - value: Optional[str] = None, - tag_set: Optional[List[Dict[str, str]]] = None, - aws_conn_id: Optional[str] = "aws_default", + key: str | None = None, + value: str | None = None, + tag_set: list[dict[str, str]] | None = None, + aws_conn_id: str | None = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -190,7 +189,7 @@ def __init__( self.bucket_name = bucket_name self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) if s3_hook.check_for_bucket(self.bucket_name): @@ -221,12 +220,12 @@ class S3DeleteBucketTaggingOperator(BaseOperator): template_fields: Sequence[str] = ("bucket_name",) - def __init__(self, bucket_name: str, aws_conn_id: Optional[str] = "aws_default", **kwargs) -> None: + def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) if s3_hook.check_for_bucket(self.bucket_name): @@ -280,10 +279,10 @@ class S3CopyObjectOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_bucket_key', - 'dest_bucket_key', - 'source_bucket_name', - 'dest_bucket_name', + "source_bucket_key", + "dest_bucket_key", + "source_bucket_name", + "dest_bucket_name", ) def __init__( @@ -291,12 +290,12 @@ def __init__( *, source_bucket_key: str, dest_bucket_key: str, - source_bucket_name: Optional[str] = None, - dest_bucket_name: Optional[str] = None, - source_version_id: Optional[str] = None, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[str, bool]] = None, - acl_policy: Optional[str] = None, + source_bucket_name: str | None = None, + dest_bucket_name: str | None = None, + source_version_id: str | None = None, + aws_conn_id: str = "aws_default", + verify: str | bool | None = None, + acl_policy: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -310,7 +309,7 @@ def __init__( self.verify = verify self.acl_policy = acl_policy - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) s3_hook.copy_object( self.source_bucket_key, @@ -360,21 +359,21 @@ class S3CreateObjectOperator(BaseOperator): """ - template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'data') + template_fields: Sequence[str] = ("s3_bucket", "s3_key", "data") def __init__( self, *, - s3_bucket: Optional[str] = None, + s3_bucket: str | None = None, s3_key: str, - data: Union[str, bytes], + data: str | bytes, replace: bool = False, encrypt: bool = False, - acl_policy: Optional[str] = None, - encoding: Optional[str] = None, - compression: Optional[str] = None, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[str, bool]] = None, + acl_policy: str | None = None, + encoding: str | None = None, + compression: str | None = None, + aws_conn_id: str = "aws_default", + verify: str | bool | None = None, **kwargs, ): super().__init__(**kwargs) @@ -390,10 +389,10 @@ def __init__( self.aws_conn_id = aws_conn_id self.verify = verify - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - s3_bucket, s3_key = s3_hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, 'dest_bucket', 'dest_key') + s3_bucket, s3_key = s3_hook.get_s3_bucket_key(self.s3_bucket, 
self.s3_key, "dest_bucket", "dest_key") if isinstance(self.data, str): s3_hook.load_string( @@ -444,16 +443,16 @@ class S3DeleteObjectsOperator(BaseOperator): CA cert bundle than the one used by botocore. """ - template_fields: Sequence[str] = ('keys', 'bucket', 'prefix') + template_fields: Sequence[str] = ("keys", "bucket", "prefix") def __init__( self, *, bucket: str, - keys: Optional[Union[str, list]] = None, - prefix: Optional[str] = None, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[str, bool]] = None, + keys: str | list | None = None, + prefix: str | None = None, + aws_conn_id: str = "aws_default", + verify: str | bool | None = None, **kwargs, ): @@ -467,7 +466,7 @@ def __init__( if not bool(keys is None) ^ bool(prefix is None): raise AirflowException("Either keys or prefix should be set.") - def execute(self, context: 'Context'): + def execute(self, context: Context): if not bool(self.keys is None) ^ bool(self.prefix is None): raise AirflowException("Either keys or prefix should be set.") @@ -525,22 +524,22 @@ class S3FileTransformOperator(BaseOperator): :param replace: Replace dest S3 key if it already exists """ - template_fields: Sequence[str] = ('source_s3_key', 'dest_s3_key', 'script_args') + template_fields: Sequence[str] = ("source_s3_key", "dest_s3_key", "script_args") template_ext: Sequence[str] = () - ui_color = '#f9c915' + ui_color = "#f9c915" def __init__( self, *, source_s3_key: str, dest_s3_key: str, - transform_script: Optional[str] = None, + transform_script: str | None = None, select_expression=None, - script_args: Optional[Sequence[str]] = None, - source_aws_conn_id: str = 'aws_default', - source_verify: Optional[Union[bool, str]] = None, - dest_aws_conn_id: str = 'aws_default', - dest_verify: Optional[Union[bool, str]] = None, + script_args: Sequence[str] | None = None, + source_aws_conn_id: str = "aws_default", + source_verify: bool | str | None = None, + dest_aws_conn_id: str = "aws_default", + dest_verify: bool | str | None = None, replace: bool = False, **kwargs, ) -> None: @@ -558,7 +557,7 @@ def __init__( self.script_args = script_args or [] self.output_encoding = sys.getdefaultencoding() - def execute(self, context: 'Context'): + def execute(self, context: Context): if self.transform_script is None and self.select_expression is None: raise AirflowException("Either transform_script or select_expression must be specified") @@ -589,7 +588,7 @@ def execute(self, context: 'Context'): ) as process: self.log.info("Output:") if process.stdout is not None: - for line in iter(process.stdout.readline, b''): + for line in iter(process.stdout.readline, b""): self.log.info(line.decode(self.output_encoding).rstrip()) process.wait() @@ -653,17 +652,17 @@ class S3ListOperator(BaseOperator): ) """ - template_fields: Sequence[str] = ('bucket', 'prefix', 'delimiter') - ui_color = '#ffd700' + template_fields: Sequence[str] = ("bucket", "prefix", "delimiter") + ui_color = "#ffd700" def __init__( self, *, bucket: str, - prefix: str = '', - delimiter: str = '', - aws_conn_id: str = 'aws_default', - verify: Optional[Union[str, bool]] = None, + prefix: str = "", + delimiter: str = "", + aws_conn_id: str = "aws_default", + verify: str | bool | None = None, **kwargs, ): super().__init__(**kwargs) @@ -673,11 +672,11 @@ def __init__( self.aws_conn_id = aws_conn_id self.verify = verify - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) self.log.info( - 'Getting the list of files 
from bucket: %s in prefix: %s (Delimiter %s)', + "Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)", self.bucket, self.prefix, self.delimiter, @@ -727,8 +726,8 @@ class S3ListPrefixesOperator(BaseOperator): ) """ - template_fields: Sequence[str] = ('bucket', 'prefix', 'delimiter') - ui_color = '#ffd700' + template_fields: Sequence[str] = ("bucket", "prefix", "delimiter") + ui_color = "#ffd700" def __init__( self, @@ -736,8 +735,8 @@ def __init__( bucket: str, prefix: str, delimiter: str, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[str, bool]] = None, + aws_conn_id: str = "aws_default", + verify: str | bool | None = None, **kwargs, ): super().__init__(**kwargs) @@ -747,11 +746,11 @@ def __init__( self.aws_conn_id = aws_conn_id self.verify = verify - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) self.log.info( - 'Getting the list of subfolders from bucket: %s in prefix: %s (Delimiter %s)', + "Getting the list of subfolders from bucket: %s in prefix: %s (Delimiter %s)", self.bucket, self.prefix, self.delimiter, diff --git a/airflow/providers/amazon/aws/operators/s3_bucket.py b/airflow/providers/amazon/aws/operators/s3_bucket.py deleted file mode 100644 index e5806fa780848..0000000000000 --- a/airflow/providers/amazon/aws/operators/s3_bucket.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/s3_bucket_tagging.py b/airflow/providers/amazon/aws/operators/s3_bucket_tagging.py deleted file mode 100644 index fb214412643ae..0000000000000 --- a/airflow/providers/amazon/aws/operators/s3_bucket_tagging.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
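# The modules deleted in the surrounding hunks (s3_bucket, s3_bucket_tagging,
# s3_copy_object, ...) were deprecation shims; everything now lives in the
# consolidated s3 module. A minimal sketch with an illustrative bucket name,
# assumed to sit inside a `with DAG(...):` block:
from airflow.providers.amazon.aws.operators.s3 import (
    S3CreateBucketOperator,
    S3DeleteBucketOperator,
    S3ListOperator,
)

create_bucket = S3CreateBucketOperator(task_id="create_bucket", bucket_name="my-example-bucket")
list_keys = S3ListOperator(task_id="list_keys", bucket="my-example-bucket", prefix="data/", delimiter="/")
delete_bucket = S3DeleteBucketOperator(task_id="delete_bucket", bucket_name="my-example-bucket", force_delete=True)
create_bucket >> list_keys >> delete_bucket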
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3 import ( # noqa - S3DeleteBucketTaggingOperator, - S3GetBucketTaggingOperator, - S3PutBucketTaggingOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/s3_copy_object.py b/airflow/providers/amazon/aws/operators/s3_copy_object.py deleted file mode 100644 index 298f8d21724cb..0000000000000 --- a/airflow/providers/amazon/aws/operators/s3_copy_object.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/s3_delete_objects.py b/airflow/providers/amazon/aws/operators/s3_delete_objects.py deleted file mode 100644 index 35d86893eb377..0000000000000 --- a/airflow/providers/amazon/aws/operators/s3_delete_objects.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. 
Please use :mod:`airflow.providers.amazon.aws.operators.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3 import S3DeleteObjectsOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/s3_file_transform.py b/airflow/providers/amazon/aws/operators/s3_file_transform.py deleted file mode 100644 index 4400b202bbda0..0000000000000 --- a/airflow/providers/amazon/aws/operators/s3_file_transform.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3 import S3FileTransformOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/s3_list.py b/airflow/providers/amazon/aws/operators/s3_list.py deleted file mode 100644 index c114a0b814886..0000000000000 --- a/airflow/providers/amazon/aws/operators/s3_list.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3 import S3ListOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/s3_list_prefixes.py b/airflow/providers/amazon/aws/operators/s3_list_prefixes.py deleted file mode 100644 index 5e94f0c17f471..0000000000000 --- a/airflow/providers/amazon/aws/operators/s3_list_prefixes.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.s3 import S3ListPrefixesOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 084e3e857e9be..c5f08db049206 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -14,28 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json -import sys -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from botocore.exceptions import ClientError +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: from airflow.utils.context import Context -DEFAULT_CONN_ID = 'aws_default' -CHECK_INTERVAL_SECOND = 30 +DEFAULT_CONN_ID: str = "aws_default" +CHECK_INTERVAL_SECOND: int = 30 + + +def serialize(result: dict) -> str: + return json.loads(json.dumps(result, cls=AirflowJsonEncoder)) class SageMakerBaseOperator(BaseOperator): @@ -44,17 +45,18 @@ class SageMakerBaseOperator(BaseOperator): :param config: The configuration necessary to start a training job (templated) """ - template_fields: Sequence[str] = ('config',) + template_fields: Sequence[str] = ("config",) template_ext: Sequence[str] = () - template_fields_renderers = {'config': 'json'} - ui_color = '#ededed' - integer_fields: List[List[Any]] = [] + template_fields_renderers: dict = {"config": "json"} + ui_color: str = "#ededed" + integer_fields: list[list[Any]] = [] - def __init__(self, *, config: dict, **kwargs): + def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs): super().__init__(**kwargs) self.config = config + self.aws_conn_id = aws_conn_id - def parse_integer(self, config, field): + def parse_integer(self, config: dict, field: list[str] | str) -> None: """Recursive method for parsing string fields holding integer values to integers.""" if len(field) == 1: if isinstance(config, list): @@ -74,34 +76,39 @@ def parse_integer(self, config, field): self.parse_integer(config[head], 
tail) return - def parse_config_integers(self): - """ - Parse the integer fields of training config to integers in case the config is rendered by Jinja and - all fields are str - """ + def parse_config_integers(self) -> None: + """Parse the integer fields to ints in case the config is rendered by Jinja and all fields are str.""" for field in self.integer_fields: self.parse_integer(self.config, field) - def expand_role(self): + def expand_role(self) -> None: """Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN.""" - def preprocess_config(self): + def preprocess_config(self) -> None: """Process the config into a usable form.""" - self.log.info('Preprocessing the config and doing required s3_operations') + self._create_integer_fields() + self.log.info("Preprocessing the config and doing required s3_operations") self.hook.configure_s3_resources(self.config) self.parse_config_integers() self.expand_role() self.log.info( - 'After preprocessing the config is:\n %s', - json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')), + "After preprocessing the config is:\n %s", + json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": ")), ) - def execute(self, context: 'Context'): - raise NotImplementedError('Please implement execute() in sub class!') + def _create_integer_fields(self) -> None: + """ + Set fields which should be cast to integers. + Child classes should override this method if they need integer fields parsed. + """ + self.integer_fields = [] + + def execute(self, context: Context) -> None | dict: + raise NotImplementedError("Please implement execute() in sub class!") @cached_property def hook(self): - """Return SageMakerHook""" + """Return SageMakerHook.""" return SageMakerHook(aws_conn_id=self.aws_conn_id) @@ -140,55 +147,64 @@ def __init__( wait_for_completion: bool = True, print_log: bool = True, check_interval: int = CHECK_INTERVAL_SECOND, - max_ingestion_time: Optional[int] = None, - action_if_job_exists: str = 'increment', + max_ingestion_time: int | None = None, + action_if_job_exists: str = "increment", **kwargs, ): - super().__init__(config=config, **kwargs) - if action_if_job_exists not in ('increment', 'fail'): + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) + if action_if_job_exists not in ("increment", "fail"): raise AirflowException( f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \ Provided value: '{action_if_job_exists}'." 
) self.action_if_job_exists = action_if_job_exists - self.aws_conn_id = aws_conn_id self.wait_for_completion = wait_for_completion self.print_log = print_log self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time - self._create_integer_fields() def _create_integer_fields(self) -> None: - """Set fields which should be casted to integers.""" - self.integer_fields = [ - ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], - ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], + """Set fields which should be cast to integers.""" + self.integer_fields: list[list[str] | list[list[str]]] = [ + ["ProcessingResources", "ClusterConfig", "InstanceCount"], + ["ProcessingResources", "ClusterConfig", "VolumeSizeInGB"], ] - if 'StoppingCondition' in self.config: - self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']] + if "StoppingCondition" in self.config: + self.integer_fields.append(["StoppingCondition", "MaxRuntimeInSeconds"]) def expand_role(self) -> None: - if 'RoleArn' in self.config: - hook = AwsBaseHook(self.aws_conn_id, client_type='iam') - self.config['RoleArn'] = hook.expand_role(self.config['RoleArn']) + """Expands an IAM role name into an ARN.""" + if "RoleArn" in self.config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"]) - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.preprocess_config() - processing_job_name = self.config['ProcessingJobName'] - if self.hook.find_processing_job_by_name(processing_job_name): - raise AirflowException( - f'A SageMaker processing job with name {processing_job_name} already exists.' - ) - self.log.info('Creating SageMaker processing job %s.', self.config['ProcessingJobName']) + processing_job_name = self.config["ProcessingJobName"] + processing_job_dedupe_pattern = "-[0-9]+$" + existing_jobs_found = self.hook.count_processing_jobs_by_name( + processing_job_name, processing_job_dedupe_pattern + ) + if existing_jobs_found: + if self.action_if_job_exists == "fail": + raise AirflowException( + f"A SageMaker processing job with name {processing_job_name} already exists." + ) + elif self.action_if_job_exists == "increment": + self.log.info("Found existing processing job with name '%s'.", processing_job_name) + new_processing_job_name = f"{processing_job_name}-{existing_jobs_found + 1}" + self.config["ProcessingJobName"] = new_processing_job_name + self.log.info("Incremented processing job name to '%s'.", new_processing_job_name) + response = self.hook.create_processing_job( self.config, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, max_ingestion_time=self.max_ingestion_time, ) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker Processing Job creation failed: {response}') - return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])} + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker Processing Job creation failed: {response}") + return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))} class SageMakerEndpointConfigOperator(SageMakerBaseOperator): @@ -209,8 +225,6 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator): :return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker. 
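# Sketch of the processing operator with the duplicate-name handling shown above;
# describe_* results are passed through serialize() so the value pushed to XCom is
# JSON-safe. The config dict is heavily abbreviated and the job name/role ARN are
# placeholders, not values taken from this change.
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator

preprocess = SageMakerProcessingOperator(
    task_id="preprocess",
    config={
        "ProcessingJobName": "my-processing-job",   # suffixed (e.g. "-2") if a job with this name already exists
        "RoleArn": "arn:aws:iam::123456789012:role/SageMakerExecutionRole",
        # ... AppSpecification, ProcessingResources, ProcessingInputs/Outputs ...
    },
    wait_for_completion=True,
    action_if_job_exists="increment",   # "fail" raises instead of renaming
)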
""" - integer_fields = [['ProductionVariants', 'InitialInstanceCount']] - def __init__( self, *, @@ -218,18 +232,24 @@ def __init__( aws_conn_id: str = DEFAULT_CONN_ID, **kwargs, ): - super().__init__(config=config, **kwargs) - self.config = config - self.aws_conn_id = aws_conn_id + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) + + def _create_integer_fields(self) -> None: + """Set fields which should be cast to integers.""" + self.integer_fields: list[list[str]] = [["ProductionVariants", "InitialInstanceCount"]] - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.preprocess_config() - self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName']) + self.log.info("Creating SageMaker Endpoint Config %s.", self.config["EndpointConfigName"]) response = self.hook.create_endpoint_config(self.config) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker endpoint config creation failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker endpoint config creation failed: {response}") else: - return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])} + return { + "EndpointConfig": serialize( + self.hook.describe_endpoint_config(self.config["EndpointConfigName"]) + ) + } class SageMakerEndpointOperator(SageMakerBaseOperator): @@ -287,54 +307,54 @@ def __init__( aws_conn_id: str = DEFAULT_CONN_ID, wait_for_completion: bool = True, check_interval: int = CHECK_INTERVAL_SECOND, - max_ingestion_time: Optional[int] = None, - operation: str = 'create', + max_ingestion_time: int | None = None, + operation: str = "create", **kwargs, ): - super().__init__(config=config, **kwargs) - self.config = config - self.aws_conn_id = aws_conn_id + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) self.wait_for_completion = wait_for_completion self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time self.operation = operation.lower() - if self.operation not in ['create', 'update']: + if self.operation not in ["create", "update"]: raise ValueError('Invalid value! 
Argument operation has to be one of "create" and "update"') - self.create_integer_fields() - def create_integer_fields(self) -> None: - """Set fields which should be casted to integers.""" - if 'EndpointConfig' in self.config: - self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']] + def _create_integer_fields(self) -> None: + """Set fields which should be cast to integers.""" + if "EndpointConfig" in self.config: + self.integer_fields: list[list[str]] = [ + ["EndpointConfig", "ProductionVariants", "InitialInstanceCount"] + ] def expand_role(self) -> None: - if 'Model' not in self.config: + """Expands an IAM role name into an ARN.""" + if "Model" not in self.config: return - hook = AwsBaseHook(self.aws_conn_id, client_type='iam') - config = self.config['Model'] - if 'ExecutionRoleArn' in config: - config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn']) + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + config = self.config["Model"] + if "ExecutionRoleArn" in config: + config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"]) - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.preprocess_config() - model_info = self.config.get('Model') - endpoint_config_info = self.config.get('EndpointConfig') - endpoint_info = self.config.get('Endpoint', self.config) + model_info = self.config.get("Model") + endpoint_config_info = self.config.get("EndpointConfig") + endpoint_info = self.config.get("Endpoint", self.config) if model_info: - self.log.info('Creating SageMaker model %s.', model_info['ModelName']) + self.log.info("Creating SageMaker model %s.", model_info["ModelName"]) self.hook.create_model(model_info) if endpoint_config_info: - self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName']) + self.log.info("Creating endpoint config %s.", endpoint_config_info["EndpointConfigName"]) self.hook.create_endpoint_config(endpoint_config_info) - if self.operation == 'create': + if self.operation == "create": sagemaker_operation = self.hook.create_endpoint - log_str = 'Creating' - elif self.operation == 'update': + log_str = "Creating" + elif self.operation == "update": sagemaker_operation = self.hook.update_endpoint - log_str = 'Updating' + log_str = "Updating" else: raise ValueError('Invalid value! 
Argument operation has to be one of "create" and "update"') - self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName']) + self.log.info("%s SageMaker endpoint %s.", log_str, endpoint_info["EndpointName"]) try: response = sagemaker_operation( endpoint_info, @@ -343,21 +363,23 @@ def execute(self, context: 'Context') -> dict: max_ingestion_time=self.max_ingestion_time, ) except ClientError: - self.operation = 'update' + self.operation = "update" sagemaker_operation = self.hook.update_endpoint - log_str = 'Updating' + log_str = "Updating" response = sagemaker_operation( endpoint_info, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, max_ingestion_time=self.max_ingestion_time, ) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker endpoint creation failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker endpoint creation failed: {response}") else: return { - 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']), - 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']), + "EndpointConfig": serialize( + self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"]) + ), + "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])), } @@ -396,6 +418,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator): :param max_ingestion_time: If wait is set to True, the operation fails if the transform job doesn't finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. + :param check_if_job_exists: If set to true, then the operator will check whether a transform job + already exists for the name in the config. + :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" + (default) and "fail". + This is only relevant if check_if_job_exists is True. :return Dict: Returns The ARN of the model created in Amazon SageMaker. """ @@ -406,58 +433,85 @@ def __init__( aws_conn_id: str = DEFAULT_CONN_ID, wait_for_completion: bool = True, check_interval: int = CHECK_INTERVAL_SECOND, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, + check_if_job_exists: bool = True, + action_if_job_exists: str = "increment", **kwargs, ): - super().__init__(config=config, **kwargs) - self.config = config - self.aws_conn_id = aws_conn_id + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) self.wait_for_completion = wait_for_completion self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time - self.create_integer_fields() - - def create_integer_fields(self) -> None: - """Set fields which should be casted to integers.""" - self.integer_fields: List[List[str]] = [ - ['Transform', 'TransformResources', 'InstanceCount'], - ['Transform', 'MaxConcurrentTransforms'], - ['Transform', 'MaxPayloadInMB'], + self.check_if_job_exists = check_if_job_exists + if action_if_job_exists in ("increment", "fail"): + self.action_if_job_exists = action_if_job_exists + else: + raise AirflowException( + f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \ + Provided value: '{action_if_job_exists}'." 
+ ) + + def _create_integer_fields(self) -> None: + """Set fields which should be cast to integers.""" + self.integer_fields: list[list[str]] = [ + ["Transform", "TransformResources", "InstanceCount"], + ["Transform", "MaxConcurrentTransforms"], + ["Transform", "MaxPayloadInMB"], ] - if 'Transform' not in self.config: + if "Transform" not in self.config: for field in self.integer_fields: field.pop(0) def expand_role(self) -> None: - if 'Model' not in self.config: + """Expands an IAM role name into an ARN.""" + if "Model" not in self.config: return - config = self.config['Model'] - if 'ExecutionRoleArn' in config: - hook = AwsBaseHook(self.aws_conn_id, client_type='iam') - config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn']) + config = self.config["Model"] + if "ExecutionRoleArn" in config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"]) - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.preprocess_config() - model_config = self.config.get('Model') - transform_config = self.config.get('Transform', self.config) + model_config = self.config.get("Model") + transform_config = self.config.get("Transform", self.config) + if self.check_if_job_exists: + self._check_if_transform_job_exists() if model_config: - self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName']) + self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"]) self.hook.create_model(model_config) - self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName']) + self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"]) response = self.hook.create_transform_job( transform_config, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, max_ingestion_time=self.max_ingestion_time, ) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker transform Job creation failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker transform Job creation failed: {response}") else: return { - 'Model': self.hook.describe_model(transform_config['ModelName']), - 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']), + "Model": serialize(self.hook.describe_model(transform_config["ModelName"])), + "Transform": serialize( + self.hook.describe_transform_job(transform_config["TransformJobName"]) + ), } + def _check_if_transform_job_exists(self) -> None: + transform_config = self.config.get("Transform", self.config) + transform_job_name = transform_config["TransformJobName"] + transform_jobs = self.hook.list_transform_jobs(name_contains=transform_job_name) + if transform_job_name in [tj["TransformJobName"] for tj in transform_jobs]: + if self.action_if_job_exists == "increment": + self.log.info("Found existing transform job with name '%s'.", transform_job_name) + new_transform_job_name = f"{transform_job_name}-{(len(transform_jobs) + 1)}" + transform_config["TransformJobName"] = new_transform_job_name + self.log.info("Incremented transform job name to '%s'.", new_transform_job_name) + elif self.action_if_job_exists == "fail": + raise AirflowException( + f"A SageMaker transform job with name {transform_job_name} already exists." 
+ ) + class SageMakerTuningOperator(SageMakerBaseOperator): """ @@ -485,14 +539,6 @@ class SageMakerTuningOperator(SageMakerBaseOperator): :return Dict: Returns The ARN of the tuning job created in Amazon SageMaker. """ - integer_fields = [ - ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'], - ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'], - ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'], - ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'], - ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'], - ] - def __init__( self, *, @@ -500,27 +546,36 @@ def __init__( aws_conn_id: str = DEFAULT_CONN_ID, wait_for_completion: bool = True, check_interval: int = CHECK_INTERVAL_SECOND, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, **kwargs, ): - super().__init__(config=config, **kwargs) - self.config = config - self.aws_conn_id = aws_conn_id + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) self.wait_for_completion = wait_for_completion self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time def expand_role(self) -> None: - if 'TrainingJobDefinition' in self.config: - config = self.config['TrainingJobDefinition'] - if 'RoleArn' in config: - hook = AwsBaseHook(self.aws_conn_id, client_type='iam') - config['RoleArn'] = hook.expand_role(config['RoleArn']) + """Expands an IAM role name into an ARN.""" + if "TrainingJobDefinition" in self.config: + config = self.config["TrainingJobDefinition"] + if "RoleArn" in config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + config["RoleArn"] = hook.expand_role(config["RoleArn"]) + + def _create_integer_fields(self) -> None: + """Set fields which should be cast to integers.""" + self.integer_fields: list[list[str]] = [ + ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxNumberOfTrainingJobs"], + ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxParallelTrainingJobs"], + ["TrainingJobDefinition", "ResourceConfig", "InstanceCount"], + ["TrainingJobDefinition", "ResourceConfig", "VolumeSizeInGB"], + ["TrainingJobDefinition", "StoppingCondition", "MaxRuntimeInSeconds"], + ] - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.preprocess_config() self.log.info( - 'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName'] + "Creating SageMaker Hyper-Parameter Tuning Job %s", self.config["HyperParameterTuningJobName"] ) response = self.hook.create_tuning_job( self.config, @@ -528,10 +583,12 @@ def execute(self, context: 'Context') -> dict: check_interval=self.check_interval, max_ingestion_time=self.max_ingestion_time, ) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker Tuning Job creation failed: {response}") else: - return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])} + return { + "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"])) + } class SageMakerModelOperator(SageMakerBaseOperator): @@ -552,24 +609,23 @@ class SageMakerModelOperator(SageMakerBaseOperator): :return Dict: Returns The ARN of the model created in Amazon SageMaker. 
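# Sketch of the tuning operator above and the training operator updated in the
# hunks that follow. Because the integer fields are now declared in
# _create_integer_fields(), values templated as strings are cast back to ints.
# Config dicts are heavily abbreviated; all names and ARNs are placeholders.
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerTrainingOperator,
    SageMakerTuningOperator,
)

tune = SageMakerTuningOperator(
    task_id="tune_model",
    config={
        "HyperParameterTuningJobName": "my-tuning-job",
        "HyperParameterTuningJobConfig": {
            "ResourceLimits": {"MaxNumberOfTrainingJobs": "10", "MaxParallelTrainingJobs": "2"},
            # ... Strategy, HyperParameterTuningJobObjective, ParameterRanges ...
        },
        # ... TrainingJobDefinition ...
    },
    wait_for_completion=True,
)

train = SageMakerTrainingOperator(
    task_id="train_model",
    config={
        "TrainingJobName": "my-training-job",
        "ResourceConfig": {"InstanceCount": "1", "InstanceType": "ml.m5.xlarge", "VolumeSizeInGB": "30"},
        "StoppingCondition": {"MaxRuntimeInSeconds": "3600"},
        # ... AlgorithmSpecification, RoleArn, InputDataConfig, OutputDataConfig ...
    },
    check_if_job_exists=True,            # set to False to skip the existing-job lookup
    action_if_job_exists="increment",    # or "fail"
)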
""" - def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs): - super().__init__(config=config, **kwargs) - self.config = config - self.aws_conn_id = aws_conn_id + def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs): + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) def expand_role(self) -> None: - if 'ExecutionRoleArn' in self.config: - hook = AwsBaseHook(self.aws_conn_id, client_type='iam') - self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn']) + """Expands an IAM role name into an ARN.""" + if "ExecutionRoleArn" in self.config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + self.config["ExecutionRoleArn"] = hook.expand_role(self.config["ExecutionRoleArn"]) - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.preprocess_config() - self.log.info('Creating SageMaker Model %s.', self.config['ModelName']) + self.log.info("Creating SageMaker Model %s.", self.config["ModelName"]) response = self.hook.create_model(self.config) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker model creation failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker model creation failed: {response}") else: - return {'Model': self.hook.describe_model(self.config['ModelName'])} + return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))} class SageMakerTrainingOperator(SageMakerBaseOperator): @@ -597,16 +653,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): already exists for the name in the config. :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" (default) and "fail". - This is only relevant if check_if + This is only relevant if check_if_job_exists is True. :return Dict: Returns The ARN of the training job created in Amazon SageMaker. 
""" - integer_fields = [ - ['ResourceConfig', 'InstanceCount'], - ['ResourceConfig', 'VolumeSizeInGB'], - ['StoppingCondition', 'MaxRuntimeInSeconds'], - ] - def __init__( self, *, @@ -615,19 +665,18 @@ def __init__( wait_for_completion: bool = True, print_log: bool = True, check_interval: int = CHECK_INTERVAL_SECOND, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, check_if_job_exists: bool = True, - action_if_job_exists: str = 'increment', + action_if_job_exists: str = "increment", **kwargs, ): - super().__init__(config=config, **kwargs) - self.aws_conn_id = aws_conn_id + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) self.wait_for_completion = wait_for_completion self.print_log = print_log self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time self.check_if_job_exists = check_if_job_exists - if action_if_job_exists in ('increment', 'fail'): + if action_if_job_exists in ("increment", "fail"): self.action_if_job_exists = action_if_job_exists else: raise AirflowException( @@ -636,15 +685,24 @@ def __init__( ) def expand_role(self) -> None: - if 'RoleArn' in self.config: - hook = AwsBaseHook(self.aws_conn_id, client_type='iam') - self.config['RoleArn'] = hook.expand_role(self.config['RoleArn']) + """Expands an IAM role name into an ARN.""" + if "RoleArn" in self.config: + hook = AwsBaseHook(self.aws_conn_id, client_type="iam") + self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"]) + + def _create_integer_fields(self) -> None: + """Set fields which should be cast to integers.""" + self.integer_fields: list[list[str]] = [ + ["ResourceConfig", "InstanceCount"], + ["ResourceConfig", "VolumeSizeInGB"], + ["StoppingCondition", "MaxRuntimeInSeconds"], + ] - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.preprocess_config() if self.check_if_job_exists: self._check_if_job_exists() - self.log.info('Creating SageMaker training job %s.', self.config['TrainingJobName']) + self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"]) response = self.hook.create_training_job( self.config, wait_for_completion=self.wait_for_completion, @@ -652,23 +710,23 @@ def execute(self, context: 'Context') -> dict: check_interval=self.check_interval, max_ingestion_time=self.max_ingestion_time, ) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker Training Job creation failed: {response}') + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Sagemaker Training Job creation failed: {response}") else: - return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])} + return {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))} def _check_if_job_exists(self) -> None: - training_job_name = self.config['TrainingJobName'] + training_job_name = self.config["TrainingJobName"] training_jobs = self.hook.list_training_jobs(name_contains=training_job_name) - if training_job_name in [tj['TrainingJobName'] for tj in training_jobs]: - if self.action_if_job_exists == 'increment': + if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]: + if self.action_if_job_exists == "increment": self.log.info("Found existing training job with name '%s'.", training_job_name) - new_training_job_name = f'{training_job_name}-{(len(training_jobs) + 1)}' - self.config['TrainingJobName'] = new_training_job_name + new_training_job_name = 
f"{training_job_name}-{(len(training_jobs) + 1)}" + self.config["TrainingJobName"] = new_training_job_name self.log.info("Incremented training job name to '%s'.", new_training_job_name) - elif self.action_if_job_exists == 'fail': + elif self.action_if_job_exists == "fail": raise AirflowException( - f'A SageMaker training job with name {training_job_name} already exists.' + f"A SageMaker training job with name {training_job_name} already exists." ) @@ -685,12 +743,10 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator): :param aws_conn_id: The AWS connection ID to use. """ - def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs): - super().__init__(config=config, **kwargs) - self.config = config - self.aws_conn_id = aws_conn_id + def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs): + super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) - def execute(self, context: 'Context') -> Any: + def execute(self, context: Context) -> Any: sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id) - sagemaker_hook.delete_model(model_name=self.config['ModelName']) - self.log.info("Model %s deleted successfully.", self.config['ModelName']) + sagemaker_hook.delete_model(model_name=self.config["ModelName"]) + self.log.info("Model %s deleted successfully.", self.config["ModelName"]) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_base.py b/airflow/providers/amazon/aws/operators/sagemaker_base.py deleted file mode 100644 index 22e44d71c8ca0..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_base.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py deleted file mode 100644 index 5351431f00f68..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py deleted file mode 100644 index 737e5b6c7d2c8..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_model.py b/airflow/providers/amazon/aws/operators/sagemaker_model.py deleted file mode 100644 index fffe8d96e4b69..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_model.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. 
Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerModelOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_processing.py b/airflow/providers/amazon/aws/operators/sagemaker_processing.py deleted file mode 100644 index b3a4be8fa2aa3..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_processing.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py b/airflow/providers/amazon/aws/operators/sagemaker_training.py deleted file mode 100644 index 40f13b45e932f..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_training.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py deleted file mode 100644 index 1e833edebcb16..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py deleted file mode 100644 index 18a8263a4f102..0000000000000 --- a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py index e916798d03386..99525c4cb6bc2 100644 --- a/airflow/providers/amazon/aws/operators/sns.py +++ b/airflow/providers/amazon/aws/operators/sns.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations """Publish message to SNS queue""" -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.sns import SnsHook @@ -42,7 +43,7 @@ class SnsPublishOperator(BaseOperator): determined automatically) """ - template_fields: Sequence[str] = ('message', 'subject', 'message_attributes') + template_fields: Sequence[str] = ("target_arn", "message", "subject", "message_attributes", "aws_conn_id") template_ext: Sequence[str] = () template_fields_renderers = {"message_attributes": "json"} @@ -51,9 +52,9 @@ def __init__( *, target_arn: str, message: str, - aws_conn_id: str = 'aws_default', - subject: Optional[str] = None, - message_attributes: Optional[dict] = None, + subject: str | None = None, + message_attributes: dict | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -63,11 +64,11 @@ def __init__( self.message_attributes = message_attributes self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): sns = SnsHook(aws_conn_id=self.aws_conn_id) self.log.info( - 'Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s', + "Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s", self.target_arn, self.aws_conn_id, self.subject, diff --git a/airflow/providers/amazon/aws/operators/sqs.py b/airflow/providers/amazon/aws/operators/sqs.py index 6eff54134a75d..0b0cfc4f16297 100644 --- a/airflow/providers/amazon/aws/operators/sqs.py +++ b/airflow/providers/amazon/aws/operators/sqs.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Publish message to SQS queue""" -import warnings -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.sqs import SqsHook @@ -39,21 +39,30 @@ class SqsPublishOperator(BaseOperator): :param message_attributes: additional attributes for the message (default: None) For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` :param delay_seconds: message delay (templated) (default: 1 second) + :param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues. 
(default: None) + For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` :param aws_conn_id: AWS connection id (default: aws_default) """ - template_fields: Sequence[str] = ('sqs_queue', 'message_content', 'delay_seconds', 'message_attributes') - template_fields_renderers = {'message_attributes': 'json'} - ui_color = '#6ad3fa' + template_fields: Sequence[str] = ( + "sqs_queue", + "message_content", + "delay_seconds", + "message_attributes", + "message_group_id", + ) + template_fields_renderers = {"message_attributes": "json"} + ui_color = "#6ad3fa" def __init__( self, *, sqs_queue: str, message_content: str, - message_attributes: Optional[dict] = None, + message_attributes: dict | None = None, delay_seconds: int = 0, - aws_conn_id: str = 'aws_default', + message_group_id: str | None = None, + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -62,15 +71,15 @@ def __init__( self.message_content = message_content self.delay_seconds = delay_seconds self.message_attributes = message_attributes or {} + self.message_group_id = message_group_id - def execute(self, context: 'Context'): + def execute(self, context: Context) -> dict: """ Publish the message to the Amazon SQS queue :param context: the context object :return: dict with information about the message sent For details of the returned dict see :py:meth:`botocore.client.SQS.send_message` - :rtype: dict """ hook = SqsHook(aws_conn_id=self.aws_conn_id) @@ -79,24 +88,9 @@ def execute(self, context: 'Context'): message_body=self.message_content, delay_seconds=self.delay_seconds, message_attributes=self.message_attributes, + message_group_id=self.message_group_id, ) - self.log.info('send_message result: %s', result) + self.log.info("send_message result: %s", result) return result - - -class SQSPublishOperator(SqsPublishOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. " - "Please use `airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py index 7c32b33890117..c131dabc742e1 100644 --- a/airflow/providers/amazon/aws/operators/step_function.py +++ b/airflow/providers/amazon/aws/operators/step_function.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from __future__ import annotations import json -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -44,18 +44,18 @@ class StepFunctionStartExecutionOperator(BaseOperator): :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn. 
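# --- Illustrative sketch (not part of the patch) -----------------------------
# Chains the two Step Functions operators in this hunk: the first task starts
# an execution and pushes the execution ARN to XCom, the second pulls that ARN
# through its templated execution_arn field and fetches the execution output.
# The state machine ARN is a placeholder.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.step_function import (
    StepFunctionGetExecutionOutputOperator,
    StepFunctionStartExecutionOperator,
)

with DAG("example_step_functions", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    start_execution = StepFunctionStartExecutionOperator(
        task_id="start_execution",
        state_machine_arn="arn:aws:states:eu-west-1:123456789012:stateMachine:my-state-machine",
        state_machine_input={"payload": "hello"},
    )

    get_output = StepFunctionGetExecutionOutputOperator(
        task_id="get_execution_output",
        # execution_arn is a templated field, so the ARN returned (and
        # XCom-pushed) by the previous task can be pulled with Jinja.
        execution_arn="{{ task_instance.xcom_pull(task_ids='start_execution') }}",
    )

    start_execution >> get_output
# -----------------------------------------------------------------------------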
""" - template_fields: Sequence[str] = ('state_machine_arn', 'name', 'input') + template_fields: Sequence[str] = ("state_machine_arn", "name", "input") template_ext: Sequence[str] = () - ui_color = '#f9c915' + ui_color = "#f9c915" def __init__( self, *, state_machine_arn: str, - name: Optional[str] = None, - state_machine_input: Union[dict, str, None] = None, - aws_conn_id: str = 'aws_default', - region_name: Optional[str] = None, + name: str | None = None, + state_machine_input: dict | str | None = None, + aws_conn_id: str = "aws_default", + region_name: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -65,15 +65,15 @@ def __init__( self.aws_conn_id = aws_conn_id self.region_name = region_name - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input) if execution_arn is None: - raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}') + raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}") - self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn) + self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn) return execution_arn @@ -92,16 +92,16 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator): :param aws_conn_id: aws connection to use, defaults to 'aws_default' """ - template_fields: Sequence[str] = ('execution_arn',) + template_fields: Sequence[str] = ("execution_arn",) template_ext: Sequence[str] = () - ui_color = '#f9c915' + ui_color = "#f9c915" def __init__( self, *, execution_arn: str, - aws_conn_id: str = 'aws_default', - region_name: Optional[str] = None, + aws_conn_id: str = "aws_default", + region_name: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -109,12 +109,12 @@ def __init__( self.aws_conn_id = aws_conn_id self.region_name = region_name - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) execution_status = hook.describe_execution(self.execution_arn) - execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None + execution_output = json.loads(execution_status["output"]) if "output" in execution_status else None - self.log.info('Got State Machine Execution output for %s', self.execution_arn) + self.log.info("Got State Machine Execution output for %s", self.execution_arn) return execution_output diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py deleted file mode 100644 index 2b047241ae861..0000000000000 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.step_function`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.step_function import ( # noqa - StepFunctionGetExecutionOutputOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py deleted file mode 100644 index 10a847ffc90bd..0000000000000 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.step_function`.""" - -import warnings - -from airflow.providers.amazon.aws.operators.step_function import StepFunctionStartExecutionOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py index 8b72f955ac6d5..684a8d4afbee9 100644 --- a/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -16,31 +16,22 @@ # specific language governing permissions and limitations # under the License. 
"""Objects relating to sourcing secrets from AWS Secrets Manager""" +from __future__ import annotations import ast import json -import re -import sys import warnings -from typing import Optional -from urllib.parse import urlencode - -import boto3 - -from airflow.version import version as airflow_version - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from typing import TYPE_CHECKING, Any +from urllib.parse import unquote, urlencode +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.utils import get_airflow_version, trim_none_values from airflow.secrets import BaseSecretsBackend from airflow.utils.log.logging_mixin import LoggingMixin - -def _parse_version(val): - val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val) - return tuple(int(x) for x in val.split('.')) +if TYPE_CHECKING: + # Avoid circular import problems when instantiating the backend during configuration. + from airflow.models.connection import Connection class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin): @@ -63,8 +54,17 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin): if you provide ``{"config_prefix": "airflow/config"}`` and request config key ``sql_alchemy_conn``. - You can also pass additional keyword arguments like ``aws_secret_access_key``, ``aws_access_key_id`` - or ``region_name`` to this class and they would be passed on to Boto3 client. + You can also pass additional keyword arguments listed in AWS Connection Extra config + to this class, and they would be used for establishing a connection and passed on to Boto3 client. + + .. code-block:: ini + + [secrets] + backend = airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend + backend_kwargs = {"connections_prefix": "airflow/connections", "region_name": "eu-west-1"} + + .. seealso:: + :ref:`howto/connection:aws:configuring-the-connection` There are two ways of storing secrets in Secret Manager for using them with this operator: storing them as a conn URI in one field, or taking advantage of native approach of Secrets Manager @@ -74,7 +74,7 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin): .. code-block:: python possible_words_for_conn_fields = { - "user": ["user", "username", "login", "user_name"], + "login": ["user", "username", "login", "user_name"], "password": ["password", "pass", "key"], "host": ["host", "remote_host", "server"], "port": ["port"], @@ -95,11 +95,13 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin): :param config_prefix: Specifies the prefix of the secret to read to get Configurations. If set to None (null value in the configuration), requests for configurations will not be sent to AWS Secrets Manager. If you don't want a config_prefix, set it as an empty string - :param profile_name: The name of a profile to use. If not given, then the default profile is used. :param sep: separator used to concatenate secret_prefix and secret_id. Default: "/" :param full_url_mode: if True, the secrets must be stored as one conn URI in just one field per secret. If False (set it as false in backend_kwargs), you can store the secret using different fields (password, user...). + :param are_secret_values_urlencoded: If True, and full_url_mode is False, then the values are assumed to + be URL-encoded and will be decoded before being passed into a Connection object. This option is + ignored when full_url_mode is True. 
:param extra_conn_words: for using just when you set full_url_mode as false and store the secrets in different fields of secrets manager. You can add more words for each connection part beyond the default ones. The extra words to be searched should be passed as a dict of lists, @@ -109,13 +111,13 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin): def __init__( self, - connections_prefix: str = 'airflow/connections', - variables_prefix: str = 'airflow/variables', - config_prefix: str = 'airflow/config', - profile_name: Optional[str] = None, + connections_prefix: str = "airflow/connections", + variables_prefix: str = "airflow/variables", + config_prefix: str = "airflow/config", sep: str = "/", full_url_mode: bool = True, - extra_conn_words: Optional[dict] = None, + are_secret_values_urlencoded: bool | None = None, + extra_conn_words: dict[str, list[str]] | None = None, **kwargs, ): super().__init__() @@ -131,23 +133,62 @@ def __init__( self.config_prefix = config_prefix.rstrip(sep) else: self.config_prefix = config_prefix - self.profile_name = profile_name self.sep = sep self.full_url_mode = full_url_mode + + if are_secret_values_urlencoded is None: + self.are_secret_values_urlencoded = True + else: + warnings.warn( + "The `secret_values_are_urlencoded` kwarg only exists to assist in migrating away from" + " URL-encoding secret values when `full_url_mode` is False. It will be considered deprecated" + " when values are not required to be URL-encoded by default.", + DeprecationWarning, + stacklevel=2, + ) + if full_url_mode and not are_secret_values_urlencoded: + warnings.warn( + "The `secret_values_are_urlencoded` kwarg for the SecretsManagerBackend is only used" + " when `full_url_mode` is False. When `full_url_mode` is True, the secret needs to be" + " URL-encoded.", + UserWarning, + stacklevel=2, + ) + self.are_secret_values_urlencoded = are_secret_values_urlencoded self.extra_conn_words = extra_conn_words or {} + + self.profile_name = kwargs.get("profile_name", None) + # Remove client specific arguments from kwargs + self.api_version = kwargs.pop("api_version", None) + self.use_ssl = kwargs.pop("use_ssl", None) + self.kwargs = kwargs @cached_property def client(self): """Create a Secrets Manager client""" - session = boto3.session.Session(profile_name=self.profile_name) - - return session.client(service_name="secretsmanager", **self.kwargs) + from airflow.providers.amazon.aws.hooks.base_aws import SessionFactory + from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper + + conn_id = f"{self.__class__.__name__}__connection" + conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs) + client_kwargs = trim_none_values( + { + "region_name": conn_config.region_name, + "verify": conn_config.verify, + "endpoint_url": conn_config.endpoint_url, + "api_version": self.api_version, + "use_ssl": self.use_ssl, + } + ) + + session = SessionFactory(conn=conn_config).create_session() + return session.client(service_name="secretsmanager", **client_kwargs) @staticmethod - def _format_uri_with_extra(secret, conn_string): + def _format_uri_with_extra(secret, conn_string: str) -> str: try: - extra_dict = secret['extra'] + extra_dict = secret["extra"] except KeyError: return conn_string @@ -156,31 +197,161 @@ def _format_uri_with_extra(secret, conn_string): return conn_string - def get_uri_from_secret(self, secret): + def get_connection(self, conn_id: str) -> Connection | None: + if not self.full_url_mode: + # Avoid circular 
import problems when instantiating the backend during configuration. + from airflow.models.connection import Connection + + secret_string = self._get_secret(self.connections_prefix, conn_id) + secret_dict = self._deserialize_json_string(secret_string) + + if not secret_dict: + return None + + if "extra" in secret_dict and isinstance(secret_dict["extra"], str): + secret_dict["extra"] = self._deserialize_json_string(secret_dict["extra"]) + + data = self._standardize_secret_keys(secret_dict) + + if self.are_secret_values_urlencoded: + data = self._remove_escaping_in_secret_dict(secret=data, conn_id=conn_id) + + port: int | None = None + + if data["port"] is not None: + port = int(data["port"]) + + return Connection( + conn_id=conn_id, + login=data["user"], + password=data["password"], + host=data["host"], + port=port, + schema=data["schema"], + conn_type=data["conn_type"], + extra=data["extra"], + ) + + return super().get_connection(conn_id=conn_id) + + def _standardize_secret_keys(self, secret: dict[str, Any]) -> dict[str, Any]: + """Standardize the names of the keys in the dict. These keys align with""" possible_words_for_conn_fields = { - 'user': ['user', 'username', 'login', 'user_name'], - 'password': ['password', 'pass', 'key'], - 'host': ['host', 'remote_host', 'server'], - 'port': ['port'], - 'schema': ['database', 'schema'], - 'conn_type': ['conn_type', 'conn_id', 'connection_type', 'engine'], + "user": ["user", "username", "login", "user_name"], + "password": ["password", "pass", "key"], + "host": ["host", "remote_host", "server"], + "port": ["port"], + "schema": ["database", "schema"], + "conn_type": ["conn_type", "conn_id", "connection_type", "engine"], + "extra": ["extra"], } for conn_field, extra_words in self.extra_conn_words.items(): possible_words_for_conn_fields[conn_field].extend(extra_words) - conn_d = {} + conn_d: dict[str, Any] = {} for conn_field, possible_words in possible_words_for_conn_fields.items(): try: conn_d[conn_field] = [v for k, v in secret.items() if k in possible_words][0] except IndexError: - conn_d[conn_field] = '' + conn_d[conn_field] = None - conn_string = "{conn_type}://{user}:{password}@{host}:{port}/{schema}".format(**conn_d) + return conn_d + def get_uri_from_secret(self, secret: dict[str, str]) -> str: + conn_d: dict[str, str] = {k: v if v else "" for k, v in self._standardize_secret_keys(secret).items()} + conn_string = "{conn_type}://{user}:{password}@{host}:{port}/{schema}".format(**conn_d) return self._format_uri_with_extra(secret, conn_string) - def get_conn_value(self, conn_id: str): + def _deserialize_json_string(self, value: str | None) -> dict[Any, Any] | None: + if not value: + return None + try: + # Use ast.literal_eval for backwards compatibility. + # Previous version of this code had a comment saying that using json.loads caused errors. + # This likely means people were using dict reprs instead of valid JSONs. + res: dict[str, Any] = json.loads(value) + except json.JSONDecodeError: + try: + res = ast.literal_eval(value) if value else None + warnings.warn( + f"In future versions, `{type(self).__name__}` will only support valid JSONs, not dict" + " reprs. Please make sure your secret is a valid JSON." + ) + except ValueError: # 'malformed node or string: ' error, for empty conns + return None + + return res + + def _remove_escaping_in_secret_dict(self, secret: dict[str, Any], conn_id: str) -> dict[str, Any]: + # When ``unquote(v) == v``, then removing unquote won't affect the user, regardless of + # whether or not ``v`` is URL-encoded. 
For example, "foo bar" is not URL-encoded. But + # because decoding it doesn't affect the value, then it will migrate safely when + # ``unquote`` gets removed. + # + # When parameters are URL-encoded, but decoding is idempotent, we need to warn the user + # to un-escape their secrets. For example, if "foo%20bar" is a URL-encoded string, then + # decoding is idempotent because ``unquote(unquote("foo%20bar")) == unquote("foo%20bar")``. + # + # In the rare situation that value is URL-encoded but the decoding is _not_ idempotent, + # this causes a major issue. For example, if "foo%2520bar" is URL-encoded, then decoding is + # _not_ idempotent because ``unquote(unquote("foo%2520bar")) != unquote("foo%2520bar")`` + # + # This causes a problem for migration because if the user decodes their value, we cannot + # infer that is the case by looking at the decoded value (from our vantage point, it will + # look to be URL-encoded.) + # + # So when this uncommon situation occurs, the user _must_ adjust the configuration and set + # ``parameters_are_urlencoded`` to False to migrate safely. In all other cases, we do not + # need the user to adjust this object to migrate; they can transition their secrets with + # the default configuration. + + warn_user = False + idempotent = True + + for k, v in secret.copy().items(): + + if k == "extra" and isinstance(v, dict): + # The old behavior was that extras were _not_ urlencoded inside the secret. + # If they were urlencoded (e.g. "foo%20bar"), then they would be re-urlencoded + # (e.g. "foo%20bar" becomes "foo%2520bar") and then unquoted once when parsed. + # So we should just allow the extra dict to remain as-is. + continue + + elif v is not None: + v_unquoted = unquote(v) + if v != v_unquoted: + secret[k] = unquote(v) + warn_user = True + + # Check to see if decoding is idempotent. + if v_unquoted == unquote(v_unquoted): + idempotent = False + + if warn_user: + msg = ( + "When full_url_mode=False, URL-encoding secret values is deprecated. In future versions, " + f"this value will not be un-escaped. For the conn_id {conn_id!r}, please remove the " + "URL-encoding.\n\n" + "This warning was raised because the SecretsManagerBackend detected that this " + "connection was URL-encoded." + ) + if idempotent: + msg = f" Once the values for conn_id {conn_id!r} are decoded, this warning will go away." + if not idempotent: + msg += ( + " In addition to decoding the values for your connection, you must also set" + " secret_values_are_urlencoded=False for your config variable" + " secrets.backend_kwargs because this connection's URL encoding is not idempotent." 
+ " For more information, see:" + " https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/secrets-backends" + "/aws-secrets-manager.html#url-encoding-of-secrets-when-full-url-mode-is-false" + ) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + return secret + + def get_conn_value(self, conn_id: str) -> str | None: """ Get serialized representation of Connection @@ -191,13 +362,30 @@ def get_conn_value(self, conn_id: str): if self.full_url_mode: return self._get_secret(self.connections_prefix, conn_id) - try: - secret_string = self._get_secret(self.connections_prefix, conn_id) - # json.loads gives error - secret = ast.literal_eval(secret_string) if secret_string else None - except ValueError: # 'malformed node or string: ' error, for empty conns - connection = None - secret = None + else: + warnings.warn( + f"In future versions, `{type(self).__name__}.get_conn_value` will return a JSON string when" + " full_url_mode is False, not a URI.", + DeprecationWarning, + ) + + # It is very rare for user code to get to this point, since: + # + # - When full_url_mode is True, the previous statement returns. + # - When full_url_mode is False, get_connection() does not call + # `get_conn_value`. Additionally, full_url_mode defaults to True. + # + # So the code would have to be calling `get_conn_value` directly, and + # the user would be using a non-default setting. + # + # As of Airflow 2.3.0, get_conn_value() is allowed to return a JSON + # string in the base implementation. This is a way to deprecate this + # behavior gracefully. + + secret_string = self._get_secret(self.connections_prefix, conn_id) + + secret = self._deserialize_json_string(secret_string) + connection = None # These lines will check if we have with some denomination stored an username, password and host if secret: @@ -205,7 +393,7 @@ def get_conn_value(self, conn_id: str): return connection - def get_conn_uri(self, conn_id: str) -> Optional[str]: + def get_conn_uri(self, conn_id: str) -> str | None: """ Return URI representation of Connection conn_id. @@ -214,7 +402,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]: :param conn_id: the connection id :return: deserialized Connection """ - if _parse_version(airflow_version) >= (2, 3): + if get_airflow_version() >= (2, 3): warnings.warn( f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " "in a future release. 
Please use method `get_conn_value` instead.", @@ -223,7 +411,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]: ) return self.get_conn_value(conn_id) - def get_variable(self, key: str) -> Optional[str]: + def get_variable(self, key: str) -> str | None: """ Get Airflow Variable from Environment Variable :param key: Variable Key @@ -234,7 +422,7 @@ def get_variable(self, key: str) -> Optional[str]: return self._get_secret(self.variables_prefix, key) - def get_config(self, key: str) -> Optional[str]: + def get_config(self, key: str) -> str | None: """ Get Airflow Configuration :param key: Configuration Option Key @@ -245,12 +433,13 @@ def get_config(self, key: str) -> Optional[str]: return self._get_secret(self.config_prefix, key) - def _get_secret(self, path_prefix, secret_id: str) -> Optional[str]: + def _get_secret(self, path_prefix, secret_id: str) -> str | None: """ Get secret value from Secrets Manager :param path_prefix: Prefix for the Path to get Secret :param secret_id: Secret Key """ + error_msg = "An error occurred when calling the get_secret_value operation" if path_prefix: secrets_path = self.build_path(path_prefix, secret_id, self.sep) else: @@ -260,18 +449,39 @@ def _get_secret(self, path_prefix, secret_id: str) -> Optional[str]: response = self.client.get_secret_value( SecretId=secrets_path, ) - return response.get('SecretString') + return response.get("SecretString") except self.client.exceptions.ResourceNotFoundException: self.log.debug( - "An error occurred (ResourceNotFoundException) when calling the " - "get_secret_value operation: " - "Secret %s not found.", + "ResourceNotFoundException: %s. Secret %s not found.", + error_msg, secret_id, ) return None - except self.client.exceptions.AccessDeniedException: + except self.client.exceptions.InvalidParameterException: + self.log.debug( + "InvalidParameterException: %s", + error_msg, + exc_info=True, + ) + return None + except self.client.exceptions.InvalidRequestException: + self.log.debug( + "InvalidRequestException: %s", + error_msg, + exc_info=True, + ) + return None + except self.client.exceptions.DecryptionFailure: + self.log.debug( + "DecryptionFailure: %s", + error_msg, + exc_info=True, + ) + return None + except self.client.exceptions.InternalServiceError: self.log.debug( - "An error occurred (AccessDeniedException) when calling the get_secret_value operation", + "InternalServiceError: %s", + error_msg, exc_info=True, ) return None diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py index e45a5500ab003..af45cb5e9813b 100644 --- a/airflow/providers/amazon/aws/secrets/systems_manager.py +++ b/airflow/providers/amazon/aws/secrets/systems_manager.py @@ -16,29 +16,16 @@ # specific language governing permissions and limitations # under the License. 
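# --- Illustrative sketch (not part of the patch) -----------------------------
# Shows the field-based secret layout that the SecretsManagerBackend hunk above
# parses when full_url_mode=False: each connection attribute is stored as its
# own JSON key, using any of the recognised "possible words" (user/login,
# password/pass, host/server, ...), and the secret name sits under the
# configured connections_prefix. Names, region and credentials are placeholders;
# airflow.cfg would also need the [secrets] backend settings shown in the
# docstring above, e.g.
#   backend_kwargs = {"connections_prefix": "airflow/connections", "full_url_mode": false}
import json

import boto3

secretsmanager = boto3.client("secretsmanager", region_name="eu-west-1")
secretsmanager.create_secret(
    Name="airflow/connections/my_postgres",  # <connections_prefix><sep><conn_id>
    SecretString=json.dumps(
        {
            "conn_type": "postgres",
            "login": "airflow_user",
            "password": "s3cret",  # stored as-is; URL-encoding the value is being deprecated
            "host": "db.example.com",
            "port": 5432,
            "schema": "airflow_db",
            "extra": {"sslmode": "require"},
        }
    ),
)
# -----------------------------------------------------------------------------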
"""Objects relating to sourcing connections from AWS SSM Parameter Store""" -import re -import sys -import warnings -from typing import Optional - -import boto3 - -from airflow.version import version as airflow_version +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +import warnings +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.utils import trim_none_values from airflow.secrets import BaseSecretsBackend from airflow.utils.log.logging_mixin import LoggingMixin -def _parse_version(val): - val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val) - return tuple(int(x) for x in val.split('.')) - - class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin): """ Retrieves Connection or Variables from AWS SSM Parameter Store @@ -62,15 +49,26 @@ class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin): If set to None (null), requests for variables will not be sent to AWS SSM Parameter Store. :param config_prefix: Specifies the prefix of the secret to read to get Variables. If set to None (null), requests for configurations will not be sent to AWS SSM Parameter Store. - :param profile_name: The name of a profile to use. If not given, then the default profile is used. + + You can also pass additional keyword arguments listed in AWS Connection Extra config + to this class, and they would be used for establish connection and passed on to Boto3 client. + + .. code-block:: ini + + [secrets] + backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend + backend_kwargs = {"connections_prefix": "airflow/connections", "region_name": "eu-west-1"} + + .. 
seealso:: + :ref:`howto/connection:aws:configuring-the-connection` + """ def __init__( self, - connections_prefix: str = '/airflow/connections', - variables_prefix: str = '/airflow/variables', - config_prefix: str = '/airflow/config', - profile_name: Optional[str] = None, + connections_prefix: str = "/airflow/connections", + variables_prefix: str = "/airflow/variables", + config_prefix: str = "/airflow/config", **kwargs, ): super().__init__() @@ -79,23 +77,43 @@ def __init__( else: self.connections_prefix = connections_prefix if variables_prefix is not None: - self.variables_prefix = variables_prefix.rstrip('/') + self.variables_prefix = variables_prefix.rstrip("/") else: self.variables_prefix = variables_prefix if config_prefix is not None: - self.config_prefix = config_prefix.rstrip('/') + self.config_prefix = config_prefix.rstrip("/") else: self.config_prefix = config_prefix - self.profile_name = profile_name + + self.profile_name = kwargs.get("profile_name", None) + # Remove client specific arguments from kwargs + self.api_version = kwargs.pop("api_version", None) + self.use_ssl = kwargs.pop("use_ssl", None) + self.kwargs = kwargs @cached_property def client(self): """Create a SSM client""" - session = boto3.Session(profile_name=self.profile_name) - return session.client("ssm", **self.kwargs) - - def get_conn_value(self, conn_id: str) -> Optional[str]: + from airflow.providers.amazon.aws.hooks.base_aws import SessionFactory + from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper + + conn_id = f"{self.__class__.__name__}__connection" + conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs) + client_kwargs = trim_none_values( + { + "region_name": conn_config.region_name, + "verify": conn_config.verify, + "endpoint_url": conn_config.endpoint_url, + "api_version": self.api_version, + "use_ssl": self.use_ssl, + } + ) + + session = SessionFactory(conn=conn_config).create_session() + return session.client(service_name="ssm", **client_kwargs) + + def get_conn_value(self, conn_id: str) -> str | None: """ Get param value @@ -106,25 +124,27 @@ def get_conn_value(self, conn_id: str) -> Optional[str]: return self._get_secret(self.connections_prefix, conn_id) - def get_conn_uri(self, conn_id: str) -> Optional[str]: + def get_conn_uri(self, conn_id: str) -> str | None: """ Return URI representation of Connection conn_id. As of Airflow version 2.3.0 this method is deprecated. :param conn_id: the connection id - :return: deserialized Connection """ - if _parse_version(airflow_version) >= (2, 3): - warnings.warn( - f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " - "in a future release. Please use method `get_conn_value` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.get_conn_value(conn_id) - - def get_variable(self, key: str) -> Optional[str]: + warnings.warn( + f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " + "in a future release. 
Please use method `get_conn_value` instead.", + DeprecationWarning, + stacklevel=2, + ) + value = self.get_conn_value(conn_id) + if value is None: + return None + + return self.deserialize_connection(conn_id, value).get_uri() + + def get_variable(self, key: str) -> str | None: """ Get Airflow Variable from Environment Variable @@ -136,7 +156,7 @@ def get_variable(self, key: str) -> Optional[str]: return self._get_secret(self.variables_prefix, key) - def get_config(self, key: str) -> Optional[str]: + def get_config(self, key: str) -> str | None: """ Get Airflow Configuration @@ -148,7 +168,7 @@ def get_config(self, key: str) -> Optional[str]: return self._get_secret(self.config_prefix, key) - def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + def _get_secret(self, path_prefix: str, secret_id: str) -> str | None: """ Get secret value from Parameter Store. diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py index 927f512143362..1954d15a8e004 100644 --- a/airflow/providers/amazon/aws/sensors/athena.py +++ b/airflow/providers/amazon/aws/sensors/athena.py @@ -15,17 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys -from typing import TYPE_CHECKING, Any, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence if TYPE_CHECKING: from airflow.utils.context import Context -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.sensors.base import BaseSensorOperator @@ -37,7 +34,7 @@ class AthenaSensor(BaseSensorOperator): If the query fails, the task will fail. .. 
seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:AthenaSensor` @@ -50,25 +47,25 @@ class AthenaSensor(BaseSensorOperator): """ INTERMEDIATE_STATES = ( - 'QUEUED', - 'RUNNING', + "QUEUED", + "RUNNING", ) FAILURE_STATES = ( - 'FAILED', - 'CANCELLED', + "FAILED", + "CANCELLED", ) - SUCCESS_STATES = ('SUCCEEDED',) + SUCCESS_STATES = ("SUCCEEDED",) - template_fields: Sequence[str] = ('query_execution_id',) + template_fields: Sequence[str] = ("query_execution_id",) template_ext: Sequence[str] = () - ui_color = '#66c3ff' + ui_color = "#66c3ff" def __init__( self, *, query_execution_id: str, - max_retries: Optional[int] = None, - aws_conn_id: str = 'aws_default', + max_retries: int | None = None, + aws_conn_id: str = "aws_default", sleep_time: int = 10, **kwargs: Any, ) -> None: @@ -78,11 +75,11 @@ def __init__( self.sleep_time = sleep_time self.max_retries = max_retries - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: state = self.hook.poll_query_status(self.query_execution_id, self.max_retries) if state in self.FAILURE_STATES: - raise AirflowException('Athena sensor failed') + raise AirflowException("Athena sensor failed") if state in self.INTERMEDIATE_STATES: return False diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py index faab424e314fa..facd01a00e4cf 100644 --- a/airflow/providers/amazon/aws/sensors/batch.py +++ b/airflow/providers/amazon/aws/sensors/batch.py @@ -14,8 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence +import sys +from typing import TYPE_CHECKING, Sequence + +if sys.version_info >= (3, 8): + from functools import cached_property +else: + from cached_property import cached_property from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook @@ -36,29 +43,30 @@ class BatchSensor(BaseSensorOperator): :param job_id: Batch job_id to check the state for :param aws_conn_id: aws connection to use, defaults to 'aws_default' + :param region_name: aws region name associated with the client """ - template_fields: Sequence[str] = ('job_id',) + template_fields: Sequence[str] = ("job_id",) template_ext: Sequence[str] = () - ui_color = '#66c3ff' + ui_color = "#66c3ff" def __init__( self, *, job_id: str, - aws_conn_id: str = 'aws_default', - region_name: Optional[str] = None, + aws_conn_id: str = "aws_default", + region_name: str | None = None, **kwargs, ): super().__init__(**kwargs) self.job_id = job_id self.aws_conn_id = aws_conn_id self.region_name = region_name - self.hook: Optional[BatchClientHook] = None + self.hook: BatchClientHook | None = None - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: job_description = self.get_hook().get_job_description(self.job_id) - state = job_description['status'] + state = job_description["status"] if state == BatchClientHook.SUCCESS_STATE: return True @@ -67,9 +75,9 @@ def poke(self, context: 'Context') -> bool: return False if state == BatchClientHook.FAILURE_STATE: - raise AirflowException(f'Batch sensor failed. AWS Batch job status: {state}') + raise AirflowException(f"Batch sensor failed. 
AWS Batch job status: {state}") - raise AirflowException(f'Batch sensor failed. Unknown AWS Batch job status: {state}') + raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}") def get_hook(self) -> BatchClientHook: """Create and return a BatchClientHook""" @@ -81,3 +89,129 @@ def get_hook(self) -> BatchClientHook: region_name=self.region_name, ) return self.hook + + +class BatchComputeEnvironmentSensor(BaseSensorOperator): + """ + Asks for the state of the Batch compute environment until it reaches a failure state or success state. + If the environment fails, the task will fail. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:BatchComputeEnvironmentSensor` + + :param compute_environment: Batch compute environment name + + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + + :param region_name: aws region name associated with the client + """ + + template_fields: Sequence[str] = ("compute_environment",) + template_ext: Sequence[str] = () + ui_color = "#66c3ff" + + def __init__( + self, + compute_environment: str, + aws_conn_id: str = "aws_default", + region_name: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.compute_environment = compute_environment + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + @cached_property + def hook(self) -> BatchClientHook: + """Create and return a BatchClientHook""" + return BatchClientHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + ) + + def poke(self, context: Context) -> bool: + response = self.hook.client.describe_compute_environments( + computeEnvironments=[self.compute_environment] + ) + + if len(response["computeEnvironments"]) == 0: + raise AirflowException(f"AWS Batch compute environment {self.compute_environment} not found") + + status = response["computeEnvironments"][0]["status"] + + if status in BatchClientHook.COMPUTE_ENVIRONMENT_TERMINAL_STATUS: + return True + + if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS: + return False + + raise AirflowException( + f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}" + ) + + +class BatchJobQueueSensor(BaseSensorOperator): + """ + Asks for the state of the Batch job queue until it reaches a failure state or success state. + If the queue fails, the task will fail. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:BatchJobQueueSensor` + + :param job_queue: Batch job queue name + + :param treat_non_existing_as_deleted: If True, a non-existing Batch job queue is considered as a deleted + queue and as such a valid case. 
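As a rough illustration of how the BatchComputeEnvironmentSensor added above might be wired into a DAG: the DAG boilerplate, task id, and compute environment name below are illustrative and not part of the patch; only the import path and constructor arguments come from the code above.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.batch import BatchComputeEnvironmentSensor

with DAG(dag_id="example_batch_compute_env", start_date=datetime(2022, 1, 1)) as dag:
    # Pokes until the compute environment reaches a terminal status;
    # intermediate statuses return False and the sensor keeps waiting.
    wait_for_compute_env = BatchComputeEnvironmentSensor(
        task_id="wait_for_compute_env",
        compute_environment="my-compute-env",  # illustrative name
        aws_conn_id="aws_default",
    )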
+ + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + + :param region_name: aws region name associated with the client + """ + + template_fields: Sequence[str] = ("job_queue",) + template_ext: Sequence[str] = () + ui_color = "#66c3ff" + + def __init__( + self, + job_queue: str, + treat_non_existing_as_deleted: bool = False, + aws_conn_id: str = "aws_default", + region_name: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.job_queue = job_queue + self.treat_non_existing_as_deleted = treat_non_existing_as_deleted + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + @cached_property + def hook(self) -> BatchClientHook: + """Create and return a BatchClientHook""" + return BatchClientHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + ) + + def poke(self, context: Context) -> bool: + response = self.hook.client.describe_job_queues(jobQueues=[self.job_queue]) + + if len(response["jobQueues"]) == 0: + if self.treat_non_existing_as_deleted: + return True + else: + raise AirflowException(f"AWS Batch job queue {self.job_queue} not found") + + status = response["jobQueues"][0]["status"] + + if status in BatchClientHook.JOB_QUEUE_TERMINAL_STATUS: + return True + + if status in BatchClientHook.JOB_QUEUE_INTERMEDIATE_STATUS: + return False + + raise AirflowException(f"AWS Batch job queue failed. AWS Batch job queue status: {status}") diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py index fb01bdc7f6827..d2bd45592654f 100644 --- a/airflow/providers/amazon/aws/sensors/cloud_formation.py +++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py @@ -16,17 +16,14 @@ # specific language governing permissions and limitations # under the License. """This module contains sensors for AWS CloudFormation.""" -import sys -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence if TYPE_CHECKING: from airflow.utils.context import Context -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook from airflow.sensors.base import BaseSensorOperator @@ -36,36 +33,35 @@ class CloudFormationCreateStackSensor(BaseSensorOperator): Waits for a stack to be created successfully on AWS CloudFormation. .. 
seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:CloudFormationCreateStackSensor` - :param stack_name: The name of the stack to wait for (templated) :param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are stored :param poke_interval: Time in seconds that the job should wait between each try """ - template_fields: Sequence[str] = ('stack_name',) - ui_color = '#C5CAE9' + template_fields: Sequence[str] = ("stack_name",) + ui_color = "#C5CAE9" - def __init__(self, *, stack_name, aws_conn_id='aws_default', region_name=None, **kwargs): + def __init__(self, *, stack_name, aws_conn_id="aws_default", region_name=None, **kwargs): super().__init__(**kwargs) self.stack_name = stack_name self.aws_conn_id = aws_conn_id self.region_name = region_name - def poke(self, context: 'Context'): + def poke(self, context: Context): stack_status = self.hook.get_stack_status(self.stack_name) - if stack_status == 'CREATE_COMPLETE': + if stack_status == "CREATE_COMPLETE": return True - if stack_status in ('CREATE_IN_PROGRESS', None): + if stack_status in ("CREATE_IN_PROGRESS", None): return False - raise ValueError(f'Stack {self.stack_name} in bad state: {stack_status}') + raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}") @cached_property def hook(self) -> CloudFormationHook: - """Create and return an CloudFormationHook""" + """Create and return a CloudFormationHook""" return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) @@ -74,7 +70,7 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator): Waits for a stack to be deleted successfully on AWS CloudFormation. .. 
seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:CloudFormationDeleteStackSensor` :param stack_name: The name of the stack to wait for (templated) @@ -83,15 +79,15 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator): :param poke_interval: Time in seconds that the job should wait between each try """ - template_fields: Sequence[str] = ('stack_name',) - ui_color = '#C5CAE9' + template_fields: Sequence[str] = ("stack_name",) + ui_color = "#C5CAE9" def __init__( self, *, stack_name: str, - aws_conn_id: str = 'aws_default', - region_name: Optional[str] = None, + aws_conn_id: str = "aws_default", + region_name: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -99,15 +95,15 @@ def __init__( self.region_name = region_name self.stack_name = stack_name - def poke(self, context: 'Context'): + def poke(self, context: Context): stack_status = self.hook.get_stack_status(self.stack_name) - if stack_status in ('DELETE_COMPLETE', None): + if stack_status in ("DELETE_COMPLETE", None): return True - if stack_status == 'DELETE_IN_PROGRESS': + if stack_status == "DELETE_IN_PROGRESS": return False - raise ValueError(f'Stack {self.stack_name} in bad state: {stack_status}') + raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}") @cached_property def hook(self) -> CloudFormationHook: - """Create and return an CloudFormationHook""" + """Create and return a CloudFormationHook""" return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) diff --git a/airflow/providers/amazon/aws/sensors/dms.py b/airflow/providers/amazon/aws/sensors/dms.py index 26e6b7148fdcd..9a05d77a21253 100644 --- a/airflow/providers/amazon/aws/sensors/dms.py +++ b/airflow/providers/amazon/aws/sensors/dms.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Optional, Sequence +from typing import TYPE_CHECKING, Iterable, Sequence from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.dms import DmsHook @@ -40,15 +41,15 @@ class DmsTaskBaseSensor(BaseSensorOperator): the task reaches any of these states """ - template_fields: Sequence[str] = ('replication_task_arn',) + template_fields: Sequence[str] = ("replication_task_arn",) template_ext: Sequence[str] = () def __init__( self, replication_task_arn: str, - aws_conn_id='aws_default', - target_statuses: Optional[Iterable[str]] = None, - termination_statuses: Optional[Iterable[str]] = None, + aws_conn_id="aws_default", + target_statuses: Iterable[str] | None = None, + termination_statuses: Iterable[str] | None = None, *args, **kwargs, ): @@ -57,7 +58,7 @@ def __init__( self.replication_task_arn = replication_task_arn self.target_statuses: Iterable[str] = target_statuses or [] self.termination_statuses: Iterable[str] = termination_statuses or [] - self.hook: Optional[DmsHook] = None + self.hook: DmsHook | None = None def get_hook(self) -> DmsHook: """Get DmsHook""" @@ -67,21 +68,21 @@ def get_hook(self) -> DmsHook: self.hook = DmsHook(self.aws_conn_id) return self.hook - def poke(self, context: 'Context'): - status: Optional[str] = self.get_hook().get_task_status(self.replication_task_arn) + def poke(self, context: Context): + status: str | None = self.get_hook().get_task_status(self.replication_task_arn) if not status: raise AirflowException( - f'Failed to read task status, task with ARN {self.replication_task_arn} not found' + f"Failed to read task status, task with ARN {self.replication_task_arn} not found" ) - self.log.info('DMS Replication task (%s) has status: %s', self.replication_task_arn, status) + self.log.info("DMS Replication task (%s) has status: %s", self.replication_task_arn, status) if status in self.target_statuses: return True if status in self.termination_statuses: - raise AirflowException(f'Unexpected status: {status}') + raise AirflowException(f"Unexpected status: {status}") return False @@ -91,25 +92,25 @@ class DmsTaskCompletedSensor(DmsTaskBaseSensor): Pokes DMS task until it is completed. .. seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:DmsTaskCompletedSensor` :param replication_task_arn: AWS DMS replication task ARN """ - template_fields: Sequence[str] = ('replication_task_arn',) + template_fields: Sequence[str] = ("replication_task_arn",) template_ext: Sequence[str] = () def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.target_statuses = ['stopped'] + self.target_statuses = ["stopped"] self.termination_statuses = [ - 'creating', - 'deleting', - 'failed', - 'failed-move', - 'modifying', - 'moving', - 'ready', - 'testing', + "creating", + "deleting", + "failed", + "failed-move", + "modifying", + "moving", + "ready", + "testing", ] diff --git a/airflow/providers/amazon/aws/sensors/dms_task.py b/airflow/providers/amazon/aws/sensors/dms_task.py deleted file mode 100644 index 26a06c5812556..0000000000000 --- a/airflow/providers/amazon/aws/sensors/dms_task.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.dms`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.dms`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py index 7d4a640593fec..bfdb8fcd411c0 100644 --- a/airflow/providers/amazon/aws/sensors/ec2.py +++ b/airflow/providers/amazon/aws/sensors/ec2.py @@ -15,9 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook from airflow.sensors.base import BaseSensorOperator @@ -51,7 +51,7 @@ def __init__( target_state: str, instance_id: str, aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, + region_name: str | None = None, **kwargs, ): if target_state not in self.valid_states: @@ -62,7 +62,7 @@ def __init__( self.aws_conn_id = aws_conn_id self.region_name = region_name - def poke(self, context: 'Context'): + def poke(self, context: Context): ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id) self.log.info("instance state: %s", instance_state) diff --git a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py deleted file mode 100644 index d166b69d20ffb..0000000000000 --- a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.ec2`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.sensors.ec2`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py new file mode 100644 index 0000000000000..c1151b8f9a55c --- /dev/null +++ b/airflow/providers/amazon/aws/sensors/ecs.py @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +import boto3 + +from airflow import AirflowException +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.hooks.ecs import ( + EcsClusterStates, + EcsHook, + EcsTaskDefinitionStates, + EcsTaskStates, +) +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + +DEFAULT_CONN_ID: str = "aws_default" + + +def _check_failed(current_state, target_state, failure_states): + if (current_state != target_state) and (current_state in failure_states): + raise AirflowException( + f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}" + ) + + +class EcsBaseSensor(BaseSensorOperator): + """Contains general sensor behavior for Elastic Container Service.""" + + def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs): + self.aws_conn_id = aws_conn_id + self.region = region + super().__init__(**kwargs) + + @cached_property + def hook(self) -> EcsHook: + """Create and return an EcsHook.""" + return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region) + + @cached_property + def client(self) -> boto3.client: + """Create and return an EcsHook client.""" + return self.hook.conn + + +class EcsClusterStateSensor(EcsBaseSensor): + """ + Polls the cluster state until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/sensor:EcsClusterStateSensor` + + :param cluster_name: The name of your cluster. + :param target_state: Success state to watch for. (Default: "ACTIVE") + :param failure_states: Fail if any of these states are reached before the + Success State. 
(Default: "FAILED" or "INACTIVE") + """ + + template_fields: Sequence[str] = ("cluster_name", "target_state", "failure_states") + + def __init__( + self, + *, + cluster_name: str, + target_state: EcsClusterStates | None = EcsClusterStates.ACTIVE, + failure_states: set[EcsClusterStates] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.cluster_name = cluster_name + self.target_state = target_state + self.failure_states = failure_states or {EcsClusterStates.FAILED, EcsClusterStates.INACTIVE} + + def poke(self, context: Context): + cluster_state = EcsClusterStates(self.hook.get_cluster_state(cluster_name=self.cluster_name)) + + self.log.info("Cluster state: %s, waiting for: %s", cluster_state, self.target_state) + _check_failed(cluster_state, self.target_state, self.failure_states) + + return cluster_state == self.target_state + + +class EcsTaskDefinitionStateSensor(EcsBaseSensor): + """ + Polls the task definition state until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/sensor:EcsTaskDefinitionStateSensor` + + :param task_definition: The family for the latest ACTIVE revision, family and + revision (family:revision ) for a specific revision in the family, or full + Amazon Resource Name (ARN) of the task definition. + :param target_state: Success state to watch for. (Default: "ACTIVE") + """ + + template_fields: Sequence[str] = ("task_definition", "target_state", "failure_states") + + def __init__( + self, + *, + task_definition: str, + target_state: EcsTaskDefinitionStates | None = EcsTaskDefinitionStates.ACTIVE, + **kwargs, + ): + super().__init__(**kwargs) + self.task_definition = task_definition + self.target_state = target_state + # There are only two possible states, so set failure_state to whatever is not the target_state + self.failure_states = { + ( + EcsTaskDefinitionStates.INACTIVE + if target_state == EcsTaskDefinitionStates.ACTIVE + else EcsTaskDefinitionStates.ACTIVE + ) + } + + def poke(self, context: Context): + task_definition_state = EcsTaskDefinitionStates( + self.hook.get_task_definition_state(task_definition=self.task_definition) + ) + + self.log.info("Task Definition state: %s, waiting for: %s", task_definition_state, self.target_state) + _check_failed(task_definition_state, self.target_state, [self.failure_states]) + return task_definition_state == self.target_state + + +class EcsTaskStateSensor(EcsBaseSensor): + """ + Polls the task state until it reaches a terminal state. Raises an + AirflowException with the failure reason if a failed state is reached. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/sensor:EcsTaskStateSensor` + + :param cluster: The short name or full Amazon Resource Name (ARN) of the cluster that hosts the task. + :param task: The task ID or full ARN of the task to poll. + :param target_state: Success state to watch for. (Default: "ACTIVE") + :param failure_states: Fail if any of these states are reached before + the Success State. 
(Default: "STOPPED") + """ + + template_fields: Sequence[str] = ("cluster", "task", "target_state", "failure_states") + + def __init__( + self, + *, + cluster: str, + task: str, + target_state: EcsTaskStates | None = EcsTaskStates.RUNNING, + failure_states: set[EcsTaskStates] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.cluster = cluster + self.task = task + self.target_state = target_state + self.failure_states = failure_states or {EcsTaskStates.STOPPED} + + def poke(self, context: Context): + task_state = EcsTaskStates(self.hook.get_task_state(cluster=self.cluster, task=self.task)) + + self.log.info("Task state: %s, waiting for: %s", task_state, self.target_state) + _check_failed(task_state, self.target_state, self.failure_states) + return task_state == self.target_state diff --git a/airflow/providers/amazon/aws/sensors/eks.py b/airflow/providers/amazon/aws/sensors/eks.py index 92ed55da4d31e..f2ee37215173d 100644 --- a/airflow/providers/amazon/aws/sensors/eks.py +++ b/airflow/providers/amazon/aws/sensors/eks.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """Tracking the state of Amazon EKS Clusters, Amazon EKS managed node groups, and AWS Fargate profiles.""" -import warnings -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.eks import ( @@ -85,7 +85,7 @@ def __init__( cluster_name: str, target_state: ClusterStates = ClusterStates.ACTIVE, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ): self.cluster_name = cluster_name @@ -98,7 +98,7 @@ def __init__( self.region = region super().__init__(**kwargs) - def poke(self, context: 'Context'): + def poke(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -121,7 +121,7 @@ class EksFargateProfileStateSensor(BaseSensorOperator): Check the state of an AWS Fargate profile until it reaches the target state or another terminal state. .. seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:EksFargateProfileStateSensor` :param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. 
(templated) @@ -153,7 +153,7 @@ def __init__( fargate_profile_name: str, target_state: FargateProfileStates = FargateProfileStates.ACTIVE, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ): self.cluster_name = cluster_name @@ -167,7 +167,7 @@ def __init__( self.region = region super().__init__(**kwargs) - def poke(self, context: 'Context'): + def poke(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -224,7 +224,7 @@ def __init__( nodegroup_name: str, target_state: NodegroupStates = NodegroupStates.ACTIVE, aws_conn_id: str = DEFAULT_CONN_ID, - region: Optional[str] = None, + region: str | None = None, **kwargs, ): self.cluster_name = cluster_name @@ -238,7 +238,7 @@ def __init__( self.region = region super().__init__(**kwargs) - def poke(self, context: 'Context'): + def poke(self, context: Context): eks_hook = EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, @@ -256,51 +256,3 @@ def poke(self, context: 'Context'): ) ) return nodegroup_state == self.target_state - - -class EKSClusterStateSensor(EksClusterStateSensor): - """ - This sensor is deprecated. - Please use :class:`airflow.providers.amazon.aws.sensors.eks.EksClusterStateSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This sensor is deprecated. " - "Please use `airflow.providers.amazon.aws.sensors.eks.EksClusterStateSensor`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSFargateProfileStateSensor(EksFargateProfileStateSensor): - """ - This sensor is deprecated. - Please use :class:`airflow.providers.amazon.aws.sensors.eks.EksFargateProfileStateSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This sensor is deprecated. " - "Please use `airflow.providers.amazon.aws.sensors.eks.EksFargateProfileStateSensor`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - -class EKSNodegroupStateSensor(EksNodegroupStateSensor): - """ - This sensor is deprecated. - Please use :class:`airflow.providers.amazon.aws.sensors.eks.EksNodegroupStateSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This sensor is deprecated. " - "Please use `airflow.providers.amazon.aws.sensors.eks.EksNodegroupStateSensor`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index a4c2b3a71142b..a3684fa249a1d 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -15,21 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import sys -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence +from __future__ import annotations -if TYPE_CHECKING: - from airflow.utils.context import Context +from typing import TYPE_CHECKING, Any, Iterable, Sequence +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook +from airflow.sensors.base import BaseSensorOperator, poke_mode_only -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +if TYPE_CHECKING: + from airflow.utils.context import Context -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook -from airflow.sensors.base import BaseSensorOperator +from airflow.compat.functools import cached_property class EmrBaseSensor(BaseSensorOperator): @@ -43,17 +40,17 @@ class EmrBaseSensor(BaseSensorOperator): Subclasses should set ``target_states`` and ``failed_states`` fields. - :param aws_conn_id: aws connection to uses + :param aws_conn_id: aws connection to use """ - ui_color = '#66c3ff' + ui_color = "#66c3ff" - def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs): + def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.target_states: Iterable[str] = [] # will be set in subclasses self.failed_states: Iterable[str] = [] # will be set in subclasses - self.hook: Optional[EmrHook] = None + self.hook: EmrHook | None = None def get_hook(self) -> EmrHook: """Get EmrHook""" @@ -63,58 +60,173 @@ def get_hook(self) -> EmrHook: self.hook = EmrHook(aws_conn_id=self.aws_conn_id) return self.hook - def poke(self, context: 'Context'): + def poke(self, context: Context): response = self.get_emr_response() - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - self.log.info('Bad HTTP response: %s', response) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + self.log.info("Bad HTTP response: %s", response) return False state = self.state_from_response(response) - self.log.info('Job flow currently %s', state) + self.log.info("Job flow currently %s", state) if state in self.target_states: return True if state in self.failed_states: - final_message = 'EMR job failed' + final_message = "EMR job failed" failure_message = self.failure_message_from_response(response) if failure_message: - final_message += ' ' + failure_message + final_message += " " + failure_message raise AirflowException(final_message) return False - def get_emr_response(self) -> Dict[str, Any]: + def get_emr_response(self) -> dict[str, Any]: """ Make an API call with boto3 and get response. :return: response - :rtype: dict[str, Any] """ - raise NotImplementedError('Please implement get_emr_response() in subclass') + raise NotImplementedError("Please implement get_emr_response() in subclass") @staticmethod - def state_from_response(response: Dict[str, Any]) -> str: + def state_from_response(response: dict[str, Any]) -> str: """ Get state from response dictionary. :param response: response from AWS API :return: state - :rtype: str """ - raise NotImplementedError('Please implement state_from_response() in subclass') + raise NotImplementedError("Please implement state_from_response() in subclass") + + @staticmethod + def failure_message_from_response(response: dict[str, Any]) -> str | None: + """ + Get failure message from response dictionary. 
+ + :param response: response from AWS API + :return: failure message + """ + raise NotImplementedError("Please implement failure_message_from_response() in subclass") + + +class EmrServerlessJobSensor(BaseSensorOperator): + """ + Asks for the state of the job run until it reaches a failure state or success state. + If the job run fails, the task will fail. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EmrServerlessJobSensor` + + :param application_id: application_id to check the state of + :param job_run_id: job_run_id to check the state of + :param target_states: a set of states to wait for, defaults to 'SUCCESS' + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + """ + + template_fields: Sequence[str] = ( + "application_id", + "job_run_id", + ) + + def __init__( + self, + *, + application_id: str, + job_run_id: str, + target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES), + aws_conn_id: str = "aws_default", + **kwargs: Any, + ) -> None: + self.aws_conn_id = aws_conn_id + self.target_states = target_states + self.application_id = application_id + self.job_run_id = job_run_id + super().__init__(**kwargs) + + def poke(self, context: Context) -> bool: + response = self.hook.conn.get_job_run(applicationId=self.application_id, jobRunId=self.job_run_id) + + state = response["jobRun"]["state"] + + if state in EmrServerlessHook.JOB_FAILURE_STATES: + failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}" + raise AirflowException(failure_message) + + return state in self.target_states + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) + + @staticmethod + def failure_message_from_response(response: dict[str, Any]) -> str | None: + """ + Get failure message from response dictionary. + + :param response: response from AWS API + :return: failure message + """ + return response["jobRun"]["stateDetails"] + + +class EmrServerlessApplicationSensor(BaseSensorOperator): + """ + Asks for the state of the application until it reaches a failure state or success state. + If the application fails, the task will fail. + + .. 
seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EmrServerlessApplicationSensor` + + :param application_id: application_id to check the state of + :param target_states: a set of states to wait for, defaults to {'CREATED', 'STARTED'} + :param aws_conn_id: aws connection to use, defaults to 'aws_default' + """ + + template_fields: Sequence[str] = ("application_id",) + + def __init__( + self, + *, + application_id: str, + target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES), + aws_conn_id: str = "aws_default", + **kwargs: Any, + ) -> None: + self.aws_conn_id = aws_conn_id + self.target_states = target_states + self.application_id = application_id + super().__init__(**kwargs) + + def poke(self, context: Context) -> bool: + response = self.hook.conn.get_application(applicationId=self.application_id) + + state = response["application"]["state"] + + if state in EmrServerlessHook.APPLICATION_FAILURE_STATES: + failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}" + raise AirflowException(failure_message) + + return state in self.target_states + + @cached_property + def hook(self) -> EmrServerlessHook: + """Create and return an EmrServerlessHook""" + return EmrServerlessHook(aws_conn_id=self.aws_conn_id) @staticmethod - def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + def failure_message_from_response(response: dict[str, Any]) -> str | None: """ Get failure message from response dictionary. :param response: response from AWS API :return: failure message - :rtype: Optional[str] """ - raise NotImplementedError('Please implement failure_message_from_response() in subclass') + return response["application"]["stateDetails"] class EmrContainerSensor(BaseSensorOperator): @@ -146,17 +258,17 @@ class EmrContainerSensor(BaseSensorOperator): ) SUCCESS_STATES = ("COMPLETED",) - template_fields: Sequence[str] = ('virtual_cluster_id', 'job_id') + template_fields: Sequence[str] = ("virtual_cluster_id", "job_id") template_ext: Sequence[str] = () - ui_color = '#66c3ff' + ui_color = "#66c3ff" def __init__( self, *, virtual_cluster_id: str, job_id: str, - max_retries: Optional[int] = None, - aws_conn_id: str = 'aws_default', + max_retries: int | None = None, + aws_conn_id: str = "aws_default", poll_interval: int = 10, **kwargs: Any, ) -> None: @@ -167,11 +279,15 @@ def __init__( self.poll_interval = poll_interval self.max_retries = max_retries - def poke(self, context: 'Context') -> bool: - state = self.hook.poll_query_status(self.job_id, self.max_retries, self.poll_interval) + def poke(self, context: Context) -> bool: + state = self.hook.poll_query_status( + self.job_id, + max_polling_attempts=self.max_retries, + poll_interval=self.poll_interval, + ) if state in self.FAILURE_STATES: - raise AirflowException('EMR Containers sensor failed') + raise AirflowException("EMR Containers sensor failed") if state in self.INTERMEDIATE_STATES: return False @@ -204,23 +320,23 @@ class EmrJobFlowSensor(EmrBaseSensor): job flow reaches any of these states """ - template_fields: Sequence[str] = ('job_flow_id', 'target_states', 'failed_states') + template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states") template_ext: Sequence[str] = () def __init__( self, *, job_flow_id: str, - target_states: Optional[Iterable[str]] = None, - failed_states: Optional[Iterable[str]] = None, + target_states: Iterable[str] | None = None, + failed_states: 
Iterable[str] | None = None, **kwargs, ): super().__init__(**kwargs) self.job_flow_id = job_flow_id - self.target_states = target_states or ['TERMINATED'] - self.failed_states = failed_states or ['TERMINATED_WITH_ERRORS'] + self.target_states = target_states or ["TERMINATED"] + self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"] - def get_emr_response(self) -> Dict[str, Any]: + def get_emr_response(self) -> dict[str, Any]: """ Make an API call with boto3 and get cluster-level details. @@ -228,35 +344,32 @@ def get_emr_response(self) -> Dict[str, Any]: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_cluster :return: response - :rtype: dict[str, Any] """ emr_client = self.get_hook().get_conn() - self.log.info('Poking cluster %s', self.job_flow_id) + self.log.info("Poking cluster %s", self.job_flow_id) return emr_client.describe_cluster(ClusterId=self.job_flow_id) @staticmethod - def state_from_response(response: Dict[str, Any]) -> str: + def state_from_response(response: dict[str, Any]) -> str: """ Get state from response dictionary. :param response: response from AWS API :return: current state of the cluster - :rtype: str """ - return response['Cluster']['Status']['State'] + return response["Cluster"]["Status"]["State"] @staticmethod - def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + def failure_message_from_response(response: dict[str, Any]) -> str | None: """ Get failure message from response dictionary. :param response: response from AWS API :return: failure message - :rtype: Optional[str] """ - cluster_status = response['Cluster']['Status'] - state_change_reason = cluster_status.get('StateChangeReason') + cluster_status = response["Cluster"]["Status"] + state_change_reason = cluster_status.get("StateChangeReason") if state_change_reason: return ( f"for code: {state_change_reason.get('Code', 'No code')} " @@ -265,6 +378,7 @@ def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: return None +@poke_mode_only class EmrStepSensor(EmrBaseSensor): """ Asks for the state of the step until it reaches any of the target states. @@ -284,7 +398,7 @@ class EmrStepSensor(EmrBaseSensor): step reaches any of these states """ - template_fields: Sequence[str] = ('job_flow_id', 'step_id', 'target_states', 'failed_states') + template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states") template_ext: Sequence[str] = () def __init__( @@ -292,17 +406,17 @@ def __init__( *, job_flow_id: str, step_id: str, - target_states: Optional[Iterable[str]] = None, - failed_states: Optional[Iterable[str]] = None, + target_states: Iterable[str] | None = None, + failed_states: Iterable[str] | None = None, **kwargs, ): super().__init__(**kwargs) self.job_flow_id = job_flow_id self.step_id = step_id - self.target_states = target_states or ['COMPLETED'] - self.failed_states = failed_states or ['CANCELLED', 'FAILED', 'INTERRUPTED'] + self.target_states = target_states or ["COMPLETED"] + self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"] - def get_emr_response(self) -> Dict[str, Any]: + def get_emr_response(self) -> dict[str, Any]: """ Make an API call with boto3 and get details about the cluster step. 
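To make the EMR Serverless sensors introduced above concrete, here is a rough usage sketch; the application and job-run ids would normally come from an upstream task and are hard-coded placeholders here, while the class names and parameters are taken from the patch.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.emr import (
    EmrServerlessApplicationSensor,
    EmrServerlessJobSensor,
)

with DAG(dag_id="example_emr_serverless", start_date=datetime(2022, 1, 1)):
    wait_for_app = EmrServerlessApplicationSensor(
        task_id="wait_for_app",
        application_id="application-id-placeholder",
    )
    # Defaults to the hook's JOB_SUCCESS_STATES; any failure state raises.
    wait_for_job = EmrServerlessJobSensor(
        task_id="wait_for_job",
        application_id="application-id-placeholder",
        job_run_id="job-run-id-placeholder",
    )
    wait_for_app >> wait_for_job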
@@ -310,34 +424,31 @@ def get_emr_response(self) -> Dict[str, Any]: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_step :return: response - :rtype: dict[str, Any] """ emr_client = self.get_hook().get_conn() - self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id) + self.log.info("Poking step %s on cluster %s", self.step_id, self.job_flow_id) return emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id) @staticmethod - def state_from_response(response: Dict[str, Any]) -> str: + def state_from_response(response: dict[str, Any]) -> str: """ Get state from response dictionary. :param response: response from AWS API :return: execution state of the cluster step - :rtype: str """ - return response['Step']['Status']['State'] + return response["Step"]["Status"]["State"] @staticmethod - def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + def failure_message_from_response(response: dict[str, Any]) -> str | None: """ Get failure message from response dictionary. :param response: response from AWS API :return: failure message - :rtype: Optional[str] """ - fail_details = response['Step']['Status'].get('FailureDetails') + fail_details = response["Step"]["Status"].get("FailureDetails") if fail_details: return ( f"for reason {fail_details.get('Reason')} " diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py b/airflow/providers/amazon/aws/sensors/emr_base.py deleted file mode 100644 index 89991d7fa04aa..0000000000000 --- a/airflow/providers/amazon/aws/sensors/emr_base.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/emr_containers.py b/airflow/providers/amazon/aws/sensors/emr_containers.py deleted file mode 100644 index 6cfa9adb7d76d..0000000000000 --- a/airflow/providers/amazon/aws/sensors/emr_containers.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor, EmrContainerSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.", - DeprecationWarning, - stacklevel=2, -) - - -class EMRContainerSensor(EmrContainerSensor): - """ - This class is deprecated. - Please use :class:`airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`. - """ - - def __init__(self, **kwargs): - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`.""", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/sensors/emr_job_flow.py b/airflow/providers/amazon/aws/sensors/emr_job_flow.py deleted file mode 100644 index 31b6dabb4a379..0000000000000 --- a/airflow/providers/amazon/aws/sensors/emr_job_flow.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/emr_step.py b/airflow/providers/amazon/aws/sensors/emr_step.py deleted file mode 100644 index aca71619089e7..0000000000000 --- a/airflow/providers/amazon/aws/sensors/emr_step.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor, EmrStepSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/glacier.py b/airflow/providers/amazon/aws/sensors/glacier.py index e92f5a4326b3a..857e578327a9e 100644 --- a/airflow/providers/amazon/aws/sensors/glacier.py +++ b/airflow/providers/amazon/aws/sensors/glacier.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from enum import Enum from typing import TYPE_CHECKING, Any, Sequence @@ -38,7 +40,7 @@ class GlacierJobOperationSensor(BaseSensorOperator): Glacier sensor for checking job state. This operator runs only in reschedule mode. .. seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:GlacierJobOperationSensor` :param aws_conn_id: The reference to the AWS connection details @@ -65,7 +67,7 @@ class GlacierJobOperationSensor(BaseSensorOperator): def __init__( self, *, - aws_conn_id: str = 'aws_default', + aws_conn_id: str = "aws_default", vault_name: str, job_id: str, poke_interval: int = 60 * 20, @@ -79,7 +81,7 @@ def __init__( self.poke_interval = poke_interval self.mode = mode - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = GlacierHook(aws_conn_id=self.aws_conn_id) response = hook.describe_job(vault_name=self.vault_name, job_id=self.job_id) diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py index 525e7b8ee6234..87e8f2c249d63 100644 --- a/airflow/providers/amazon/aws/sensors/glue.py +++ b/airflow/providers/amazon/aws/sensors/glue.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import warnings +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException @@ -37,43 +38,51 @@ class GlueJobSensor(BaseSensorOperator): :param job_name: The AWS Glue Job unique name :param run_id: The AWS Glue current running job identifier + :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. 
(default: False) """ - template_fields: Sequence[str] = ('job_name', 'run_id') + template_fields: Sequence[str] = ("job_name", "run_id") - def __init__(self, *, job_name: str, run_id: str, aws_conn_id: str = 'aws_default', **kwargs): + def __init__( + self, + *, + job_name: str, + run_id: str, + verbose: bool = False, + aws_conn_id: str = "aws_default", + **kwargs, + ): super().__init__(**kwargs) self.job_name = job_name self.run_id = run_id + self.verbose = verbose self.aws_conn_id = aws_conn_id - self.success_states = ['SUCCEEDED'] - self.errored_states = ['FAILED', 'STOPPED', 'TIMEOUT'] + self.success_states: list[str] = ["SUCCEEDED"] + self.errored_states: list[str] = ["FAILED", "STOPPED", "TIMEOUT"] + self.next_log_token: str | None = None - def poke(self, context: 'Context'): + def poke(self, context: Context): hook = GlueJobHook(aws_conn_id=self.aws_conn_id) self.log.info("Poking for job run status :for Glue Job %s and ID %s", self.job_name, self.run_id) job_state = hook.get_job_state(job_name=self.job_name, run_id=self.run_id) - if job_state in self.success_states: - self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state) - return True - elif job_state in self.errored_states: - job_error_message = f"Exiting Job {self.run_id} Run State: {job_state}" - raise AirflowException(job_error_message) - else: - return False - - -class AwsGlueJobSensor(GlueJobSensor): - """ - This sensor is deprecated. - Please use :class:`airflow.providers.amazon.aws.sensors.glue.GlueJobSensor`. - """ + job_failed = False - def __init__(self, *args, **kwargs): - warnings.warn( - "This sensor is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.sensors.glue.GlueJobSensor`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) + try: + if job_state in self.success_states: + self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state) + return True + elif job_state in self.errored_states: + job_failed = True + job_error_message = "Exiting Job %s Run State: %s", self.run_id, job_state + self.log.info(job_error_message) + raise AirflowException(job_error_message) + else: + return False + finally: + if self.verbose: + self.next_log_token = hook.print_job_logs( + job_name=self.job_name, + run_id=self.run_id, + job_failed=job_failed, + next_token=self.next_log_token, + ) diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py index c49277f34342f..21bf8cb772447 100644 --- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
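A short sketch of the reworked GlueJobSensor with its new verbose flag; the job name is a placeholder, and the run id is assumed to be pushed to XCom by a hypothetical upstream task called "submit_glue_job" that is not part of this patch.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor

with DAG(dag_id="example_glue_job_sensor", start_date=datetime(2022, 1, 1)):
    # run_id is a templated field, so it can be resolved from XCom at runtime.
    wait_for_glue_job = GlueJobSensor(
        task_id="wait_for_glue_job",
        job_name="my-glue-job",  # placeholder
        run_id="{{ task_instance.xcom_pull(task_ids='submit_glue_job') }}",
        verbose=True,  # stream Glue job-run logs into the Airflow task log
    )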
-import warnings -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook from airflow.sensors.base import BaseSensorOperator @@ -47,20 +48,20 @@ class GlueCatalogPartitionSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'database_name', - 'table_name', - 'expression', + "database_name", + "table_name", + "expression", ) - ui_color = '#C5CAE9' + ui_color = "#C5CAE9" def __init__( self, *, table_name: str, expression: str = "ds='{{ ds }}'", - aws_conn_id: str = 'aws_default', - region_name: Optional[str] = None, - database_name: str = 'default', + aws_conn_id: str = "aws_default", + region_name: str | None = None, + database_name: str = "default", poke_interval: int = 60 * 3, **kwargs, ): @@ -70,14 +71,14 @@ def __init__( self.table_name = table_name self.expression = expression self.database_name = database_name - self.hook: Optional[GlueCatalogHook] = None + self.hook: GlueCatalogHook | None = None - def poke(self, context: 'Context'): + def poke(self, context: Context): """Checks for existence of the partition in the AWS Glue Catalog table""" - if '.' in self.table_name: - self.database_name, self.table_name = self.table_name.split('.') + if "." in self.table_name: + self.database_name, self.table_name = self.table_name.split(".") self.log.info( - 'Poking for table %s. %s, expression %s', self.database_name, self.table_name, self.expression + "Poking for table %s. %s, expression %s", self.database_name, self.table_name, self.expression ) return self.get_hook().check_for_partition(self.database_name, self.table_name, self.expression) @@ -89,19 +90,3 @@ def get_hook(self) -> GlueCatalogHook: self.hook = GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) return self.hook - - -class AwsGlueCatalogPartitionSensor(GlueCatalogPartitionSensor): - """ - This sensor is deprecated. Please use - :class:`airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This sensor is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor`.", # noqa: 501 - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/sensors/glue_crawler.py b/airflow/providers/amazon/aws/sensors/glue_crawler.py index 52944ce2eb30e..3032c603c625b 100644 --- a/airflow/providers/amazon/aws/sensors/glue_crawler.py +++ b/airflow/providers/amazon/aws/sensors/glue_crawler.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
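Along the same lines, a hedged example of the GlueCatalogPartitionSensor shown above, using the default partition expression from its signature; the database and table names are illustrative.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor

with DAG(dag_id="example_glue_partition", start_date=datetime(2022, 1, 1)):
    # Waits until the ds partition for the logical date exists in the Glue Catalog.
    wait_for_partition = GlueCatalogPartitionSensor(
        task_id="wait_for_partition",
        database_name="analytics",   # illustrative
        table_name="events",         # illustrative
        expression="ds='{{ ds }}'",  # default shown in the signature above
        poke_interval=60 * 3,
    )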
-import warnings -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook @@ -39,21 +40,23 @@ class GlueCrawlerSensor(BaseSensorOperator): :param aws_conn_id: aws connection to use, defaults to 'aws_default' """ - def __init__(self, *, crawler_name: str, aws_conn_id: str = 'aws_default', **kwargs) -> None: + template_fields: Sequence[str] = ("crawler_name",) + + def __init__(self, *, crawler_name: str, aws_conn_id: str = "aws_default", **kwargs) -> None: super().__init__(**kwargs) self.crawler_name = crawler_name self.aws_conn_id = aws_conn_id - self.success_statuses = 'SUCCEEDED' - self.errored_statuses = ('FAILED', 'CANCELLED') - self.hook: Optional[GlueCrawlerHook] = None + self.success_statuses = "SUCCEEDED" + self.errored_statuses = ("FAILED", "CANCELLED") + self.hook: GlueCrawlerHook | None = None - def poke(self, context: 'Context'): + def poke(self, context: Context): hook = self.get_hook() self.log.info("Poking for AWS Glue crawler: %s", self.crawler_name) - crawler_state = hook.get_crawler(self.crawler_name)['State'] - if crawler_state == 'READY': + crawler_state = hook.get_crawler(self.crawler_name)["State"] + if crawler_state == "READY": self.log.info("State: %s", crawler_state) - crawler_status = hook.get_crawler(self.crawler_name)['LastCrawl']['Status'] + crawler_status = hook.get_crawler(self.crawler_name)["LastCrawl"]["Status"] if crawler_status == self.success_statuses: self.log.info("Status: %s", crawler_status) return True @@ -69,19 +72,3 @@ def get_hook(self) -> GlueCrawlerHook: self.hook = GlueCrawlerHook(aws_conn_id=self.aws_conn_id) return self.hook - - -class AwsGlueCrawlerSensor(GlueCrawlerSensor): - """ - This sensor is deprecated. Please use - :class:`airflow.providers.amazon.aws.sensors.glue_crawler.GlueCrawlerSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This sensor is deprecated. " - "Please use :class:`airflow.providers.amazon.aws.sensors.glue_crawler.GlueCrawlerSensor`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/sensors/quicksight.py b/airflow/providers/amazon/aws/sensors/quicksight.py index da94980e80650..09cc92cf96dd5 100644 --- a/airflow/providers/amazon/aws/sensors/quicksight.py +++ b/airflow/providers/amazon/aws/sensors/quicksight.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -import sys -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Sequence +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook from airflow.providers.amazon.aws.hooks.sts import StsHook @@ -27,11 +28,6 @@ if TYPE_CHECKING: from airflow.utils.context import Context -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - class QuickSightSensor(BaseSensorOperator): """ @@ -50,6 +46,8 @@ class QuickSightSensor(BaseSensorOperator): maintained on each worker node). 
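Stepping back to the GlueCrawlerSensor changes above, a minimal sketch of the sensor in reschedule mode, which frees the worker slot between pokes; the crawler name is a placeholder and the mode/poke_interval settings are standard BaseSensorOperator options rather than anything introduced by this patch.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor

with DAG(dag_id="example_glue_crawler", start_date=datetime(2022, 1, 1)):
    wait_for_crawler = GlueCrawlerSensor(
        task_id="wait_for_crawler",
        crawler_name="my-crawler",  # placeholder
        mode="reschedule",          # release the worker slot between pokes
        poke_interval=120,
    )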
""" + template_fields: Sequence[str] = ("data_set_id", "ingestion_id", "aws_conn_id") + def __init__( self, *, @@ -64,16 +62,15 @@ def __init__( self.aws_conn_id = aws_conn_id self.success_status = "COMPLETED" self.errored_statuses = ("FAILED", "CANCELLED") - self.quicksight_hook: Optional[QuickSightHook] = None - self.sts_hook: Optional[StsHook] = None + self.quicksight_hook: QuickSightHook | None = None + self.sts_hook: StsHook | None = None - def poke(self, context: "Context"): + def poke(self, context: Context) -> bool: """ Pokes until the QuickSight Ingestion has successfully finished. :param context: The task context during execution. :return: True if it COMPLETED and False if not. - :rtype: bool """ quicksight_hook = self.get_quicksight_hook sts_hook = self.get_sts_hook diff --git a/airflow/providers/amazon/aws/sensors/rds.py b/airflow/providers/amazon/aws/sensors/rds.py index 3c24c82fbf940..731c8b5def6fb 100644 --- a/airflow/providers/amazon/aws/sensors/rds.py +++ b/airflow/providers/amazon/aws/sensors/rds.py @@ -14,12 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Sequence +from typing import TYPE_CHECKING, Sequence -from botocore.exceptions import ClientError - -from airflow import AirflowException +from airflow.exceptions import AirflowNotFoundException from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.utils.rds import RdsDbType from airflow.sensors.base import BaseSensorOperator @@ -34,43 +33,19 @@ class RdsBaseSensor(BaseSensorOperator): ui_color = "#ddbb77" ui_fgcolor = "#ffffff" - def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: Optional[dict] = None, **kwargs): + def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs): hook_params = hook_params or {} self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params) - self.target_statuses: List[str] = [] + self.target_statuses: list[str] = [] super().__init__(*args, **kwargs) - def _describe_item(self, item_type: str, item_name: str) -> list: - - if item_type == 'instance_snapshot': - db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name) - return db_snaps['DBSnapshots'] - elif item_type == 'cluster_snapshot': - cl_snaps = self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name) - return cl_snaps['DBClusterSnapshots'] - elif item_type == 'export_task': - exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name) - return exports['ExportTasks'] - else: - raise AirflowException(f"Method for {item_type} is not implemented") - - def _check_item(self, item_type: str, item_name: str) -> bool: - """Get certain item from `_describe_item()` and check its status""" - - try: - items = self._describe_item(item_type, item_name) - except ClientError: - return False - else: - return bool(items) and any(map(lambda s: items[0]['Status'].lower() == s, self.target_statuses)) - class RdsSnapshotExistenceSensor(RdsBaseSensor): """ Waits for RDS snapshot with a specific status. .. 
seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:RdsSnapshotExistenceSensor` :param db_type: Type of the DB - either "instance" or "cluster" @@ -79,8 +54,8 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor): """ template_fields: Sequence[str] = ( - 'db_snapshot_identifier', - 'target_statuses', + "db_snapshot_identifier", + "target_statuses", ) def __init__( @@ -88,23 +63,27 @@ def __init__( *, db_type: str, db_snapshot_identifier: str, - target_statuses: Optional[List[str]] = None, + target_statuses: list[str] | None = None, aws_conn_id: str = "aws_conn_id", **kwargs, ): super().__init__(aws_conn_id=aws_conn_id, **kwargs) self.db_type = RdsDbType(db_type) self.db_snapshot_identifier = db_snapshot_identifier - self.target_statuses = target_statuses or ['available'] + self.target_statuses = target_statuses or ["available"] - def poke(self, context: 'Context'): + def poke(self, context: Context): self.log.info( - 'Poking for statuses : %s\nfor snapshot %s', self.target_statuses, self.db_snapshot_identifier + "Poking for statuses : %s\nfor snapshot %s", self.target_statuses, self.db_snapshot_identifier ) - if self.db_type.value == "instance": - return self._check_item(item_type='instance_snapshot', item_name=self.db_snapshot_identifier) - else: - return self._check_item(item_type='cluster_snapshot', item_name=self.db_snapshot_identifier) + try: + if self.db_type.value == "instance": + state = self.hook.get_db_snapshot_state(self.db_snapshot_identifier) + else: + state = self.hook.get_db_cluster_snapshot_state(self.db_snapshot_identifier) + except AirflowNotFoundException: + return False + return state in self.target_statuses class RdsExportTaskExistenceSensor(RdsBaseSensor): @@ -112,7 +91,7 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor): Waits for RDS export task with a specific status. .. seealso:: - For more information on how to use this operator, take a look at the guide: + For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:RdsExportTaskExistenceSensor` :param export_task_identifier: A unique identifier for the snapshot export task. @@ -120,15 +99,15 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor): """ template_fields: Sequence[str] = ( - 'export_task_identifier', - 'target_statuses', + "export_task_identifier", + "target_statuses", ) def __init__( self, *, export_task_identifier: str, - target_statuses: Optional[List[str]] = None, + target_statuses: list[str] | None = None, aws_conn_id: str = "aws_default", **kwargs, ): @@ -136,21 +115,74 @@ def __init__( self.export_task_identifier = export_task_identifier self.target_statuses = target_statuses or [ - 'starting', - 'in_progress', - 'complete', - 'canceling', - 'canceled', + "starting", + "in_progress", + "complete", + "canceling", + "canceled", ] - def poke(self, context: 'Context'): + def poke(self, context: Context): + self.log.info( + "Poking for statuses : %s\nfor export task %s", self.target_statuses, self.export_task_identifier + ) + try: + state = self.hook.get_export_task_state(self.export_task_identifier) + except AirflowNotFoundException: + return False + return state in self.target_statuses + + +class RdsDbSensor(RdsBaseSensor): + """ + Waits for an RDS instance or cluster to enter one of a number of states + + .. 
seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:RdsDbSensor` + + :param db_type: Type of the DB - either "instance" or "cluster" (default: 'instance') + :param db_identifier: The AWS identifier for the DB + :param target_statuses: Target status of DB + """ + + template_fields: Sequence[str] = ( + "db_identifier", + "db_type", + "target_statuses", + ) + + def __init__( + self, + *, + db_identifier: str, + db_type: RdsDbType | str = RdsDbType.INSTANCE, + target_statuses: list[str] | None = None, + aws_conn_id: str = "aws_default", + **kwargs, + ): + super().__init__(aws_conn_id=aws_conn_id, **kwargs) + self.db_identifier = db_identifier + self.target_statuses = target_statuses or ["available"] + self.db_type = db_type + + def poke(self, context: Context): + db_type = RdsDbType(self.db_type) self.log.info( - 'Poking for statuses : %s\nfor export task %s', self.target_statuses, self.export_task_identifier + "Poking for statuses : %s\nfor db instance %s", self.target_statuses, self.db_identifier ) - return self._check_item(item_type='export_task', item_name=self.export_task_identifier) + try: + if db_type == RdsDbType.INSTANCE: + state = self.hook.get_db_instance_state(self.db_identifier) + else: + state = self.hook.get_db_cluster_state(self.db_identifier) + except AirflowNotFoundException: + return False + return state in self.target_statuses __all__ = [ "RdsExportTaskExistenceSensor", + "RdsDbSensor", "RdsSnapshotExistenceSensor", ] diff --git a/airflow/providers/amazon/aws/sensors/redshift.py b/airflow/providers/amazon/aws/sensors/redshift.py deleted file mode 100644 index 6a73e7ddba962..0000000000000 --- a/airflow/providers/amazon/aws/sensors/redshift.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import warnings - -from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor - -AwsRedshiftClusterSensor = RedshiftClusterSensor - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["AwsRedshiftClusterSensor", "RedshiftClusterSensor"] diff --git a/airflow/providers/amazon/aws/sensors/redshift_cluster.py b/airflow/providers/amazon/aws/sensors/redshift_cluster.py index ae772e95ffbf0..76f4f90111755 100644 --- a/airflow/providers/amazon/aws/sensors/redshift_cluster.py +++ b/airflow/providers/amazon/aws/sensors/redshift_cluster.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
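To illustrate the new RdsDbSensor introduced above, a hedged usage sketch follows; the dag_id and db_identifier are assumptions.

# Minimal sketch: block downstream tasks until an assumed RDS instance reports "available".
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.rds import RdsDbSensor

with DAG(dag_id="example_rds_db_wait", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    wait_for_db = RdsDbSensor(
        task_id="wait_for_db",
        db_identifier="my-rds-instance",   # assumed identifier
        db_type="instance",                # accepts "instance" or "cluster" (RdsDbType or plain string)
        target_statuses=["available"],     # the default when omitted
    )

Because the status lookup is wrapped in AirflowNotFoundException handling in poke(), the sensor simply keeps poking while the identifier does not resolve yet.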
-from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook from airflow.sensors.base import BaseSensorOperator @@ -35,24 +37,24 @@ class RedshiftClusterSensor(BaseSensorOperator): :param target_status: The cluster status desired. """ - template_fields: Sequence[str] = ('cluster_identifier', 'target_status') + template_fields: Sequence[str] = ("cluster_identifier", "target_status") def __init__( self, *, cluster_identifier: str, - target_status: str = 'available', - aws_conn_id: str = 'aws_default', + target_status: str = "available", + aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.target_status = target_status self.aws_conn_id = aws_conn_id - self.hook: Optional[RedshiftHook] = None + self.hook: RedshiftHook | None = None - def poke(self, context: 'Context'): - self.log.info('Poking for status : %s\nfor cluster %s', self.target_status, self.cluster_identifier) + def poke(self, context: Context): + self.log.info("Poking for status : %s\nfor cluster %s", self.target_status, self.cluster_identifier) return self.get_hook().cluster_status(self.cluster_identifier) == self.target_status def get_hook(self) -> RedshiftHook: diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 182b05864cf1d..c60e46841ea98 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -15,24 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations +import fnmatch import os import re -import sys -import warnings from datetime import datetime -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Union +from typing import TYPE_CHECKING, Callable, Sequence if TYPE_CHECKING: from airflow.utils.context import Context - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.sensors.base import BaseSensorOperator, poke_mode_only @@ -77,17 +71,17 @@ def check_fn(files: List) -> bool: CA cert bundle than the one used by botocore. 
""" - template_fields: Sequence[str] = ('bucket_key', 'bucket_name') + template_fields: Sequence[str] = ("bucket_key", "bucket_name") def __init__( self, *, - bucket_key: Union[str, List[str]], - bucket_name: Optional[str] = None, + bucket_key: str | list[str], + bucket_name: str | None = None, wildcard_match: bool = False, - check_fn: Optional[Callable[..., bool]] = None, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[str, bool]] = None, + check_fn: Callable[..., bool] | None = None, + aws_conn_id: str = "aws_default", + verify: str | bool | None = None, **kwargs, ): super().__init__(**kwargs) @@ -97,11 +91,11 @@ def __init__( self.check_fn = check_fn self.aws_conn_id = aws_conn_id self.verify = verify - self.hook: Optional[S3Hook] = None + self.hook: S3Hook | None = None def _check_key(self, key): - bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, 'bucket_name', 'bucket_key') - self.log.info('Poking for key : s3://%s/%s', bucket_name, key) + bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key") + self.log.info("Poking for key : s3://%s/%s", bucket_name, key) """ Set variable `files` which contains a list of dict which contains only the size @@ -111,25 +105,26 @@ def _check_key(self, key): }] """ if self.wildcard_match: - prefix = re.split(r'[\[\*\?]', key, 1)[0] - files = self.get_hook().get_file_metadata(prefix, bucket_name) - if len(files) == 0: + prefix = re.split(r"[\[\*\?]", key, 1)[0] + keys = self.get_hook().get_file_metadata(prefix, bucket_name) + key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)] + if len(key_matches) == 0: return False # Reduce the set of metadata to size only - files = list(map(lambda f: {'Size': f['Size']}, files)) + files = list(map(lambda f: {"Size": f["Size"]}, key_matches)) else: obj = self.get_hook().head_object(key, bucket_name) if obj is None: return False - files = [{'Size': obj['ContentLength']}] + files = [{"Size": obj["ContentLength"]}] if self.check_fn is not None: return self.check_fn(files) return True - def poke(self, context: 'Context'): + def poke(self, context: Context): return all(self._check_key(key) for key in self.bucket_key) def get_hook(self) -> S3Hook: @@ -141,40 +136,6 @@ def get_hook(self) -> S3Hook: return self.hook -class S3KeySizeSensor(S3KeySensor): - """ - This class is deprecated. - Please use :class:`~airflow.providers.amazon.aws.sensors.s3.S3KeySensor`. - """ - - def __init__( - self, - *, - check_fn: Optional[Callable[..., bool]] = None, - **kwargs, - ): - warnings.warn( - """ - S3KeySizeSensor is deprecated. - Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`. - """, - DeprecationWarning, - stacklevel=2, - ) - - super().__init__( - check_fn=check_fn if check_fn is not None else S3KeySizeSensor.default_check_fn, **kwargs - ) - - @staticmethod - def default_check_fn(data: List) -> bool: - """Default function for checking that S3 Objects have size more than 0 - - :param data: List of the objects in S3 bucket. - """ - return all(f.get('Size', 0) > 0 for f in data) - - @poke_mode_only class S3KeysUnchangedSensor(BaseSensorOperator): """ @@ -213,18 +174,18 @@ class S3KeysUnchangedSensor(BaseSensorOperator): when this happens. If false an error will be raised. 
""" - template_fields: Sequence[str] = ('bucket_name', 'prefix') + template_fields: Sequence[str] = ("bucket_name", "prefix") def __init__( self, *, bucket_name: str, prefix: str, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, + aws_conn_id: str = "aws_default", + verify: bool | str | None = None, inactivity_period: float = 60 * 60, min_objects: int = 1, - previous_objects: Optional[Set[str]] = None, + previous_objects: set[str] | None = None, allow_delete: bool = True, **kwargs, ) -> None: @@ -242,14 +203,14 @@ def __init__( self.allow_delete = allow_delete self.aws_conn_id = aws_conn_id self.verify = verify - self.last_activity_time: Optional[datetime] = None + self.last_activity_time: datetime | None = None @cached_property def hook(self): """Returns S3Hook.""" return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - def is_keys_unchanged(self, current_objects: Set[str]) -> bool: + def is_keys_unchanged(self, current_objects: set[str]) -> bool: """ Checks whether new objects have been uploaded and the inactivity_period has passed and updates the state of the sensor accordingly. @@ -313,36 +274,5 @@ def is_keys_unchanged(self, current_objects: Set[str]) -> bool: return False return False - def poke(self, context: 'Context'): + def poke(self, context: Context): return self.is_keys_unchanged(set(self.hook.list_keys(self.bucket_name, prefix=self.prefix))) - - -class S3PrefixSensor(S3KeySensor): - """ - This class is deprecated. - Please use :class:`~airflow.providers.amazon.aws.sensors.s3.S3KeySensor`. - """ - - template_fields: Sequence[str] = ('prefix', 'bucket_name') - - def __init__( - self, - *, - prefix: Union[str, List[str]], - delimiter: str = '/', - **kwargs, - ): - warnings.warn( - """ - S3PrefixSensor is deprecated. - Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`. - """, - DeprecationWarning, - stacklevel=2, - ) - - self.prefix = prefix - prefixes = [self.prefix] if isinstance(self.prefix, str) else self.prefix - keys = [pref if pref.endswith(delimiter) else pref + delimiter for pref in prefixes] - - super().__init__(bucket_key=keys, **kwargs) diff --git a/airflow/providers/amazon/aws/sensors/s3_key.py b/airflow/providers/amazon/aws/sensors/s3_key.py deleted file mode 100644 index deff11d00d1f2..0000000000000 --- a/airflow/providers/amazon/aws/sensors/s3_key.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeySizeSensor # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.sensors.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py b/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py deleted file mode 100644 index 792d29c46cfcd..0000000000000 --- a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/s3_prefix.py b/airflow/providers/amazon/aws/sensors/s3_prefix.py deleted file mode 100644 index 3990e8774a675..0000000000000 --- a/airflow/providers/amazon/aws/sensors/s3_prefix.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.s3 import S3PrefixSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py index 925ddaed17fd9..2d9c9aacf0f06 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
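As a companion to the S3KeySensor changes above (wildcard keys are now matched with fnmatch against full key names, and check_fn receives one {"Size": ...} dict per matched key), here is a small sketch; the bucket name and key pattern are assumptions.

# Minimal sketch: wait until at least one non-empty CSV lands under an assumed prefix.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def non_empty(files: list) -> bool:
    # `files` is the reduced metadata built in _check_key(): one {"Size": ...} dict per matched key
    return all(f.get("Size", 0) > 0 for f in files)


with DAG(dag_id="example_s3_key_wait", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    wait_for_files = S3KeySensor(
        task_id="wait_for_files",
        bucket_name="my-bucket",          # assumed bucket
        bucket_key="incoming/*.csv",      # fnmatch-style pattern; requires wildcard_match=True
        wildcard_match=True,
        check_fn=non_empty,
    )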
+from __future__ import annotations import time -from typing import TYPE_CHECKING, Optional, Sequence, Set +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook @@ -34,12 +35,12 @@ class SageMakerBaseSensor(BaseSensorOperator): Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods. """ - ui_color = '#ededed' + ui_color = "#ededed" - def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs): + def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id - self.hook: Optional[SageMakerHook] = None + self.hook: SageMakerHook | None = None def get_hook(self) -> SageMakerHook: """Get SageMakerHook.""" @@ -48,39 +49,39 @@ def get_hook(self) -> SageMakerHook: self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id) return self.hook - def poke(self, context: 'Context'): + def poke(self, context: Context): response = self.get_sagemaker_response() - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - self.log.info('Bad HTTP response: %s', response) + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + self.log.info("Bad HTTP response: %s", response) return False state = self.state_from_response(response) - self.log.info('Job currently %s', state) + self.log.info("Job currently %s", state) if state in self.non_terminal_states(): return False if state in self.failed_states(): failed_reason = self.get_failed_reason_from_response(response) - raise AirflowException(f'Sagemaker job failed for the following reason: {failed_reason}') + raise AirflowException(f"Sagemaker job failed for the following reason: {failed_reason}") return True - def non_terminal_states(self) -> Set[str]: + def non_terminal_states(self) -> set[str]: """Placeholder for returning states with should not terminate.""" - raise NotImplementedError('Please implement non_terminal_states() in subclass') + raise NotImplementedError("Please implement non_terminal_states() in subclass") - def failed_states(self) -> Set[str]: + def failed_states(self) -> set[str]: """Placeholder for returning states with are considered failed.""" - raise NotImplementedError('Please implement failed_states() in subclass') + raise NotImplementedError("Please implement failed_states() in subclass") def get_sagemaker_response(self) -> dict: """Placeholder for checking status of a SageMaker task.""" - raise NotImplementedError('Please implement get_sagemaker_response() in subclass') + raise NotImplementedError("Please implement get_sagemaker_response() in subclass") def get_failed_reason_from_response(self, response: dict) -> str: """Placeholder for extracting the reason for failure from an AWS response.""" - return 'Unknown' + return "Unknown" def state_from_response(self, response: dict) -> str: """Placeholder for extracting the state from an AWS response.""" - raise NotImplementedError('Please implement state_from_response() in subclass') + raise NotImplementedError("Please implement state_from_response() in subclass") class SageMakerEndpointSensor(SageMakerBaseSensor): @@ -95,7 +96,7 @@ class SageMakerEndpointSensor(SageMakerBaseSensor): :param endpoint_name: Name of the endpoint instance to watch. 
""" - template_fields: Sequence[str] = ('endpoint_name',) + template_fields: Sequence[str] = ("endpoint_name",) template_ext: Sequence[str] = () def __init__(self, *, endpoint_name, **kwargs): @@ -109,14 +110,14 @@ def failed_states(self): return SageMakerHook.failed_states def get_sagemaker_response(self): - self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name) + self.log.info("Poking Sagemaker Endpoint %s", self.endpoint_name) return self.get_hook().describe_endpoint(self.endpoint_name) def get_failed_reason_from_response(self, response): - return response['FailureReason'] + return response["FailureReason"] def state_from_response(self, response): - return response['EndpointStatus'] + return response["EndpointStatus"] class SageMakerTransformSensor(SageMakerBaseSensor): @@ -131,7 +132,7 @@ class SageMakerTransformSensor(SageMakerBaseSensor): :param job_name: Name of the transform job to watch. """ - template_fields: Sequence[str] = ('job_name',) + template_fields: Sequence[str] = ("job_name",) template_ext: Sequence[str] = () def __init__(self, *, job_name: str, **kwargs): @@ -145,14 +146,14 @@ def failed_states(self): return SageMakerHook.failed_states def get_sagemaker_response(self): - self.log.info('Poking Sagemaker Transform Job %s', self.job_name) + self.log.info("Poking Sagemaker Transform Job %s", self.job_name) return self.get_hook().describe_transform_job(self.job_name) def get_failed_reason_from_response(self, response): - return response['FailureReason'] + return response["FailureReason"] def state_from_response(self, response): - return response['TransformJobStatus'] + return response["TransformJobStatus"] class SageMakerTuningSensor(SageMakerBaseSensor): @@ -167,7 +168,7 @@ class SageMakerTuningSensor(SageMakerBaseSensor): :param job_name: Name of the tuning instance to watch. """ - template_fields: Sequence[str] = ('job_name',) + template_fields: Sequence[str] = ("job_name",) template_ext: Sequence[str] = () def __init__(self, *, job_name: str, **kwargs): @@ -181,14 +182,14 @@ def failed_states(self): return SageMakerHook.failed_states def get_sagemaker_response(self): - self.log.info('Poking Sagemaker Tuning Job %s', self.job_name) + self.log.info("Poking Sagemaker Tuning Job %s", self.job_name) return self.get_hook().describe_tuning_job(self.job_name) def get_failed_reason_from_response(self, response): - return response['FailureReason'] + return response["FailureReason"] def state_from_response(self, response): - return response['HyperParameterTuningJobStatus'] + return response["HyperParameterTuningJobStatus"] class SageMakerTrainingSensor(SageMakerBaseSensor): @@ -204,7 +205,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor): :param print_log: Prints the cloudwatch log if True; Defaults to True. 
""" - template_fields: Sequence[str] = ('job_name',) + template_fields: Sequence[str] = ("job_name",) template_ext: Sequence[str] = () def __init__(self, *, job_name, print_log=True, **kwargs): @@ -213,8 +214,8 @@ def __init__(self, *, job_name, print_log=True, **kwargs): self.print_log = print_log self.positions = {} self.stream_names = [] - self.instance_count: Optional[int] = None - self.state: Optional[int] = None + self.instance_count: int | None = None + self.state: int | None = None self.last_description = None self.last_describe_job_call = None self.log_resource_inited = False @@ -222,8 +223,8 @@ def __init__(self, *, job_name, print_log=True, **kwargs): def init_log_resource(self, hook: SageMakerHook) -> None: """Set tailing LogState for associated training job.""" description = hook.describe_training_job(self.job_name) - self.instance_count = description['ResourceConfig']['InstanceCount'] - status = description['TrainingJobStatus'] + self.instance_count = description["ResourceConfig"]["InstanceCount"] + status = description["TrainingJobStatus"] job_already_completed = status not in self.non_terminal_states() self.state = LogState.COMPLETE if job_already_completed else LogState.TAILING self.last_description = description @@ -258,13 +259,13 @@ def get_sagemaker_response(self): status = self.state_from_response(self.last_description) if (status not in self.non_terminal_states()) and (status not in self.failed_states()): billable_time = ( - self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime'] - ) * self.last_description['ResourceConfig']['InstanceCount'] - self.log.info('Billable seconds: %s', (int(billable_time.total_seconds()) + 1)) + self.last_description["TrainingEndTime"] - self.last_description["TrainingStartTime"] + ) * self.last_description["ResourceConfig"]["InstanceCount"] + self.log.info("Billable seconds: %s", (int(billable_time.total_seconds()) + 1)) return self.last_description def get_failed_reason_from_response(self, response): - return response['FailureReason'] + return response["FailureReason"] def state_from_response(self, response): - return response['TrainingJobStatus'] + return response["TrainingJobStatus"] diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_base.py b/airflow/providers/amazon/aws/sensors/sagemaker_base.py deleted file mode 100644 index 102c410d279a4..0000000000000 --- a/airflow/providers/amazon/aws/sensors/sagemaker_base.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.amazon.aws.sensors.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py deleted file mode 100644 index 00ed8442386b4..0000000000000 --- a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py deleted file mode 100644 index d1949964956e8..0000000000000 --- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py deleted file mode 100644 index 7a48f3e7de006..0000000000000 --- a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTransformSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py deleted file mode 100644 index d5f0d90555ffd..0000000000000 --- a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTuningSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py index cc026ec5e8306..8b09203d2e319 100644 --- a/airflow/providers/amazon/aws/sensors/sqs.py +++ b/airflow/providers/amazon/aws/sensors/sqs.py @@ -16,14 +16,16 @@ # specific language governing permissions and limitations # under the License. """Reads and then deletes the message from SQS queue""" +from __future__ import annotations + import json -import warnings -from typing import TYPE_CHECKING, Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Collection, Sequence from jsonpath_ng import parse from typing_extensions import Literal from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection from airflow.providers.amazon.aws.hooks.sqs import SqsHook from airflow.sensors.base import BaseSensorOperator @@ -34,9 +36,13 @@ class SqsSensor(BaseSensorOperator): """ Get messages from an Amazon SQS queue and then delete the messages from the queue. - If deletion of messages fails an AirflowException is thrown. 
Otherwise, the messages + If deletion of messages fails, an AirflowException is thrown. Otherwise, the messages are pushed through XCom with the key ``messages``. + By default, the sensor performs one and only one SQS call per poke, which limits the result to + a maximum of 10 messages. However, the total number of SQS API calls per poke can be controlled + by the num_batches param. + .. seealso:: For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:SqsSensor` @@ -44,6 +50,7 @@ class SqsSensor(BaseSensorOperator): :param aws_conn_id: AWS connection id :param sqs_queue: The SQS queue url (templated) :param max_messages: The maximum number of messages to retrieve for each poke (templated) + :param num_batches: The number of times the sensor will call the SQS API to receive messages (default: 1) :param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second) :param visibility_timeout: Visibility timeout, a period of time during which Amazon SQS prevents other consumers from receiving and processing the message. @@ -64,17 +71,18 @@ class SqsSensor(BaseSensorOperator): """ - template_fields: Sequence[str] = ('sqs_queue', 'max_messages', 'message_filtering_config') + template_fields: Sequence[str] = ("sqs_queue", "max_messages", "message_filtering_config") def __init__( self, *, sqs_queue, - aws_conn_id: str = 'aws_default', + aws_conn_id: str = "aws_default", max_messages: int = 5, + num_batches: int = 1, wait_time_seconds: int = 1, - visibility_timeout: Optional[int] = None, - message_filtering: Optional[Literal["literal", "jsonpath"]] = None, + visibility_timeout: int | None = None, + message_filtering: Literal["literal", "jsonpath"] | None = None, message_filtering_match_values: Any = None, message_filtering_config: Any = None, delete_message_on_reception: bool = True, @@ -84,6 +92,7 @@ def __init__( self.sqs_queue = sqs_queue self.aws_conn_id = aws_conn_id self.max_messages = max_messages + self.num_batches = num_batches self.wait_time_seconds = wait_time_seconds self.visibility_timeout = visibility_timeout @@ -96,39 +105,37 @@ def __init__( message_filtering_match_values = set(message_filtering_match_values) self.message_filtering_match_values = message_filtering_match_values - if self.message_filtering == 'literal': + if self.message_filtering == "literal": if self.message_filtering_match_values is None: - raise TypeError('message_filtering_match_values must be specified for literal matching') + raise TypeError("message_filtering_match_values must be specified for literal matching") self.message_filtering_config = message_filtering_config - self.hook: Optional[SqsHook] = None + self.hook: SqsHook | None = None - def poke(self, context: 'Context'): + def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection: """ - Check for message on subscribed queue and write to xcom the message with key ``messages``
- :param context: the context object - :return: ``True`` if message is available or ``False`` + :param sqs_conn: SQS connection + :return: A list of messages retrieved from SQS """ - sqs_conn = self.get_hook().get_conn() - - self.log.info('SqsSensor checking for message on queue: %s', self.sqs_queue) + self.log.info("SqsSensor checking for message on queue: %s", self.sqs_queue) receive_message_kwargs = { - 'QueueUrl': self.sqs_queue, - 'MaxNumberOfMessages': self.max_messages, - 'WaitTimeSeconds': self.wait_time_seconds, + "QueueUrl": self.sqs_queue, + "MaxNumberOfMessages": self.max_messages, + "WaitTimeSeconds": self.wait_time_seconds, } if self.visibility_timeout is not None: - receive_message_kwargs['VisibilityTimeout'] = self.visibility_timeout + receive_message_kwargs["VisibilityTimeout"] = self.visibility_timeout response = sqs_conn.receive_message(**receive_message_kwargs) if "Messages" not in response: - return False + return [] - messages = response['Messages'] + messages = response["Messages"] num_messages = len(messages) self.log.info("Received %d messages", num_messages) @@ -136,28 +143,47 @@ def poke(self, context: 'Context'): messages = self.filter_messages(messages) num_messages = len(messages) self.log.info("There are %d messages left after filtering", num_messages) + return messages - if not num_messages: - return False + def poke(self, context: Context): + """ + Check subscribed queue for messages and write them to xcom with the ``messages`` key. - if not self.delete_message_on_reception: - context['ti'].xcom_push(key='messages', value=messages) - return True + :param context: the context object + :return: ``True`` if message is available or ``False`` + """ + sqs_conn = self.get_hook().get_conn() - self.log.info("Deleting %d messages", num_messages) + message_batch: list[Any] = [] - entries = [ - {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']} for message in messages - ] - response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries) + # perform multiple SQS call to retrieve messages in series + for _ in range(self.num_batches): + messages = self.poll_sqs(sqs_conn=sqs_conn) - if 'Successful' in response: - context['ti'].xcom_push(key='messages', value=messages) - return True - else: - raise AirflowException( - 'Delete SQS Messages failed ' + str(response) + ' for messages ' + str(messages) - ) + if not len(messages): + continue + + message_batch.extend(messages) + + if self.delete_message_on_reception: + + self.log.info("Deleting %d messages", len(messages)) + + entries = [ + {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} + for message in messages + ] + response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries) + + if "Successful" not in response: + raise AirflowException( + "Delete SQS Messages failed " + str(response) + " for messages " + str(messages) + ) + if not len(message_batch): + return False + + context["ti"].xcom_push(key="messages", value=message_batch) + return True def get_hook(self) -> SqsHook: """Create and return an SqsHook""" @@ -168,17 +194,17 @@ def get_hook(self) -> SqsHook: return self.hook def filter_messages(self, messages): - if self.message_filtering == 'literal': + if self.message_filtering == "literal": return self.filter_messages_literal(messages) - if self.message_filtering == 'jsonpath': + if self.message_filtering == "jsonpath": return self.filter_messages_jsonpath(messages) else: - raise NotImplementedError('Override this method to define custom 
filters') + raise NotImplementedError("Override this method to define custom filters") def filter_messages_literal(self, messages): filtered_messages = [] for message in messages: - if message['Body'] in self.message_filtering_match_values: + if message["Body"] in self.message_filtering_match_values: filtered_messages.append(message) return filtered_messages @@ -186,7 +212,7 @@ def filter_messages_jsonpath(self, messages): jsonpath_expr = parse(self.message_filtering_config) filtered_messages = [] for message in messages: - body = message['Body'] + body = message["Body"] # Body is a string, deserialize to an object and then parse body = json.loads(body) results = jsonpath_expr.find(body) @@ -200,18 +226,3 @@ def filter_messages_jsonpath(self, messages): filtered_messages.append(message) break return filtered_messages - - -class SQSSensor(SqsSensor): - """ - This sensor is deprecated. - Please use :class:`airflow.providers.amazon.aws.sensors.sqs.SqsSensor`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This class is deprecated. Please use `airflow.providers.amazon.aws.sensors.sqs.SqsSensor`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/sensors/step_function.py b/airflow/providers/amazon/aws/sensors/step_function.py index 6f82c0bc99d8c..fda6f932d8bbd 100644 --- a/airflow/providers/amazon/aws/sensors/step_function.py +++ b/airflow/providers/amazon/aws/sensors/step_function.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook @@ -43,45 +44,45 @@ class StepFunctionExecutionSensor(BaseSensorOperator): :param aws_conn_id: aws connection to use, defaults to 'aws_default' """ - INTERMEDIATE_STATES = ('RUNNING',) + INTERMEDIATE_STATES = ("RUNNING",) FAILURE_STATES = ( - 'FAILED', - 'TIMED_OUT', - 'ABORTED', + "FAILED", + "TIMED_OUT", + "ABORTED", ) - SUCCESS_STATES = ('SUCCEEDED',) + SUCCESS_STATES = ("SUCCEEDED",) - template_fields: Sequence[str] = ('execution_arn',) + template_fields: Sequence[str] = ("execution_arn",) template_ext: Sequence[str] = () - ui_color = '#66c3ff' + ui_color = "#66c3ff" def __init__( self, *, execution_arn: str, - aws_conn_id: str = 'aws_default', - region_name: Optional[str] = None, + aws_conn_id: str = "aws_default", + region_name: str | None = None, **kwargs, ): super().__init__(**kwargs) self.execution_arn = execution_arn self.aws_conn_id = aws_conn_id self.region_name = region_name - self.hook: Optional[StepFunctionHook] = None + self.hook: StepFunctionHook | None = None - def poke(self, context: 'Context'): + def poke(self, context: Context): execution_status = self.get_hook().describe_execution(self.execution_arn) - state = execution_status['status'] - output = json.loads(execution_status['output']) if 'output' in execution_status else None + state = execution_status["status"] + output = json.loads(execution_status["output"]) if "output" in execution_status else None if state in self.FAILURE_STATES: - raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}') + raise AirflowException(f"Step Function sensor failed. 
State Machine Output: {output}") if state in self.INTERMEDIATE_STATES: return False - self.log.info('Doing xcom_push of output') - self.xcom_push(context, 'output', output) + self.log.info("Doing xcom_push of output") + self.xcom_push(context, "output", output) return True def get_hook(self) -> StepFunctionHook: diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py deleted file mode 100644 index 267343cd1df9e..0000000000000 --- a/airflow/providers/amazon/aws/sensors/step_function_execution.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.step_function`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.step_function`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py index a6f5f8da21aae..155f5439a6f0e 100644 --- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -15,17 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - """ This module contains operators to replicate records from DynamoDB table to S3. 
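Stepping back to the SqsSensor change earlier in this diff, a short sketch of the new num_batches knob; the queue URL and dag_id are assumptions.

# Minimal sketch: poll an assumed queue three times per poke, combining up to 30 messages.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.sqs import SqsSensor

with DAG(dag_id="example_sqs_wait", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    wait_for_messages = SqsSensor(
        task_id="wait_for_messages",
        sqs_queue="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",  # assumed queue URL
        max_messages=10,   # per-call ceiling; 10 is the SQS ReceiveMessage limit
        num_batches=3,     # number of ReceiveMessage calls per poke, new in this change
    )

All messages collected across the batches are pushed to XCom under the ``messages`` key, as in the reworked poke() above.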
""" +from __future__ import annotations + import json from copy import copy from os.path import getsize from tempfile import NamedTemporaryFile -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence +from typing import IO, TYPE_CHECKING, Any, Callable, Sequence from uuid import uuid4 from airflow.models import BaseOperator @@ -36,12 +36,12 @@ from airflow.utils.context import Context -def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes: - return (json.dumps(item) + '\n').encode('utf-8') +def _convert_item_to_json_bytes(item: dict[str, Any]) -> bytes: + return (json.dumps(item) + "\n").encode("utf-8") def _upload_file_to_s3( - file_obj: IO, bucket_name: str, s3_key_prefix: str, aws_conn_id: str = 'aws_default' + file_obj: IO, bucket_name: str, s3_key_prefix: str, aws_conn_id: str = "aws_default" ) -> None: s3_client = S3Hook(aws_conn_id=aws_conn_id).get_conn() file_obj.seek(0) @@ -80,8 +80,9 @@ class DynamoDBToS3Operator(BaseOperator): """ template_fields: Sequence[str] = ( - 's3_bucket_name', - 'dynamodb_table_name', + "s3_bucket_name", + "s3_key_prefix", + "dynamodb_table_name", ) template_fields_renderers = { "dynamodb_scan_kwargs": "json", @@ -93,10 +94,10 @@ def __init__( dynamodb_table_name: str, s3_bucket_name: str, file_size: int, - dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None, - s3_key_prefix: str = '', - process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes, - aws_conn_id: str = 'aws_default', + dynamodb_scan_kwargs: dict[str, Any] | None = None, + s3_key_prefix: str = "", + process_func: Callable[[dict[str, Any]], bytes] = _convert_item_to_json_bytes, + aws_conn_id: str = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -108,12 +109,13 @@ def __init__( self.s3_key_prefix = s3_key_prefix self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DynamoDBHook(aws_conn_id=self.aws_conn_id) table = hook.get_conn().Table(self.dynamodb_table_name) scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {} err = None + f: IO[Any] with NamedTemporaryFile() as f: try: f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table) @@ -127,16 +129,16 @@ def execute(self, context: 'Context') -> None: def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO: while True: response = table.scan(**scan_kwargs) - items = response['Items'] + items = response["Items"] for item in items: temp_file.write(self.process_func(item)) - if 'LastEvaluatedKey' not in response: + if "LastEvaluatedKey" not in response: # no more items to scan break - last_evaluated_key = response['LastEvaluatedKey'] - scan_kwargs['ExclusiveStartKey'] = last_evaluated_key + last_evaluated_key = response["LastEvaluatedKey"] + scan_kwargs["ExclusiveStartKey"] = last_evaluated_key # Upload the file to S3 if reach file size limit if getsize(temp_file.name) >= self.file_size: diff --git a/airflow/providers/amazon/aws/transfers/exasol_to_s3.py b/airflow/providers/amazon/aws/transfers/exasol_to_s3.py index 0f2fb7a99ea10..d5e78adf481b1 100644 --- a/airflow/providers/amazon/aws/transfers/exasol_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/exasol_to_s3.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. 
"""Transfers data from Exasol database into a S3 Bucket.""" +from __future__ import annotations from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -49,25 +50,25 @@ class ExasolToS3Operator(BaseOperator): method of :class:`~pyexasol.connection.ExaConnection`. """ - template_fields: Sequence[str] = ('query_or_table', 'key', 'bucket_name', 'query_params', 'export_params') + template_fields: Sequence[str] = ("query_or_table", "key", "bucket_name", "query_params", "export_params") template_fields_renderers = {"query_or_table": "sql", "query_params": "json", "export_params": "json"} - template_ext: Sequence[str] = ('.sql',) - ui_color = '#ededed' + template_ext: Sequence[str] = (".sql",) + ui_color = "#ededed" def __init__( self, *, query_or_table: str, key: str, - bucket_name: Optional[str] = None, + bucket_name: str | None = None, replace: bool = False, encrypt: bool = False, gzip: bool = False, - acl_policy: Optional[str] = None, - query_params: Optional[Dict] = None, - export_params: Optional[Dict] = None, - exasol_conn_id: str = 'exasol_default', - aws_conn_id: str = 'aws_default', + acl_policy: str | None = None, + query_params: dict | None = None, + export_params: dict | None = None, + exasol_conn_id: str = "exasol_default", + aws_conn_id: str = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -83,7 +84,7 @@ def __init__( self.exasol_conn_id = exasol_conn_id self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): exasol_hook = ExasolHook(exasol_conn_id=self.exasol_conn_id) s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/transfers/ftp_to_s3.py b/airflow/providers/amazon/aws/transfers/ftp_to_s3.py index 1426599bc4763..45811a82e4823 100644 --- a/airflow/providers/amazon/aws/transfers/ftp_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/ftp_to_s3.py @@ -15,8 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -60,7 +62,7 @@ class FTPToS3Operator(BaseOperator): uploaded to the S3 bucket. 
""" - template_fields: Sequence[str] = ('ftp_path', 's3_bucket', 's3_key', 'ftp_filenames', 's3_filenames') + template_fields: Sequence[str] = ("ftp_path", "s3_bucket", "s3_key", "ftp_filenames", "s3_filenames") def __init__( self, @@ -68,14 +70,14 @@ def __init__( ftp_path: str, s3_bucket: str, s3_key: str, - ftp_filenames: Optional[Union[str, List[str]]] = None, - s3_filenames: Optional[Union[str, List[str]]] = None, - ftp_conn_id: str = 'ftp_default', - aws_conn_id: str = 'aws_default', + ftp_filenames: str | list[str] | None = None, + s3_filenames: str | list[str] | None = None, + ftp_conn_id: str = "ftp_default", + aws_conn_id: str = "aws_default", replace: bool = False, encrypt: bool = False, gzip: bool = False, - acl_policy: Optional[str] = None, + acl_policy: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -90,8 +92,8 @@ def __init__( self.encrypt = encrypt self.gzip = gzip self.acl_policy = acl_policy - self.s3_hook: Optional[S3Hook] = None - self.ftp_hook: Optional[FTPHook] = None + self.s3_hook: S3Hook | None = None + self.ftp_hook: FTPHook | None = None def __upload_to_s3_from_ftp(self, remote_filename, s3_file_key): with NamedTemporaryFile() as local_tmp_file: @@ -108,35 +110,35 @@ def __upload_to_s3_from_ftp(self, remote_filename, s3_file_key): gzip=self.gzip, acl_policy=self.acl_policy, ) - self.log.info('File upload to %s', s3_file_key) + self.log.info("File upload to %s", s3_file_key) - def execute(self, context: 'Context'): + def execute(self, context: Context): self.ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id) self.s3_hook = S3Hook(self.aws_conn_id) if self.ftp_filenames: if isinstance(self.ftp_filenames, str): - self.log.info('Getting files in %s', self.ftp_path) + self.log.info("Getting files in %s", self.ftp_path) list_dir = self.ftp_hook.list_directory( path=self.ftp_path, ) - if self.ftp_filenames == '*': + if self.ftp_filenames == "*": files = list_dir else: ftp_filename: str = self.ftp_filenames files = list(filter(lambda f: ftp_filename in f, list_dir)) for file in files: - self.log.info('Moving file %s', file) + self.log.info("Moving file %s", file) if self.s3_filenames and isinstance(self.s3_filenames, str): filename = file.replace(self.ftp_filenames, self.s3_filenames) else: filename = file - s3_file_key = f'{self.s3_key}{filename}' + s3_file_key = f"{self.s3_key}{filename}" self.__upload_to_s3_from_ftp(file, s3_file_key) else: diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py index b521ce5360741..6a4a58103689b 100644 --- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud Storage to S3 operator.""" +from __future__ import annotations + import os import warnings -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -83,30 +85,30 @@ class GCSToS3Operator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'prefix', - 'delimiter', - 'dest_s3_key', - 'google_impersonation_chain', + "bucket", + "prefix", + "delimiter", + "dest_s3_key", + "google_impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, *, bucket: str, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - dest_aws_conn_id: str = 'aws_default', + prefix: str | None = None, + delimiter: str | None = None, + gcp_conn_id: str = "google_cloud_default", + google_cloud_storage_conn_id: str | None = None, + delegate_to: str | None = None, + dest_aws_conn_id: str = "aws_default", dest_s3_key: str, - dest_verify: Optional[Union[str, bool]] = None, + dest_verify: str | bool | None = None, replace: bool = False, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - dest_s3_extra_args: Optional[Dict] = None, - s3_acl_policy: Optional[str] = None, + google_impersonation_chain: str | Sequence[str] | None = None, + dest_s3_extra_args: dict | None = None, + s3_acl_policy: str | None = None, keep_directory_structure: bool = True, **kwargs, ) -> None: @@ -135,7 +137,7 @@ def __init__( self.s3_acl_policy = s3_acl_policy self.keep_directory_structure = keep_directory_structure - def execute(self, context: 'Context') -> List[str]: + def execute(self, context: Context) -> list[str]: # list all files in an Google Cloud Storage bucket hook = GCSHook( gcp_conn_id=self.gcp_conn_id, @@ -144,7 +146,7 @@ def execute(self, context: 'Context') -> List[str]: ) self.log.info( - 'Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s', + "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s", self.bucket, self.delimiter, self.prefix, @@ -170,7 +172,7 @@ def execute(self, context: 'Context') -> List[str]: # in case that no files exists, return an empty array to avoid errors existing_files = existing_files if existing_files is not None else [] # remove the prefix for the existing files to allow the match - existing_files = [file.replace(prefix, '', 1) for file in existing_files] + existing_files = [file.replace(prefix, "", 1) for file in existing_files] files = list(set(files) - set(existing_files)) if files: diff --git a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py index 07d3410ee7334..4a189230a4bb1 100644 --- a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +++ b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py @@ -15,8 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import tempfile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glacier import GlacierHook @@ -71,8 +73,8 @@ def __init__( object_name: str, gzip: bool, chunk_size: int = 1024, - delegate_to: Optional[str] = None, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -86,7 +88,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = google_impersonation_chain - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id) gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py index f3e10b62b2aa0..223b13472948c 100644 --- a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module allows you to transfer data from any Google API endpoint into a S3 Bucket.""" +from __future__ import annotations + import json import sys -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator, TaskInstance from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY @@ -57,6 +58,10 @@ class GoogleApiToS3Operator(BaseOperator): :param google_api_endpoint_params: The params to control the corresponding endpoint result. :param s3_destination_key: The url where to put the data retrieved from the endpoint in S3. + + .. note See https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html + for valid url formats. + :param google_api_response_via_xcom: Can be set to expose the google api response to xcom. :param google_api_endpoint_params_via_xcom: If set to a value this value will be used as a key for pulling from xcom and updating the google api endpoint params. 
@@ -85,12 +90,13 @@ class GoogleApiToS3Operator(BaseOperator): """ template_fields: Sequence[str] = ( - 'google_api_endpoint_params', - 's3_destination_key', - 'google_impersonation_chain', + "google_api_endpoint_params", + "s3_destination_key", + "google_impersonation_chain", + "gcp_conn_id", ) template_ext: Sequence[str] = () - ui_color = '#cc181e' + ui_color = "#cc181e" def __init__( self, @@ -100,16 +106,16 @@ def __init__( google_api_endpoint_path: str, google_api_endpoint_params: dict, s3_destination_key: str, - google_api_response_via_xcom: Optional[str] = None, - google_api_endpoint_params_via_xcom: Optional[str] = None, - google_api_endpoint_params_via_xcom_task_ids: Optional[str] = None, + google_api_response_via_xcom: str | None = None, + google_api_endpoint_params_via_xcom: str | None = None, + google_api_endpoint_params_via_xcom_task_ids: str | None = None, google_api_pagination: bool = False, google_api_num_retries: int = 0, s3_overwrite: bool = False, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - aws_conn_id: str = 'aws_default', - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + aws_conn_id: str = "aws_default", + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -129,23 +135,23 @@ def __init__( self.aws_conn_id = aws_conn_id self.google_impersonation_chain = google_impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """ Transfers Google APIs json data to S3. :param context: The context that is being provided when executing. """ - self.log.info('Transferring data from %s to s3', self.google_api_service_name) + self.log.info("Transferring data from %s to s3", self.google_api_service_name) if self.google_api_endpoint_params_via_xcom: - self._update_google_api_endpoint_params_via_xcom(context['task_instance']) + self._update_google_api_endpoint_params_via_xcom(context["task_instance"]) data = self._retrieve_data_from_google_api() self._load_data_to_s3(data) if self.google_api_response_via_xcom: - self._expose_google_api_response_via_xcom(context['task_instance'], data) + self._expose_google_api_response_via_xcom(context["task_instance"], data) def _retrieve_data_from_google_api(self) -> dict: google_discovery_api_hook = GoogleDiscoveryApiHook( @@ -165,7 +171,10 @@ def _retrieve_data_from_google_api(self) -> dict: def _load_data_to_s3(self, data: dict) -> None: s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) s3_hook.load_string( - string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite + string_data=json.dumps(data), + bucket_name=S3Hook.parse_s3_url(self.s3_destination_key)[0], + key=S3Hook.parse_s3_url(self.s3_destination_key)[1], + replace=self.s3_overwrite, ) def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None: @@ -181,4 +190,4 @@ def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data if sys.getsizeof(data) < MAX_XCOM_SIZE: task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data) else: - raise RuntimeError('The size of the downloaded data is too large to push to XCom!') + raise RuntimeError("The size of the downloaded data is too large to push to XCom!") diff --git a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py 
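The _load_data_to_s3 change above splits s3_destination_key with S3Hook.parse_s3_url into an explicit bucket and key before calling load_string. A small illustration of that split; the URL is an illustrative placeholder.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# parse_s3_url() returns (bucket_name, key); the operator now derives both parts
# from the full URL held in s3_destination_key.
bucket, key = S3Hook.parse_s3_url("s3://my-destination-bucket/google/analytics/output.json")
print(bucket)  # my-destination-bucket
print(key)     # google/analytics/output.json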
index 652eca22b86df..167665d627b82 100644 --- a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +++ b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py @@ -15,11 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains operator to move data from Hive to DynamoDB.""" +from __future__ import annotations import json -from typing import TYPE_CHECKING, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Callable, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook @@ -52,10 +52,10 @@ class HiveToDynamoDBOperator(BaseOperator): :param aws_conn_id: aws connection """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) template_fields_renderers = {"sql": "hql"} - ui_color = '#a0e08c' + ui_color = "#a0e08c" def __init__( self, @@ -63,13 +63,13 @@ def __init__( sql: str, table_name: str, table_keys: list, - pre_process: Optional[Callable] = None, - pre_process_args: Optional[list] = None, - pre_process_kwargs: Optional[list] = None, - region_name: Optional[str] = None, - schema: str = 'default', - hiveserver2_conn_id: str = 'hiveserver2_default', - aws_conn_id: str = 'aws_default', + pre_process: Callable | None = None, + pre_process_args: list | None = None, + pre_process_kwargs: list | None = None, + region_name: str | None = None, + schema: str = "default", + hiveserver2_conn_id: str = "hiveserver2_default", + aws_conn_id: str = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -84,10 +84,10 @@ def __init__( self.hiveserver2_conn_id = hiveserver2_conn_id self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) - self.log.info('Extracting data from Hive') + self.log.info("Extracting data from Hive") self.log.info(self.sql) data = hive.get_pandas_df(self.sql, schema=self.schema) @@ -98,13 +98,13 @@ def execute(self, context: 'Context'): region_name=self.region_name, ) - self.log.info('Inserting rows into dynamodb') + self.log.info("Inserting rows into dynamodb") if self.pre_process is None: - dynamodb.write_batch_data(json.loads(data.to_json(orient='records'))) + dynamodb.write_batch_data(json.loads(data.to_json(orient="records"))) else: dynamodb.write_batch_data( self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs) ) - self.log.info('Done.') + self.log.info("Done.") diff --git a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py index e79276dbc7190..b4394d6389254 100644 --- a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. 
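For the HiveToDynamoDBOperator pre_process path above, a hedged sketch: write_batch_data() receives whatever the callable returns, and the callable's signature simply mirrors the keyword call pre_process(data=..., args=..., kwargs=...) made in execute(). The query, table and key names are illustrative.

import json
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator


def add_partition_key(data, args=None, kwargs=None):
    # "data" is the pandas DataFrame produced by the Hive query; return the records
    # to hand to DynamoDBHook.write_batch_data().
    records = json.loads(data.to_json(orient="records"))
    for record in records:
        record["pk"] = f"{record['customer_id']}#{record['order_date']}"  # illustrative key
    return records


with DAG(
    dag_id="example_hive_to_dynamodb",  # illustrative
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    load_orders = HiveToDynamoDBOperator(
        task_id="load_orders",
        sql="SELECT customer_id, order_date, total FROM orders",  # illustrative HQL
        table_name="orders",                                       # illustrative DynamoDB table
        table_keys=["pk"],
        pre_process=add_partition_key,
    )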
"""This module allows you to transfer mail attachments from a mail server into s3 bucket.""" +from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -51,7 +53,7 @@ class ImapAttachmentToS3Operator(BaseOperator): :param aws_conn_id: AWS connection to use. """ - template_fields: Sequence[str] = ('imap_attachment_name', 's3_key', 'imap_mail_filter') + template_fields: Sequence[str] = ("imap_attachment_name", "s3_key", "imap_mail_filter") def __init__( self, @@ -60,12 +62,12 @@ def __init__( s3_bucket: str, s3_key: str, imap_check_regex: bool = False, - imap_mail_folder: str = 'INBOX', - imap_mail_filter: str = 'All', + imap_mail_folder: str = "INBOX", + imap_mail_filter: str = "All", s3_overwrite: bool = False, - imap_conn_id: str = 'imap_default', - s3_conn_id: Optional[str] = None, - aws_conn_id: str = 'aws_default', + imap_conn_id: str = "imap_default", + s3_conn_id: str | None = None, + aws_conn_id: str = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -83,14 +85,14 @@ def __init__( self.imap_conn_id = imap_conn_id self.aws_conn_id = aws_conn_id - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """ This function executes the transfer from the email server (via imap) into s3. :param context: The context while executing. """ self.log.info( - 'Transferring mail attachment %s from mail server via imap to s3 key %s...', + "Transferring mail attachment %s from mail server via imap to s3 key %s...", self.imap_attachment_name, self.s3_key, ) diff --git a/airflow/providers/amazon/aws/transfers/local_to_s3.py b/airflow/providers/amazon/aws/transfers/local_to_s3.py index d3e76dcb1d275..d431935738b09 100644 --- a/airflow/providers/amazon/aws/transfers/local_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/local_to_s3.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -62,20 +64,20 @@ class LocalFilesystemToS3Operator(BaseOperator): uploaded to the S3 bucket. 
""" - template_fields: Sequence[str] = ('filename', 'dest_key', 'dest_bucket') + template_fields: Sequence[str] = ("filename", "dest_key", "dest_bucket") def __init__( self, *, filename: str, dest_key: str, - dest_bucket: Optional[str] = None, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[str, bool]] = None, + dest_bucket: str | None = None, + aws_conn_id: str = "aws_default", + verify: str | bool | None = None, replace: bool = False, encrypt: bool = False, gzip: bool = False, - acl_policy: Optional[str] = None, + acl_policy: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -90,10 +92,10 @@ def __init__( self.gzip = gzip self.acl_policy = acl_policy - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) s3_bucket, s3_key = s3_hook.get_s3_bucket_key( - self.dest_bucket, self.dest_key, 'dest_bucket', 'dest_key' + self.dest_bucket, self.dest_key, "dest_bucket", "dest_key" ) s3_hook.load_file( self.filename, diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py index 44aae36378a29..eaa41f114d7da 100644 --- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py @@ -15,9 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import json import warnings -from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast from bson import json_util @@ -57,25 +59,25 @@ class MongoToS3Operator(BaseOperator): :param compression: type of compression to use for output file in S3. Currently only gzip is supported. 
""" - template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'mongo_query', 'mongo_collection') - ui_color = '#589636' + template_fields: Sequence[str] = ("s3_bucket", "s3_key", "mongo_query", "mongo_collection") + ui_color = "#589636" template_fields_renderers = {"mongo_query": "json"} def __init__( self, *, - s3_conn_id: Optional[str] = None, - mongo_conn_id: str = 'mongo_default', - aws_conn_id: str = 'aws_default', + s3_conn_id: str | None = None, + mongo_conn_id: str = "mongo_default", + aws_conn_id: str = "aws_default", mongo_collection: str, - mongo_query: Union[list, dict], + mongo_query: list | dict, s3_bucket: str, s3_key: str, - mongo_db: Optional[str] = None, - mongo_projection: Optional[Union[list, dict]] = None, + mongo_db: str | None = None, + mongo_projection: list | dict | None = None, replace: bool = False, allow_disk_use: bool = False, - compression: Optional[str] = None, + compression: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -99,7 +101,7 @@ def __init__( self.allow_disk_use = allow_disk_use self.compression = compression - def execute(self, context: 'Context'): + def execute(self, context: Context): """Is written to depend on transform method""" s3_conn = S3Hook(self.aws_conn_id) @@ -132,7 +134,7 @@ def execute(self, context: 'Context'): ) @staticmethod - def _stringify(iterable: Iterable, joinable: str = '\n') -> str: + def _stringify(iterable: Iterable, joinable: str = "\n") -> str: """ Takes an iterable (pymongo Cursor or Array) containing dictionaries and returns a stringified version using python join diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py deleted file mode 100644 index dc3d84ecb3658..0000000000000 --- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py +++ /dev/null @@ -1,72 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import warnings -from typing import Optional - -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator - -warnings.warn( - "This module is deprecated. Please use airflow.providers.amazon.aws.transfers.sql_to_s3`.", - DeprecationWarning, - stacklevel=2, -) - - -class MySQLToS3Operator(SqlToS3Operator): - """ - This class is deprecated. - Please use `airflow.providers.amazon.aws.transfers.sql_to_s3.SqlToS3Operator`. - """ - - template_fields_renderers = { - "pd_csv_kwargs": "json", - } - - def __init__( - self, - *, - mysql_conn_id: str = 'mysql_default', - pd_csv_kwargs: Optional[dict] = None, - index: bool = False, - header: bool = False, - **kwargs, - ) -> None: - warnings.warn( - """ - MySQLToS3Operator is deprecated. - Please use `airflow.providers.amazon.aws.transfers.sql_to_s3.SqlToS3Operator`. 
- """, - DeprecationWarning, - stacklevel=2, - ) - - pd_kwargs = kwargs.get('pd_kwargs', {}) - if kwargs.get('file_format', "csv") == "csv": - if "path_or_buf" in pd_kwargs: - raise AirflowException('The argument path_or_buf is not allowed, please remove it') - if "index" not in pd_kwargs: - pd_kwargs["index"] = index - if "header" not in pd_kwargs: - pd_kwargs["header"] = header - kwargs["pd_kwargs"] = {**kwargs.get('pd_kwargs', {}), **pd_kwargs} - elif pd_csv_kwargs is not None: - raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'") - - super().__init__(sql_conn_id=mysql_conn_id, **kwargs) diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index 0bfdda44e78d6..ed57df0984054 100644 --- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """Transfers data from AWS Redshift into a S3 Bucket.""" -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook @@ -29,7 +31,7 @@ class RedshiftToS3Operator(BaseOperator): """ - Executes an UNLOAD command to s3 as a CSV with headers + Execute an UNLOAD command to s3 as a CSV with headers. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -68,44 +70,45 @@ class RedshiftToS3Operator(BaseOperator): """ template_fields: Sequence[str] = ( - 's3_bucket', - 's3_key', - 'schema', - 'table', - 'unload_options', - 'select_query', + "s3_bucket", + "s3_key", + "schema", + "table", + "unload_options", + "select_query", + "redshift_conn_id", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'select_query': 'sql'} - ui_color = '#ededed' + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"select_query": "sql"} + ui_color = "#ededed" def __init__( self, *, s3_bucket: str, s3_key: str, - schema: Optional[str] = None, - table: Optional[str] = None, - select_query: Optional[str] = None, - redshift_conn_id: str = 'redshift_default', - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - unload_options: Optional[List] = None, + schema: str | None = None, + table: str | None = None, + select_query: str | None = None, + redshift_conn_id: str = "redshift_default", + aws_conn_id: str = "aws_default", + verify: bool | str | None = None, + unload_options: list | None = None, autocommit: bool = False, include_header: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Iterable | Mapping | None = None, table_as_file_name: bool = True, # Set to True by default for not breaking current workflows **kwargs, ) -> None: super().__init__(**kwargs) self.s3_bucket = s3_bucket - self.s3_key = f'{s3_key}/{table}_' if (table and table_as_file_name) else s3_key + self.s3_key = f"{s3_key}/{table}_" if (table and table_as_file_name) else s3_key self.schema = schema self.table = table self.redshift_conn_id = redshift_conn_id self.aws_conn_id = aws_conn_id self.verify = verify - self.unload_options = unload_options or [] # type: List + self.unload_options: list = unload_options or [] self.autocommit = autocommit self.include_header = 
include_header self.parameters = parameters @@ -117,12 +120,12 @@ def __init__( self.select_query = f"SELECT * FROM {self.schema}.{self.table}" else: raise ValueError( - 'Please provide both `schema` and `table` params or `select_query` to fetch the data.' + "Please provide both `schema` and `table` params or `select_query` to fetch the data." ) - if self.include_header and 'HEADER' not in [uo.upper().strip() for uo in self.unload_options]: + if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]: self.unload_options = list(self.unload_options) + [ - 'HEADER', + "HEADER", ] def _build_unload_query( @@ -136,22 +139,22 @@ def _build_unload_query( {unload_options}; """ - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) conn = S3Hook.get_connection(conn_id=self.aws_conn_id) - if conn.extra_dejson.get('role_arn', False): + if conn.extra_dejson.get("role_arn", False): credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}" else: s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) credentials = s3_hook.get_credentials() credentials_block = build_credentials_block(credentials) - unload_options = '\n\t\t\t'.join(self.unload_options) + unload_options = "\n\t\t\t".join(self.unload_options) unload_query = self._build_unload_query( credentials_block, self.select_query, self.s3_key, unload_options ) - self.log.info('Executing UNLOAD command...') + self.log.info("Executing UNLOAD command...") redshift_hook.run(unload_query, self.autocommit, parameters=self.parameters) self.log.info("UNLOAD command complete...") diff --git a/airflow/providers/amazon/aws/transfers/s3_to_ftp.py b/airflow/providers/amazon/aws/transfers/s3_to_ftp.py index 2e07a9575fd26..9a15772310747 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_ftp.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_ftp.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Sequence @@ -46,7 +47,7 @@ class S3ToFTPOperator(BaseOperator): establishing a connection to the FTP server. 
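For reference, a minimal RedshiftToS3Operator sketch matching the hunks above: with table_as_file_name=True the key becomes "<s3_key>/<table>_", and include_header=True appends HEADER to unload_options. The bucket, schema and table names are illustrative, and the UNLOAD options shown are ordinary Redshift options rather than anything introduced by this patch.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

with DAG(
    dag_id="example_redshift_unload",  # illustrative
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Files land under "exports/sales/sales_*" because table_as_file_name defaults to True.
    unload_sales = RedshiftToS3Operator(
        task_id="unload_sales",
        s3_bucket="my-unload-bucket",  # illustrative bucket
        s3_key="exports/sales",
        schema="public",
        table="sales",
        include_header=True,           # adds HEADER to unload_options
        unload_options=["DELIMITER ','", "ALLOWOVERWRITE"],
    )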
""" - template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'ftp_path') + template_fields: Sequence[str] = ("s3_bucket", "s3_key", "ftp_path") def __init__( self, @@ -54,8 +55,8 @@ def __init__( s3_bucket, s3_key, ftp_path, - aws_conn_id='aws_default', - ftp_conn_id='ftp_default', + aws_conn_id="aws_default", + ftp_conn_id="ftp_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -65,15 +66,15 @@ def __init__( self.aws_conn_id = aws_conn_id self.ftp_conn_id = ftp_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): s3_hook = S3Hook(self.aws_conn_id) ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id) s3_obj = s3_hook.get_key(self.s3_key, self.s3_bucket) with NamedTemporaryFile() as local_tmp_file: - self.log.info('Downloading file from %s', self.s3_key) + self.log.info("Downloading file from %s", self.s3_key) s3_obj.download_fileobj(local_tmp_file) local_tmp_file.seek(0) ftp_hook.store_file(self.ftp_path, local_tmp_file.name) - self.log.info('File stored in %s', {self.ftp_path}) + self.log.info("File stored in %s", {self.ftp_path}) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 014e23ec070f2..bbb83310b7bb7 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -import warnings -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -28,7 +28,7 @@ from airflow.utils.context import Context -AVAILABLE_METHODS = ['APPEND', 'REPLACE', 'UPSERT'] +AVAILABLE_METHODS = ["APPEND", "REPLACE", "UPSERT"] class S3ToRedshiftOperator(BaseOperator): @@ -64,9 +64,17 @@ class S3ToRedshiftOperator(BaseOperator): :param upsert_keys: List of fields to use as key on upsert action """ - template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'schema', 'table', 'column_list', 'copy_options') + template_fields: Sequence[str] = ( + "s3_bucket", + "s3_key", + "schema", + "table", + "column_list", + "copy_options", + "redshift_conn_id", + ) template_ext: Sequence[str] = () - ui_color = '#99e699' + ui_color = "#99e699" def __init__( self, @@ -75,27 +83,16 @@ def __init__( table: str, s3_bucket: str, s3_key: str, - redshift_conn_id: str = 'redshift_default', - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - column_list: Optional[List[str]] = None, - copy_options: Optional[List] = None, + redshift_conn_id: str = "redshift_default", + aws_conn_id: str = "aws_default", + verify: bool | str | None = None, + column_list: list[str] | None = None, + copy_options: list | None = None, autocommit: bool = False, - method: str = 'APPEND', - upsert_keys: Optional[List[str]] = None, + method: str = "APPEND", + upsert_keys: list[str] | None = None, **kwargs, ) -> None: - - if 'truncate_table' in kwargs: - warnings.warn( - """`truncate_table` is deprecated. 
Please use `REPLACE` method.""", - DeprecationWarning, - stacklevel=2, - ) - if kwargs['truncate_table']: - method = 'REPLACE' - kwargs.pop('truncate_table', None) - super().__init__(**kwargs) self.schema = schema self.table = table @@ -111,10 +108,10 @@ def __init__( self.upsert_keys = upsert_keys if self.method not in AVAILABLE_METHODS: - raise AirflowException(f'Method not found! Available methods: {AVAILABLE_METHODS}') + raise AirflowException(f"Method not found! Available methods: {AVAILABLE_METHODS}") def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_options: str) -> str: - column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else '' + column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else "" return f""" COPY {copy_destination} {column_names} FROM 's3://{self.s3_bucket}/{self.s3_key}' @@ -123,34 +120,34 @@ def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_ {copy_options}; """ - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) conn = S3Hook.get_connection(conn_id=self.aws_conn_id) - if conn.extra_dejson.get('role_arn', False): + if conn.extra_dejson.get("role_arn", False): credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}" else: s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) credentials = s3_hook.get_credentials() credentials_block = build_credentials_block(credentials) - copy_options = '\n\t\t\t'.join(self.copy_options) - destination = f'{self.schema}.{self.table}' - copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination + copy_options = "\n\t\t\t".join(self.copy_options) + destination = f"{self.schema}.{self.table}" + copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options) - sql: Union[list, str] + sql: str | Iterable[str] - if self.method == 'REPLACE': + if self.method == "REPLACE": sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"] - elif self.method == 'UPSERT': + elif self.method == "UPSERT": keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema) if not keys: raise AirflowException( f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'" ) - where_statement = ' AND '.join([f'{self.table}.{k} = {copy_destination}.{k}' for k in keys]) + where_statement = " AND ".join([f"{self.table}.{k} = {copy_destination}.{k}" for k in keys]) sql = [ f"CREATE TABLE {copy_destination} (LIKE {destination});", @@ -164,6 +161,6 @@ def execute(self, context: 'Context') -> None: else: sql = copy_statement - self.log.info('Executing COPY command...') + self.log.info("Executing COPY command...") redshift_hook.run(sql, autocommit=self.autocommit) self.log.info("COPY command complete...") diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py index 7c003cfb72608..3038c17cb0a6c 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py @@ -15,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
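Since the deprecated truncate_table shim is removed above, a hedged sketch of the two load modes that replace it: method="REPLACE" wraps DELETE and COPY in one transaction, while method="UPSERT" stages into a temporary table and merges on upsert_keys (or the table's primary key when upsert_keys is omitted). Bucket, table and key names are illustrative.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

with DAG(
    dag_id="example_s3_to_redshift",  # illustrative
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Full refresh: the former truncate_table=True behaviour is now method="REPLACE".
    copy_full_refresh = S3ToRedshiftOperator(
        task_id="copy_full_refresh",
        schema="public",
        table="sales",
        s3_bucket="my-staging-bucket",  # illustrative bucket
        s3_key="exports/sales.csv",
        copy_options=["CSV", "IGNOREHEADER 1"],
        method="REPLACE",
    )

    # Incremental merge keyed on "sale_id" (illustrative key column).
    copy_incremental = S3ToRedshiftOperator(
        task_id="copy_incremental",
        schema="public",
        table="sales",
        s3_bucket="my-staging-bucket",
        s3_key="exports/sales_delta.csv",
        copy_options=["CSV", "IGNOREHEADER 1"],
        method="UPSERT",
        upsert_keys=["sale_id"],
    )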
+from __future__ import annotations + import warnings from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Sequence +from urllib.parse import urlsplit from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -51,7 +53,7 @@ class S3ToSFTPOperator(BaseOperator): downloading the file from S3. """ - template_fields: Sequence[str] = ('s3_key', 'sftp_path') + template_fields: Sequence[str] = ("s3_key", "sftp_path") def __init__( self, @@ -59,9 +61,9 @@ def __init__( s3_bucket: str, s3_key: str, sftp_path: str, - sftp_conn_id: str = 'ssh_default', - s3_conn_id: Optional[str] = None, - aws_conn_id: str = 'aws_default', + sftp_conn_id: str = "ssh_default", + s3_conn_id: str | None = None, + aws_conn_id: str = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -78,10 +80,10 @@ def __init__( @staticmethod def get_s3_key(s3_key: str) -> str: """This parses the correct format for S3 keys regardless of how the S3 url is passed.""" - parsed_s3_key = urlparse(s3_key) - return parsed_s3_key.path.lstrip('/') + parsed_s3_key = urlsplit(s3_key) + return parsed_s3_key.path.lstrip("/") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self.s3_key = self.get_s3_key(self.s3_key) ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) s3_hook = S3Hook(self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py b/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py index a953693f1f193..f0d9c82f3fe63 100644 --- a/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import os import tempfile -from typing import TYPE_CHECKING, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -70,7 +71,7 @@ def __init__( s3_key: str, salesforce_conn_id: str, export_format: str = "csv", - query_params: Optional[Dict] = None, + query_params: dict | None = None, include_deleted: bool = False, coerce_to_timestamp: bool = False, record_time_added: bool = False, @@ -78,7 +79,7 @@ def __init__( replace: bool = False, encrypt: bool = False, gzip: bool = False, - acl_policy: Optional[str] = None, + acl_policy: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -97,7 +98,7 @@ def __init__( self.gzip = gzip self.acl_policy = acl_policy - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: salesforce_hook = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id) response = salesforce_hook.make_query( query=self.salesforce_query, diff --git a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py index 71376e3179c0a..546c4710267b4 100644 --- a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py @@ -15,9 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
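The s3_to_sftp change above swaps urlparse for urlsplit in get_s3_key; unlike urlparse, urlsplit does not strip ";" parameter segments out of the path, so keys containing ";" survive intact, and either a bare key or a full S3 URL resolves to the same key. A tiny illustration with placeholder values:

from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator

# get_s3_key() is a staticmethod, so it can be exercised without an AWS connection.
print(S3ToSFTPOperator.get_s3_key("s3://my-bucket/path/to/file.csv"))  # path/to/file.csv
print(S3ToSFTPOperator.get_s3_key("path/to/file.csv"))                 # path/to/file.csv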
+from __future__ import annotations + from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Sequence -from urllib.parse import urlparse +from urllib.parse import urlsplit from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -50,7 +52,7 @@ class SFTPToS3Operator(BaseOperator): if False streams file from SFTP to S3. """ - template_fields: Sequence[str] = ('s3_key', 'sftp_path') + template_fields: Sequence[str] = ("s3_key", "sftp_path") def __init__( self, @@ -58,8 +60,8 @@ def __init__( s3_bucket: str, s3_key: str, sftp_path: str, - sftp_conn_id: str = 'ssh_default', - s3_conn_id: str = 'aws_default', + sftp_conn_id: str = "ssh_default", + s3_conn_id: str = "aws_default", use_temp_file: bool = True, **kwargs, ) -> None: @@ -74,10 +76,10 @@ def __init__( @staticmethod def get_s3_key(s3_key: str) -> str: """This parses the correct format for S3 keys regardless of how the S3 url is passed.""" - parsed_s3_key = urlparse(s3_key) - return parsed_s3_key.path.lstrip('/') + parsed_s3_key = urlsplit(s3_key) + return parsed_s3_key.path.lstrip("/") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self.s3_key = self.get_s3_key(self.s3_key) ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) s3_hook = S3Hook(self.s3_conn_id) @@ -90,5 +92,5 @@ def execute(self, context: 'Context') -> None: s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True) else: - with sftp_client.file(self.sftp_path, mode='rb') as data: + with sftp_client.file(self.sftp_path, mode="rb") as data: s3_hook.get_conn().upload_fileobj(data, self.s3_bucket, self.s3_key, Callback=self.log.info) diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py index f399c271416e4..713fbc005901f 100644 --- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
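A short sketch of SFTPToS3Operator with use_temp_file=False, the branch shown above that streams the remote file straight to S3 via upload_fileobj instead of staging it in a NamedTemporaryFile. Paths, bucket and connection ids are illustrative.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator

with DAG(
    dag_id="example_sftp_to_s3",  # illustrative
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # use_temp_file=False avoids writing a local temporary copy of large files.
    stream_upload = SFTPToS3Operator(
        task_id="stream_upload",
        sftp_path="/incoming/large_export.csv",  # illustrative remote path
        s3_bucket="my-landing-bucket",           # illustrative bucket
        s3_key="sftp/large_export.csv",
        use_temp_file=False,
    )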
+from __future__ import annotations +import enum from collections import namedtuple -from enum import Enum from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence import numpy as np import pandas as pd @@ -27,25 +28,28 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.hooks.dbapi import DbApiHook from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: from airflow.utils.context import Context -FILE_FORMAT = Enum( - "FILE_FORMAT", - "CSV, JSON, PARQUET", -) +class FILE_FORMAT(enum.Enum): + """Possible file formats.""" -FileOptions = namedtuple('FileOptions', ['mode', 'suffix', 'function']) + CSV = enum.auto() + JSON = enum.auto() + PARQUET = enum.auto() + + +FileOptions = namedtuple("FileOptions", ["mode", "suffix", "function"]) FILE_OPTIONS_MAP = { - FILE_FORMAT.CSV: FileOptions('r+', '.csv', 'to_csv'), - FILE_FORMAT.JSON: FileOptions('r+', '.json', 'to_json'), - FILE_FORMAT.PARQUET: FileOptions('rb+', '.parquet', 'to_parquet'), + FILE_FORMAT.CSV: FileOptions("r+", ".csv", "to_csv"), + FILE_FORMAT.JSON: FileOptions("r+", ".json", "to_json"), + FILE_FORMAT.PARQUET: FileOptions("rb+", ".parquet", "to_parquet"), } @@ -79,11 +83,12 @@ class SqlToS3Operator(BaseOperator): """ template_fields: Sequence[str] = ( - 's3_bucket', - 's3_key', - 'query', + "s3_bucket", + "s3_key", + "query", + "sql_conn_id", ) - template_ext: Sequence[str] = ('.sql',) + template_ext: Sequence[str] = (".sql",) template_fields_renderers = { "query": "sql", "pd_kwargs": "json", @@ -96,12 +101,12 @@ def __init__( s3_bucket: str, s3_key: str, sql_conn_id: str, - parameters: Union[None, Mapping, Iterable] = None, + parameters: None | Mapping | Iterable = None, replace: bool = False, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - file_format: Literal['csv', 'json', 'parquet'] = 'csv', - pd_kwargs: Optional[dict] = None, + aws_conn_id: str = "aws_default", + verify: bool | str | None = None, + file_format: Literal["csv", "json", "parquet"] = "csv", + pd_kwargs: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -116,17 +121,26 @@ def __init__( self.parameters = parameters if "path_or_buf" in self.pd_kwargs: - raise AirflowException('The argument path_or_buf is not allowed, please remove it') - - self.file_format = getattr(FILE_FORMAT, file_format.upper(), None) + raise AirflowException("The argument path_or_buf is not allowed, please remove it") - if self.file_format is None: + try: + self.file_format = FILE_FORMAT[file_format.upper()] + except KeyError: raise AirflowException(f"The argument file_format doesn't support {file_format} value.") @staticmethod - def _fix_int_dtypes(df: pd.DataFrame) -> None: - """Mutate DataFrame to set dtypes for int columns containing NaN values.""" + def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None: + """ + Mutate DataFrame to set dtypes for float columns containing NaN values. + Set dtype of object to str to allow for downstream transformations. + """ for col in df: + + if df[col].dtype.name == "object" and file_format == "parquet": + # if the type wasn't identified or converted, change it to a string so if can still be + # processed. 
+ df[col] = df[col].astype(str) + if "float" in df[col].dtype.name and df[col].hasnans: # inspect values to determine if dtype of non-null values is int or float notna_series = df[col].dropna().values @@ -139,13 +153,13 @@ def _fix_int_dtypes(df: pd.DataFrame) -> None: df[col] = np.where(df[col].isnull(), None, df[col]) df[col] = df[col].astype(pd.Float64Dtype()) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: sql_hook = self._get_hook() s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters) self.log.info("Data from SQL obtained") - self._fix_int_dtypes(data_df) + self._fix_dtypes(data_df, self.file_format) file_options = FILE_OPTIONS_MAP[self.file_format] with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file: @@ -162,7 +176,7 @@ def _get_hook(self) -> DbApiHook: self.log.debug("Get connection for %s", self.sql_conn_id) conn = BaseHook.get_connection(self.sql_conn_id) hook = conn.get_hook() - if not callable(getattr(hook, 'get_pandas_df', None)): + if not callable(getattr(hook, "get_pandas_df", None)): raise AirflowException( "This hook is not supported. The hook class must have get_pandas_df method." ) diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py index 13a83393a9124..0aff1289dc5e5 100644 --- a/airflow/providers/amazon/aws/utils/__init__.py +++ b/airflow/providers/amazon/aws/utils/__init__.py @@ -14,3 +14,33 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +import re +from datetime import datetime + +from airflow.version import version + + +def trim_none_values(obj: dict): + return {key: val for key, val in obj.items() if val is not None} + + +def datetime_to_epoch(date_time: datetime) -> int: + """Convert a datetime object to an epoch integer (seconds).""" + return int(date_time.timestamp()) + + +def datetime_to_epoch_ms(date_time: datetime) -> int: + """Convert a datetime object to an epoch integer (milliseconds).""" + return int(date_time.timestamp() * 1_000) + + +def datetime_to_epoch_us(date_time: datetime) -> int: + """Convert a datetime object to an epoch integer (microseconds).""" + return int(date_time.timestamp() * 1_000_000) + + +def get_airflow_version() -> tuple[int, ...]: + val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), version) + return tuple(int(x) for x in val.split(".")) diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py new file mode 100644 index 0000000000000..f6ae3fdf01837 --- /dev/null +++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -0,0 +1,475 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
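Tying the sql_to_s3 changes above together, a minimal SqlToS3Operator sketch: file_format is resolved through the new FILE_FORMAT enum ("csv", "json" or "parquet"), and for parquet output _fix_dtypes() now casts leftover object columns to str before writing. The connection id, query and bucket are illustrative placeholders.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

with DAG(
    dag_id="example_sql_to_s3",  # illustrative
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    export_orders = SqlToS3Operator(
        task_id="export_orders",
        query="SELECT * FROM orders WHERE order_date = '{{ ds }}'",  # illustrative query
        sql_conn_id="my_database",                                   # illustrative connection
        s3_bucket="my-export-bucket",                                # illustrative bucket
        s3_key="orders/{{ ds }}.parquet",
        file_format="parquet",
        replace=True,
    )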
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import warnings +from copy import deepcopy +from dataclasses import MISSING, InitVar, dataclass, field, fields +from typing import TYPE_CHECKING, Any + +from botocore.config import Config + +from airflow.compat.functools import cached_property +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.log.secrets_masker import mask_secret +from airflow.utils.types import NOTSET, ArgNotSet + +if TYPE_CHECKING: + from airflow.models.connection import Connection  # Avoid circular imports. + + +@dataclass +class _ConnectionMetadata: + """Connection metadata data-class. + + This class implements the main :ref:`~airflow.models.connection.Connection` attributes + and is used in AwsConnectionWrapper to avoid circular imports. + + Only for internal usage; this class might change or be removed in the future. + """ + + conn_id: str | None = None + conn_type: str | None = None + description: str | None = None + host: str | None = None + login: str | None = None + password: str | None = None + schema: str | None = None + port: int | None = None + extra: str | dict | None = None + + @property + def extra_dejson(self): + if not self.extra: + return {} + extra = deepcopy(self.extra) + if isinstance(extra, str): + try: + extra = json.loads(extra) + except json.JSONDecodeError as err: + raise AirflowException( + f"'extra' expected valid JSON-Object string. Original error:\n * {err}" + ) from None + if not isinstance(extra, dict): + raise TypeError(f"Expected JSON-Object or dict, got {type(extra).__name__}.") + return extra + + +@dataclass +class AwsConnectionWrapper(LoggingMixin): + """ + AWS Connection Wrapper class helper. + Used to validate and resolve AWS Connection parameters. + + ``conn`` references an Airflow Connection object or AwsConnectionWrapper; + if it is set to ``None`` then default values are used. + + The precedence rules for ``region_name``: + 1. Explicitly set (in Hook) ``region_name``. + 2. Airflow Connection Extra 'region_name'. + + The precedence rules for ``botocore_config``: + 1. Explicitly set (in Hook) ``botocore_config``. + 2. Constructed from Airflow Connection Extra 'config_kwargs'. + 3. The wrapper's default value. + """ + + conn: InitVar[Connection | AwsConnectionWrapper | _ConnectionMetadata | None] + region_name: str | None = field(default=None) + # boto3 client/resource configs + botocore_config: Config | None = field(default=None) + verify: bool | str | None = field(default=None) + + # Reference to Airflow Connection attributes + # ``extra_config`` contains original Airflow Connection Extra. + conn_id: str | ArgNotSet | None = field(init=False, default=NOTSET) + conn_type: str | None = field(init=False, default=None) + login: str | None = field(init=False, repr=False, default=None) + password: str | None = field(init=False, repr=False, default=None) + extra_config: dict[str, Any] = field(init=False, repr=False, default_factory=dict) + + # AWS Credentials from connection.
+ aws_access_key_id: str | None = field(init=False, default=None) + aws_secret_access_key: str | None = field(init=False, default=None) + aws_session_token: str | None = field(init=False, default=None) + + # AWS Shared Credential profile_name + profile_name: str | None = field(init=False, default=None) + # Custom endpoint_url for boto3.client and boto3.resource + endpoint_url: str | None = field(init=False, default=None) + + # Assume Role Configurations + role_arn: str | None = field(init=False, default=None) + assume_role_method: str | None = field(init=False, default=None) + assume_role_kwargs: dict[str, Any] = field(init=False, default_factory=dict) + + @cached_property + def conn_repr(self): + return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})" + + def __post_init__(self, conn: Connection): + if isinstance(conn, type(self)): + # For every field with init=False we copy reference value from original wrapper + # For every field with init=True we use init values if it not equal default + # We can't use ``dataclasses.replace`` in classmethod because + # we limited by InitVar arguments since it not stored in object, + # and also we do not want to run __post_init__ method again which print all logs/warnings again. + for fl in fields(conn): + value = getattr(conn, fl.name) + if not fl.init: + setattr(self, fl.name, value) + else: + if fl.default is not MISSING: + default = fl.default + elif fl.default_factory is not MISSING: + default = fl.default_factory() # zero-argument callable + else: + continue # Value mandatory, skip + + orig_value = getattr(self, fl.name) + if orig_value == default: + # Only replace value if it not equal default value + setattr(self, fl.name, value) + return + elif not conn: + return + + # Assign attributes from AWS Connection + self.conn_id = conn.conn_id + self.conn_type = conn.conn_type or "aws" + self.login = conn.login + self.password = conn.password + self.extra_config = deepcopy(conn.extra_dejson) + + if self.conn_type.lower() == "s3": + warnings.warn( + f"{self.conn_repr} has connection type 's3', " + "which has been replaced by connection type 'aws'. " + "Please update your connection to have `conn_type='aws'`.", + DeprecationWarning, + stacklevel=2, + ) + elif self.conn_type != "aws": + warnings.warn( + f"{self.conn_repr} expected connection type 'aws', got {self.conn_type!r}. " + "This connection might not work correctly. " + "Please use Amazon Web Services Connection type.", + UserWarning, + stacklevel=2, + ) + + extra = deepcopy(conn.extra_dejson) + session_kwargs = extra.get("session_kwargs", {}) + if session_kwargs: + warnings.warn( + "'session_kwargs' in extra config is deprecated and will be removed in a future releases. 
" + f"Please specify arguments passed to boto3 Session directly in {self.conn_repr} extra.", + DeprecationWarning, + stacklevel=2, + ) + + # Retrieve initial connection credentials + init_credentials = self._get_credentials(**extra) + self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials + + if not self.region_name: + if "region_name" in extra: + self.region_name = extra["region_name"] + self.log.debug("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr) + elif "region_name" in session_kwargs: + self.region_name = session_kwargs["region_name"] + self.log.debug( + "Retrieving region_name=%s from %s extra['session_kwargs'].", + self.region_name, + self.conn_repr, + ) + + if self.verify is None and "verify" in extra: + self.verify = extra["verify"] + self.log.debug("Retrieving verify=%s from %s extra.", self.verify, self.conn_repr) + + if "profile_name" in extra: + self.profile_name = extra["profile_name"] + self.log.debug("Retrieving profile_name=%s from %s extra.", self.profile_name, self.conn_repr) + elif "profile_name" in session_kwargs: + self.profile_name = session_kwargs["profile_name"] + self.log.debug( + "Retrieving profile_name=%s from %s extra['session_kwargs'].", + self.profile_name, + self.conn_repr, + ) + + # Warn the user that an invalid parameter is being used which actually not related to 'profile_name'. + # ToDo: Remove this check entirely as soon as drop support credentials from s3_config_file + if "profile" in extra and "s3_config_file" not in extra and not self.profile_name: + warnings.warn( + f"Found 'profile' without specifying 's3_config_file' in {self.conn_repr} extra. " + "If required profile from AWS Shared Credentials please " + f"set 'profile_name' in {self.conn_repr} extra.", + UserWarning, + stacklevel=2, + ) + + config_kwargs = extra.get("config_kwargs") + if not self.botocore_config and config_kwargs: + # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + self.log.debug("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr) + self.botocore_config = Config(**config_kwargs) + + if conn.host: + warnings.warn( + f"Host {conn.host} specified in the connection is not used." + " Please, set it on extra['endpoint_url'] instead", + DeprecationWarning, + stacklevel=2, + ) + + self.endpoint_url = extra.get("host") + if self.endpoint_url: + warnings.warn( + "extra['host'] is deprecated and will be removed in a future release." + " Please set extra['endpoint_url'] instead", + DeprecationWarning, + stacklevel=2, + ) + else: + self.endpoint_url = extra.get("endpoint_url") + + # Retrieve Assume Role Configuration + assume_role_configs = self._get_assume_role_configs(**extra) + self.role_arn, self.assume_role_method, self.assume_role_kwargs = assume_role_configs + + @classmethod + def from_connection_metadata( + cls, + conn_id: str | None = None, + login: str | None = None, + password: str | None = None, + extra: dict[str, Any] | None = None, + ): + """ + Create config from connection metadata. + + :param conn_id: Custom connection ID. + :param login: AWS Access Key ID. + :param password: AWS Secret Access Key. + :param extra: Connection Extra metadata. 
+        """
+        conn_meta = _ConnectionMetadata(
+            conn_id=conn_id, conn_type="aws", login=login, password=password, extra=extra
+        )
+        return cls(conn=conn_meta)
+
+    @property
+    def extra_dejson(self):
+        """Compatibility with `airflow.models.Connection.extra_dejson` property."""
+        return self.extra_config
+
+    @property
+    def session_kwargs(self) -> dict[str, Any]:
+        """Additional kwargs passed to boto3.session.Session."""
+        return trim_none_values(
+            {
+                "aws_access_key_id": self.aws_access_key_id,
+                "aws_secret_access_key": self.aws_secret_access_key,
+                "aws_session_token": self.aws_session_token,
+                "region_name": self.region_name,
+                "profile_name": self.profile_name,
+            }
+        )
+
+    def __bool__(self):
+        return self.conn_id is not NOTSET
+
+    def _get_credentials(
+        self,
+        *,
+        aws_access_key_id: str | None = None,
+        aws_secret_access_key: str | None = None,
+        aws_session_token: str | None = None,
+        # Deprecated Values
+        s3_config_file: str | None = None,
+        s3_config_format: str | None = None,
+        profile: str | None = None,
+        session_kwargs: dict[str, Any] | None = None,
+        **kwargs,
+    ) -> tuple[str | None, str | None, str | None]:
+        """
+        Get AWS credentials from connection login/password and extra.
+
+        ``aws_access_key_id`` and ``aws_secret_access_key`` are resolved in this order:
+        1. From Connection login and password
+        2. From Connection extra['aws_access_key_id'] and extra['aws_secret_access_key']
+        3. (deprecated) From Connection extra['session_kwargs']
+        4. (deprecated) From local credentials file
+
+        Get ``aws_session_token`` from extra['aws_session_token']
+
+        """
+        session_kwargs = session_kwargs or {}
+        session_aws_access_key_id = session_kwargs.get("aws_access_key_id")
+        session_aws_secret_access_key = session_kwargs.get("aws_secret_access_key")
+        session_aws_session_token = session_kwargs.get("aws_session_token")
+
+        if self.login and self.password:
+            self.log.info("%s credentials retrieved from login and password.", self.conn_repr)
+            aws_access_key_id, aws_secret_access_key = self.login, self.password
+        elif aws_access_key_id and aws_secret_access_key:
+            self.log.info("%s credentials retrieved from extra.", self.conn_repr)
+        elif session_aws_access_key_id and session_aws_secret_access_key:
+            aws_access_key_id = session_aws_access_key_id
+            aws_secret_access_key = session_aws_secret_access_key
+            self.log.info("%s credentials retrieved from extra['session_kwargs'].", self.conn_repr)
+        elif s3_config_file:
+            aws_access_key_id, aws_secret_access_key = _parse_s3_config(
+                s3_config_file,
+                s3_config_format,
+                profile,
+            )
+            self.log.info("%s credentials retrieved from extra['s3_config_file']", self.conn_repr)
+
+        if aws_session_token:
+            self.log.info(
+                "%s session token retrieved from extra, please note you are responsible for renewing these.",
+                self.conn_repr,
+            )
+        elif session_aws_session_token:
+            aws_session_token = session_aws_session_token
+            self.log.info(
+                "%s session token retrieved from extra['session_kwargs'], "
+                "please note you are responsible for renewing these.",
+                self.conn_repr,
+            )
+
+        return aws_access_key_id, aws_secret_access_key, aws_session_token
+
+    def _get_assume_role_configs(
+        self,
+        *,
+        role_arn: str | None = None,
+        assume_role_method: str = "assume_role",
+        assume_role_kwargs: dict[str, Any] | None = None,
+        # Deprecated Values
+        aws_account_id: str | None = None,
+        aws_iam_role: str | None = None,
+        external_id: str | None = None,
+        **kwargs,
+    ) -> tuple[str | None, str | None, dict[Any, str]]:
+        """Get assume role configs from Connection extra."""
+        if role_arn:
+            self.log.debug("Retrieving role_arn=%r from %s extra.", role_arn, self.conn_repr)
+        elif aws_account_id and aws_iam_role:
+            warnings.warn(
+                "Constructing 'role_arn' from extra['aws_account_id'] and extra['aws_iam_role'] is deprecated"
+                f" and will be removed in a future release."
+                f" Please set 'role_arn' in {self.conn_repr} extra.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
+            self.log.debug(
+                "Constructed role_arn=%r from %s extra['aws_account_id'] and extra['aws_iam_role'].",
+                role_arn,
+                self.conn_repr,
+            )
+
+        if not role_arn:
+            # No reason to obtain `assume_role_method` and `assume_role_kwargs` if `role_arn` is not set.
+            return None, None, {}
+
+        supported_methods = ["assume_role", "assume_role_with_saml", "assume_role_with_web_identity"]
+        if assume_role_method not in supported_methods:
+            raise NotImplementedError(
+                f"Found assume_role_method={assume_role_method!r} in {self.conn_repr} extra."
+                f" Currently {supported_methods} are supported."
+                ' (Omitting this setting will default to "assume_role".)'
+            )
+        self.log.debug("Retrieved assume_role_method=%r from %s.", assume_role_method, self.conn_repr)
+
+        assume_role_kwargs = assume_role_kwargs or {}
+        if "ExternalId" not in assume_role_kwargs and external_id:
+            warnings.warn(
+                "'external_id' in extra config is deprecated and will be removed in a future release. "
+                f"Please set 'ExternalId' in 'assume_role_kwargs' in {self.conn_repr} extra.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            assume_role_kwargs["ExternalId"] = external_id
+
+        return role_arn, assume_role_method, assume_role_kwargs
+
+
+def _parse_s3_config(
+    config_file_name: str, config_format: str | None = "boto", profile: str | None = None
+) -> tuple[str | None, str | None]:
+    """
+    Parses a config file for S3 credentials. Currently supports
+    the boto, s3cmd.conf and AWS SDK config formats.
+
+    :param config_file_name: path to the config file
+    :param config_format: config type. One of "boto", "s3cmd" or "aws".
+        Defaults to "boto"
+    :param profile: profile name in AWS type config file
+    """
+    warnings.warn(
+        "Use of a local credentials file is not documented or well tested. 
" + "Obtain credentials by this way deprecated and will be removed in a future releases.", + DeprecationWarning, + stacklevel=4, + ) + + import configparser + + config = configparser.ConfigParser() + if config.read(config_file_name): # pragma: no cover + sections = config.sections() + else: + raise AirflowException(f"Couldn't read {config_file_name}") + # Setting option names depending on file format + if config_format is None: + config_format = "boto" + conf_format = config_format.lower() + if conf_format == "boto": # pragma: no cover + if profile is not None and "profile " + profile in sections: + cred_section = "profile " + profile + else: + cred_section = "Credentials" + elif conf_format == "aws" and profile is not None: + cred_section = profile + else: + cred_section = "default" + # Option names + if conf_format in ("boto", "aws"): # pragma: no cover + key_id_option = "aws_access_key_id" + secret_key_option = "aws_secret_access_key" + else: + key_id_option = "access_key" + secret_key_option = "secret_key" + # Actual Parsing + if cred_section not in sections: + raise AirflowException("This config file format is not recognized") + else: + try: + access_key = config.get(cred_section, key_id_option) + secret_key = config.get(cred_section, secret_key_option) + mask_secret(secret_key) + except Exception: + raise AirflowException("Option Error in parsing s3 config file") + return access_key, secret_key diff --git a/airflow/providers/amazon/aws/utils/eks_get_token.py b/airflow/providers/amazon/aws/utils/eks_get_token.py index d9422b35e301c..27bf51676f8b1 100644 --- a/airflow/providers/amazon/aws/utils/eks_get_token.py +++ b/airflow/providers/amazon/aws/utils/eks_get_token.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import argparse import json @@ -29,23 +30,23 @@ def get_expiration_time(): token_expiration = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRATION_MINUTES) - return token_expiration.strftime('%Y-%m-%dT%H:%M:%SZ') + return token_expiration.strftime("%Y-%m-%dT%H:%M:%SZ") def get_parser(): - parser = argparse.ArgumentParser(description='Get a token for authentication with an Amazon EKS cluster.') + parser = argparse.ArgumentParser(description="Get a token for authentication with an Amazon EKS cluster.") parser.add_argument( - '--cluster-name', help='The name of the cluster to generate kubeconfig file for.', required=True + "--cluster-name", help="The name of the cluster to generate kubeconfig file for.", required=True ) parser.add_argument( - '--aws-conn-id', + "--aws-conn-id", help=( - 'The Airflow connection used for AWS credentials. ' - 'If not specified or empty then the default boto3 behaviour is used.' + "The Airflow connection used for AWS credentials. " + "If not specified or empty then the default boto3 behaviour is used." ), ) parser.add_argument( - '--region-name', help='AWS region_name. If not specified then the default boto3 behaviour is used.' + "--region-name", help="AWS region_name. If not specified then the default boto3 behaviour is used." 
) return parser @@ -66,5 +67,5 @@ def main(): print(json.dumps(exec_credential_object)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/airflow/providers/amazon/aws/utils/emailer.py b/airflow/providers/amazon/aws/utils/emailer.py index 7f0356a3a8d59..3e00abc78a08c 100644 --- a/airflow/providers/amazon/aws/utils/emailer.py +++ b/airflow/providers/amazon/aws/utils/emailer.py @@ -16,23 +16,25 @@ # specific language governing permissions and limitations # under the License. """Airflow module for email backend using AWS SES""" -from typing import Any, Dict, List, Optional, Union +from __future__ import annotations + +from typing import Any from airflow.providers.amazon.aws.hooks.ses import SesHook def send_email( - to: Union[List[str], str], + to: list[str] | str, subject: str, html_content: str, - files: Optional[List] = None, - cc: Optional[Union[List[str], str]] = None, - bcc: Optional[Union[List[str], str]] = None, - mime_subtype: str = 'mixed', - mime_charset: str = 'utf-8', - conn_id: str = 'aws_default', - from_email: Optional[str] = None, - custom_headers: Optional[Dict[str, Any]] = None, + files: list | None = None, + cc: list[str] | str | None = None, + bcc: list[str] | str | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + conn_id: str = "aws_default", + from_email: str | None = None, + custom_headers: dict[str, Any] | None = None, **kwargs, ) -> None: """Email backend for SES.""" diff --git a/airflow/providers/amazon/aws/utils/rds.py b/airflow/providers/amazon/aws/utils/rds.py index 154f65b5560c1..873f2cf83ecf0 100644 --- a/airflow/providers/amazon/aws/utils/rds.py +++ b/airflow/providers/amazon/aws/utils/rds.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from enum import Enum diff --git a/airflow/providers/amazon/aws/utils/redshift.py b/airflow/providers/amazon/aws/utils/redshift.py index bb64c9b46f463..d931cb047430d 100644 --- a/airflow/providers/amazon/aws/utils/redshift.py +++ b/airflow/providers/amazon/aws/utils/redshift.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import logging diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 413b6dcfed7ca..59277d469d37a 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -22,6 +22,12 @@ description: | Amazon integration (including `Amazon Web Services (AWS) `__). 
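A note on the reworked ``send_email`` signature above: a minimal, hypothetical call to the SES email backend could look like the sketch below. Only the keyword names come from the signature in this diff; the addresses, subject and body are purely illustrative, and ``from_email`` is assumed to be an SES-verified sender.

from airflow.providers.amazon.aws.utils.emailer import send_email

# Illustrative values only; "aws_default" is the connection id defaulted in the signature above.
send_email(
    to=["alerts@example.com"],
    subject="Airflow task failed",
    html_content="<p>See the task logs for details.</p>",
    from_email="no-reply@example.com",  # assumed SES-verified sender
    conn_id="aws_default",
)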
versions: + - 6.1.0 + - 6.0.0 + - 5.1.0 + - 5.0.0 + - 4.1.0 + - 4.0.0 - 3.4.0 - 3.3.0 - 3.2.0 @@ -40,8 +46,22 @@ versions: - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - boto3>=1.24.0 + # watchtower 3 has been released end Jan and introduced breaking change across the board that might + # change logging behaviour: + # https://github.com/kislyuk/watchtower/blob/develop/Changes.rst#changes-for-v300-2022-01-26 + # TODO: update to watchtower >3 + - watchtower~=2.0.1 + - jsonpath_ng>=1.5.3 + - redshift_connector>=2.0.888 + - sqlalchemy_redshift>=0.8.6 + - pandas>=0.17.1 + - mypy-boto3-rds>=1.24.0 + - mypy-boto3-redshift-data>=1.24.0 + - mypy-boto3-appflow>=1.24.0 integrations: - integration-name: Amazon Athena @@ -101,6 +121,12 @@ integrations: - /docs/apache-airflow-providers-amazon/operators/emr_eks.rst logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png tags: [aws] + - integration-name: Amazon EMR Serverless + external-doc-url: https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/emr_serverless.rst + logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png + tags: [aws] - integration-name: Amazon Glacier external-doc-url: https://aws.amazon.com/glacier/ logo: /integration-logos/aws/Amazon-S3-Glacier_light-bg@4x.png @@ -212,6 +238,12 @@ integrations: external-doc-url: https://docs.aws.amazon.com/STS/latest/APIReference/welcome.html logo: /integration-logos/aws/AWS-STS_light-bg@4x.png tags: [aws] + - integration-name: Amazon Appflow + external-doc-url: https://docs.aws.amazon.com/appflow/1.0/APIReference/Welcome.html + logo: /integration-logos/aws/Amazon_AppFlow_light.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/appflow.rst + tags: [aws] operators: - integration-name: Amazon Athena @@ -229,15 +261,8 @@ operators: - integration-name: AWS Database Migration Service python-modules: - airflow.providers.amazon.aws.operators.dms - - airflow.providers.amazon.aws.operators.dms_create_task - - airflow.providers.amazon.aws.operators.dms_delete_task - - airflow.providers.amazon.aws.operators.dms_describe_tasks - - airflow.providers.amazon.aws.operators.dms_start_task - - airflow.providers.amazon.aws.operators.dms_stop_task - integration-name: Amazon EC2 python-modules: - - airflow.providers.amazon.aws.operators.ec2_start_instance - - airflow.providers.amazon.aws.operators.ec2_stop_instance - airflow.providers.amazon.aws.operators.ec2 - integration-name: Amazon ECS python-modules: @@ -248,14 +273,9 @@ operators: - integration-name: Amazon EMR python-modules: - airflow.providers.amazon.aws.operators.emr - - airflow.providers.amazon.aws.operators.emr_add_steps - - airflow.providers.amazon.aws.operators.emr_create_job_flow - - airflow.providers.amazon.aws.operators.emr_modify_cluster - - airflow.providers.amazon.aws.operators.emr_terminate_job_flow - integration-name: Amazon EMR on EKS python-modules: - airflow.providers.amazon.aws.operators.emr - - airflow.providers.amazon.aws.operators.emr_containers - integration-name: Amazon Glacier python-modules: - airflow.providers.amazon.aws.operators.glacier @@ -266,27 +286,13 @@ operators: - integration-name: AWS Lambda python-modules: - airflow.providers.amazon.aws.operators.aws_lambda + - airflow.providers.amazon.aws.operators.lambda_function - integration-name: Amazon Simple Storage Service (S3) python-modules: - - 
airflow.providers.amazon.aws.operators.s3_bucket - - airflow.providers.amazon.aws.operators.s3_bucket_tagging - - airflow.providers.amazon.aws.operators.s3_copy_object - - airflow.providers.amazon.aws.operators.s3_delete_objects - - airflow.providers.amazon.aws.operators.s3_file_transform - - airflow.providers.amazon.aws.operators.s3_list - - airflow.providers.amazon.aws.operators.s3_list_prefixes - airflow.providers.amazon.aws.operators.s3 - integration-name: Amazon SageMaker python-modules: - airflow.providers.amazon.aws.operators.sagemaker - - airflow.providers.amazon.aws.operators.sagemaker_base - - airflow.providers.amazon.aws.operators.sagemaker_endpoint - - airflow.providers.amazon.aws.operators.sagemaker_endpoint_config - - airflow.providers.amazon.aws.operators.sagemaker_model - - airflow.providers.amazon.aws.operators.sagemaker_processing - - airflow.providers.amazon.aws.operators.sagemaker_training - - airflow.providers.amazon.aws.operators.sagemaker_transform - - airflow.providers.amazon.aws.operators.sagemaker_tuning - integration-name: Amazon Simple Notification Service (SNS) python-modules: - airflow.providers.amazon.aws.operators.sns @@ -295,21 +301,21 @@ operators: - airflow.providers.amazon.aws.operators.sqs - integration-name: AWS Step Functions python-modules: - - airflow.providers.amazon.aws.operators.step_function_get_execution_output - - airflow.providers.amazon.aws.operators.step_function_start_execution - airflow.providers.amazon.aws.operators.step_function - integration-name: Amazon RDS python-modules: - airflow.providers.amazon.aws.operators.rds - integration-name: Amazon Redshift python-modules: - - airflow.providers.amazon.aws.operators.redshift - airflow.providers.amazon.aws.operators.redshift_sql - airflow.providers.amazon.aws.operators.redshift_cluster - airflow.providers.amazon.aws.operators.redshift_data - integration-name: Amazon QuickSight python-modules: - airflow.providers.amazon.aws.operators.quicksight + - integration-name: Amazon Appflow + python-modules: + - airflow.providers.amazon.aws.operators.appflow sensors: - integration-name: Amazon Athena @@ -323,25 +329,22 @@ sensors: - airflow.providers.amazon.aws.sensors.cloud_formation - integration-name: AWS Database Migration Service python-modules: - - airflow.providers.amazon.aws.sensors.dms_task - airflow.providers.amazon.aws.sensors.dms - integration-name: Amazon EC2 python-modules: - - airflow.providers.amazon.aws.sensors.ec2_instance_state - airflow.providers.amazon.aws.sensors.ec2 + - integration-name: Amazon ECS + python-modules: + - airflow.providers.amazon.aws.sensors.ecs - integration-name: Amazon Elastic Kubernetes Service (EKS) python-modules: - airflow.providers.amazon.aws.sensors.eks - integration-name: Amazon EMR python-modules: - airflow.providers.amazon.aws.sensors.emr - - airflow.providers.amazon.aws.sensors.emr_base - - airflow.providers.amazon.aws.sensors.emr_job_flow - - airflow.providers.amazon.aws.sensors.emr_step - integration-name: Amazon EMR on EKS python-modules: - airflow.providers.amazon.aws.sensors.emr - - airflow.providers.amazon.aws.sensors.emr_containers - integration-name: Amazon Glacier python-modules: - airflow.providers.amazon.aws.sensors.glacier @@ -355,28 +358,18 @@ sensors: - airflow.providers.amazon.aws.sensors.rds - integration-name: Amazon Redshift python-modules: - - airflow.providers.amazon.aws.sensors.redshift - airflow.providers.amazon.aws.sensors.redshift_cluster - integration-name: Amazon Simple Storage Service (S3) python-modules: - - 
airflow.providers.amazon.aws.sensors.s3_key - - airflow.providers.amazon.aws.sensors.s3_keys_unchanged - - airflow.providers.amazon.aws.sensors.s3_prefix - airflow.providers.amazon.aws.sensors.s3 - integration-name: Amazon SageMaker python-modules: - airflow.providers.amazon.aws.sensors.sagemaker - - airflow.providers.amazon.aws.sensors.sagemaker_base - - airflow.providers.amazon.aws.sensors.sagemaker_endpoint - - airflow.providers.amazon.aws.sensors.sagemaker_training - - airflow.providers.amazon.aws.sensors.sagemaker_transform - - airflow.providers.amazon.aws.sensors.sagemaker_tuning - integration-name: Amazon Simple Queue Service (SQS) python-modules: - airflow.providers.amazon.aws.sensors.sqs - integration-name: AWS Step Functions python-modules: - - airflow.providers.amazon.aws.sensors.step_function_execution - airflow.providers.amazon.aws.sensors.step_function - integration-name: Amazon QuickSight python-modules: @@ -389,7 +382,6 @@ hooks: - integration-name: Amazon DynamoDB python-modules: - airflow.providers.amazon.aws.hooks.dynamodb - - airflow.providers.amazon.aws.hooks.aws_dynamodb - integration-name: Amazon Web Services python-modules: - airflow.providers.amazon.aws.hooks.base_aws @@ -409,6 +401,9 @@ hooks: - integration-name: Amazon EC2 python-modules: - airflow.providers.amazon.aws.hooks.ec2 + - integration-name: Amazon ECS + python-modules: + - airflow.providers.amazon.aws.hooks.ecs - integration-name: Amazon ElastiCache python-modules: - airflow.providers.amazon.aws.hooks.elasticache_replication_group @@ -420,7 +415,7 @@ hooks: - airflow.providers.amazon.aws.hooks.emr - integration-name: Amazon EMR on EKS python-modules: - - airflow.providers.amazon.aws.hooks.emr_containers + - airflow.providers.amazon.aws.hooks.emr - integration-name: Amazon Glacier python-modules: - airflow.providers.amazon.aws.hooks.glacier @@ -443,7 +438,6 @@ hooks: - airflow.providers.amazon.aws.hooks.rds - integration-name: Amazon Redshift python-modules: - - airflow.providers.amazon.aws.hooks.redshift - airflow.providers.amazon.aws.hooks.redshift_sql - airflow.providers.amazon.aws.hooks.redshift_cluster - airflow.providers.amazon.aws.hooks.redshift_data @@ -474,6 +468,9 @@ hooks: - integration-name: AWS Security Token Service (STS) python-modules: - airflow.providers.amazon.aws.hooks.sts + - integration-name: Amazon Appflow + python-modules: + - airflow.providers.amazon.aws.hooks.appflow transfers: - source-integration-name: Amazon DynamoDB @@ -504,10 +501,6 @@ transfers: target-integration-name: Amazon Simple Storage Service (S3) how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/mongo_to_s3.rst python-module: airflow.providers.amazon.aws.transfers.mongo_to_s3 - - source-integration-name: MySQL - target-integration-name: Amazon Simple Storage Service (S3) - how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/sql_to_s3.rst - python-module: airflow.providers.amazon.aws.transfers.mysql_to_s3 - source-integration-name: Amazon Redshift target-integration-name: Amazon Simple Storage Service (S3) how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/redshift_to_s3.rst @@ -543,32 +536,26 @@ transfers: target-integration-name: Amazon Simple Storage Service (S3) how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/local_to_s3.rst python-module: airflow.providers.amazon.aws.transfers.local_to_s3 - - source-integration-name: SQL + - source-integration-name: Common SQL target-integration-name: Amazon Simple Storage Service (S3) 
how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/sql_to_s3.rst python-module: airflow.providers.amazon.aws.transfers.sql_to_s3 -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.amazon.aws.hooks.s3.S3Hook - - airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook - - airflow.providers.amazon.aws.hooks.emr.EmrHook - - airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook extra-links: - - airflow.providers.amazon.aws.operators.emr.EmrClusterLink - - airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink + - airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink + - airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink + - airflow.providers.amazon.aws.links.batch.BatchJobQueueLink + - airflow.providers.amazon.aws.links.emr.EmrClusterLink + - airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink connection-types: - - hook-class-name: airflow.providers.amazon.aws.hooks.s3.S3Hook - connection-type: s3 - - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook + - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook connection-type: aws - hook-class-name: airflow.providers.amazon.aws.hooks.emr.EmrHook connection-type: emr - hook-class-name: airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook connection-type: redshift - - hook-class-name: airflow.providers.amazon.aws.hooks.redshift.RedshiftDataHook - connection-type: aws secrets-backends: - airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend diff --git a/airflow/providers/apache/beam/.latest-doc-only-change.txt b/airflow/providers/apache/beam/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/apache/beam/.latest-doc-only-change.txt +++ b/airflow/providers/apache/beam/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/beam/CHANGELOG.rst b/airflow/providers/apache/beam/CHANGELOG.rst index 8969c402ce18f..cfdad5d2fa201 100644 --- a/airflow/providers/apache/beam/CHANGELOG.rst +++ b/airflow/providers/apache/beam/CHANGELOG.rst @@ -16,9 +16,67 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add backward compatibility with old versions of Apache Beam (#27263)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Added missing project_id to the wait_for_job (#24020)`` +* ``Support impersonation service account parameter for Dataflow runner (#23961)`` + +Misc +~~~~ + +* ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``AIP-47 - Migrate beam DAGs to new design #22439 (#24211)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.4.0 ..... diff --git a/airflow/providers/apache/beam/example_dags/example_beam.py b/airflow/providers/apache/beam/example_dags/example_beam.py deleted file mode 100644 index ea52458303129..0000000000000 --- a/airflow/providers/apache/beam/example_dags/example_beam.py +++ /dev/null @@ -1,437 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" -Example Airflow DAG for Apache Beam operators -""" -import os -from datetime import datetime -from urllib.parse import urlparse - -from airflow import models -from airflow.providers.apache.beam.operators.beam import ( - BeamRunGoPipelineOperator, - BeamRunJavaPipelineOperator, - BeamRunPythonPipelineOperator, -) -from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus -from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration -from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor -from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator -from airflow.utils.trigger_rule import TriggerRule - -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCS_INPUT = os.environ.get('APACHE_BEAM_PYTHON', 'gs://INVALID BUCKET NAME/shakespeare/kinglear.txt') -GCS_TMP = os.environ.get('APACHE_BEAM_GCS_TMP', 'gs://INVALID BUCKET NAME/temp/') -GCS_STAGING = os.environ.get('APACHE_BEAM_GCS_STAGING', 'gs://INVALID BUCKET NAME/staging/') -GCS_OUTPUT = os.environ.get('APACHE_BEAM_GCS_OUTPUT', 'gs://INVALID BUCKET NAME/output') -GCS_PYTHON = os.environ.get('APACHE_BEAM_PYTHON', 'gs://INVALID BUCKET NAME/wordcount_debugging.py') -GCS_PYTHON_DATAFLOW_ASYNC = os.environ.get( - 'APACHE_BEAM_PYTHON_DATAFLOW_ASYNC', 'gs://INVALID BUCKET NAME/wordcount_debugging.py' -) -GCS_GO = os.environ.get('APACHE_BEAM_GO', 'gs://INVALID BUCKET NAME/wordcount_debugging.go') -GCS_GO_DATAFLOW_ASYNC = os.environ.get( - 'APACHE_BEAM_GO_DATAFLOW_ASYNC', 'gs://INVALID BUCKET NAME/wordcount_debugging.go' -) -GCS_JAR_DIRECT_RUNNER = os.environ.get( - 'APACHE_BEAM_DIRECT_RUNNER_JAR', - 'gs://INVALID BUCKET NAME/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-DirectRunner.jar', -) -GCS_JAR_DATAFLOW_RUNNER = os.environ.get( - 'APACHE_BEAM_DATAFLOW_RUNNER_JAR', 'gs://INVALID BUCKET NAME/word-count-beam-bundled-0.1.jar' -) -GCS_JAR_SPARK_RUNNER = os.environ.get( - 'APACHE_BEAM_SPARK_RUNNER_JAR', - 'gs://INVALID BUCKET NAME/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-SparkRunner.jar', -) -GCS_JAR_FLINK_RUNNER = os.environ.get( - 'APACHE_BEAM_FLINK_RUNNER_JAR', - 'gs://INVALID BUCKET NAME/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-FlinkRunner.jar', -) - -GCS_JAR_DIRECT_RUNNER_PARTS = urlparse(GCS_JAR_DIRECT_RUNNER) -GCS_JAR_DIRECT_RUNNER_BUCKET_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.netloc -GCS_JAR_DIRECT_RUNNER_OBJECT_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.path[1:] -GCS_JAR_DATAFLOW_RUNNER_PARTS = urlparse(GCS_JAR_DATAFLOW_RUNNER) -GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.netloc -GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.path[1:] -GCS_JAR_SPARK_RUNNER_PARTS = urlparse(GCS_JAR_SPARK_RUNNER) -GCS_JAR_SPARK_RUNNER_BUCKET_NAME = GCS_JAR_SPARK_RUNNER_PARTS.netloc -GCS_JAR_SPARK_RUNNER_OBJECT_NAME = GCS_JAR_SPARK_RUNNER_PARTS.path[1:] -GCS_JAR_FLINK_RUNNER_PARTS = urlparse(GCS_JAR_FLINK_RUNNER) -GCS_JAR_FLINK_RUNNER_BUCKET_NAME = GCS_JAR_FLINK_RUNNER_PARTS.netloc -GCS_JAR_FLINK_RUNNER_OBJECT_NAME = GCS_JAR_FLINK_RUNNER_PARTS.path[1:] - - -DEFAULT_ARGS = { - 'default_pipeline_options': {'output': '/tmp/example_beam'}, - 'trigger_rule': TriggerRule.ALL_DONE, -} -START_DATE = datetime(2021, 1, 1) - - -with models.DAG( - "example_beam_native_java_direct_runner", - schedule_interval=None, # Override to match your needs - start_date=START_DATE, - catchup=False, - tags=['example'], -) as dag_native_java_direct_runner: - - # [START 
howto_operator_start_java_direct_runner_pipeline] - jar_to_local_direct_runner = GCSToLocalFilesystemOperator( - task_id="jar_to_local_direct_runner", - bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME, - object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME, - filename="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar", - ) - - start_java_pipeline_direct_runner = BeamRunJavaPipelineOperator( - task_id="start_java_pipeline_direct_runner", - jar="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar", - pipeline_options={ - 'output': '/tmp/start_java_pipeline_direct_runner', - 'inputFile': GCS_INPUT, - }, - job_class='org.apache.beam.examples.WordCount', - ) - - jar_to_local_direct_runner >> start_java_pipeline_direct_runner - # [END howto_operator_start_java_direct_runner_pipeline] - -with models.DAG( - "example_beam_native_java_dataflow_runner", - schedule_interval=None, # Override to match your needs - start_date=START_DATE, - catchup=False, - tags=['example'], -) as dag_native_java_dataflow_runner: - # [START howto_operator_start_java_dataflow_runner_pipeline] - jar_to_local_dataflow_runner = GCSToLocalFilesystemOperator( - task_id="jar_to_local_dataflow_runner", - bucket=GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME, - object_name=GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME, - filename="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar", - ) - - start_java_pipeline_dataflow = BeamRunJavaPipelineOperator( - task_id="start_java_pipeline_dataflow", - runner="DataflowRunner", - jar="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar", - pipeline_options={ - 'tempLocation': GCS_TMP, - 'stagingLocation': GCS_STAGING, - 'output': GCS_OUTPUT, - }, - job_class='org.apache.beam.examples.WordCount', - dataflow_config={"job_name": "{{task.task_id}}", "location": "us-central1"}, - ) - - jar_to_local_dataflow_runner >> start_java_pipeline_dataflow - # [END howto_operator_start_java_dataflow_runner_pipeline] - -with models.DAG( - "example_beam_native_java_spark_runner", - schedule_interval=None, # Override to match your needs - start_date=START_DATE, - catchup=False, - tags=['example'], -) as dag_native_java_spark_runner: - - jar_to_local_spark_runner = GCSToLocalFilesystemOperator( - task_id="jar_to_local_spark_runner", - bucket=GCS_JAR_SPARK_RUNNER_BUCKET_NAME, - object_name=GCS_JAR_SPARK_RUNNER_OBJECT_NAME, - filename="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar", - ) - - start_java_pipeline_spark_runner = BeamRunJavaPipelineOperator( - task_id="start_java_pipeline_spark_runner", - runner="SparkRunner", - jar="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar", - pipeline_options={ - 'output': '/tmp/start_java_pipeline_spark_runner', - 'inputFile': GCS_INPUT, - }, - job_class='org.apache.beam.examples.WordCount', - ) - - jar_to_local_spark_runner >> start_java_pipeline_spark_runner - -with models.DAG( - "example_beam_native_java_flink_runner", - schedule_interval=None, # Override to match your needs - start_date=START_DATE, - catchup=False, - tags=['example'], -) as dag_native_java_flink_runner: - - jar_to_local_flink_runner = GCSToLocalFilesystemOperator( - task_id="jar_to_local_flink_runner", - bucket=GCS_JAR_FLINK_RUNNER_BUCKET_NAME, - object_name=GCS_JAR_FLINK_RUNNER_OBJECT_NAME, - filename="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar", - ) - - start_java_pipeline_flink_runner = BeamRunJavaPipelineOperator( - task_id="start_java_pipeline_flink_runner", - runner="FlinkRunner", - jar="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar", - pipeline_options={ - 'output': 
'/tmp/start_java_pipeline_flink_runner', - 'inputFile': GCS_INPUT, - }, - job_class='org.apache.beam.examples.WordCount', - ) - - jar_to_local_flink_runner >> start_java_pipeline_flink_runner - - -with models.DAG( - "example_beam_native_python", - start_date=START_DATE, - schedule_interval=None, # Override to match your needs - catchup=False, - default_args=DEFAULT_ARGS, - tags=['example'], -) as dag_native_python: - - # [START howto_operator_start_python_direct_runner_pipeline_local_file] - start_python_pipeline_local_direct_runner = BeamRunPythonPipelineOperator( - task_id="start_python_pipeline_local_direct_runner", - py_file='apache_beam.examples.wordcount', - py_options=['-m'], - py_requirements=['apache-beam[gcp]==2.26.0'], - py_interpreter='python3', - py_system_site_packages=False, - ) - # [END howto_operator_start_python_direct_runner_pipeline_local_file] - - # [START howto_operator_start_python_direct_runner_pipeline_gcs_file] - start_python_pipeline_direct_runner = BeamRunPythonPipelineOperator( - task_id="start_python_pipeline_direct_runner", - py_file=GCS_PYTHON, - py_options=[], - pipeline_options={"output": GCS_OUTPUT}, - py_requirements=['apache-beam[gcp]==2.26.0'], - py_interpreter='python3', - py_system_site_packages=False, - ) - # [END howto_operator_start_python_direct_runner_pipeline_gcs_file] - - # [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file] - start_python_pipeline_dataflow_runner = BeamRunPythonPipelineOperator( - task_id="start_python_pipeline_dataflow_runner", - runner="DataflowRunner", - py_file=GCS_PYTHON, - pipeline_options={ - 'tempLocation': GCS_TMP, - 'stagingLocation': GCS_STAGING, - 'output': GCS_OUTPUT, - }, - py_options=[], - py_requirements=['apache-beam[gcp]==2.26.0'], - py_interpreter='python3', - py_system_site_packages=False, - dataflow_config=DataflowConfiguration( - job_name='{{task.task_id}}', project_id=GCP_PROJECT_ID, location="us-central1" - ), - ) - # [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file] - - start_python_pipeline_local_spark_runner = BeamRunPythonPipelineOperator( - task_id="start_python_pipeline_local_spark_runner", - py_file='apache_beam.examples.wordcount', - runner="SparkRunner", - py_options=['-m'], - py_requirements=['apache-beam[gcp]==2.26.0'], - py_interpreter='python3', - py_system_site_packages=False, - ) - - start_python_pipeline_local_flink_runner = BeamRunPythonPipelineOperator( - task_id="start_python_pipeline_local_flink_runner", - py_file='apache_beam.examples.wordcount', - runner="FlinkRunner", - py_options=['-m'], - pipeline_options={ - 'output': '/tmp/start_python_pipeline_local_flink_runner', - }, - py_requirements=['apache-beam[gcp]==2.26.0'], - py_interpreter='python3', - py_system_site_packages=False, - ) - - ( - [ - start_python_pipeline_local_direct_runner, - start_python_pipeline_direct_runner, - ] - >> start_python_pipeline_local_flink_runner - >> start_python_pipeline_local_spark_runner - ) - - -with models.DAG( - "example_beam_native_python_dataflow_async", - default_args=DEFAULT_ARGS, - start_date=START_DATE, - schedule_interval=None, # Override to match your needs - catchup=False, - tags=['example'], -) as dag_native_python_dataflow_async: - # [START howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file] - start_python_job_dataflow_runner_async = BeamRunPythonPipelineOperator( - task_id="start_python_job_dataflow_runner_async", - runner="DataflowRunner", - py_file=GCS_PYTHON_DATAFLOW_ASYNC, - pipeline_options={ - 'tempLocation': GCS_TMP, - 
'stagingLocation': GCS_STAGING, - 'output': GCS_OUTPUT, - }, - py_options=[], - py_requirements=['apache-beam[gcp]==2.26.0'], - py_interpreter='python3', - py_system_site_packages=False, - dataflow_config=DataflowConfiguration( - job_name='{{task.task_id}}', - project_id=GCP_PROJECT_ID, - location="us-central1", - wait_until_finished=False, - ), - ) - - wait_for_python_job_dataflow_runner_async_done = DataflowJobStatusSensor( - task_id="wait-for-python-job-async-done", - job_id="{{task_instance.xcom_pull('start_python_job_dataflow_runner_async')['dataflow_job_id']}}", - expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, - project_id=GCP_PROJECT_ID, - location='us-central1', - ) - - start_python_job_dataflow_runner_async >> wait_for_python_job_dataflow_runner_async_done - # [END howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file] - - -with models.DAG( - "example_beam_native_go", - start_date=START_DATE, - schedule_interval="@once", - catchup=False, - default_args=DEFAULT_ARGS, - tags=['example'], -) as dag_native_go: - - # [START howto_operator_start_go_direct_runner_pipeline_local_file] - start_go_pipeline_local_direct_runner = BeamRunGoPipelineOperator( - task_id="start_go_pipeline_local_direct_runner", - go_file='files/apache_beam/examples/wordcount.go', - ) - # [END howto_operator_start_go_direct_runner_pipeline_local_file] - - # [START howto_operator_start_go_direct_runner_pipeline_gcs_file] - start_go_pipeline_direct_runner = BeamRunGoPipelineOperator( - task_id="start_go_pipeline_direct_runner", - go_file=GCS_GO, - pipeline_options={"output": GCS_OUTPUT}, - ) - # [END howto_operator_start_go_direct_runner_pipeline_gcs_file] - - # [START howto_operator_start_go_dataflow_runner_pipeline_gcs_file] - start_go_pipeline_dataflow_runner = BeamRunGoPipelineOperator( - task_id="start_go_pipeline_dataflow_runner", - runner="DataflowRunner", - go_file=GCS_GO, - pipeline_options={ - 'tempLocation': GCS_TMP, - 'stagingLocation': GCS_STAGING, - 'output': GCS_OUTPUT, - 'WorkerHarnessContainerImage': "apache/beam_go_sdk:latest", - }, - dataflow_config=DataflowConfiguration( - job_name='{{task.task_id}}', project_id=GCP_PROJECT_ID, location="us-central1" - ), - ) - # [END howto_operator_start_go_dataflow_runner_pipeline_gcs_file] - - start_go_pipeline_local_spark_runner = BeamRunGoPipelineOperator( - task_id="start_go_pipeline_local_spark_runner", - go_file='/files/apache_beam/examples/wordcount.go', - runner="SparkRunner", - pipeline_options={ - 'endpoint': '/your/spark/endpoint', - }, - ) - - start_go_pipeline_local_flink_runner = BeamRunGoPipelineOperator( - task_id="start_go_pipeline_local_flink_runner", - go_file='/files/apache_beam/examples/wordcount.go', - runner="FlinkRunner", - pipeline_options={ - 'output': '/tmp/start_go_pipeline_local_flink_runner', - }, - ) - - ( - [ - start_go_pipeline_local_direct_runner, - start_go_pipeline_direct_runner, - ] - >> start_go_pipeline_local_flink_runner - >> start_go_pipeline_local_spark_runner - ) - - -with models.DAG( - "example_beam_native_go_dataflow_async", - default_args=DEFAULT_ARGS, - start_date=START_DATE, - schedule_interval="@once", - catchup=False, - tags=['example'], -) as dag_native_go_dataflow_async: - # [START howto_operator_start_go_dataflow_runner_pipeline_async_gcs_file] - start_go_job_dataflow_runner_async = BeamRunGoPipelineOperator( - task_id="start_go_job_dataflow_runner_async", - runner="DataflowRunner", - go_file=GCS_GO_DATAFLOW_ASYNC, - pipeline_options={ - 'tempLocation': GCS_TMP, - 'stagingLocation': 
GCS_STAGING, - 'output': GCS_OUTPUT, - 'WorkerHarnessContainerImage': "apache/beam_go_sdk:latest", - }, - dataflow_config=DataflowConfiguration( - job_name='{{task.task_id}}', - project_id=GCP_PROJECT_ID, - location="us-central1", - wait_until_finished=False, - ), - ) - - wait_for_go_job_dataflow_runner_async_done = DataflowJobStatusSensor( - task_id="wait-for-go-job-async-done", - job_id="{{task_instance.xcom_pull('start_go_job_dataflow_runner_async')['dataflow_job_id']}}", - expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, - project_id=GCP_PROJECT_ID, - location='us-central1', - ) - - start_go_job_dataflow_runner_async >> wait_for_go_job_dataflow_runner_async_done - # [END howto_operator_start_go_dataflow_runner_pipeline_async_gcs_file] diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py index 0644e02b625f0..28a5abc0c6e3e 100644 --- a/airflow/providers/apache/beam/hooks/beam.py +++ b/airflow/providers/apache/beam/hooks/beam.py @@ -16,6 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a Apache Beam Hook.""" +from __future__ import annotations + +import contextlib import json import os import select @@ -24,7 +27,9 @@ import subprocess import textwrap from tempfile import TemporaryDirectory -from typing import Callable, List, Optional +from typing import Callable + +from packaging.version import Version from airflow.exceptions import AirflowConfigException, AirflowException from airflow.hooks.base import BaseHook @@ -50,7 +55,7 @@ class BeamRunnerType: Twister2Runner = "Twister2Runner" -def beam_options_to_args(options: dict) -> List[str]: +def beam_options_to_args(options: dict) -> list[str]: """ Returns a formatted pipeline options from a dictionary of arguments @@ -60,12 +65,11 @@ def beam_options_to_args(options: dict) -> List[str]: :param options: Dictionary with options :return: List of arguments - :rtype: List[str] """ if not options: return [] - args: List[str] = [] + args: list[str] = [] for attr, value in options.items(): if value is None or (isinstance(value, bool) and value): args.append(f"--{attr}") @@ -88,14 +92,14 @@ class BeamCommandRunner(LoggingMixin): def __init__( self, - cmd: List[str], - process_line_callback: Optional[Callable[[str], None]] = None, - working_directory: Optional[str] = None, + cmd: list[str], + process_line_callback: Callable[[str], None] | None = None, + working_directory: str | None = None, ) -> None: super().__init__() self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd)) self.process_line_callback = process_line_callback - self.job_id: Optional[str] = None + self.job_id: str | None = None self._proc = subprocess.Popen( cmd, @@ -173,9 +177,9 @@ def __init__( def _start_pipeline( self, variables: dict, - command_prefix: List[str], - process_line_callback: Optional[Callable[[str], None]] = None, - working_directory: Optional[str] = None, + command_prefix: list[str], + process_line_callback: Callable[[str], None] | None = None, + working_directory: str | None = None, ) -> None: cmd = command_prefix + [ f"--runner={self.runner}", @@ -193,11 +197,11 @@ def start_python_pipeline( self, variables: dict, py_file: str, - py_options: List[str], + py_options: list[str], py_interpreter: str = "python3", - py_requirements: Optional[List[str]] = None, + py_requirements: list[str] | None = None, py_system_site_packages: bool = False, - process_line_callback: Optional[Callable[[str], None]] = None, + process_line_callback: 
Callable[[str], None] | None = None,
     ):
         """
         Starts Apache Beam python pipeline.
@@ -225,37 +229,47 @@ def start_python_pipeline(
         if "labels" in variables:
             variables["labels"] = [f"{key}={value}" for key, value in variables["labels"].items()]
 
-        if py_requirements is not None:
-            if not py_requirements and not py_system_site_packages:
-                warning_invalid_environment = textwrap.dedent(
-                    """\
-                    Invalid method invocation. You have disabled inclusion of system packages and empty list
-                    required for installation, so it is not possible to create a valid virtual environment.
-                    In the virtual environment, apache-beam package must be installed for your job to be \
-                    executed. To fix this problem:
-                    * install apache-beam on the system, then set parameter py_system_site_packages to True,
-                    * add apache-beam to the list of required packages in parameter py_requirements.
-                    """
-                )
-                raise AirflowException(warning_invalid_environment)
-
-            with TemporaryDirectory(prefix="apache-beam-venv") as tmp_dir:
+        with contextlib.ExitStack() as exit_stack:
+            if py_requirements is not None:
+                if not py_requirements and not py_system_site_packages:
+                    warning_invalid_environment = textwrap.dedent(
+                        """\
+                        Invalid method invocation. You have disabled inclusion of system packages and the
+                        list of required packages is empty, so it is not possible to create a valid virtual
+                        environment. In the virtual environment, the apache-beam package must be installed
+                        for your job to be executed.
+
+                        To fix this problem:
+                        * install apache-beam on the system, then set parameter py_system_site_packages
+                          to True,
+                        * add apache-beam to the list of required packages in parameter py_requirements.
+                        """
+                    )
+                    raise AirflowException(warning_invalid_environment)
+                tmp_dir = exit_stack.enter_context(TemporaryDirectory(prefix="apache-beam-venv"))
                 py_interpreter = prepare_virtualenv(
                     venv_directory=tmp_dir,
                     python_bin=py_interpreter,
                     system_site_packages=py_system_site_packages,
                     requirements=py_requirements,
                 )
-                command_prefix = [py_interpreter] + py_options + [py_file]
-                self._start_pipeline(
-                    variables=variables,
-                    command_prefix=command_prefix,
-                    process_line_callback=process_line_callback,
-                )
-        else:
             command_prefix = [py_interpreter] + py_options + [py_file]
+
+            beam_version = (
+                subprocess.check_output(
+                    [py_interpreter, "-c", "import apache_beam; print(apache_beam.__version__)"]
+                )
+                .decode()
+                .strip()
+            )
+            self.log.info("Beam version: %s", beam_version)
+            impersonate_service_account = variables.get("impersonate_service_account")
+            if impersonate_service_account:
+                if Version(beam_version) < Version("2.39.0"):
+                    raise AirflowException(
+                        "The impersonateServiceAccount option requires Apache Beam 2.39.0 or newer."
+                    )
             self._start_pipeline(
                 variables=variables,
                 command_prefix=command_prefix,
@@ -266,8 +280,8 @@ def start_java_pipeline(
         self,
         variables: dict,
         jar: str,
-        job_class: Optional[str] = None,
-        process_line_callback: Optional[Callable[[str], None]] = None,
+        job_class: str | None = None,
+        process_line_callback: Callable[[str], None] | None = None,
     ) -> None:
         """
         Starts Apache Beam Java pipeline.
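To make the new guard above easier to follow, here is a small, self-contained sketch (not part of this change) of the same pattern: ask the target interpreter for its apache-beam version and reject impersonation on SDKs older than 2.39.0. The function name is illustrative.

import subprocess

from packaging.version import Version


def ensure_beam_supports_impersonation(py_interpreter: str = "python3") -> None:
    """Raise if the apache-beam SDK seen by ``py_interpreter`` is older than 2.39.0."""
    beam_version = (
        subprocess.check_output(
            [py_interpreter, "-c", "import apache_beam; print(apache_beam.__version__)"]
        )
        .decode()
        .strip()
    )
    # impersonate_service_account support landed in Apache Beam 2.39.0
    if Version(beam_version) < Version("2.39.0"):
        raise RuntimeError("The impersonateServiceAccount option requires Apache Beam 2.39.0 or newer.")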
@@ -292,7 +306,7 @@ def start_go_pipeline( self, variables: dict, go_file: str, - process_line_callback: Optional[Callable[[str], None]] = None, + process_line_callback: Callable[[str], None] | None = None, should_init_module: bool = False, ) -> None: """ diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py index e7f7af5e236a5..efef187aaf28a 100644 --- a/airflow/providers/apache/beam/operators/beam.py +++ b/airflow/providers/apache/beam/operators/beam.py @@ -16,11 +16,13 @@ # specific language governing permissions and limitations # under the License. """This module contains Apache Beam operators.""" +from __future__ import annotations + import copy import tempfile from abc import ABC, ABCMeta from contextlib import ExitStack -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Sequence from airflow import AirflowException from airflow.models import BaseOperator @@ -47,17 +49,17 @@ class BeamDataflowMixin(metaclass=ABCMeta): :class:`~airflow.providers.apache.beam.operators.beam.BeamRunGoPipelineOperator`. """ - dataflow_hook: Optional[DataflowHook] + dataflow_hook: DataflowHook | None dataflow_config: DataflowConfiguration gcp_conn_id: str - delegate_to: Optional[str] + delegate_to: str | None dataflow_support_impersonation: bool = True def _set_dataflow( self, pipeline_options: dict, - job_name_variable_key: Optional[str] = None, - ) -> Tuple[str, dict, Callable[[str], None]]: + job_name_variable_key: str | None = None, + ) -> tuple[str, dict, Callable[[str], None]]: self.dataflow_hook = self.__set_dataflow_hook() self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id dataflow_job_name = self.__get_dataflow_job_name() @@ -85,7 +87,7 @@ def __get_dataflow_job_name(self) -> str: ) def __get_dataflow_pipeline_options( - self, pipeline_options: dict, job_name: str, job_name_key: Optional[str] = None + self, pipeline_options: dict, job_name: str, job_name_key: str | None = None ) -> dict: pipeline_options = copy.deepcopy(pipeline_options) if job_name_key is not None: @@ -151,11 +153,11 @@ def __init__( self, *, runner: str = "DirectRunner", - default_pipeline_options: Optional[dict] = None, - pipeline_options: Optional[dict] = None, + default_pipeline_options: dict | None = None, + pipeline_options: dict | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None, + delegate_to: str | None = None, + dataflow_config: DataflowConfiguration | dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -168,9 +170,9 @@ def __init__( self.dataflow_config = DataflowConfiguration(**dataflow_config) else: self.dataflow_config = dataflow_config or DataflowConfiguration() - self.beam_hook: Optional[BeamHook] = None - self.dataflow_hook: Optional[DataflowHook] = None - self.dataflow_job_id: Optional[str] = None + self.beam_hook: BeamHook | None = None + self.dataflow_hook: DataflowHook | None = None + self.dataflow_job_id: str | None = None if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower(): self.log.warning( @@ -180,13 +182,13 @@ def __init__( def _init_pipeline_options( self, format_pipeline_options: bool = False, - job_name_variable_key: Optional[str] = None, - ) -> Tuple[bool, Optional[str], dict, Optional[Callable[[str], None]]]: + job_name_variable_key: str | None = None, 
+ ) -> tuple[bool, str | None, dict, Callable[[str], None] | None]: self.beam_hook = BeamHook(runner=self.runner) pipeline_options = self.default_pipeline_options.copy() - process_line_callback: Optional[Callable[[str], None]] = None + process_line_callback: Callable[[str], None] | None = None is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower() - dataflow_job_name: Optional[str] = None + dataflow_job_name: str | None = None if is_dataflow: dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow( pipeline_options=pipeline_options, @@ -247,7 +249,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator): "default_pipeline_options", "dataflow_config", ) - template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'} + template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} operator_extra_links = (DataflowJobLink(),) def __init__( @@ -255,15 +257,15 @@ def __init__( *, py_file: str, runner: str = "DirectRunner", - default_pipeline_options: Optional[dict] = None, - pipeline_options: Optional[dict] = None, + default_pipeline_options: dict | None = None, + pipeline_options: dict | None = None, py_interpreter: str = "python3", - py_options: Optional[List[str]] = None, - py_requirements: Optional[List[str]] = None, + py_options: list[str] | None = None, + py_requirements: list[str] | None = None, py_system_site_packages: bool = False, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None, + delegate_to: str | None = None, + dataflow_config: DataflowConfiguration | dict | None = None, **kwargs, ) -> None: super().__init__( @@ -285,7 +287,7 @@ def __init__( {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} ) - def execute(self, context: 'Context'): + def execute(self, context: Context): """Execute the Apache Beam Pipeline.""" ( is_dataflow, @@ -343,7 +345,7 @@ def execute(self, context: 'Context'): def on_kill(self) -> None: if self.dataflow_hook and self.dataflow_job_id: - self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id) + self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id) self.dataflow_hook.cancel_job( job_id=self.dataflow_job_id, project_id=self.dataflow_config.project_id, @@ -386,7 +388,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator): "default_pipeline_options", "dataflow_config", ) - template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'} + template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} ui_color = "#0273d4" operator_extra_links = (DataflowJobLink(),) @@ -396,12 +398,12 @@ def __init__( *, jar: str, runner: str = "DirectRunner", - job_class: Optional[str] = None, - default_pipeline_options: Optional[dict] = None, - pipeline_options: Optional[dict] = None, + job_class: str | None = None, + default_pipeline_options: dict | None = None, + pipeline_options: dict | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None, + delegate_to: str | None = None, + dataflow_config: DataflowConfiguration | dict | None = None, **kwargs, ) -> None: super().__init__( @@ -416,7 +418,7 @@ def __init__( self.jar = jar self.job_class = job_class - def execute(self, context: 'Context'): + def execute(self, context: 
Context): """Execute the Apache Beam Pipeline.""" ( is_dataflow, @@ -469,11 +471,7 @@ def execute(self, context: 'Context'): process_line_callback=process_line_callback, ) if dataflow_job_name and self.dataflow_config.location: - multiple_jobs = ( - self.dataflow_config.multiple_jobs - if self.dataflow_config.multiple_jobs - else False - ) + multiple_jobs = self.dataflow_config.multiple_jobs or False DataflowJobLink.persist( self, context, @@ -499,7 +497,7 @@ def execute(self, context: 'Context'): def on_kill(self) -> None: if self.dataflow_hook and self.dataflow_job_id: - self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id) + self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id) self.dataflow_hook.cancel_job( job_id=self.dataflow_job_id, project_id=self.dataflow_config.project_id, @@ -533,7 +531,7 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator): "default_pipeline_options", "dataflow_config", ] - template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'} + template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} operator_extra_links = (DataflowJobLink(),) def __init__( @@ -541,11 +539,11 @@ def __init__( *, go_file: str, runner: str = "DirectRunner", - default_pipeline_options: Optional[dict] = None, - pipeline_options: Optional[dict] = None, + default_pipeline_options: dict | None = None, + pipeline_options: dict | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None, + delegate_to: str | None = None, + dataflow_config: DataflowConfiguration | dict | None = None, **kwargs, ) -> None: super().__init__( @@ -571,7 +569,7 @@ def __init__( {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")} ) - def execute(self, context: 'Context'): + def execute(self, context: Context): """Execute the Apache Beam Pipeline.""" ( is_dataflow, @@ -629,7 +627,7 @@ def execute(self, context: 'Context'): def on_kill(self) -> None: if self.dataflow_hook and self.dataflow_job_id: - self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id) + self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id) self.dataflow_hook.cancel_job( job_id=self.dataflow_job_id, project_id=self.dataflow_config.project_id, diff --git a/airflow/providers/apache/beam/provider.yaml b/airflow/providers/apache/beam/provider.yaml index 4d6bbeab7a23c..5f83ce135203a 100644 --- a/airflow/providers/apache/beam/provider.yaml +++ b/airflow/providers/apache/beam/provider.yaml @@ -22,6 +22,8 @@ description: | `Apache Beam `__. 
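For orientation, a minimal DAG sketch condensed from the example DAG removed earlier in this diff, showing how ``BeamRunPythonPipelineOperator`` is typically wired up with the DirectRunner; the DAG id is illustrative and the apache-beam pin is copied from the removed example.

from datetime import datetime

from airflow import models
from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator

with models.DAG(
    "example_beam_wordcount",  # illustrative DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Run the bundled wordcount module locally with the DirectRunner (the default runner).
    BeamRunPythonPipelineOperator(
        task_id="start_python_pipeline_local_direct_runner",
        py_file="apache_beam.examples.wordcount",
        py_options=["-m"],
        py_requirements=["apache-beam[gcp]==2.26.0"],  # version pin from the removed example DAG
        py_interpreter="python3",
        py_system_site_packages=False,
    )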
versions: + - 4.1.0 + - 4.0.0 - 3.4.0 - 3.3.0 - 3.2.1 @@ -33,8 +35,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-beam>=2.33.0 integrations: - integration-name: Apache Beam @@ -54,4 +57,6 @@ hooks: - airflow.providers.apache.beam.hooks.beam additional-extras: - google: apache-beam[gcp] + - name: google + dependencies: + - apache-beam[gcp] diff --git a/airflow/providers/apache/cassandra/.latest-doc-only-change.txt b/airflow/providers/apache/cassandra/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/apache/cassandra/.latest-doc-only-change.txt +++ b/airflow/providers/apache/cassandra/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/cassandra/CHANGELOG.rst b/airflow/providers/apache/cassandra/CHANGELOG.rst index 029740ddef0b2..51d14af93e4e9 100644 --- a/airflow/providers/apache/cassandra/CHANGELOG.rst +++ b/airflow/providers/apache/cassandra/CHANGELOG.rst @@ -16,9 +16,56 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + * ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``AIP-47 - Migrate cassandra DAGs to new design #22439 (#24209)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/apache/cassandra/hooks/cassandra.py b/airflow/providers/apache/cassandra/hooks/cassandra.py index 3d250741d2fc8..e058556c21b44 100644 --- a/airflow/providers/apache/cassandra/hooks/cassandra.py +++ b/airflow/providers/apache/cassandra/hooks/cassandra.py @@ -15,10 +15,10 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """This module contains hook to integrate with Apache Cassandra.""" +from __future__ import annotations -from typing import Any, Dict, Union +from typing import Any, Union from cassandra.auth import PlainTextAuthProvider from cassandra.cluster import Cluster, Session @@ -83,10 +83,10 @@ class CassandraHook(BaseHook, LoggingMixin): For details of the Cluster config, see cassandra.cluster. """ - conn_name_attr = 'cassandra_conn_id' - default_conn_name = 'cassandra_default' - conn_type = 'cassandra' - hook_name = 'Cassandra' + conn_name_attr = "cassandra_conn_id" + default_conn_name = "cassandra_default" + conn_type = "cassandra" + hook_name = "Cassandra" def __init__(self, cassandra_conn_id: str = default_conn_name): super().__init__() @@ -94,31 +94,31 @@ def __init__(self, cassandra_conn_id: str = default_conn_name): conn_config = {} if conn.host: - conn_config['contact_points'] = conn.host.split(',') + conn_config["contact_points"] = conn.host.split(",") if conn.port: - conn_config['port'] = int(conn.port) + conn_config["port"] = int(conn.port) if conn.login: - conn_config['auth_provider'] = PlainTextAuthProvider(username=conn.login, password=conn.password) + conn_config["auth_provider"] = PlainTextAuthProvider(username=conn.login, password=conn.password) - policy_name = conn.extra_dejson.get('load_balancing_policy', None) - policy_args = conn.extra_dejson.get('load_balancing_policy_args', {}) + policy_name = conn.extra_dejson.get("load_balancing_policy", None) + policy_args = conn.extra_dejson.get("load_balancing_policy_args", {}) lb_policy = self.get_lb_policy(policy_name, policy_args) if lb_policy: - conn_config['load_balancing_policy'] = lb_policy + conn_config["load_balancing_policy"] = lb_policy - cql_version = conn.extra_dejson.get('cql_version', None) + cql_version = conn.extra_dejson.get("cql_version", None) if cql_version: - conn_config['cql_version'] = cql_version + conn_config["cql_version"] = cql_version - ssl_options = conn.extra_dejson.get('ssl_options', None) + ssl_options = conn.extra_dejson.get("ssl_options", None) if ssl_options: - conn_config['ssl_options'] = ssl_options + conn_config["ssl_options"] = ssl_options - protocol_version = conn.extra_dejson.get('protocol_version', None) + protocol_version = conn.extra_dejson.get("protocol_version", None) if protocol_version: - conn_config['protocol_version'] = protocol_version + conn_config["protocol_version"] = protocol_version self.cluster = Cluster(**conn_config) self.keyspace = conn.schema @@ -141,37 +141,36 @@ def shutdown_cluster(self) -> None: self.cluster.shutdown() @staticmethod - def get_lb_policy(policy_name: str, policy_args: Dict[str, Any]) -> Policy: + def get_lb_policy(policy_name: str, policy_args: dict[str, Any]) -> Policy: """ Creates load balancing policy. :param policy_name: Name of the policy to use. :param policy_args: Parameters for the policy. 
""" - if policy_name == 'DCAwareRoundRobinPolicy': - local_dc = policy_args.get('local_dc', '') - used_hosts_per_remote_dc = int(policy_args.get('used_hosts_per_remote_dc', 0)) + if policy_name == "DCAwareRoundRobinPolicy": + local_dc = policy_args.get("local_dc", "") + used_hosts_per_remote_dc = int(policy_args.get("used_hosts_per_remote_dc", 0)) return DCAwareRoundRobinPolicy(local_dc, used_hosts_per_remote_dc) - if policy_name == 'WhiteListRoundRobinPolicy': - hosts = policy_args.get('hosts') + if policy_name == "WhiteListRoundRobinPolicy": + hosts = policy_args.get("hosts") if not hosts: - raise Exception('Hosts must be specified for WhiteListRoundRobinPolicy') + raise Exception("Hosts must be specified for WhiteListRoundRobinPolicy") return WhiteListRoundRobinPolicy(hosts) - if policy_name == 'TokenAwarePolicy': + if policy_name == "TokenAwarePolicy": allowed_child_policies = ( - 'RoundRobinPolicy', - 'DCAwareRoundRobinPolicy', - 'WhiteListRoundRobinPolicy', + "RoundRobinPolicy", + "DCAwareRoundRobinPolicy", + "WhiteListRoundRobinPolicy", ) - child_policy_name = policy_args.get('child_load_balancing_policy', 'RoundRobinPolicy') - child_policy_args = policy_args.get('child_load_balancing_policy_args', {}) + child_policy_name = policy_args.get("child_load_balancing_policy", "RoundRobinPolicy") + child_policy_args = policy_args.get("child_load_balancing_policy_args", {}) if child_policy_name not in allowed_child_policies: return TokenAwarePolicy(RoundRobinPolicy()) - else: - child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args) - return TokenAwarePolicy(child_policy) + child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args) + return TokenAwarePolicy(child_policy) # Fallback to default RoundRobinPolicy return RoundRobinPolicy() @@ -184,12 +183,12 @@ def table_exists(self, table: str) -> bool: Use dot notation to target a specific keyspace. """ keyspace = self.keyspace - if '.' in table: - keyspace, table = table.split('.', 1) + if "." in table: + keyspace, table = table.split(".", 1) cluster_metadata = self.get_conn().cluster.metadata return keyspace in cluster_metadata.keyspaces and table in cluster_metadata.keyspaces[keyspace].tables - def record_exists(self, table: str, keys: Dict[str, str]) -> bool: + def record_exists(self, table: str, keys: dict[str, str]) -> bool: """ Checks if a record exists in Cassandra @@ -198,9 +197,9 @@ def record_exists(self, table: str, keys: Dict[str, str]) -> bool: :param keys: The keys and their values to check the existence. """ keyspace = self.keyspace - if '.' in table: - keyspace, table = table.split('.', 1) - ks_str = " AND ".join(f"{key}=%({key})s" for key in keys.keys()) + if "." in table: + keyspace, table = table.split(".", 1) + ks_str = " AND ".join(f"{key}=%({key})s" for key in keys) query = f"SELECT * FROM {keyspace}.{table} WHERE {ks_str}" try: result = self.get_conn().execute(query, keys) diff --git a/airflow/providers/apache/cassandra/provider.yaml b/airflow/providers/apache/cassandra/provider.yaml index 311d12c3411e5..961bbd36798f7 100644 --- a/airflow/providers/apache/cassandra/provider.yaml +++ b/airflow/providers/apache/cassandra/provider.yaml @@ -22,6 +22,8 @@ description: | `Apache Cassandra `__. 
versions: + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - cassandra-driver>=3.13.0 integrations: - integration-name: Apache Cassandra @@ -53,9 +56,6 @@ hooks: python-modules: - airflow.providers.apache.cassandra.hooks.cassandra -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook - connection-types: - hook-class-name: airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook connection-type: cassandra diff --git a/airflow/providers/apache/cassandra/sensors/record.py b/airflow/providers/apache/cassandra/sensors/record.py index f0a407297adfd..1221c6fc37995 100644 --- a/airflow/providers/apache/cassandra/sensors/record.py +++ b/airflow/providers/apache/cassandra/sensors/record.py @@ -19,8 +19,9 @@ This module contains sensor that check the existence of a record in a Cassandra cluster. """ +from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook from airflow.sensors.base import BaseSensorOperator @@ -53,12 +54,12 @@ class CassandraRecordSensor(BaseSensorOperator): when connecting to Cassandra cluster """ - template_fields: Sequence[str] = ('table', 'keys') + template_fields: Sequence[str] = ("table", "keys") def __init__( self, *, - keys: Dict[str, str], + keys: dict[str, str], table: str, cassandra_conn_id: str = CassandraHook.default_conn_name, **kwargs: Any, @@ -68,7 +69,7 @@ def __init__( self.table = table self.keys = keys - def poke(self, context: "Context") -> bool: - self.log.info('Sensor check existence of record: %s', self.keys) + def poke(self, context: Context) -> bool: + self.log.info("Sensor check existence of record: %s", self.keys) hook = CassandraHook(self.cassandra_conn_id) return hook.record_exists(self.table, self.keys) diff --git a/airflow/providers/apache/cassandra/sensors/table.py b/airflow/providers/apache/cassandra/sensors/table.py index 2f5e6681cb4b9..60e8594f9b9f2 100644 --- a/airflow/providers/apache/cassandra/sensors/table.py +++ b/airflow/providers/apache/cassandra/sensors/table.py @@ -15,11 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module contains sensor that check the existence of a table in a Cassandra cluster. 
""" +from __future__ import annotations from typing import TYPE_CHECKING, Any, Sequence @@ -52,7 +52,7 @@ class CassandraTableSensor(BaseSensorOperator): when connecting to Cassandra cluster """ - template_fields: Sequence[str] = ('table',) + template_fields: Sequence[str] = ("table",) def __init__( self, @@ -65,7 +65,7 @@ def __init__( self.cassandra_conn_id = cassandra_conn_id self.table = table - def poke(self, context: "Context") -> bool: - self.log.info('Sensor check existence of table: %s', self.table) + def poke(self, context: Context) -> bool: + self.log.info("Sensor check existence of table: %s", self.table) hook = CassandraHook(self.cassandra_conn_id) return hook.table_exists(self.table) diff --git a/airflow/providers/apache/drill/CHANGELOG.rst b/airflow/providers/apache/drill/CHANGELOG.rst index a2bc20f91d195..cfd26afb423aa 100644 --- a/airflow/providers/apache/drill/CHANGELOG.rst +++ b/airflow/providers/apache/drill/CHANGELOG.rst @@ -16,9 +16,87 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +2.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +2.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Review and move the new changes to one of the sections above: + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +2.2.0 +..... + +Features +~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +2.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``AIP-47 - Migrate drill DAGs to new design #22439 (#24206)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Clean up in-line f-string concatenation (#23591)`` + * ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 1.0.4 ..... 
diff --git a/airflow/providers/apache/drill/hooks/drill.py b/airflow/providers/apache/drill/hooks/drill.py index a15658e9e38ce..1d847e27e68d8 100644 --- a/airflow/providers/apache/drill/hooks/drill.py +++ b/airflow/providers/apache/drill/hooks/drill.py @@ -15,13 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Iterable, Optional, Tuple +from typing import Any, Iterable from sqlalchemy import create_engine from sqlalchemy.engine import Connection -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook class DrillHook(DbApiHook): @@ -38,24 +39,24 @@ class DrillHook(DbApiHook): connection using the extras field e.g. ``{"storage_plugin": "dfs"}``. """ - conn_name_attr = 'drill_conn_id' - default_conn_name = 'drill_default' - conn_type = 'drill' - hook_name = 'Drill' + conn_name_attr = "drill_conn_id" + default_conn_name = "drill_default" + conn_type = "drill" + hook_name = "Drill" supports_autocommit = False def get_conn(self) -> Connection: """Establish a connection to Drillbit.""" conn_md = self.get_connection(getattr(self, self.conn_name_attr)) - creds = f'{conn_md.login}:{conn_md.password}@' if conn_md.login else '' + creds = f"{conn_md.login}:{conn_md.password}@" if conn_md.login else "" engine = create_engine( f'{conn_md.extra_dejson.get("dialect_driver", "drill+sadrill")}://{creds}' - f'{conn_md.host}:{conn_md.port}/' + f"{conn_md.host}:{conn_md.port}/" f'{conn_md.extra_dejson.get("storage_plugin", "dfs")}' ) self.log.info( - 'Connected to the Drillbit at %s:%s as user %s', conn_md.host, conn_md.port, conn_md.login + "Connected to the Drillbit at %s:%s as user %s", conn_md.host, conn_md.port, conn_md.login ) return engine.raw_connection() @@ -68,11 +69,11 @@ def get_uri(self) -> str: conn_md = self.get_connection(getattr(self, self.conn_name_attr)) host = conn_md.host if conn_md.port is not None: - host += f':{conn_md.port}' - conn_type = 'drill' if not conn_md.conn_type else conn_md.conn_type - dialect_driver = conn_md.extra_dejson.get('dialect_driver', 'drill+sadrill') - storage_plugin = conn_md.extra_dejson.get('storage_plugin', 'dfs') - return f'{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}' + host += f":{conn_md.port}" + conn_type = conn_md.conn_type or "drill" + dialect_driver = conn_md.extra_dejson.get("dialect_driver", "drill+sadrill") + storage_plugin = conn_md.extra_dejson.get("storage_plugin", "dfs") + return f"{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}" def set_autocommit(self, conn: Connection, autocommit: bool) -> NotImplementedError: raise NotImplementedError("There are no transactions in Drill.") @@ -80,8 +81,8 @@ def set_autocommit(self, conn: Connection, autocommit: bool) -> NotImplementedEr def insert_rows( self, table: str, - rows: Iterable[Tuple[str]], - target_fields: Optional[Iterable[str]] = None, + rows: Iterable[tuple[str]], + target_fields: Iterable[str] | None = None, commit_every: int = 1000, replace: bool = False, **kwargs: Any, diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py index 791ed546c34fe..1be54e8405648 100644 --- a/airflow/providers/apache/drill/operators/drill.py +++ b/airflow/providers/apache/drill/operators/drill.py @@ -15,18 +15,15 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union +from __future__ import annotations -import sqlparse +import warnings +from typing import Sequence -from airflow.models import BaseOperator -from airflow.providers.apache.drill.hooks.drill import DrillHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -if TYPE_CHECKING: - from airflow.utils.context import Context - -class DrillOperator(BaseOperator): +class DrillOperator(SQLExecuteQueryOperator): """ Executes the provided SQL in the identified Drill environment. @@ -42,28 +39,16 @@ class DrillOperator(BaseOperator): :param parameters: (optional) the parameters to render the SQL query with. """ - template_fields: Sequence[str] = ('sql',) - template_fields_renderers = {'sql': 'sql'} - template_ext: Sequence[str] = ('.sql',) - ui_color = '#ededed' - - def __init__( - self, - *, - sql: str, - drill_conn_id: str = 'drill_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.sql = sql - self.drill_conn_id = drill_conn_id - self.parameters = parameters - self.hook: Optional[DrillHook] = None - - def execute(self, context: 'Context'): - self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id) - self.hook = DrillHook(drill_conn_id=self.drill_conn_id) - sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True)) - no_term_sql = [s[:-1] for s in sql if s[-1] == ';'] - self.hook.run(no_term_sql, parameters=self.parameters) + template_fields: Sequence[str] = ("sql",) + template_fields_renderers = {"sql": "sql"} + template_ext: Sequence[str] = (".sql",) + ui_color = "#ededed" + + def __init__(self, *, drill_conn_id: str = "drill_default", **kwargs) -> None: + super().__init__(conn_id=drill_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/apache/drill/provider.yaml b/airflow/providers/apache/drill/provider.yaml index 6dfeade83e3c0..8cab31b8bce20 100644 --- a/airflow/providers/apache/drill/provider.yaml +++ b/airflow/providers/apache/drill/provider.yaml @@ -22,14 +22,21 @@ description: | `Apache Drill `__. versions: + - 2.3.0 + - 2.2.1 + - 2.2.0 + - 2.1.0 + - 2.0.0 - 1.0.4 - 1.0.3 - 1.0.2 - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - sqlalchemy-drill>=1.1.0 integrations: - integration-name: Apache Drill @@ -49,9 +56,6 @@ hooks: python-modules: - airflow.providers.apache.drill.hooks.drill -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.drill.hooks.drill.DrillHook - connection-types: - hook-class-name: airflow.providers.apache.drill.hooks.drill.DrillHook connection-type: drill diff --git a/airflow/providers/apache/druid/CHANGELOG.rst b/airflow/providers/apache/druid/CHANGELOG.rst index ae22cead26039..957a2a2bfc4ea 100644 --- a/airflow/providers/apache/druid/CHANGELOG.rst +++ b/airflow/providers/apache/druid/CHANGELOG.rst @@ -16,9 +16,90 @@ under the License. +.. 
NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Bug Fixes +~~~~~~~~~ + +* ``BugFix - Druid Airflow Exception to about content (#27174)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Review and move the new changes to one of the sections above: + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Move all "old" SQL operators to common.sql providers (#25350)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``AIP-47 - Migrate druid DAGs to new design #22439 (#24207)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.3.3 ..... diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py index 671c914be604f..abfd86d68e3b1 100644 --- a/airflow/providers/apache/druid/hooks/druid.py +++ b/airflow/providers/apache/druid/hooks/druid.py @@ -15,16 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import time -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Iterable import requests from pydruid.db import connect from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook class DruidHook(BaseHook): @@ -44,16 +45,16 @@ class DruidHook(BaseHook): def __init__( self, - druid_ingest_conn_id: str = 'druid_ingest_default', + druid_ingest_conn_id: str = "druid_ingest_default", timeout: int = 1, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, ) -> None: super().__init__() self.druid_ingest_conn_id = druid_ingest_conn_id self.timeout = timeout self.max_ingestion_time = max_ingestion_time - self.header = {'content-type': 'application/json'} + self.header = {"content-type": "application/json"} if self.timeout < 1: raise ValueError("Druid timeout should be equal or greater than 1") @@ -63,11 +64,11 @@ def get_conn_url(self) -> str: conn = self.get_connection(self.druid_ingest_conn_id) host = conn.host port = conn.port - conn_type = 'http' if not conn.conn_type else conn.conn_type - endpoint = conn.extra_dejson.get('endpoint', '') + conn_type = conn.conn_type or "http" + endpoint = conn.extra_dejson.get("endpoint", "") return f"{conn_type}://{host}:{port}/{endpoint}" - def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]: + def get_auth(self) -> requests.auth.HTTPBasicAuth | None: """ Return username and password from connections tab as requests.auth.HTTPBasicAuth object. @@ -81,18 +82,21 @@ def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]: else: return None - def submit_indexing_job(self, json_index_spec: Union[Dict[str, Any], str]) -> None: + def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None: """Submit Druid ingestion job""" url = self.get_conn_url() self.log.info("Druid ingestion spec: %s", json_index_spec) req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth()) - if req_index.status_code != 200: - raise AirflowException(f'Did not get 200 when submitting the Druid job to {url}') + + code = req_index.status_code + if code != 200: + self.log.error("Error submitting the Druid job to %s (%s) %s", url, code, req_index.content) + raise AirflowException(f"Did not get 200 when submitting the Druid job to {url}") req_json = req_index.json() # Wait until the job is completed - druid_task_id = req_json['task'] + druid_task_id = req_json["task"] self.log.info("Druid indexing task-id: %s", druid_task_id) running = True @@ -106,23 +110,23 @@ def submit_indexing_job(self, json_index_spec: Union[Dict[str, Any], str]) -> No if self.max_ingestion_time and sec > self.max_ingestion_time: # ensure that the job gets killed if the max ingestion time is exceeded requests.post(f"{url}/{druid_task_id}/shutdown", auth=self.get_auth()) - raise AirflowException(f'Druid ingestion took more than {self.max_ingestion_time} seconds') + raise AirflowException(f"Druid ingestion took more than {self.max_ingestion_time} seconds") time.sleep(self.timeout) sec += self.timeout - status = req_status.json()['status']['status'] - if status == 'RUNNING': + status = req_status.json()["status"]["status"] + if status == "RUNNING": running = True - elif status == 'SUCCESS': + elif status == "SUCCESS": running = False # Great success! 
- elif status == 'FAILED': - raise AirflowException('Druid indexing job failed, check console for more info') + elif status == "FAILED": + raise AirflowException("Druid indexing job failed, check console for more info") else: - raise AirflowException(f'Could not get status of the job, got {status}') + raise AirflowException(f"Could not get status of the job, got {status}") - self.log.info('Successful index') + self.log.info("Successful index") class DruidDbApiHook(DbApiHook): @@ -133,10 +137,10 @@ class DruidDbApiHook(DbApiHook): For ingestion, please use druidHook. """ - conn_name_attr = 'druid_broker_conn_id' - default_conn_name = 'druid_broker_default' - conn_type = 'druid' - hook_name = 'Druid' + conn_name_attr = "druid_broker_conn_id" + default_conn_name = "druid_broker_default" + conn_type = "druid" + hook_name = "Druid" supports_autocommit = False def get_conn(self) -> connect: @@ -145,12 +149,12 @@ def get_conn(self) -> connect: druid_broker_conn = connect( host=conn.host, port=conn.port, - path=conn.extra_dejson.get('endpoint', '/druid/v2/sql'), - scheme=conn.extra_dejson.get('schema', 'http'), + path=conn.extra_dejson.get("endpoint", "/druid/v2/sql"), + scheme=conn.extra_dejson.get("schema", "http"), user=conn.login, password=conn.password, ) - self.log.info('Get the connection to druid broker on %s using user %s', conn.host, conn.login) + self.log.info("Get the connection to druid broker on %s using user %s", conn.host, conn.login) return druid_broker_conn def get_uri(self) -> str: @@ -162,10 +166,10 @@ def get_uri(self) -> str: conn = self.get_connection(getattr(self, self.conn_name_attr)) host = conn.host if conn.port is not None: - host += f':{conn.port}' - conn_type = 'druid' if not conn.conn_type else conn.conn_type - endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql') - return f'{conn_type}://{host}/{endpoint}' + host += f":{conn.port}" + conn_type = conn.conn_type or "druid" + endpoint = conn.extra_dejson.get("endpoint", "druid/v2/sql") + return f"{conn_type}://{host}/{endpoint}" def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplementedError: raise NotImplementedError() @@ -173,8 +177,8 @@ def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplementedError def insert_rows( self, table: str, - rows: Iterable[Tuple[str]], - target_fields: Optional[Iterable[str]] = None, + rows: Iterable[tuple[str]], + target_fields: Iterable[str] | None = None, commit_every: int = 1000, replace: bool = False, **kwargs: Any, diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py index 9fd8b3595af5c..2ef39886dab0a 100644 --- a/airflow/providers/apache/druid/operators/druid.py +++ b/airflow/providers/apache/druid/operators/druid.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.apache.druid.hooks.druid import DruidHook @@ -37,17 +38,17 @@ class DruidOperator(BaseOperator): :param max_ingestion_time: The maximum ingestion time before assuming the job failed """ - template_fields: Sequence[str] = ('json_index_file',) - template_ext: Sequence[str] = ('.json',) - template_fields_renderers = {'json_index_file': 'json'} + template_fields: Sequence[str] = ("json_index_file",) + template_ext: Sequence[str] = (".json",) + template_fields_renderers = {"json_index_file": "json"} def __init__( self, *, json_index_file: str, - druid_ingest_conn_id: str = 'druid_ingest_default', + druid_ingest_conn_id: str = "druid_ingest_default", timeout: int = 1, - max_ingestion_time: Optional[int] = None, + max_ingestion_time: int | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -56,7 +57,7 @@ def __init__( self.timeout = timeout self.max_ingestion_time = max_ingestion_time - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = DruidHook( druid_ingest_conn_id=self.conn_id, timeout=self.timeout, diff --git a/airflow/providers/apache/druid/operators/druid_check.py b/airflow/providers/apache/druid/operators/druid_check.py index 33a4151350e55..84c00c4d61b0f 100644 --- a/airflow/providers/apache/druid/operators/druid_check.py +++ b/airflow/providers/apache/druid/operators/druid_check.py @@ -15,21 +15,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import warnings -from airflow.operators.sql import SQLCheckOperator +from airflow.providers.common.sql.operators.sql import SQLCheckOperator class DruidCheckOperator(SQLCheckOperator): """ This class is deprecated. - Please use `airflow.operators.sql.SQLCheckOperator`. + Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`. """ - def __init__(self, druid_broker_conn_id: str = 'druid_broker_default', **kwargs): + def __init__(self, druid_broker_conn_id: str = "druid_broker_default", **kwargs): warnings.warn( """This class is deprecated. - Please use `airflow.operators.sql.SQLCheckOperator`.""", + Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/providers/apache/druid/provider.yaml b/airflow/providers/apache/druid/provider.yaml index 214a1f15425ad..34e577e24b83d 100644 --- a/airflow/providers/apache/druid/provider.yaml +++ b/airflow/providers/apache/druid/provider.yaml @@ -22,6 +22,11 @@ description: | `Apache Druid `__. 
versions: + - 3.3.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.3.3 - 2.3.2 - 2.3.1 @@ -35,8 +40,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - pydruid>=0.4.1 integrations: - integration-name: Apache Druid @@ -57,8 +64,6 @@ hooks: python-modules: - airflow.providers.apache.druid.hooks.druid -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.druid.hooks.druid.DruidDbApiHook connection-types: - hook-class-name: airflow.providers.apache.druid.hooks.druid.DruidDbApiHook diff --git a/airflow/providers/apache/druid/transfers/hive_to_druid.py b/airflow/providers/apache/druid/transfers/hive_to_druid.py index dc5f74109acf1..4c7523dd8d434 100644 --- a/airflow/providers/apache/druid/transfers/hive_to_druid.py +++ b/airflow/providers/apache/druid/transfers/hive_to_druid.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains operator to move data from Hive to Druid.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.apache.druid.hooks.druid import DruidHook @@ -64,9 +64,9 @@ class HiveToDruidOperator(BaseOperator): :param job_properties: additional properties for job """ - template_fields: Sequence[str] = ('sql', 'intervals') - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'hql'} + template_fields: Sequence[str] = ("sql", "intervals") + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "hql"} def __init__( self, @@ -74,25 +74,25 @@ def __init__( sql: str, druid_datasource: str, ts_dim: str, - metric_spec: Optional[List[Any]] = None, - hive_cli_conn_id: str = 'hive_cli_default', - druid_ingest_conn_id: str = 'druid_ingest_default', - metastore_conn_id: str = 'metastore_default', - hadoop_dependency_coordinates: Optional[List[str]] = None, - intervals: Optional[List[Any]] = None, + metric_spec: list[Any] | None = None, + hive_cli_conn_id: str = "hive_cli_default", + druid_ingest_conn_id: str = "druid_ingest_default", + metastore_conn_id: str = "metastore_default", + hadoop_dependency_coordinates: list[str] | None = None, + intervals: list[Any] | None = None, num_shards: float = -1, target_partition_size: int = -1, query_granularity: str = "NONE", segment_granularity: str = "DAY", - hive_tblproperties: Optional[Dict[Any, Any]] = None, - job_properties: Optional[Dict[Any, Any]] = None, + hive_tblproperties: dict[Any, Any] | None = None, + job_properties: dict[Any, Any] | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.sql = sql self.druid_datasource = druid_datasource self.ts_dim = ts_dim - self.intervals = intervals or ['{{ ds }}/{{ tomorrow_ds }}'] + self.intervals = intervals or ["{{ ds }}/{{ tomorrow_ds }}"] self.num_shards = num_shards self.target_partition_size = target_partition_size self.query_granularity = query_granularity @@ -105,12 +105,12 @@ def __init__( self.hive_tblproperties = hive_tblproperties or {} self.job_properties = job_properties - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) 
self.log.info("Extracting data from Hive") - hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_') - sql = self.sql.strip().strip(';') - tblproperties = ''.join(f", '{k}' = '{v}'" for k, v in self.hive_tblproperties.items()) + hive_table = "druid." + context["task_instance_key_str"].replace(".", "_") + sql = self.sql.strip().strip(";") + tblproperties = "".join(f", '{k}' = '{v}'" for k, v in self.hive_tblproperties.items()) hql = f"""\ SET mapred.output.compress=false; SET hive.exec.compress.output=false; @@ -152,7 +152,7 @@ def execute(self, context: "Context") -> None: hql = f"DROP TABLE IF EXISTS {hive_table}" hive.run_cli(hql) - def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[str, Any]: + def construct_ingest_query(self, static_path: str, columns: list[str]) -> dict[str, Any]: """ Builds an ingest query for an HDFS TSV load. @@ -170,13 +170,13 @@ def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[s else: num_shards = -1 - metric_names = [m['fieldName'] for m in self.metric_spec if m['type'] != 'count'] + metric_names = [m["fieldName"] for m in self.metric_spec if m["type"] != "count"] # Take all the columns, which are not the time dimension # or a metric, as the dimension columns dimensions = [c for c in columns if c not in metric_names and c != self.ts_dim] - ingest_query_dict: Dict[str, Any] = { + ingest_query_dict: dict[str, Any] = { "type": "index_hadoop", "spec": { "dataSchema": { @@ -220,9 +220,9 @@ def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[s } if self.job_properties: - ingest_query_dict['spec']['tuningConfig']['jobProperties'].update(self.job_properties) + ingest_query_dict["spec"]["tuningConfig"]["jobProperties"].update(self.job_properties) if self.hadoop_dependency_coordinates: - ingest_query_dict['hadoopDependencyCoordinates'] = self.hadoop_dependency_coordinates + ingest_query_dict["hadoopDependencyCoordinates"] = self.hadoop_dependency_coordinates return ingest_query_dict diff --git a/airflow/providers/apache/hdfs/.latest-doc-only-change.txt b/airflow/providers/apache/hdfs/.latest-doc-only-change.txt index e7c3c940c9c77..ff7136e07d744 100644 --- a/airflow/providers/apache/hdfs/.latest-doc-only-change.txt +++ b/airflow/providers/apache/hdfs/.latest-doc-only-change.txt @@ -1 +1 @@ -602abe8394fafe7de54df7e73af56de848cdf617 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/hdfs/CHANGELOG.rst b/airflow/providers/apache/hdfs/CHANGELOG.rst index e71546e3a0c7c..b7f11a87c1024 100644 --- a/airflow/providers/apache/hdfs/CHANGELOG.rst +++ b/airflow/providers/apache/hdfs/CHANGELOG.rst @@ -16,9 +16,74 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Adding Authentication to webhdfs sensor (#25110)`` + +3.0.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``'WebHDFSHook' Bugfix/optional port (#24550)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/apache/hdfs/hooks/hdfs.py b/airflow/providers/apache/hdfs/hooks/hdfs.py index 6d1ce3d010125..0e98320cb77d7 100644 --- a/airflow/providers/apache/hdfs/hooks/hdfs.py +++ b/airflow/providers/apache/hdfs/hooks/hdfs.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """Hook for HDFS operations""" -from typing import Any, Optional +from __future__ import annotations + +from typing import Any from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -43,20 +45,20 @@ class HDFSHook(BaseHook): :param autoconfig: use snakebite's automatically configured client """ - conn_name_attr = 'hdfs_conn_id' - default_conn_name = 'hdfs_default' - conn_type = 'hdfs' - hook_name = 'HDFS' + conn_name_attr = "hdfs_conn_id" + default_conn_name = "hdfs_default" + conn_type = "hdfs" + hook_name = "HDFS" def __init__( - self, hdfs_conn_id: str = 'hdfs_default', proxy_user: Optional[str] = None, autoconfig: bool = False + self, hdfs_conn_id: str = "hdfs_default", proxy_user: str | None = None, autoconfig: bool = False ): super().__init__() if not snakebite_loaded: raise ImportError( - 'This HDFSHook implementation requires snakebite, but ' - 'snakebite is not compatible with Python 3 ' - '(as of August 2015). Please help by submitting a PR!' + "This HDFSHook implementation requires snakebite, but " + "snakebite is not compatible with Python 3 " + "(as of August 2015). Please help by submitting a PR!" ) self.hdfs_conn_id = hdfs_conn_id self.proxy_user = proxy_user @@ -68,7 +70,7 @@ def get_conn(self) -> Any: # take the first. 
effective_user = self.proxy_user autoconfig = self.autoconfig - use_sasl = conf.get('core', 'security') == 'kerberos' + use_sasl = conf.get("core", "security") == "kerberos" try: connections = self.get_connections(self.hdfs_conn_id) @@ -76,8 +78,8 @@ def get_conn(self) -> Any: if not effective_user: effective_user = connections[0].login if not autoconfig: - autoconfig = connections[0].extra_dejson.get('autoconfig', False) - hdfs_namenode_principal = connections[0].extra_dejson.get('hdfs_namenode_principal') + autoconfig = connections[0].extra_dejson.get("autoconfig", False) + hdfs_namenode_principal = connections[0].extra_dejson.get("hdfs_namenode_principal") except AirflowException: if not autoconfig: raise diff --git a/airflow/providers/apache/hdfs/hooks/webhdfs.py b/airflow/providers/apache/hdfs/hooks/webhdfs.py index a32206ba9bef8..67608c481cd2a 100644 --- a/airflow/providers/apache/hdfs/hooks/webhdfs.py +++ b/airflow/providers/apache/hdfs/hooks/webhdfs.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. """Hook for Web HDFS""" +from __future__ import annotations + import logging import socket -from typing import Any, Optional +from typing import Any import requests from hdfs import HdfsError, InsecureClient @@ -50,7 +52,7 @@ class WebHDFSHook(BaseHook): :param proxy_user: The user used to authenticate. """ - def __init__(self, webhdfs_conn_id: str = 'webhdfs_default', proxy_user: Optional[str] = None): + def __init__(self, webhdfs_conn_id: str = "webhdfs_default", proxy_user: str | None = None): super().__init__() self.webhdfs_conn_id = webhdfs_conn_id self.proxy_user = proxy_user @@ -59,7 +61,6 @@ def get_conn(self) -> Any: """ Establishes a connection depending on the security mode set via config or environment variable. :return: a hdfscli InsecureClient or KerberosClient object. 
- :rtype: hdfs.InsecureClient or hdfs.ext.kerberos.KerberosClient """ connection = self._find_valid_server() if connection is None: @@ -68,42 +69,54 @@ def get_conn(self) -> Any: def _find_valid_server(self) -> Any: connection = self.get_connection(self.webhdfs_conn_id) - namenodes = connection.host.split(',') + namenodes = connection.host.split(",") for namenode in namenodes: host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.log.info("Trying to connect to %s:%s", namenode, connection.port) try: conn_check = host_socket.connect_ex((namenode, connection.port)) if conn_check == 0: - self.log.info('Trying namenode %s', namenode) + self.log.info("Trying namenode %s", namenode) client = self._get_client( - namenode, connection.port, connection.login, connection.extra_dejson + namenode, + connection.port, + connection.login, + connection.get_password(), + connection.schema, + connection.extra_dejson, ) - client.status('/') - self.log.info('Using namenode %s for hook', namenode) + client.status("/") + self.log.info("Using namenode %s for hook", namenode) host_socket.close() return client else: self.log.warning("Could not connect to %s:%s", namenode, connection.port) except HdfsError as hdfs_error: - self.log.info('Read operation on namenode %s failed with error: %s', namenode, hdfs_error) + self.log.info("Read operation on namenode %s failed with error: %s", namenode, hdfs_error) return None - def _get_client(self, namenode: str, port: int, login: str, extra_dejson: dict) -> Any: - connection_str = f'http://{namenode}:{port}' + def _get_client( + self, namenode: str, port: int, login: str, password: str | None, schema: str, extra_dejson: dict + ) -> Any: + connection_str = f"http://{namenode}" session = requests.Session() + if password is not None: + session.auth = (login, password) - if extra_dejson.get('use_ssl', False): - connection_str = f'https://{namenode}:{port}' - session.verify = extra_dejson.get('verify', True) + if extra_dejson.get("use_ssl", "False") == "True" or extra_dejson.get("use_ssl", False): + connection_str = f"https://{namenode}" + session.verify = extra_dejson.get("verify", False) - if _kerberos_security_mode: - client = KerberosClient(connection_str, session=session) - else: - proxy_user = self.proxy_user or login - client = InsecureClient(connection_str, user=proxy_user, session=session) + if port is not None: + connection_str += f":{port}" - return client + if schema is not None: + connection_str += f"/{schema}" + + if _kerberos_security_mode: + return KerberosClient(connection_str, session=session) + proxy_user = self.proxy_user or login + return InsecureClient(connection_str, user=proxy_user, session=session) def check_for_path(self, hdfs_path: str) -> bool: """ @@ -111,7 +124,6 @@ def check_for_path(self, hdfs_path: str) -> bool: :param hdfs_path: The path to check. :return: True if the path exists and False if not. - :rtype: bool """ conn = self.get_conn() diff --git a/airflow/providers/apache/hdfs/provider.yaml b/airflow/providers/apache/hdfs/provider.yaml index a01a734d72629..2f575d7da32da 100644 --- a/airflow/providers/apache/hdfs/provider.yaml +++ b/airflow/providers/apache/hdfs/provider.yaml @@ -23,6 +23,10 @@ description: | and `WebHDFS `__. 
versions: + - 3.2.0 + - 3.1.0 + - 3.0.1 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -33,8 +37,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - snakebite-py3 + - hdfs[avro,dataframe,kerberos]>=2.0.4 integrations: - integration-name: Hadoop Distributed File System (HDFS) @@ -66,9 +72,6 @@ hooks: python-modules: - airflow.providers.apache.hdfs.hooks.webhdfs -hook-class-names: - # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook connection-types: - hook-class-name: airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook diff --git a/airflow/providers/apache/hdfs/sensors/hdfs.py b/airflow/providers/apache/hdfs/sensors/hdfs.py index a445bb688e722..5c55209f4d1de 100644 --- a/airflow/providers/apache/hdfs/sensors/hdfs.py +++ b/airflow/providers/apache/hdfs/sensors/hdfs.py @@ -15,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging import re import sys -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Sequence, Type +from typing import TYPE_CHECKING, Any, Pattern, Sequence from airflow import settings from airflow.providers.apache.hdfs.hooks.hdfs import HDFSHook @@ -45,23 +47,23 @@ class HdfsSensor(BaseSensorOperator): :ref:`howto/operator:HdfsSensor` """ - template_fields: Sequence[str] = ('filepath',) - ui_color = settings.WEB_COLORS['LIGHTBLUE'] + template_fields: Sequence[str] = ("filepath",) + ui_color = settings.WEB_COLORS["LIGHTBLUE"] def __init__( self, *, filepath: str, - hdfs_conn_id: str = 'hdfs_default', - ignored_ext: Optional[List[str]] = None, + hdfs_conn_id: str = "hdfs_default", + ignored_ext: list[str] | None = None, ignore_copying: bool = True, - file_size: Optional[int] = None, - hook: Type[HDFSHook] = HDFSHook, + file_size: int | None = None, + hook: type[HDFSHook] = HDFSHook, **kwargs: Any, ) -> None: super().__init__(**kwargs) if ignored_ext is None: - ignored_ext = ['_COPYING_'] + ignored_ext = ["_COPYING_"] self.filepath = filepath self.hdfs_conn_id = hdfs_conn_id self.file_size = file_size @@ -70,7 +72,7 @@ def __init__( self.hook = hook @staticmethod - def filter_for_filesize(result: List[Dict[Any, Any]], size: Optional[int] = None) -> List[Dict[Any, Any]]: + def filter_for_filesize(result: list[dict[Any, Any]], size: int | None = None) -> list[dict[Any, Any]]: """ Will test the filepath result and test if its size is at least self.filesize @@ -79,16 +81,16 @@ def filter_for_filesize(result: List[Dict[Any, Any]], size: Optional[int] = None :return: (bool) depending on the matching criteria """ if size: - log.debug('Filtering for file size >= %s in files: %s', size, map(lambda x: x['path'], result)) + log.debug("Filtering for file size >= %s in files: %s", size, map(lambda x: x["path"], result)) size *= settings.MEGABYTE - result = [x for x in result if x['length'] >= size] - log.debug('HdfsSensor.poke: after size filter result is %s', result) + result = [x for x in result if x["length"] >= size] + log.debug("HdfsSensor.poke: after size filter result is %s", result) return result @staticmethod def filter_for_ignored_ext( - result: List[Dict[Any, Any]], ignored_ext: List[str], ignore_copying: bool - ) -> List[Dict[Any, Any]]: + result: list[dict[Any, Any]], ignored_ext: list[str], ignore_copying: bool + ) -> list[dict[Any, Any]]: """ Will filter if instructed 
to do so the result to remove matching criteria @@ -96,24 +98,23 @@ def filter_for_ignored_ext( :param ignored_ext: list of ignored extensions :param ignore_copying: shall we ignore ? :return: list of dicts which were not removed - :rtype: list[dict] """ if ignore_copying: - regex_builder = r"^.*\.(%s$)$" % '$|'.join(ignored_ext) + regex_builder = r"^.*\.(%s$)$" % "$|".join(ignored_ext) ignored_extensions_regex = re.compile(regex_builder) log.debug( - 'Filtering result for ignored extensions: %s in files %s', + "Filtering result for ignored extensions: %s in files %s", ignored_extensions_regex.pattern, - map(lambda x: x['path'], result), + map(lambda x: x["path"], result), ) - result = [x for x in result if not ignored_extensions_regex.match(x['path'])] - log.debug('HdfsSensor.poke: after ext filter result is %s', result) + result = [x for x in result if not ignored_extensions_regex.match(x["path"])] + log.debug("HdfsSensor.poke: after ext filter result is %s", result) return result - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: """Get a snakebite client connection and check for file.""" sb_client = self.hook(self.hdfs_conn_id).get_conn() - self.log.info('Poking for file %s', self.filepath) + self.log.info("Poking for file %s", self.filepath) try: # IMOO it's not right here, as there is no raise of any kind. # if the filepath is let's say '/data/mydirectory', @@ -121,7 +122,7 @@ def poke(self, context: "Context") -> bool: # it's not correct as the directory exists and sb_client does not raise any error # here is a quick fix result = sb_client.ls([self.filepath], include_toplevel=False) - self.log.debug('HdfsSensor.poke: result is %s', result) + self.log.debug("HdfsSensor.poke: result is %s", result) result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying) result = self.filter_for_filesize(result, self.file_size) return bool(result) @@ -144,7 +145,7 @@ def __init__(self, regex: Pattern[str], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.regex = regex - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: """ Poke matching files in a directory with self.regex @@ -152,12 +153,12 @@ def poke(self, context: "Context") -> bool: """ sb_client = self.hook(self.hdfs_conn_id).get_conn() self.log.info( - 'Poking for %s to be a directory with files matching %s', self.filepath, self.regex.pattern + "Poking for %s to be a directory with files matching %s", self.filepath, self.regex.pattern ) result = [ f for f in sb_client.ls([self.filepath], include_toplevel=False) - if f['file_type'] == 'f' and self.regex.match(f['path'].replace(f'{self.filepath}/', '')) + if f["file_type"] == "f" and self.regex.match(f["path"].replace(f"{self.filepath}/", "")) ] result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying) result = self.filter_for_filesize(result, self.file_size) @@ -177,7 +178,7 @@ def __init__(self, be_empty: bool = False, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.be_empty = be_empty - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: """ Poke for a non empty directory @@ -188,9 +189,9 @@ def poke(self, context: "Context") -> bool: result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying) result = self.filter_for_filesize(result, self.file_size) if self.be_empty: - self.log.info('Poking for filepath %s to a empty directory', self.filepath) - return 
len(result) == 1 and result[0]['path'] == self.filepath + self.log.info("Poking for filepath %s to a empty directory", self.filepath) + return len(result) == 1 and result[0]["path"] == self.filepath else: - self.log.info('Poking for filepath %s to a non empty directory', self.filepath) + self.log.info("Poking for filepath %s to a non empty directory", self.filepath) result.pop(0) - return bool(result) and result[0]['file_type'] == 'f' + return bool(result) and result[0]["file_type"] == "f" diff --git a/airflow/providers/apache/hdfs/sensors/web_hdfs.py b/airflow/providers/apache/hdfs/sensors/web_hdfs.py index adacdefecad07..38e1047679b03 100644 --- a/airflow/providers/apache/hdfs/sensors/web_hdfs.py +++ b/airflow/providers/apache/hdfs/sensors/web_hdfs.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Sequence from airflow.sensors.base import BaseSensorOperator @@ -26,16 +28,16 @@ class WebHdfsSensor(BaseSensorOperator): """Waits for a file or folder to land in HDFS""" - template_fields: Sequence[str] = ('filepath',) + template_fields: Sequence[str] = ("filepath",) - def __init__(self, *, filepath: str, webhdfs_conn_id: str = 'webhdfs_default', **kwargs: Any) -> None: + def __init__(self, *, filepath: str, webhdfs_conn_id: str = "webhdfs_default", **kwargs: Any) -> None: super().__init__(**kwargs) self.filepath = filepath self.webhdfs_conn_id = webhdfs_conn_id - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook hook = WebHDFSHook(self.webhdfs_conn_id) - self.log.info('Poking for file %s', self.filepath) + self.log.info("Poking for file %s", self.filepath) return hook.check_for_path(hdfs_path=self.filepath) diff --git a/airflow/providers/apache/hive/CHANGELOG.rst b/airflow/providers/apache/hive/CHANGELOG.rst index bb28c620b4126..19543da737b5b 100644 --- a/airflow/providers/apache/hive/CHANGELOG.rst +++ b/airflow/providers/apache/hive/CHANGELOG.rst @@ -16,9 +16,104 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Bug Fixes +~~~~~~~~~ + +* ``Filter out invalid schemas in Hive hook (#27647)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +4.0.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Review and move the new changes to one of the sections above: + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.0.0 +..... + +Breaking Changes +~~~~~~~~~~~~~~~~ + +* The ``hql`` parameter in ``get_records`` of ``HiveServer2Hook`` has been renamed to sql to match the + ``get_records`` DbApiHook signature. 
If you used it as a positional parameter, this is no change for you, + but if you used it as keyword one, you need to rename it. +* ``hive_conf`` parameter has been renamed to ``parameters`` and it is now second parameter, to match ``get_records`` + signature from the DbApiHook. You need to rename it if you used it. +* ``schema`` parameter in ``get_records`` is an optional kwargs extra parameter that you can add, to match + the schema of ``get_records`` from DbApiHook. + +* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)`` +* ``Remove Smart Sensors (#25507)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +Bug Fixes +~~~~~~~~~ + +* ``fix connection extra parameter 'auth_mechanism' in 'HiveMetastoreHook' and 'HiveServer2Hook' (#24713)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* ``chore: Refactoring and Cleaning Apache Providers (#24219)`` +* ``AIP-47 - Migrate hive DAGs to new design #22439 (#24204)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add typing for airflow/configuration.py (#23716)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.3.3 ..... diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py index bccd279f77cf0..edd8a4b372c06 100644 --- a/airflow/providers/apache/hive/hooks/hive.py +++ b/airflow/providers/apache/hive/hooks/hive.py @@ -15,15 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import contextlib import os import re import socket import subprocess import time +import warnings from collections import OrderedDict from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import Any, Dict, List, Optional, Union +from typing import Any, Iterable, Mapping import pandas import unicodecsv as csv @@ -31,15 +34,15 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.security import utils from airflow.utils.helpers import as_flattened_list from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING -HIVE_QUEUE_PRIORITIES = ['VERY_HIGH', 'HIGH', 'NORMAL', 'LOW', 'VERY_LOW'] +HIVE_QUEUE_PRIORITIES = ["VERY_HIGH", "HIGH", "NORMAL", "LOW", "VERY_LOW"] -def get_context_from_env_var() -> Dict[Any, Any]: +def get_context_from_env_var() -> dict[Any, Any]: """ Extract context from env variable, e.g. dag_id, task_id and execution_date, so that they can be used inside BashOperator and PythonOperator. 
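Before the remaining ``hive.py`` hunks, a short usage sketch of the ``get_records`` rename called out in the 4.0.0 breaking-change notes above. This is illustrative only and not part of the patch; the connection id, query, and Hive settings are assumptions.

# Illustrative sketch only, not part of the patch; conn id, query and settings are assumed.
from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

hook = HiveServer2Hook(hiveserver2_conn_id="hiveserver2_default")

# Before 4.0.0 the first parameter was named `hql` and extra Hive settings
# were passed via `hive_conf`:
#   rows = hook.get_records(hql="SELECT * FROM t LIMIT 10", schema="default", hive_conf={"k": "v"})

# From 4.0.0 the first parameter is `sql` (matching DbApiHook.get_records),
# Hive settings travel through `parameters`, and the target schema is an
# optional keyword argument:
rows = hook.get_records("SELECT * FROM t LIMIT 10", parameters={"k": "v"}, schema="default")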
@@ -47,7 +50,7 @@ def get_context_from_env_var() -> Dict[Any, Any]: :return: The context of interest. """ return { - format_map['default']: os.environ.get(format_map['env_var_format'], '') + format_map["default"]: os.environ.get(format_map["env_var_format"], "") for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() } @@ -77,24 +80,24 @@ class HiveCliHook(BaseHook): This can make monitoring easier. """ - conn_name_attr = 'hive_cli_conn_id' - default_conn_name = 'hive_cli_default' - conn_type = 'hive_cli' - hook_name = 'Hive Client Wrapper' + conn_name_attr = "hive_cli_conn_id" + default_conn_name = "hive_cli_default" + conn_type = "hive_cli" + hook_name = "Hive Client Wrapper" def __init__( self, hive_cli_conn_id: str = default_conn_name, - run_as: Optional[str] = None, - mapred_queue: Optional[str] = None, - mapred_queue_priority: Optional[str] = None, - mapred_job_name: Optional[str] = None, + run_as: str | None = None, + mapred_queue: str | None = None, + mapred_queue_priority: str | None = None, + mapred_job_name: str | None = None, ) -> None: super().__init__() conn = self.get_connection(hive_cli_conn_id) - self.hive_cli_params: str = conn.extra_dejson.get('hive_cli_params', '') - self.use_beeline: bool = conn.extra_dejson.get('use_beeline', False) - self.auth = conn.extra_dejson.get('auth', 'noSasl') + self.hive_cli_params: str = conn.extra_dejson.get("hive_cli_params", "") + self.use_beeline: bool = conn.extra_dejson.get("use_beeline", False) + self.auth = conn.extra_dejson.get("auth", "noSasl") self.conn = conn self.run_as = run_as self.sub_process: Any = None @@ -106,7 +109,7 @@ def __init__( f"Invalid Mapred Queue Priority. Valid values are: {', '.join(HIVE_QUEUE_PRIORITIES)}" ) - self.mapred_queue = mapred_queue or conf.get('hive', 'default_hive_mapred_queue') + self.mapred_queue = mapred_queue or conf.get("hive", "default_hive_mapred_queue") self.mapred_queue_priority = mapred_queue_priority self.mapred_job_name = mapred_job_name @@ -114,7 +117,7 @@ def _get_proxy_user(self) -> str: """This function set the proper proxy_user value in case the user overwrite the default.""" conn = self.conn - proxy_user_value: str = conn.extra_dejson.get('proxy_user', "") + proxy_user_value: str = conn.extra_dejson.get("proxy_user", "") if proxy_user_value == "login" and conn.login: return f"hive.server2.proxy.user={conn.login}" if proxy_user_value == "owner" and self.run_as: @@ -123,17 +126,17 @@ def _get_proxy_user(self) -> str: return f"hive.server2.proxy.user={proxy_user_value}" return proxy_user_value # The default proxy user (undefined) - def _prepare_cli_cmd(self) -> List[Any]: + def _prepare_cli_cmd(self) -> list[Any]: """This function creates the command list from available information""" conn = self.conn - hive_bin = 'hive' + hive_bin = "hive" cmd_extra = [] if self.use_beeline: - hive_bin = 'beeline' + hive_bin = "beeline" jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}" - if conf.get('core', 'security') == 'kerberos': - template = conn.extra_dejson.get('principal', "hive/_HOST@EXAMPLE.COM") + if conf.get("core", "security") == "kerberos": + template = conn.extra_dejson.get("principal", "hive/_HOST@EXAMPLE.COM") if "_HOST" in template: template = utils.replace_hostname_pattern(utils.get_components(template)) @@ -145,18 +148,18 @@ def _prepare_cli_cmd(self) -> List[Any]: jdbc_url = f'"{jdbc_url}"' - cmd_extra += ['-u', jdbc_url] + cmd_extra += ["-u", jdbc_url] if conn.login: - cmd_extra += ['-n', conn.login] + cmd_extra += ["-n", conn.login] if conn.password: - 
cmd_extra += ['-p', conn.password] + cmd_extra += ["-p", conn.password] hive_params_list = self.hive_cli_params.split() return [hive_bin] + cmd_extra + hive_params_list @staticmethod - def _prepare_hiveconf(d: Dict[Any, Any]) -> List[Any]: + def _prepare_hiveconf(d: dict[Any, Any]) -> list[Any]: """ This function prepares a list of hiveconf params from a dictionary of key value pairs. @@ -177,9 +180,9 @@ def _prepare_hiveconf(d: Dict[Any, Any]) -> List[Any]: def run_cli( self, hql: str, - schema: Optional[str] = None, + schema: str | None = None, verbose: bool = True, - hive_conf: Optional[Dict[Any, Any]] = None, + hive_conf: dict[Any, Any] | None = None, ) -> Any: """ Run an hql statement using the hive cli. If hive_conf is specified @@ -201,13 +204,15 @@ def run_cli( """ conn = self.conn schema = schema or conn.schema + if "!" in schema or ";" in schema: + raise RuntimeError(f"The schema `{schema}` contains invalid characters (!;)") if schema: hql = f"USE {schema};\n{hql}" - with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir: + with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir: with NamedTemporaryFile(dir=tmp_dir) as f: - hql += '\n' - f.write(hql.encode('UTF-8')) + hql += "\n" + f.write(hql.encode("UTF-8")) f.flush() hive_cmd = self._prepare_cli_cmd() env_context = get_context_from_env_var() @@ -218,25 +223,25 @@ def run_cli( if self.mapred_queue: hive_conf_params.extend( [ - '-hiveconf', - f'mapreduce.job.queuename={self.mapred_queue}', - '-hiveconf', - f'mapred.job.queue.name={self.mapred_queue}', - '-hiveconf', - f'tez.queue.name={self.mapred_queue}', + "-hiveconf", + f"mapreduce.job.queuename={self.mapred_queue}", + "-hiveconf", + f"mapred.job.queue.name={self.mapred_queue}", + "-hiveconf", + f"tez.queue.name={self.mapred_queue}", ] ) if self.mapred_queue_priority: hive_conf_params.extend( - ['-hiveconf', f'mapreduce.job.priority={self.mapred_queue_priority}'] + ["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"] ) if self.mapred_job_name: - hive_conf_params.extend(['-hiveconf', f'mapred.job.name={self.mapred_job_name}']) + hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"]) hive_cmd.extend(hive_conf_params) - hive_cmd.extend(['-f', f.name]) + hive_cmd.extend(["-f", f.name]) if verbose: self.log.info("%s", " ".join(hive_cmd)) @@ -244,14 +249,14 @@ def run_cli( hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True ) self.sub_process = sub_process - stdout = '' + stdout = "" while True: line = sub_process.stdout.readline() if not line: break - stdout += line.decode('UTF-8') + stdout += line.decode("UTF-8") if verbose: - self.log.info(line.decode('UTF-8').strip()) + self.log.info(line.decode("UTF-8").strip()) sub_process.wait() if sub_process.returncode: @@ -262,37 +267,37 @@ def run_cli( def test_hql(self, hql: str) -> None: """Test an hql statement using the hive cli and EXPLAIN""" create, insert, other = [], [], [] - for query in hql.split(';'): # naive + for query in hql.split(";"): # naive query_original = query query = query.lower().strip() - if query.startswith('create table'): + if query.startswith("create table"): create.append(query_original) - elif query.startswith(('set ', 'add jar ', 'create temporary function')): + elif query.startswith(("set ", "add jar ", "create temporary function")): other.append(query_original) - elif query.startswith('insert'): + elif query.startswith("insert"): insert.append(query_original) - other_ = ';'.join(other) + other_ = ";".join(other) for 
query_set in [create, insert]: for query in query_set: - query_preview = ' '.join(query.split())[:50] + query_preview = " ".join(query.split())[:50] self.log.info("Testing HQL [%s (...)]", query_preview) if query_set == insert: - query = other_ + '; explain ' + query + query = other_ + "; explain " + query else: - query = 'explain ' + query + query = "explain " + query try: self.run_cli(query, verbose=False) except AirflowException as e: - message = e.args[0].split('\n')[-2] + message = e.args[0].split("\n")[-2] self.log.info(message) - error_loc = re.search(r'(\d+):(\d+)', message) + error_loc = re.search(r"(\d+):(\d+)", message) if error_loc and error_loc.group(1).isdigit(): lst = int(error_loc.group(1)) begin = max(lst - 2, 0) - end = min(lst + 3, len(query.split('\n'))) - context = '\n'.join(query.split('\n')[begin:end]) + end = min(lst + 3, len(query.split("\n"))) + context = "\n".join(query.split("\n")[begin:end]) self.log.info("Context :\n %s", context) else: self.log.info("SUCCESS") @@ -301,9 +306,9 @@ def load_df( self, df: pandas.DataFrame, table: str, - field_dict: Optional[Dict[Any, Any]] = None, - delimiter: str = ',', - encoding: str = 'utf8', + field_dict: dict[Any, Any] | None = None, + delimiter: str = ",", + encoding: str = "utf8", pandas_kwargs: Any = None, **kwargs: Any, ) -> None: @@ -324,18 +329,18 @@ def load_df( :param kwargs: passed to self.load_file """ - def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]: + def _infer_field_types_from_df(df: pandas.DataFrame) -> dict[Any, Any]: dtype_kind_hive_type = { - 'b': 'BOOLEAN', # boolean - 'i': 'BIGINT', # signed integer - 'u': 'BIGINT', # unsigned integer - 'f': 'DOUBLE', # floating-point - 'c': 'STRING', # complex floating-point - 'M': 'TIMESTAMP', # datetime - 'O': 'STRING', # object - 'S': 'STRING', # (byte-)string - 'U': 'STRING', # Unicode - 'V': 'STRING', # void + "b": "BOOLEAN", # boolean + "i": "BIGINT", # signed integer + "u": "BIGINT", # unsigned integer + "f": "DOUBLE", # floating-point + "c": "STRING", # complex floating-point + "M": "TIMESTAMP", # datetime + "O": "STRING", # object + "S": "STRING", # (byte-)string + "U": "STRING", # Unicode + "V": "STRING", # void } order_type = OrderedDict() @@ -346,7 +351,7 @@ def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]: if pandas_kwargs is None: pandas_kwargs = {} - with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir: + with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir: with NamedTemporaryFile(dir=tmp_dir, mode="w") as f: if field_dict is None: field_dict = _infer_field_types_from_df(df) @@ -371,12 +376,12 @@ def load_file( filepath: str, table: str, delimiter: str = ",", - field_dict: Optional[Dict[Any, Any]] = None, + field_dict: dict[Any, Any] | None = None, create: bool = True, overwrite: bool = True, - partition: Optional[Dict[str, Any]] = None, + partition: dict[str, Any] | None = None, recreate: bool = False, - tblproperties: Optional[Dict[str, Any]] = None, + tblproperties: dict[str, Any] | None = None, ) -> None: """ Loads a local file into Hive @@ -403,7 +408,7 @@ def load_file( execution :param tblproperties: TBLPROPERTIES of the hive table being created """ - hql = '' + hql = "" if recreate: hql += f"DROP TABLE IF EXISTS {table};\n" if create or recreate: @@ -433,14 +438,14 @@ def load_file( # As a workaround for HIVE-10541, add a newline character # at the end of hql (AIRFLOW-2412). 
- hql += ';\n' + hql += ";\n" self.log.info(hql) self.run_cli(hql) def kill(self) -> None: """Kill Hive cli command""" - if hasattr(self, 'sub_process'): + if hasattr(self, "sub_process"): if self.sub_process.poll() is None: print("Killing the Hive job") self.sub_process.terminate() @@ -459,26 +464,26 @@ class HiveMetastoreHook(BaseHook): # java short max val MAX_PART_COUNT = 32767 - conn_name_attr = 'metastore_conn_id' - default_conn_name = 'metastore_default' - conn_type = 'hive_metastore' - hook_name = 'Hive Metastore Thrift' + conn_name_attr = "metastore_conn_id" + default_conn_name = "metastore_default" + conn_type = "hive_metastore" + hook_name = "Hive Metastore Thrift" def __init__(self, metastore_conn_id: str = default_conn_name) -> None: super().__init__() self.conn = self.get_connection(metastore_conn_id) self.metastore = self.get_metastore_client() - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: # This is for pickling to work despite the thrift hive client not # being picklable state = dict(self.__dict__) - del state['metastore'] + del state["metastore"] return state - def __setstate__(self, d: Dict[str, Any]) -> None: + def __setstate__(self, d: dict[str, Any]) -> None: self.__dict__.update(d) - self.__dict__['metastore'] = self.get_metastore_client() + self.__dict__["metastore"] = self.get_metastore_client() def get_metastore_client(self) -> Any: """Returns a Hive thrift client.""" @@ -492,15 +497,24 @@ def get_metastore_client(self) -> Any: if not host: raise AirflowException("Failed to locate the valid server.") - auth_mechanism = conn.extra_dejson.get('authMechanism', 'NOSASL') + if "authMechanism" in conn.extra_dejson: + warnings.warn( + "The 'authMechanism' option is deprecated. Please use 'auth_mechanism'.", + DeprecationWarning, + stacklevel=2, + ) + conn.extra_dejson["auth_mechanism"] = conn.extra_dejson["authMechanism"] + del conn.extra_dejson["authMechanism"] + + auth_mechanism = conn.extra_dejson.get("auth_mechanism", "NOSASL") - if conf.get('core', 'security') == 'kerberos': - auth_mechanism = conn.extra_dejson.get('authMechanism', 'GSSAPI') - kerberos_service_name = conn.extra_dejson.get('kerberos_service_name', 'hive') + if conf.get("core", "security") == "kerberos": + auth_mechanism = conn.extra_dejson.get("auth_mechanism", "GSSAPI") + kerberos_service_name = conn.extra_dejson.get("kerberos_service_name", "hive") conn_socket = TSocket.TSocket(host, conn.port) - if conf.get('core', 'security') == 'kerberos' and auth_mechanism == 'GSSAPI': + if conf.get("core", "security") == "kerberos" and auth_mechanism == "GSSAPI": try: import saslwrapper as sasl except ImportError: @@ -525,7 +539,7 @@ def sasl_factory() -> sasl.Client: def _find_valid_host(self) -> Any: conn = self.conn - hosts = conn.host.split(',') + hosts = conn.host.split(",") for host in hosts: host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.log.info("Trying to connect to %s:%s", host, conn.port) @@ -548,7 +562,6 @@ def check_for_partition(self, schema: str, table: str, partition: str) -> bool: :param table: Name of hive table @partition belongs to :param partition: Expression that matches the partitions to check for (eg `a = 'b' AND c = 'd'`) - :rtype: bool >>> hh = HiveMetastoreHook() >>> t = 'static_babynames_partitioned' @@ -569,7 +582,6 @@ def check_for_named_partition(self, schema: str, table: str, partition_name: str :param schema: Name of hive schema (database) @table belongs to :param table: Name of hive table @partition belongs to 
:param partition_name: Name of the partitions to check for (eg `a=b/c=d`) - :rtype: bool >>> hh = HiveMetastoreHook() >>> t = 'static_babynames_partitioned' @@ -581,7 +593,7 @@ def check_for_named_partition(self, schema: str, table: str, partition_name: str with self.metastore as client: return client.check_for_named_partition(schema, table, partition_name) - def get_table(self, table_name: str, db: str = 'default') -> Any: + def get_table(self, table_name: str, db: str = "default") -> Any: """Get a metastore table object >>> hh = HiveMetastoreHook() @@ -591,25 +603,23 @@ def get_table(self, table_name: str, db: str = 'default') -> Any: >>> [col.name for col in t.sd.cols] ['state', 'year', 'name', 'gender', 'num'] """ - if db == 'default' and '.' in table_name: - db, table_name = table_name.split('.')[:2] + if db == "default" and "." in table_name: + db, table_name = table_name.split(".")[:2] with self.metastore as client: return client.get_table(dbname=db, tbl_name=table_name) - def get_tables(self, db: str, pattern: str = '*') -> Any: + def get_tables(self, db: str, pattern: str = "*") -> Any: """Get a metastore table object""" with self.metastore as client: tables = client.get_tables(db_name=db, pattern=pattern) return client.get_table_objects_by_name(db, tables) - def get_databases(self, pattern: str = '*') -> Any: + def get_databases(self, pattern: str = "*") -> Any: """Get a metastore table object""" with self.metastore as client: return client.get_databases(pattern) - def get_partitions( - self, schema: str, table_name: str, partition_filter: Optional[str] = None - ) -> List[Any]: + def get_partitions(self, schema: str, table_name: str, partition_filter: str | None = None) -> list[Any]: """ Returns a list of all partitions in a table. Works only for tables with less than 32767 (java short max val). @@ -645,7 +655,7 @@ def get_partitions( @staticmethod def _get_max_partition_from_part_specs( - part_specs: List[Any], partition_key: Optional[str], filter_map: Optional[Dict[str, Any]] + part_specs: list[Any], partition_key: str | None, filter_map: dict[str, Any] | None ) -> Any: """ Helper method to get max partition of partitions with partition_key @@ -659,7 +669,6 @@ def _get_max_partition_from_part_specs( Only partitions matching all partition_key:partition_value pairs will be considered as candidates of max partition. :return: Max partition or None if part_specs is empty. - :rtype: basestring """ if not part_specs: return None @@ -691,8 +700,8 @@ def max_partition( self, schema: str, table_name: str, - field: Optional[str] = None, - filter_map: Optional[Dict[Any, Any]] = None, + field: str | None = None, + filter_map: dict[Any, Any] | None = None, ) -> Any: """ Returns the maximum value for all partitions with given field in a table. 
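The ``get_metastore_client`` hunk above, like the ``HiveServer2Hook.get_conn`` hunk that follows, switches the connection-extras key from the deprecated ``authMechanism`` to ``auth_mechanism``. A minimal sketch of a connection defined with the new key, not part of the patch; the host, port, and extra values are assumptions.

# Illustrative sketch only, not part of the patch; host, port and values are assumed.
import json

from airflow.models.connection import Connection

hive_conn = Connection(
    conn_id="hiveserver2_default",
    conn_type="hiveserver2",
    host="hiveserver2.example.com",
    port=10000,
    # "auth_mechanism" replaces the deprecated "authMechanism" extra key;
    # the old spelling still works but now emits a DeprecationWarning.
    extra=json.dumps({"auth_mechanism": "KERBEROS", "kerberos_service_name": "hive"}),
)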
@@ -732,7 +741,7 @@ def max_partition( return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, field, filter_map) - def table_exists(self, table_name: str, db: str = 'default') -> bool: + def table_exists(self, table_name: str, db: str = "default") -> bool: """ Check if table exists @@ -748,7 +757,7 @@ def table_exists(self, table_name: str, db: str = 'default') -> bool: except Exception: return False - def drop_partitions(self, table_name, part_vals, delete_data=False, db='default'): + def drop_partitions(self, table_name, part_vals, delete_data=False, db="default"): """ Drop partitions from the given table matching the part_vals input @@ -779,7 +788,7 @@ class HiveServer2Hook(DbApiHook): Wrapper around the pyhive library Notes: - * the default authMechanism is PLAIN, to override it you + * the default auth_mechanism is PLAIN, to override it you can specify it in the ``extra`` of your connection in the UI * the default for run_set_variable_statements is true, if you are using impala you may need to set it to false in the @@ -790,38 +799,47 @@ class HiveServer2Hook(DbApiHook): :param schema: Hive database name. """ - conn_name_attr = 'hiveserver2_conn_id' - default_conn_name = 'hiveserver2_default' - conn_type = 'hiveserver2' - hook_name = 'Hive Server 2 Thrift' + conn_name_attr = "hiveserver2_conn_id" + default_conn_name = "hiveserver2_default" + conn_type = "hiveserver2" + hook_name = "Hive Server 2 Thrift" supports_autocommit = False - def get_conn(self, schema: Optional[str] = None) -> Any: + def get_conn(self, schema: str | None = None) -> Any: """Returns a Hive connection object.""" - username: Optional[str] = None - password: Optional[str] = None + username: str | None = None + password: str | None = None db = self.get_connection(self.hiveserver2_conn_id) # type: ignore - auth_mechanism = db.extra_dejson.get('authMechanism', 'NONE') - if auth_mechanism == 'NONE' and db.login is None: + if "authMechanism" in db.extra_dejson: + warnings.warn( + "The 'authMechanism' option is deprecated. Please use 'auth_mechanism'.", + DeprecationWarning, + stacklevel=2, + ) + db.extra_dejson["auth_mechanism"] = db.extra_dejson["authMechanism"] + del db.extra_dejson["authMechanism"] + + auth_mechanism = db.extra_dejson.get("auth_mechanism", "NONE") + if auth_mechanism == "NONE" and db.login is None: # we need to give a username - username = 'airflow' + username = "airflow" kerberos_service_name = None - if conf.get('core', 'security') == 'kerberos': - auth_mechanism = db.extra_dejson.get('authMechanism', 'KERBEROS') - kerberos_service_name = db.extra_dejson.get('kerberos_service_name', 'hive') + if conf.get("core", "security") == "kerberos": + auth_mechanism = db.extra_dejson.get("auth_mechanism", "KERBEROS") + kerberos_service_name = db.extra_dejson.get("kerberos_service_name", "hive") # pyhive uses GSSAPI instead of KERBEROS as a auth_mechanism identifier - if auth_mechanism == 'GSSAPI': + if auth_mechanism == "GSSAPI": self.log.warning( - "Detected deprecated 'GSSAPI' for authMechanism for %s. Please use 'KERBEROS' instead", + "Detected deprecated 'GSSAPI' for auth_mechanism for %s. 
Please use 'KERBEROS' instead", self.hiveserver2_conn_id, # type: ignore ) - auth_mechanism = 'KERBEROS' + auth_mechanism = "KERBEROS" # Password should be set if and only if in LDAP or CUSTOM mode - if auth_mechanism in ('LDAP', 'CUSTOM'): + if auth_mechanism in ("LDAP", "CUSTOM"): password = db.password from pyhive.hive import connect @@ -833,20 +851,20 @@ def get_conn(self, schema: Optional[str] = None) -> Any: kerberos_service_name=kerberos_service_name, username=db.login or username, password=password, - database=schema or db.schema or 'default', + database=schema or db.schema or "default", ) def _get_results( self, - hql: Union[str, List[str]], - schema: str = 'default', - fetch_size: Optional[int] = None, - hive_conf: Optional[Dict[Any, Any]] = None, + sql: str | list[str], + schema: str = "default", + fetch_size: int | None = None, + hive_conf: Iterable | Mapping | None = None, ) -> Any: from pyhive.exc import ProgrammingError - if isinstance(hql, str): - hql = [hql] + if isinstance(sql, str): + sql = [sql] previous_description = None with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur: @@ -856,28 +874,28 @@ def _get_results( db = self.get_connection(self.hiveserver2_conn_id) # type: ignore - if db.extra_dejson.get('run_set_variable_statements', True): + if db.extra_dejson.get("run_set_variable_statements", True): env_context = get_context_from_env_var() if hive_conf: env_context.update(hive_conf) for k, v in env_context.items(): cur.execute(f"set {k}={v}") - for statement in hql: + for statement in sql: cur.execute(statement) # we only get results of statements that returns lowered_statement = statement.lower().strip() if ( - lowered_statement.startswith('select') - or lowered_statement.startswith('with') - or lowered_statement.startswith('show') - or (lowered_statement.startswith('set') and '=' not in lowered_statement) + lowered_statement.startswith("select") + or lowered_statement.startswith("with") + or lowered_statement.startswith("show") + or (lowered_statement.startswith("set") and "=" not in lowered_statement) ): description = cur.description if previous_description and previous_description != description: - message = f'''The statements are producing different descriptions: + message = f"""The statements are producing different descriptions: Current: {repr(description)} - Previous: {repr(previous_description)}''' + Previous: {repr(previous_description)}""" raise ValueError(message) elif not previous_description: previous_description = description @@ -892,41 +910,40 @@ def _get_results( def get_results( self, - hql: str, - schema: str = 'default', - fetch_size: Optional[int] = None, - hive_conf: Optional[Dict[Any, Any]] = None, - ) -> Dict[str, Any]: + sql: str | list[str], + schema: str = "default", + fetch_size: int | None = None, + hive_conf: Iterable | Mapping | None = None, + ) -> dict[str, Any]: """ Get results of the provided hql in target schema. - :param hql: hql to be executed. + :param sql: hql to be executed. :param schema: target schema, default to 'default'. :param fetch_size: max size of result to fetch. :param hive_conf: hive_conf to execute alone with the hql. 
:return: results of hql execution, dict with data (list of results) and header - :rtype: dict """ - results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf) + results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf) header = next(results_iter) - results = {'data': list(results_iter), 'header': header} + results = {"data": list(results_iter), "header": header} return results def to_csv( self, - hql: str, + sql: str, csv_filepath: str, - schema: str = 'default', - delimiter: str = ',', - lineterminator: str = '\r\n', + schema: str = "default", + delimiter: str = ",", + lineterminator: str = "\r\n", output_header: bool = True, fetch_size: int = 1000, - hive_conf: Optional[Dict[Any, Any]] = None, + hive_conf: dict[Any, Any] | None = None, ) -> None: """ Execute hql in target schema and write results to a csv file. - :param hql: hql to be executed. + :param sql: hql to be executed. :param csv_filepath: filepath of csv to write results into. :param schema: target schema, default to 'default'. :param delimiter: delimiter of the csv file, default to ','. @@ -936,16 +953,16 @@ def to_csv( :param hive_conf: hive_conf to execute alone with the hql. """ - results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf) + results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf) header = next(results_iter) message = None i = 0 - with open(csv_filepath, 'wb') as file: - writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding='utf-8') + with open(csv_filepath, "wb") as file: + writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding="utf-8") try: if output_header: - self.log.debug('Cursor description is %s', header) + self.log.debug("Cursor description is %s", header) writer.writerow([c[0] for c in header]) for i, row in enumerate(results_iter, 1): @@ -963,40 +980,39 @@ def to_csv( self.log.info("Done. Loaded a total of %s rows.", i) def get_records( - self, hql: str, schema: str = 'default', hive_conf: Optional[Dict[Any, Any]] = None + self, sql: str | list[str], parameters: Iterable | Mapping | None = None, **kwargs ) -> Any: """ - Get a set of records from a Hive query. + Get a set of records from a Hive query. You can optionally pass 'schema' kwarg + which specifies target schema and default to 'default'. - :param hql: hql to be executed. - :param schema: target schema, default to 'default'. - :param hive_conf: hive_conf to execute alone with the hql. + :param sql: hql to be executed. + :param parameters: optional configuration passed to get_results :return: result of hive execution - :rtype: list >>> hh = HiveServer2Hook() >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100" >>> len(hh.get_records(sql)) 100 """ - return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data'] + schema = kwargs["schema"] if "schema" in kwargs else "default" + return self.get_results(sql, schema=schema, hive_conf=parameters)["data"] def get_pandas_df( # type: ignore self, - hql: str, - schema: str = 'default', - hive_conf: Optional[Dict[Any, Any]] = None, + sql: str, + schema: str = "default", + hive_conf: dict[Any, Any] | None = None, **kwargs, ) -> pandas.DataFrame: """ Get a pandas dataframe from a Hive query - :param hql: hql to be executed. + :param sql: hql to be executed. :param schema: target schema, default to 'default'. :param hive_conf: hive_conf to execute alone with the hql. 
:param kwargs: (optional) passed into pandas.DataFrame constructor :return: result of hive execution - :rtype: DataFrame >>> hh = HiveServer2Hook() >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100" @@ -1006,6 +1022,6 @@ def get_pandas_df( # type: ignore :return: pandas.DateFrame """ - res = self.get_results(hql, schema=schema, hive_conf=hive_conf) - df = pandas.DataFrame(res['data'], columns=[c[0] for c in res['header']], **kwargs) + res = self.get_results(sql, schema=schema, hive_conf=hive_conf) + df = pandas.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs) return df diff --git a/airflow/providers/apache/hive/operators/hive.py b/airflow/providers/apache/hive/operators/hive.py index 45cae0fa4e31f..23f6c32edd3d7 100644 --- a/airflow/providers/apache/hive/operators/hive.py +++ b/airflow/providers/apache/hive/operators/hive.py @@ -15,9 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import os import re -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf from airflow.models import BaseOperator @@ -57,34 +59,34 @@ class HiveOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'hql', - 'schema', - 'hive_cli_conn_id', - 'mapred_queue', - 'hiveconfs', - 'mapred_job_name', - 'mapred_queue_priority', + "hql", + "schema", + "hive_cli_conn_id", + "mapred_queue", + "hiveconfs", + "mapred_job_name", + "mapred_queue_priority", ) template_ext: Sequence[str] = ( - '.hql', - '.sql', + ".hql", + ".sql", ) - template_fields_renderers = {'hql': 'hql'} - ui_color = '#f0e4ec' + template_fields_renderers = {"hql": "hql"} + ui_color = "#f0e4ec" def __init__( self, *, hql: str, - hive_cli_conn_id: str = 'hive_cli_default', - schema: str = 'default', - hiveconfs: Optional[Dict[Any, Any]] = None, + hive_cli_conn_id: str = "hive_cli_default", + schema: str = "default", + hiveconfs: dict[Any, Any] | None = None, hiveconf_jinja_translate: bool = False, - script_begin_tag: Optional[str] = None, + script_begin_tag: str | None = None, run_as_owner: bool = False, - mapred_queue: Optional[str] = None, - mapred_queue_priority: Optional[str] = None, - mapred_job_name: Optional[str] = None, + mapred_queue: str | None = None, + mapred_queue_priority: str | None = None, + mapred_job_name: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -101,20 +103,18 @@ def __init__( self.mapred_queue_priority = mapred_queue_priority self.mapred_job_name = mapred_job_name - job_name_template = conf.get( - 'hive', - 'mapred_job_name_template', + job_name_template = conf.get_mandatory_value( + "hive", + "mapred_job_name_template", fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}", ) - if job_name_template is None: - raise ValueError("Job name template should be set !") self.mapred_job_name_template: str = job_name_template # assigned lazily - just for consistency we can create the attribute with a # `None` initial value, later it will be populated by the execute method. # This also makes `on_kill` implementation consistent since it assumes `self.hook` # is defined. 
- self.hook: Optional[HiveCliHook] = None + self.hook: HiveCliHook | None = None def get_hook(self) -> HiveCliHook: """Get Hive cli hook""" @@ -132,18 +132,18 @@ def prepare_template(self) -> None: if self.script_begin_tag and self.script_begin_tag in self.hql: self.hql = "\n".join(self.hql.split(self.script_begin_tag)[1:]) - def execute(self, context: "Context") -> None: - self.log.info('Executing: %s', self.hql) + def execute(self, context: Context) -> None: + self.log.info("Executing: %s", self.hql) self.hook = self.get_hook() # set the mapred_job_name if it's not set with dag, task, execution time info if not self.mapred_job_name: - ti = context['ti'] + ti = context["ti"] self.hook.mapred_job_name = self.mapred_job_name_template.format( dag_id=ti.dag_id, task_id=ti.task_id, execution_date=ti.execution_date.isoformat(), - hostname=ti.hostname.split('.')[0], + hostname=ti.hostname.split(".")[0], ) if self.hiveconf_jinja_translate: @@ -151,7 +151,7 @@ def execute(self, context: "Context") -> None: else: self.hiveconfs.update(context_to_airflow_vars(context)) - self.log.info('Passing HiveConf: %s', self.hiveconfs) + self.log.info("Passing HiveConf: %s", self.hiveconfs) self.hook.run_cli(hql=self.hql, schema=self.schema, hive_conf=self.hiveconfs) def dry_run(self) -> None: @@ -169,6 +169,6 @@ def on_kill(self) -> None: def clear_airflow_vars(self) -> None: """Reset airflow environment variables to prevent existing ones from impacting behavior.""" blank_env_vars = { - value['env_var_format']: '' for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() + value["env_var_format"]: "" for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() } os.environ.update(blank_env_vars) diff --git a/airflow/providers/apache/hive/operators/hive_stats.py b/airflow/providers/apache/hive/operators/hive_stats.py index a1b5539622321..caa6770a8fdf8 100644 --- a/airflow/providers/apache/hive/operators/hive_stats.py +++ b/airflow/providers/apache/hive/operators/hive_stats.py @@ -15,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import json import warnings from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -58,23 +60,23 @@ class HiveStatsCollectionOperator(BaseOperator): column. 
""" - template_fields: Sequence[str] = ('table', 'partition', 'ds', 'dttm') - ui_color = '#aff7a6' + template_fields: Sequence[str] = ("table", "partition", "ds", "dttm") + ui_color = "#aff7a6" def __init__( self, *, table: str, partition: Any, - extra_exprs: Optional[Dict[str, Any]] = None, - excluded_columns: Optional[List[str]] = None, - assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None, - metastore_conn_id: str = 'metastore_default', - presto_conn_id: str = 'presto_default', - mysql_conn_id: str = 'airflow_db', + extra_exprs: dict[str, Any] | None = None, + excluded_columns: list[str] | None = None, + assignment_func: Callable[[str, str], dict[Any, Any] | None] | None = None, + metastore_conn_id: str = "metastore_default", + presto_conn_id: str = "presto_default", + mysql_conn_id: str = "airflow_db", **kwargs: Any, ) -> None: - if 'col_blacklist' in kwargs: + if "col_blacklist" in kwargs: warnings.warn( f"col_blacklist kwarg passed to {self.__class__.__name__} " f"(task_id: {kwargs.get('task_id')}) is deprecated, " @@ -82,44 +84,44 @@ def __init__( category=FutureWarning, stacklevel=2, ) - excluded_columns = kwargs.pop('col_blacklist') + excluded_columns = kwargs.pop("col_blacklist") super().__init__(**kwargs) self.table = table self.partition = partition self.extra_exprs = extra_exprs or {} - self.excluded_columns = excluded_columns or [] # type: List[str] + self.excluded_columns: list[str] = excluded_columns or [] self.metastore_conn_id = metastore_conn_id self.presto_conn_id = presto_conn_id self.mysql_conn_id = mysql_conn_id self.assignment_func = assignment_func - self.ds = '{{ ds }}' - self.dttm = '{{ execution_date.isoformat() }}' + self.ds = "{{ ds }}" + self.dttm = "{{ execution_date.isoformat() }}" - def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]: + def get_default_exprs(self, col: str, col_type: str) -> dict[Any, Any]: """Get default expressions""" if col in self.excluded_columns: return {} - exp = {(col, 'non_null'): f"COUNT({col})"} - if col_type in ['double', 'int', 'bigint', 'float']: - exp[(col, 'sum')] = f'SUM({col})' - exp[(col, 'min')] = f'MIN({col})' - exp[(col, 'max')] = f'MAX({col})' - exp[(col, 'avg')] = f'AVG({col})' - elif col_type == 'boolean': - exp[(col, 'true')] = f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)' - exp[(col, 'false')] = f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)' - elif col_type in ['string']: - exp[(col, 'len')] = f'SUM(CAST(LENGTH({col}) AS BIGINT))' - exp[(col, 'approx_distinct')] = f'APPROX_DISTINCT({col})' + exp = {(col, "non_null"): f"COUNT({col})"} + if col_type in {"double", "int", "bigint", "float"}: + exp[(col, "sum")] = f"SUM({col})" + exp[(col, "min")] = f"MIN({col})" + exp[(col, "max")] = f"MAX({col})" + exp[(col, "avg")] = f"AVG({col})" + elif col_type == "boolean": + exp[(col, "true")] = f"SUM(CASE WHEN {col} THEN 1 ELSE 0 END)" + exp[(col, "false")] = f"SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)" + elif col_type == "string": + exp[(col, "len")] = f"SUM(CAST(LENGTH({col}) AS BIGINT))" + exp[(col, "approx_distinct")] = f"APPROX_DISTINCT({col})" return exp - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) table = metastore.get_table(table_name=self.table) field_types = {col.name: col.type for col in table.sd.cols} - exprs: Any = {('', 'count'): 'COUNT(*)'} + exprs: Any = {("", "count"): "COUNT(*)"} for col, col_type in list(field_types.items()): if 
self.assignment_func: assign_exprs = self.assignment_func(col, col_type) @@ -130,15 +132,15 @@ def execute(self, context: "Context") -> None: exprs.update(assign_exprs) exprs.update(self.extra_exprs) exprs = OrderedDict(exprs) - exprs_str = ",\n ".join(v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()) + exprs_str = ",\n ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items()) where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()] where_clause = " AND\n ".join(where_clause_) sql = f"SELECT {exprs_str} FROM {self.table} WHERE {where_clause};" presto = PrestoHook(presto_conn_id=self.presto_conn_id) - self.log.info('Executing SQL check: %s', sql) - row = presto.get_first(hql=sql) + self.log.info("Executing SQL check: %s", sql) + row = presto.get_first(sql) self.log.info("Record: %s", row) if not row: raise AirflowException("The query returned None") @@ -170,15 +172,15 @@ def execute(self, context: "Context") -> None: (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row) ] mysql.insert_rows( - table='hive_stats', + table="hive_stats", rows=rows, target_fields=[ - 'ds', - 'dttm', - 'table_name', - 'partition_repr', - 'col', - 'metric', - 'value', + "ds", + "dttm", + "table_name", + "partition_repr", + "col", + "metric", + "value", ], ) diff --git a/airflow/providers/apache/hive/provider.yaml b/airflow/providers/apache/hive/provider.yaml index cf00cf2e7cc88..d3da54dfbb57c 100644 --- a/airflow/providers/apache/hive/provider.yaml +++ b/airflow/providers/apache/hive/provider.yaml @@ -22,6 +22,11 @@ description: | `Apache Hive `__ versions: + - 4.1.0 + - 4.0.1 + - 4.0.0 + - 3.1.0 + - 3.0.0 - 2.3.3 - 2.3.2 - 2.3.1 @@ -37,8 +42,17 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - hmsclient>=0.1.0 + - pandas>=0.17.1 + - pyhive[hive]>=0.6.0 + # in case of Python 3.9 sasl library needs to be installed with version higher or equal than + # 0.3.1 because only that version supports Python 3.9. For other Python version pyhive[hive] pulls + # the sasl library anyway (and there sasl library version is not relevant) + - sasl>=0.3.1; python_version>="3.9" + - thrift>=0.9.2 integrations: - integration-name: Apache Hive @@ -86,12 +100,6 @@ transfers: target-integration-name: Apache Hive python-module: airflow.providers.apache.hive.transfers.mssql_to_hive -hook-class-names: - # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.hive.hooks.hive.HiveCliHook - - airflow.providers.apache.hive.hooks.hive.HiveServer2Hook - - airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook - connection-types: - hook-class-name: airflow.providers.apache.hive.hooks.hive.HiveCliHook connection-type: hive_cli diff --git a/airflow/providers/apache/hive/sensors/hive_partition.py b/airflow/providers/apache/hive/sensors/hive_partition.py index f03dcb18f52e1..d839bb444fcbd 100644 --- a/airflow/providers/apache/hive/sensors/hive_partition.py +++ b/airflow/providers/apache/hive/sensors/hive_partition.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import TYPE_CHECKING, Any, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook from airflow.sensors.base import BaseSensorOperator @@ -43,19 +45,19 @@ class HivePartitionSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'schema', - 'table', - 'partition', + "schema", + "table", + "partition", ) - ui_color = '#C5CAE9' + ui_color = "#C5CAE9" def __init__( self, *, table: str, - partition: Optional[str] = "ds='{{ ds }}'", - metastore_conn_id: str = 'metastore_default', - schema: str = 'default', + partition: str | None = "ds='{{ ds }}'", + metastore_conn_id: str = "metastore_default", + schema: str = "default", poke_interval: int = 60 * 3, **kwargs: Any, ): @@ -67,10 +69,10 @@ def __init__( self.partition = partition self.schema = schema - def poke(self, context: "Context") -> bool: - if '.' in self.table: - self.schema, self.table = self.table.split('.') - self.log.info('Poking for table %s.%s, partition %s', self.schema, self.table, self.partition) - if not hasattr(self, 'hook'): + def poke(self, context: Context) -> bool: + if "." in self.table: + self.schema, self.table = self.table.split(".") + self.log.info("Poking for table %s.%s, partition %s", self.schema, self.table, self.partition) + if not hasattr(self, "hook"): hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) return hook.check_for_partition(self.schema, self.table, self.partition) diff --git a/airflow/providers/apache/hive/sensors/metastore_partition.py b/airflow/providers/apache/hive/sensors/metastore_partition.py index ea6c1525a1d57..57e793849efce 100644 --- a/airflow/providers/apache/hive/sensors/metastore_partition.py +++ b/airflow/providers/apache/hive/sensors/metastore_partition.py @@ -15,9 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Sequence -from airflow.sensors.sql import SqlSensor +from airflow.providers.common.sql.sensors.sql import SqlSensor if TYPE_CHECKING: from airflow.utils.context import Context @@ -40,9 +42,8 @@ class MetastorePartitionSensor(SqlSensor): :param mysql_conn_id: a reference to the MySQL conn_id for the metastore """ - template_fields: Sequence[str] = ('partition_name', 'table', 'schema') - ui_color = '#8da7be' - poke_context_fields = ('partition_name', 'table', 'schema', 'mysql_conn_id') + template_fields: Sequence[str] = ("partition_name", "table", "schema") + ui_color = "#8da7be" def __init__( self, @@ -66,11 +67,11 @@ def __init__( # constructor below and apply_defaults will no longer throw an exception. super().__init__(**kwargs) - def poke(self, context: "Context") -> Any: + def poke(self, context: Context) -> Any: if self.first_poke: self.first_poke = False - if '.' in self.table: - self.schema, self.table = self.table.split('.') + if "." in self.table: + self.schema, self.table = self.table.split(".") self.sql = """ SELECT 'X' FROM PARTITIONS A0 diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py index 9535bcdab0219..1a7e5ee36ebf8 100644 --- a/airflow/providers/apache/hive/sensors/named_hive_partition.py +++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py @@ -15,7 +15,9 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, List, Sequence, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from airflow.sensors.base import BaseSensorOperator @@ -38,15 +40,14 @@ class NamedHivePartitionSensor(BaseSensorOperator): :ref:`metastore thrift service connection id `. """ - template_fields: Sequence[str] = ('partition_names',) - ui_color = '#8d99ae' - poke_context_fields = ('partition_names', 'metastore_conn_id') + template_fields: Sequence[str] = ("partition_names",) + ui_color = "#8d99ae" def __init__( self, *, - partition_names: List[str], - metastore_conn_id: str = 'metastore_default', + partition_names: list[str], + metastore_conn_id: str = "metastore_default", poke_interval: int = 60 * 3, hook: Any = None, **kwargs: Any, @@ -55,28 +56,28 @@ def __init__( self.next_index_to_poke = 0 if isinstance(partition_names, str): - raise TypeError('partition_names must be an array of strings') + raise TypeError("partition_names must be an array of strings") self.metastore_conn_id = metastore_conn_id self.partition_names = partition_names self.hook = hook - if self.hook and metastore_conn_id != 'metastore_default': + if self.hook and metastore_conn_id != "metastore_default": self.log.warning( - 'A hook was passed but a non default metastore_conn_id=%s was used', metastore_conn_id + "A hook was passed but a non default metastore_conn_id=%s was used", metastore_conn_id ) @staticmethod - def parse_partition_name(partition: str) -> Tuple[Any, ...]: + def parse_partition_name(partition: str) -> tuple[Any, ...]: """Get schema, table, and partition info.""" - first_split = partition.split('.', 1) + first_split = partition.split(".", 1) if len(first_split) == 1: - schema = 'default' + schema = "default" table_partition = max(first_split) # poor man first else: schema, table_partition = first_split - second_split = table_partition.split('/', 1) + second_split = table_partition.split("/", 1) if len(second_split) == 1: - raise ValueError('Could not parse ' + partition + 'into table, partition') + raise ValueError(f"Could not parse {partition}into table, partition") else: table, partition = second_split return schema, table, partition @@ -90,10 +91,10 @@ def poke_partition(self, partition: str) -> Any: schema, table, partition = self.parse_partition_name(partition) - self.log.info('Poking for %s.%s/%s', schema, table, partition) + self.log.info("Poking for %s.%s/%s", schema, table, partition) return self.hook.check_for_named_partition(schema, table, partition) - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: number_of_partitions = len(self.partition_names) poke_index_start = self.next_index_to_poke @@ -104,12 +105,3 @@ def poke(self, context: "Context") -> bool: self.next_index_to_poke = 0 return True - - def is_smart_sensor_compatible(self): - result = ( - not self.soft_fail - and not self.hook - and len(self.partition_names) <= 30 - and super().is_smart_sensor_compatible() - ) - return result diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py index 65de4cc159ae2..b1a3669d7171c 100644 --- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py +++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py @@ -15,23 +15,20 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """This module contains an operator to move data from Hive to MySQL.""" +from __future__ import annotations + from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.utils.operator_helpers import context_to_airflow_vars -from airflow.www import utils as wwwutils if TYPE_CHECKING: from airflow.utils.context import Context -# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. -MYSQL_RENDERER = 'mysql' if 'mysql' in wwwutils.get_attr_renderer() else 'sql' - class HiveToMySqlOperator(BaseOperator): """ @@ -59,26 +56,26 @@ class HiveToMySqlOperator(BaseOperator): :param hive_conf: """ - template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator', 'mysql_postoperator') - template_ext: Sequence[str] = ('.sql',) + template_fields: Sequence[str] = ("sql", "mysql_table", "mysql_preoperator", "mysql_postoperator") + template_ext: Sequence[str] = (".sql",) template_fields_renderers = { - 'sql': 'hql', - 'mysql_preoperator': MYSQL_RENDERER, - 'mysql_postoperator': MYSQL_RENDERER, + "sql": "hql", + "mysql_preoperator": "mysql", + "mysql_postoperator": "mysql", } - ui_color = '#a0e08c' + ui_color = "#a0e08c" def __init__( self, *, sql: str, mysql_table: str, - hiveserver2_conn_id: str = 'hiveserver2_default', - mysql_conn_id: str = 'mysql_default', - mysql_preoperator: Optional[str] = None, - mysql_postoperator: Optional[str] = None, + hiveserver2_conn_id: str = "hiveserver2_default", + mysql_conn_id: str = "mysql_default", + mysql_preoperator: str | None = None, + mysql_postoperator: str | None = None, bulk_load: bool = False, - hive_conf: Optional[Dict] = None, + hive_conf: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -91,7 +88,7 @@ def __init__( self.bulk_load = bulk_load self.hive_conf = hive_conf - def execute(self, context: 'Context'): + def execute(self, context: Context): hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) self.log.info("Extracting data from Hive: %s", self.sql) @@ -103,15 +100,15 @@ def execute(self, context: 'Context'): hive.to_csv( self.sql, tmp_file.name, - delimiter='\t', - lineterminator='\n', + delimiter="\t", + lineterminator="\n", output_header=False, hive_conf=hive_conf, ) mysql = self._call_preoperator() mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name) else: - hive_results = hive.get_records(self.sql, hive_conf=hive_conf) + hive_results = hive.get_records(self.sql, parameters=hive_conf) mysql = self._call_preoperator() mysql.insert_rows(table=self.mysql_table, rows=hive_results) diff --git a/airflow/providers/apache/hive/transfers/hive_to_samba.py b/airflow/providers/apache/hive/transfers/hive_to_samba.py index c5ab66efa42e0..0ad815ea9cf73 100644 --- a/airflow/providers/apache/hive/transfers/hive_to_samba.py +++ b/airflow/providers/apache/hive/transfers/hive_to_samba.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module contains an operator to move data from Hive to Samba.""" +from __future__ import annotations from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Sequence @@ -42,33 +42,33 @@ class HiveToSambaOperator(BaseOperator): :ref: `Hive Server2 thrift service connection id `. """ - template_fields: Sequence[str] = ('hql', 'destination_filepath') + template_fields: Sequence[str] = ("hql", "destination_filepath") template_ext: Sequence[str] = ( - '.hql', - '.sql', + ".hql", + ".sql", ) - template_fields_renderers = {'hql': 'hql'} + template_fields_renderers = {"hql": "hql"} def __init__( self, *, hql: str, destination_filepath: str, - samba_conn_id: str = 'samba_default', - hiveserver2_conn_id: str = 'hiveserver2_default', + samba_conn_id: str = "samba_default", + hiveserver2_conn_id: str = "hiveserver2_default", **kwargs, ) -> None: super().__init__(**kwargs) self.hiveserver2_conn_id = hiveserver2_conn_id self.samba_conn_id = samba_conn_id self.destination_filepath = destination_filepath - self.hql = hql.strip().rstrip(';') + self.hql = hql.strip().rstrip(";") - def execute(self, context: 'Context'): + def execute(self, context: Context): with NamedTemporaryFile() as tmp_file: self.log.info("Fetching file from Hive") hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) - hive.to_csv(hql=self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context)) + hive.to_csv(self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context)) self.log.info("Pushing to samba") samba = SambaHook(samba_conn_id=self.samba_conn_id) samba.push_from_local(self.destination_filepath, tmp_file.name) diff --git a/airflow/providers/apache/hive/transfers/mssql_to_hive.py b/airflow/providers/apache/hive/transfers/mssql_to_hive.py index 912c2a58a36bd..9cdd581911238 100644 --- a/airflow/providers/apache/hive/transfers/mssql_to_hive.py +++ b/airflow/providers/apache/hive/transfers/mssql_to_hive.py @@ -15,12 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains an operator to move data from MSSQL to Hive.""" +from __future__ import annotations from collections import OrderedDict from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Sequence import pymssql import unicodecsv as csv @@ -28,7 +28,6 @@ from airflow.models import BaseOperator from airflow.providers.apache.hive.hooks.hive import HiveCliHook from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook -from airflow.www import utils as wwwutils if TYPE_CHECKING: from airflow.utils.context import Context @@ -64,11 +63,10 @@ class MsSqlToHiveOperator(BaseOperator): :param tblproperties: TBLPROPERTIES of the hive table being created """ - template_fields: Sequence[str] = ('sql', 'partition', 'hive_table') - template_ext: Sequence[str] = ('.sql',) - # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. 
- template_fields_renderers = {'sql': 'tsql' if 'tsql' in wwwutils.get_attr_renderer() else 'sql'} - ui_color = '#a0e08c' + template_fields: Sequence[str] = ("sql", "partition", "hive_table") + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "tsql"} + ui_color = "#a0e08c" def __init__( self, @@ -77,11 +75,11 @@ def __init__( hive_table: str, create: bool = True, recreate: bool = False, - partition: Optional[Dict] = None, + partition: dict | None = None, delimiter: str = chr(1), - mssql_conn_id: str = 'mssql_default', - hive_cli_conn_id: str = 'hive_cli_default', - tblproperties: Optional[Dict] = None, + mssql_conn_id: str = "mssql_default", + hive_cli_conn_id: str = "hive_cli_default", + tblproperties: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -100,26 +98,24 @@ def __init__( def type_map(cls, mssql_type: int) -> str: """Maps MsSQL type to Hive type.""" map_dict = { - pymssql.BINARY.value: 'INT', - pymssql.DECIMAL.value: 'FLOAT', - pymssql.NUMBER.value: 'INT', + pymssql.BINARY.value: "INT", + pymssql.DECIMAL.value: "FLOAT", + pymssql.NUMBER.value: "INT", } - return map_dict.get(mssql_type, 'STRING') + return map_dict.get(mssql_type, "STRING") - def execute(self, context: "Context"): + def execute(self, context: Context): mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id) self.log.info("Dumping Microsoft SQL Server query results to local file") with mssql.get_conn() as conn: with conn.cursor() as cursor: cursor.execute(self.sql) with NamedTemporaryFile("w") as tmp_file: - csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8') + csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding="utf-8") field_dict = OrderedDict() - col_count = 0 - for field in cursor.description: - col_count += 1 + for col_count, field in enumerate(cursor.description, start=1): col_position = f"Column{col_count}" - field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1]) + field_dict[col_position if field[0] == "" else field[0]] = self.type_map(field[1]) csv_writer.writerows(cursor) tmp_file.flush() diff --git a/airflow/providers/apache/hive/transfers/mysql_to_hive.py b/airflow/providers/apache/hive/transfers/mysql_to_hive.py index ccdaf7b682fc1..1693bf48dad9a 100644 --- a/airflow/providers/apache/hive/transfers/mysql_to_hive.py +++ b/airflow/providers/apache/hive/transfers/mysql_to_hive.py @@ -15,12 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module contains an operator to move data from MySQL to Hive.""" +from __future__ import annotations from collections import OrderedDict from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Sequence import MySQLdb import unicodecsv as csv @@ -68,10 +68,10 @@ class MySqlToHiveOperator(BaseOperator): :param tblproperties: TBLPROPERTIES of the hive table being created """ - template_fields: Sequence[str] = ('sql', 'partition', 'hive_table') - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'mysql'} - ui_color = '#a0e08c' + template_fields: Sequence[str] = ("sql", "partition", "hive_table") + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "mysql"} + ui_color = "#a0e08c" def __init__( self, @@ -80,14 +80,14 @@ def __init__( hive_table: str, create: bool = True, recreate: bool = False, - partition: Optional[Dict] = None, + partition: dict | None = None, delimiter: str = chr(1), - quoting: Optional[str] = None, + quoting: str | None = None, quotechar: str = '"', - escapechar: Optional[str] = None, - mysql_conn_id: str = 'mysql_default', - hive_cli_conn_id: str = 'hive_cli_default', - tblproperties: Optional[Dict] = None, + escapechar: str | None = None, + mysql_conn_id: str = "mysql_default", + hive_cli_conn_id: str = "hive_cli_default", + tblproperties: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -110,22 +110,22 @@ def type_map(cls, mysql_type: int) -> str: """Maps MySQL type to Hive type.""" types = MySQLdb.constants.FIELD_TYPE type_map = { - types.BIT: 'INT', - types.DECIMAL: 'DOUBLE', - types.NEWDECIMAL: 'DOUBLE', - types.DOUBLE: 'DOUBLE', - types.FLOAT: 'DOUBLE', - types.INT24: 'INT', - types.LONG: 'BIGINT', - types.LONGLONG: 'DECIMAL(38,0)', - types.SHORT: 'INT', - types.TINY: 'SMALLINT', - types.YEAR: 'INT', - types.TIMESTAMP: 'TIMESTAMP', + types.BIT: "INT", + types.DECIMAL: "DOUBLE", + types.NEWDECIMAL: "DOUBLE", + types.DOUBLE: "DOUBLE", + types.FLOAT: "DOUBLE", + types.INT24: "INT", + types.LONG: "BIGINT", + types.LONGLONG: "DECIMAL(38,0)", + types.SHORT: "INT", + types.TINY: "SMALLINT", + types.YEAR: "INT", + types.TIMESTAMP: "TIMESTAMP", } - return type_map.get(mysql_type, 'STRING') + return type_map.get(mysql_type, "STRING") - def execute(self, context: "Context"): + def execute(self, context: Context): hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) diff --git a/airflow/providers/apache/hive/transfers/s3_to_hive.py b/airflow/providers/apache/hive/transfers/s3_to_hive.py index cc189303e0584..470d1d23c7306 100644 --- a/airflow/providers/apache/hive/transfers/s3_to_hive.py +++ b/airflow/providers/apache/hive/transfers/s3_to_hive.py @@ -15,15 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module contains an operator to move data from an S3 bucket to Hive.""" +from __future__ import annotations import bz2 import gzip import os import tempfile from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -86,29 +86,29 @@ class S3ToHiveOperator(BaseOperator): :param select_expression: S3 Select expression """ - template_fields: Sequence[str] = ('s3_key', 'partition', 'hive_table') + template_fields: Sequence[str] = ("s3_key", "partition", "hive_table") template_ext: Sequence[str] = () - ui_color = '#a0e08c' + ui_color = "#a0e08c" def __init__( self, *, s3_key: str, - field_dict: Dict, + field_dict: dict, hive_table: str, - delimiter: str = ',', + delimiter: str = ",", create: bool = True, recreate: bool = False, - partition: Optional[Dict] = None, + partition: dict | None = None, headers: bool = False, check_headers: bool = False, wildcard_match: bool = False, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - hive_cli_conn_id: str = 'hive_cli_default', + aws_conn_id: str = "aws_default", + verify: bool | str | None = None, + hive_cli_conn_id: str = "hive_cli_default", input_compressed: bool = False, - tblproperties: Optional[Dict] = None, - select_expression: Optional[str] = None, + tblproperties: dict | None = None, + select_expression: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -132,7 +132,7 @@ def __init__( if self.check_headers and not (self.field_dict is not None and self.headers): raise AirflowException("To check_headers provide field_dict and headers") - def execute(self, context: 'Context'): + def execute(self, context: Context): # Downloading file from S3 s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) @@ -142,29 +142,29 @@ def execute(self, context: 'Context'): if not s3_hook.check_for_wildcard_key(self.s3_key): raise AirflowException(f"No key matches {self.s3_key}") s3_key_object = s3_hook.get_wildcard_key(self.s3_key) - else: - if not s3_hook.check_for_key(self.s3_key): - raise AirflowException(f"The key {self.s3_key} does not exists") + elif s3_hook.check_for_key(self.s3_key): s3_key_object = s3_hook.get_key(self.s3_key) + else: + raise AirflowException(f"The key {self.s3_key} does not exists") _, file_ext = os.path.splitext(s3_key_object.key) - if self.select_expression and self.input_compressed and file_ext.lower() != '.gz': + if self.select_expression and self.input_compressed and file_ext.lower() != ".gz": raise AirflowException("GZIP is the only compression format Amazon S3 Select supports") - with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir, NamedTemporaryFile( + with TemporaryDirectory(prefix="tmps32hive_") as tmp_dir, NamedTemporaryFile( mode="wb", dir=tmp_dir, suffix=file_ext ) as f: self.log.info("Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name) if self.select_expression: option = {} if self.headers: - option['FileHeaderInfo'] = 'USE' + option["FileHeaderInfo"] = "USE" if self.delimiter: - option['FieldDelimiter'] = self.delimiter + option["FieldDelimiter"] = self.delimiter - input_serialization: Dict[str, Any] = {'CSV': option} + input_serialization: dict[str, Any] = {"CSV": option} if self.input_compressed: - input_serialization['CompressionType'] = 'GZIP' + 
input_serialization["CompressionType"] = "GZIP" content = s3_hook.select_key( bucket_name=s3_key_object.bucket_name, @@ -227,8 +227,7 @@ def execute(self, context: 'Context'): def _get_top_row_as_list(self, file_name): with open(file_name) as file: header_line = file.readline().strip() - header_list = header_line.split(self.delimiter) - return header_list + return header_line.split(self.delimiter) def _match_headers(self, header_list): if not header_list: @@ -254,13 +253,13 @@ def _match_headers(self, header_list): def _delete_top_row_and_compress(input_file_name, output_file_ext, dest_dir): # When output_file_ext is not defined, file is not compressed open_fn = open - if output_file_ext.lower() == '.gz': + if output_file_ext.lower() == ".gz": open_fn = gzip.GzipFile - elif output_file_ext.lower() == '.bz2': + elif output_file_ext.lower() == ".bz2": open_fn = bz2.BZ2File _, fn_output = tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir) - with open(input_file_name, 'rb') as f_in, open_fn(fn_output, 'wb') as f_out: + with open(input_file_name, "rb") as f_in, open_fn(fn_output, "wb") as f_out: f_in.seek(0) next(f_in) for line in f_in: diff --git a/airflow/providers/apache/hive/transfers/vertica_to_hive.py b/airflow/providers/apache/hive/transfers/vertica_to_hive.py index 7e53638b0c70a..92a677af92cdd 100644 --- a/airflow/providers/apache/hive/transfers/vertica_to_hive.py +++ b/airflow/providers/apache/hive/transfers/vertica_to_hive.py @@ -15,12 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains an operator to move data from Vertica to Hive.""" +from __future__ import annotations from collections import OrderedDict from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence import unicodecsv as csv @@ -60,10 +60,10 @@ class VerticaToHiveOperator(BaseOperator): :ref:`Hive CLI connection id `. 
""" - template_fields: Sequence[str] = ('sql', 'partition', 'hive_table') - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#b4e0ff' + template_fields: Sequence[str] = ("sql", "partition", "hive_table") + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#b4e0ff" def __init__( self, @@ -72,10 +72,10 @@ def __init__( hive_table: str, create: bool = True, recreate: bool = False, - partition: Optional[Dict] = None, + partition: dict | None = None, delimiter: str = chr(1), - vertica_conn_id: str = 'vertica_default', - hive_cli_conn_id: str = 'hive_cli_default', + vertica_conn_id: str = "vertica_default", + hive_cli_conn_id: str = "hive_cli_default", **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -97,16 +97,16 @@ def type_map(cls, vertica_type): https://github.com/uber/vertica-python/blob/master/vertica_python/vertica/column.py """ type_map = { - 5: 'BOOLEAN', - 6: 'INT', - 7: 'FLOAT', - 8: 'STRING', - 9: 'STRING', - 16: 'FLOAT', + 5: "BOOLEAN", + 6: "INT", + 7: "FLOAT", + 8: "STRING", + 9: "STRING", + 16: "FLOAT", } - return type_map.get(vertica_type, 'STRING') + return type_map.get(vertica_type, "STRING") - def execute(self, context: 'Context'): + def execute(self, context: Context): hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id) @@ -115,13 +115,11 @@ def execute(self, context: 'Context'): cursor = conn.cursor() cursor.execute(self.sql) with NamedTemporaryFile("w") as f: - csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8') + csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8") field_dict = OrderedDict() - col_count = 0 - for field in cursor.description: - col_count += 1 + for col_count, field in enumerate(cursor.description, start=1): col_position = f"Column{col_count}" - field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1]) + field_dict[col_position if field[0] == "" else field[0]] = self.type_map(field[1]) csv_writer.writerows(cursor.iterate()) f.flush() cursor.close() diff --git a/airflow/providers/apache/kylin/.latest-doc-only-change.txt b/airflow/providers/apache/kylin/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/apache/kylin/.latest-doc-only-change.txt +++ b/airflow/providers/apache/kylin/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/kylin/CHANGELOG.rst b/airflow/providers/apache/kylin/CHANGELOG.rst index 5b25eda615175..f15f8611cf060 100644 --- a/airflow/providers/apache/kylin/CHANGELOG.rst +++ b/airflow/providers/apache/kylin/CHANGELOG.rst @@ -16,9 +16,60 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Bug Fixes +~~~~~~~~~ + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Remove "bad characters" from our codebase (#24841)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* ``AIP-47 - Migrate kylin DAGs to new design #22439 (#24205)`` +* ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py b/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py deleted file mode 100644 index 0d68b36d6534e..0000000000000 --- a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py +++ /dev/null @@ -1,114 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -This is an example DAG which uses the KylinCubeOperator. -The tasks below include kylin build, refresh, merge operation. -""" -from datetime import datetime - -from airflow import DAG -from airflow.providers.apache.kylin.operators.kylin_cube import KylinCubeOperator - -dag = DAG( - dag_id='example_kylin_operator', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - catchup=False, - default_args={'project': 'learn_kylin', 'cube': 'kylin_sales_cube'}, - tags=['example'], -) - - -@dag.task -def gen_build_time(): - """ - Gen build time and push to XCom (with key of "return_value") - :return: A dict with build time values. 
- """ - return {'date_start': '1325347200000', 'date_end': '1325433600000'} - - -gen_build_time_task = gen_build_time() -gen_build_time_output_date_start = gen_build_time_task['date_start'] -gen_build_time_output_date_end = gen_build_time_task['date_end'] - -build_task1 = KylinCubeOperator( - task_id="kylin_build_1", - command='build', - start_time=gen_build_time_output_date_start, - end_time=gen_build_time_output_date_end, - is_track_job=True, - dag=dag, -) - -build_task2 = KylinCubeOperator( - task_id="kylin_build_2", - command='build', - start_time=gen_build_time_output_date_end, - end_time='1325520000000', - is_track_job=True, - dag=dag, -) - -refresh_task1 = KylinCubeOperator( - task_id="kylin_refresh_1", - command='refresh', - start_time=gen_build_time_output_date_start, - end_time=gen_build_time_output_date_end, - is_track_job=True, - dag=dag, -) - -merge_task = KylinCubeOperator( - task_id="kylin_merge", - command='merge', - start_time=gen_build_time_output_date_start, - end_time='1325520000000', - is_track_job=True, - dag=dag, -) - -disable_task = KylinCubeOperator( - task_id="kylin_disable", - command='disable', - dag=dag, -) - -purge_task = KylinCubeOperator( - task_id="kylin_purge", - command='purge', - dag=dag, -) - -build_task3 = KylinCubeOperator( - task_id="kylin_build_3", - command='build', - start_time=gen_build_time_output_date_end, - end_time='1328730000000', - dag=dag, -) - -build_task1 >> build_task2 >> refresh_task1 >> merge_task >> disable_task >> purge_task >> build_task3 - -# Task dependency created via `XComArgs`: -# gen_build_time >> build_task1 -# gen_build_time >> build_task2 -# gen_build_time >> refresh_task1 -# gen_build_time >> merge_task -# gen_build_time >> build_task3 diff --git a/airflow/providers/apache/kylin/hooks/kylin.py b/airflow/providers/apache/kylin/hooks/kylin.py index 032b15c7e5bbf..709d4e7a281c7 100644 --- a/airflow/providers/apache/kylin/hooks/kylin.py +++ b/airflow/providers/apache/kylin/hooks/kylin.py @@ -15,8 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -from typing import Optional +from __future__ import annotations from kylinpy import exceptions, kylinpy @@ -35,9 +34,9 @@ class KylinHook(BaseHook): def __init__( self, - kylin_conn_id: str = 'kylin_default', - project: Optional[str] = None, - dsn: Optional[str] = None, + kylin_conn_id: str = "kylin_default", + project: str | None = None, + dsn: str | None = None, ): super().__init__() self.kylin_conn_id = kylin_conn_id @@ -48,16 +47,15 @@ def get_conn(self): conn = self.get_connection(self.kylin_conn_id) if self.dsn: return kylinpy.create_kylin(self.dsn) - else: - self.project = self.project if self.project else conn.schema - return kylinpy.Kylin( - conn.host, - username=conn.login, - password=conn.password, - port=conn.port, - project=self.project, - **conn.extra_dejson, - ) + self.project = self.project or conn.schema + return kylinpy.Kylin( + conn.host, + username=conn.login, + password=conn.password, + port=conn.port, + project=self.project, + **conn.extra_dejson, + ) def cube_run(self, datasource_name, op, **op_args): """ @@ -70,8 +68,7 @@ def cube_run(self, datasource_name, op, **op_args): """ cube_source = self.get_conn().get_datasource(datasource_name) try: - response = cube_source.invoke_command(op, **op_args) - return response + return cube_source.invoke_command(op, **op_args) except exceptions.KylinError as err: raise AirflowException(f"Cube operation {op} error , Message: {err}") diff --git a/airflow/providers/apache/kylin/operators/kylin_cube.py b/airflow/providers/apache/kylin/operators/kylin_cube.py index 5fe91ee831934..3f873cbaf54a2 100644 --- a/airflow/providers/apache/kylin/operators/kylin_cube.py +++ b/airflow/providers/apache/kylin/operators/kylin_cube.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import time from datetime import datetime -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from kylinpy import kylinpy @@ -46,14 +47,14 @@ class KylinCubeOperator(BaseOperator): :param command: (kylin command include 'build', 'merge', 'refresh', 'delete', 'build_streaming', 'merge_streaming', 'refresh_streaming', 'disable', 'enable', 'purge', 'clone', 'drop'. 
- build - use /kylin/api/cubes/{cubeName}/build rest api,and buildType is ‘BUILD’, + build - use /kylin/api/cubes/{cubeName}/build rest api,and buildType is 'BUILD', and you should give start_time and end_time - refresh - use build rest api,and buildType is ‘REFRESH’ - merge - use build rest api,and buildType is ‘MERGE’ - build_streaming - use /kylin/api/cubes/{cubeName}/build2 rest api,and buildType is ‘BUILD’ + refresh - use build rest api,and buildType is 'REFRESH' + merge - use build rest api,and buildType is 'MERGE' + build_streaming - use /kylin/api/cubes/{cubeName}/build2 rest api,and buildType is 'BUILD' and you should give offset_start and offset_end - refresh_streaming - use build2 rest api,and buildType is ‘REFRESH’ - merge_streaming - use build2 rest api,and buildType is ‘MERGE’ + refresh_streaming - use build2 rest api,and buildType is 'REFRESH' + merge_streaming - use build2 rest api,and buildType is 'MERGE' delete - delete segment, and you should give segment_name value disable - disable cube enable - enable cube @@ -75,41 +76,41 @@ class KylinCubeOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project', - 'cube', - 'dsn', - 'command', - 'start_time', - 'end_time', - 'segment_name', - 'offset_start', - 'offset_end', + "project", + "cube", + "dsn", + "command", + "start_time", + "end_time", + "segment_name", + "offset_start", + "offset_end", ) - ui_color = '#E79C46' + ui_color = "#E79C46" build_command = { - 'fullbuild', - 'build', - 'merge', - 'refresh', - 'build_streaming', - 'merge_streaming', - 'refresh_streaming', + "fullbuild", + "build", + "merge", + "refresh", + "build_streaming", + "merge_streaming", + "refresh_streaming", } jobs_end_status = {"FINISHED", "ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"} def __init__( self, *, - kylin_conn_id: str = 'kylin_default', - project: Optional[str] = None, - cube: Optional[str] = None, - dsn: Optional[str] = None, - command: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - offset_start: Optional[str] = None, - offset_end: Optional[str] = None, - segment_name: Optional[str] = None, + kylin_conn_id: str = "kylin_default", + project: str | None = None, + cube: str | None = None, + dsn: str | None = None, + command: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + offset_start: str | None = None, + offset_end: str | None = None, + segment_name: str | None = None, is_track_job: bool = False, interval: int = 60, timeout: int = 60 * 60 * 24, @@ -133,24 +134,24 @@ def __init__( self.eager_error_status = eager_error_status self.jobs_error_status = [stat.upper() for stat in eager_error_status] - def execute(self, context: 'Context'): + def execute(self, context: Context): _hook = KylinHook(kylin_conn_id=self.kylin_conn_id, project=self.project, dsn=self.dsn) _support_invoke_command = kylinpy.CubeSource.support_invoke_command if not self.command: - raise AirflowException(f'Kylin:Command {self.command} can not be empty') + raise AirflowException(f"Kylin:Command {self.command} can not be empty") if self.command.lower() not in _support_invoke_command: raise AirflowException( - f'Kylin:Command {self.command} can not match kylin command list {_support_invoke_command}' + f"Kylin:Command {self.command} can not match kylin command list {_support_invoke_command}" ) kylinpy_params = { - 'start': datetime.fromtimestamp(int(self.start_time) / 1000) if self.start_time else None, - 'end': datetime.fromtimestamp(int(self.end_time) / 1000) if 
self.end_time else None, - 'name': self.segment_name, - 'offset_start': int(self.offset_start) if self.offset_start else None, - 'offset_end': int(self.offset_end) if self.offset_end else None, + "start": datetime.fromtimestamp(int(self.start_time) / 1000) if self.start_time else None, + "end": datetime.fromtimestamp(int(self.end_time) / 1000) if self.end_time else None, + "name": self.segment_name, + "offset_start": int(self.offset_start) if self.offset_start else None, + "offset_end": int(self.offset_end) if self.offset_end else None, } rsp_data = _hook.cube_run(self.cube, self.command.lower(), **kylinpy_params) if self.is_track_job and self.command.lower() in self.build_command: @@ -163,13 +164,13 @@ def execute(self, context: 'Context'): job_status = None while job_status not in self.jobs_end_status: if time.monotonic() - started_at > self.timeout: - raise AirflowException(f'kylin job {job_id} timeout') + raise AirflowException(f"kylin job {job_id} timeout") time.sleep(self.interval) job_status = _hook.get_job_status(job_id) - self.log.info('Kylin job status is %s ', job_status) + self.log.info("Kylin job status is %s ", job_status) if job_status in self.jobs_error_status: - raise AirflowException(f'Kylin job {job_id} status {job_status} is error ') + raise AirflowException(f"Kylin job {job_id} status {job_status} is error ") if self.do_xcom_push: return rsp_data diff --git a/airflow/providers/apache/kylin/provider.yaml b/airflow/providers/apache/kylin/provider.yaml index 613042f784a92..73f69523e11e8 100644 --- a/airflow/providers/apache/kylin/provider.yaml +++ b/airflow/providers/apache/kylin/provider.yaml @@ -22,6 +22,8 @@ description: | `Apache Kylin `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +32,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - kylinpy>=2.6 integrations: - integration-name: Apache Kylin diff --git a/airflow/providers/apache/livy/.latest-doc-only-change.txt b/airflow/providers/apache/livy/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/apache/livy/.latest-doc-only-change.txt +++ b/airflow/providers/apache/livy/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/livy/CHANGELOG.rst b/airflow/providers/apache/livy/CHANGELOG.rst index 43556648aed02..7a4e720bdd9f4 100644 --- a/airflow/providers/apache/livy/CHANGELOG.rst +++ b/airflow/providers/apache/livy/CHANGELOG.rst @@ -16,9 +16,70 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add template to livy operator documentation (#27404)`` +* ``Add Spark's 'appId' to xcom output (#27376)`` +* ``add template field renderer to livy operator (#27321)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Add auth_type to LivyHook (#25183)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``AIP-47 - Migrate livy DAGs to new design #22439 (#24208)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/apache/livy/example_dags/example_livy.py b/airflow/providers/apache/livy/example_dags/example_livy.py deleted file mode 100644 index cf7bbbfb1b18d..0000000000000 --- a/airflow/providers/apache/livy/example_dags/example_livy.py +++ /dev/null @@ -1,50 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -This is an example DAG which uses the LivyOperator. -The tasks below trigger the computation of pi on the Spark instance -using the Java and Python executables provided in the example library. 
-""" -from datetime import datetime - -from airflow import DAG -from airflow.providers.apache.livy.operators.livy import LivyOperator - -with DAG( - dag_id='example_livy_operator', - default_args={'args': [10]}, - schedule_interval='@daily', - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - - # [START create_livy] - livy_java_task = LivyOperator( - task_id='pi_java_task', - file='/spark-examples.jar', - num_executors=1, - conf={ - 'spark.shuffle.compress': 'false', - }, - class_name='org.apache.spark.examples.SparkPi', - ) - - livy_python_task = LivyOperator(task_id='pi_python_task', file='/pi.py', polling_interval=60) - - livy_java_task >> livy_python_task - # [END create_livy] diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index e809dc02c7fb4..8c3e49a545431 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains the Apache Livy hook.""" +from __future__ import annotations + import json import re from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Sequence import requests @@ -31,16 +32,16 @@ class BatchState(Enum): """Batch session states""" - NOT_STARTED = 'not_started' - STARTING = 'starting' - RUNNING = 'running' - IDLE = 'idle' - BUSY = 'busy' - SHUTTING_DOWN = 'shutting_down' - ERROR = 'error' - DEAD = 'dead' - KILLED = 'killed' - SUCCESS = 'success' + NOT_STARTED = "not_started" + STARTING = "starting" + RUNNING = "running" + IDLE = "idle" + BUSY = "busy" + SHUTTING_DOWN = "shutting_down" + ERROR = "error" + DEAD = "dead" + KILLED = "killed" + SUCCESS = "success" class LivyHook(HttpHook, LoggingMixin): @@ -50,6 +51,7 @@ class LivyHook(HttpHook, LoggingMixin): :param livy_conn_id: reference to a pre-defined Livy Connection. :param extra_options: A dictionary of options passed to Livy. :param extra_headers: A dictionary of headers passed to the HTTP request to livy. + :param auth_type: The auth type for the service. .. 
seealso:: For more details refer to the Apache Livy API reference: @@ -63,30 +65,31 @@ class LivyHook(HttpHook, LoggingMixin): BatchState.ERROR, } - _def_headers = {'Content-Type': 'application/json', 'Accept': 'application/json'} + _def_headers = {"Content-Type": "application/json", "Accept": "application/json"} - conn_name_attr = 'livy_conn_id' - default_conn_name = 'livy_default' - conn_type = 'livy' - hook_name = 'Apache Livy' + conn_name_attr = "livy_conn_id" + default_conn_name = "livy_default" + conn_type = "livy" + hook_name = "Apache Livy" def __init__( self, livy_conn_id: str = default_conn_name, - extra_options: Optional[Dict[str, Any]] = None, - extra_headers: Optional[Dict[str, Any]] = None, + extra_options: dict[str, Any] | None = None, + extra_headers: dict[str, Any] | None = None, + auth_type: Any | None = None, ) -> None: super().__init__(http_conn_id=livy_conn_id) self.extra_headers = extra_headers or {} self.extra_options = extra_options or {} + self.auth_type = auth_type or self.auth_type - def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any: + def get_conn(self, headers: dict[str, Any] | None = None) -> Any: """ Returns http session for use with requests :param headers: additional headers to be passed through as a dictionary :return: requests session - :rtype: requests.Session """ tmp_headers = self._def_headers.copy() # setting default headers if headers: @@ -96,10 +99,10 @@ def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any: def run_method( self, endpoint: str, - method: str = 'GET', - data: Optional[Any] = None, - headers: Optional[Dict[str, Any]] = None, - retry_args: Optional[Dict[str, Any]] = None, + method: str = "GET", + data: Any | None = None, + headers: dict[str, Any] | None = None, + retry_args: dict[str, Any] | None = None, ) -> Any: """ Wrapper for HttpHook, allows to change method on the same HttpHook @@ -111,12 +114,11 @@ def run_method( :param retry_args: Arguments which define the retry behaviour. 
See Tenacity documentation at https://github.com/jd/tenacity :return: http response - :rtype: requests.Response """ - if method not in ('GET', 'POST', 'PUT', 'DELETE', 'HEAD'): + if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"): raise ValueError(f"Invalid http method '{method}'") if not self.extra_options: - self.extra_options = {'check_response': False} + self.extra_options = {"check_response": False} back_method = self.method self.method = method @@ -136,12 +138,11 @@ def run_method( self.method = back_method return result - def post_batch(self, *args: Any, **kwargs: Any) -> Any: + def post_batch(self, *args: Any, **kwargs: Any) -> int: """ Perform request to submit batch :return: batch session id - :rtype: int """ batch_submit_body = json.dumps(self.build_post_batch_body(*args, **kwargs)) @@ -151,7 +152,7 @@ def post_batch(self, *args: Any, **kwargs: Any) -> Any: self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url) response = self.run_method( - method='POST', endpoint='/batches', data=batch_submit_body, headers=self.extra_headers + method="POST", endpoint="/batches", data=batch_submit_body, headers=self.extra_headers ) self.log.debug("Got response: %s", response.text) @@ -170,18 +171,17 @@ def post_batch(self, *args: Any, **kwargs: Any) -> Any: return batch_id - def get_batch(self, session_id: Union[int, str]) -> Any: + def get_batch(self, session_id: int | str) -> dict: """ Fetch info about the specified batch :param session_id: identifier of the batch sessions :return: response body - :rtype: dict """ self._validate_session_id(session_id) self.log.debug("Fetching info for batch session %d", session_id) - response = self.run_method(endpoint=f'/batches/{session_id}', headers=self.extra_headers) + response = self.run_method(endpoint=f"/batches/{session_id}", headers=self.extra_headers) try: response.raise_for_status() @@ -193,9 +193,7 @@ def get_batch(self, session_id: Union[int, str]) -> Any: return response.json() - def get_batch_state( - self, session_id: Union[int, str], retry_args: Optional[Dict[str, Any]] = None - ) -> BatchState: + def get_batch_state(self, session_id: int | str, retry_args: dict[str, Any] | None = None) -> BatchState: """ Fetch the state of the specified batch @@ -203,13 +201,12 @@ def get_batch_state( :param retry_args: Arguments which define the retry behaviour. 
See Tenacity documentation at https://github.com/jd/tenacity :return: batch state - :rtype: BatchState """ self._validate_session_id(session_id) self.log.debug("Fetching info for batch session %d", session_id) response = self.run_method( - endpoint=f'/batches/{session_id}/state', retry_args=retry_args, headers=self.extra_headers + endpoint=f"/batches/{session_id}/state", retry_args=retry_args, headers=self.extra_headers ) try: @@ -221,23 +218,22 @@ def get_batch_state( ) jresp = response.json() - if 'state' not in jresp: + if "state" not in jresp: raise AirflowException(f"Unable to get state for batch with id: {session_id}") - return BatchState(jresp['state']) + return BatchState(jresp["state"]) - def delete_batch(self, session_id: Union[int, str]) -> Any: + def delete_batch(self, session_id: int | str) -> dict: """ Delete the specified batch :param session_id: identifier of the batch sessions :return: response body - :rtype: dict """ self._validate_session_id(session_id) self.log.info("Deleting batch session %d", session_id) response = self.run_method( - method='DELETE', endpoint=f'/batches/{session_id}', headers=self.extra_headers + method="DELETE", endpoint=f"/batches/{session_id}", headers=self.extra_headers ) try: @@ -250,7 +246,7 @@ def delete_batch(self, session_id: Union[int, str]) -> Any: return response.json() - def get_batch_logs(self, session_id: Union[int, str], log_start_position, log_batch_size) -> Any: + def get_batch_logs(self, session_id: int | str, log_start_position, log_batch_size) -> dict: """ Gets the session logs for a specified batch. :param session_id: identifier of the batch sessions @@ -258,12 +254,11 @@ def get_batch_logs(self, session_id: Union[int, str], log_start_position, log_ba :param log_batch_size: Number of lines to pull in one batch :return: response body - :rtype: dict """ self._validate_session_id(session_id) - log_params = {'from': log_start_position, 'size': log_batch_size} + log_params = {"from": log_start_position, "size": log_batch_size} response = self.run_method( - endpoint=f'/batches/{session_id}/log', data=log_params, headers=self.extra_headers + endpoint=f"/batches/{session_id}/log", data=log_params, headers=self.extra_headers ) try: response.raise_for_status() @@ -275,13 +270,12 @@ def get_batch_logs(self, session_id: Union[int, str], log_start_position, log_ba ) return response.json() - def dump_batch_logs(self, session_id: Union[int, str]) -> Any: + def dump_batch_logs(self, session_id: int | str) -> None: """ Dumps the session logs for a specified batch :param session_id: identifier of the batch sessions :return: response body - :rtype: dict """ self.log.info("Fetching the logs for batch session with id: %d", session_id) log_start_line = 0 @@ -291,14 +285,14 @@ def dump_batch_logs(self, session_id: Union[int, str]) -> Any: while log_start_line <= log_total_lines: # Livy log endpoint is paginated. 
response = self.get_batch_logs(session_id, log_start_line, log_batch_size) - log_total_lines = self._parse_request_response(response, 'total') + log_total_lines = self._parse_request_response(response, "total") log_start_line += log_batch_size - log_lines = self._parse_request_response(response, 'log') + log_lines = self._parse_request_response(response, "log") for log_line in log_lines: self.log.info(log_line) @staticmethod - def _validate_session_id(session_id: Union[int, str]) -> None: + def _validate_session_id(session_id: int | str) -> None: """ Validate session id is a int @@ -310,46 +304,44 @@ def _validate_session_id(session_id: Union[int, str]) -> None: raise TypeError("'session_id' must be an integer") @staticmethod - def _parse_post_response(response: Dict[Any, Any]) -> Any: + def _parse_post_response(response: dict[Any, Any]) -> int | None: """ Parse batch response for batch id :param response: response body :return: session id - :rtype: int """ - return response.get('id') + return response.get("id") @staticmethod - def _parse_request_response(response: Dict[Any, Any], parameter) -> Any: + def _parse_request_response(response: dict[Any, Any], parameter): """ Parse batch response for batch id :param response: response body :return: value of parameter - :rtype: Union[int, list] """ - return response.get(parameter) + return response.get(parameter, []) @staticmethod def build_post_batch_body( file: str, - args: Optional[Sequence[Union[str, int, float]]] = None, - class_name: Optional[str] = None, - jars: Optional[List[str]] = None, - py_files: Optional[List[str]] = None, - files: Optional[List[str]] = None, - archives: Optional[List[str]] = None, - name: Optional[str] = None, - driver_memory: Optional[str] = None, - driver_cores: Optional[Union[int, str]] = None, - executor_memory: Optional[str] = None, - executor_cores: Optional[int] = None, - num_executors: Optional[Union[int, str]] = None, - queue: Optional[str] = None, - proxy_user: Optional[str] = None, - conf: Optional[Dict[Any, Any]] = None, - ) -> Any: + args: Sequence[str | int | float] | None = None, + class_name: str | None = None, + jars: list[str] | None = None, + py_files: list[str] | None = None, + files: list[str] | None = None, + archives: list[str] | None = None, + name: str | None = None, + driver_memory: str | None = None, + driver_cores: int | str | None = None, + executor_memory: str | None = None, + executor_cores: int | None = None, + num_executors: int | str | None = None, + queue: str | None = None, + proxy_user: str | None = None, + conf: dict[Any, Any] | None = None, + ) -> dict: """ Build the post batch request body. For more information about the format refer to @@ -371,40 +363,39 @@ def build_post_batch_body( :param name: The name of this session string. :param conf: Spark configuration properties. 
:return: request body - :rtype: dict """ - body: Dict[str, Any] = {'file': file} + body: dict[str, Any] = {"file": file} if proxy_user: - body['proxyUser'] = proxy_user + body["proxyUser"] = proxy_user if class_name: - body['className'] = class_name + body["className"] = class_name if args and LivyHook._validate_list_of_stringables(args): - body['args'] = [str(val) for val in args] + body["args"] = [str(val) for val in args] if jars and LivyHook._validate_list_of_stringables(jars): - body['jars'] = jars + body["jars"] = jars if py_files and LivyHook._validate_list_of_stringables(py_files): - body['pyFiles'] = py_files + body["pyFiles"] = py_files if files and LivyHook._validate_list_of_stringables(files): - body['files'] = files + body["files"] = files if driver_memory and LivyHook._validate_size_format(driver_memory): - body['driverMemory'] = driver_memory + body["driverMemory"] = driver_memory if driver_cores: - body['driverCores'] = driver_cores + body["driverCores"] = driver_cores if executor_memory and LivyHook._validate_size_format(executor_memory): - body['executorMemory'] = executor_memory + body["executorMemory"] = executor_memory if executor_cores: - body['executorCores'] = executor_cores + body["executorCores"] = executor_cores if num_executors: - body['numExecutors'] = num_executors + body["numExecutors"] = num_executors if archives and LivyHook._validate_list_of_stringables(archives): - body['archives'] = archives + body["archives"] = archives if queue: - body['queue'] = queue + body["queue"] = queue if name: - body['name'] = name + body["name"] = name if conf and LivyHook._validate_extra_conf(conf): - body['conf'] = conf + body["conf"] = conf return body @@ -415,20 +406,18 @@ def _validate_size_format(size: str) -> bool: :param size: size value :return: true if valid format - :rtype: bool """ - if size and not (isinstance(size, str) and re.match(r'^\d+[kmgt]b?$', size, re.IGNORECASE)): + if size and not (isinstance(size, str) and re.match(r"^\d+[kmgt]b?$", size, re.IGNORECASE)): raise ValueError(f"Invalid java size format for string'{size}'") return True @staticmethod - def _validate_list_of_stringables(vals: Sequence[Union[str, int, float]]) -> bool: + def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: """ Check the values in the provided list can be converted to strings. :param vals: list to validate :return: true if valid - :rtype: bool """ if ( vals is None @@ -439,13 +428,12 @@ def _validate_list_of_stringables(vals: Sequence[Union[str, int, float]]) -> boo return True @staticmethod - def _validate_extra_conf(conf: Dict[Any, Any]) -> bool: + def _validate_extra_conf(conf: dict[Any, Any]) -> bool: """ Check configuration values are either strings or ints. :param conf: configuration variable :return: true if valid - :rtype: bool """ if conf: if not isinstance(conf, dict): diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index f0dbc9e3165fe..313f64c9f9afd 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module contains the Apache Livy operator.""" +from __future__ import annotations + from time import sleep -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -32,23 +33,24 @@ class LivyOperator(BaseOperator): This operator wraps the Apache Livy batch REST API, allowing to submit a Spark application to the underlying cluster. - :param file: path of the file containing the application to execute (required). - :param class_name: name of the application Java/Spark main class. - :param args: application command line arguments. - :param jars: jars to be used in this sessions. - :param py_files: python files to be used in this session. - :param files: files to be used in this session. - :param driver_memory: amount of memory to use for the driver process. - :param driver_cores: number of cores to use for the driver process. - :param executor_memory: amount of memory to use per executor process. - :param executor_cores: number of cores to use for each executor. - :param num_executors: number of executors to launch for this session. - :param archives: archives to be used in this session. - :param queue: name of the YARN queue to which the application is submitted. - :param name: name of this session. - :param conf: Spark configuration properties. - :param proxy_user: user to impersonate when running the job. + :param file: path of the file containing the application to execute (required). (templated) + :param class_name: name of the application Java/Spark main class. (templated) + :param args: application command line arguments. (templated) + :param jars: jars to be used in this sessions. (templated) + :param py_files: python files to be used in this session. (templated) + :param files: files to be used in this session. (templated) + :param driver_memory: amount of memory to use for the driver process. (templated) + :param driver_cores: number of cores to use for the driver process. (templated) + :param executor_memory: amount of memory to use per executor process. (templated) + :param executor_cores: number of cores to use for each executor. (templated) + :param num_executors: number of executors to launch for this session. (templated) + :param archives: archives to be used in this session. (templated) + :param queue: name of the YARN queue to which the application is submitted. (templated) + :param name: name of this session. (templated) + :param conf: Spark configuration properties. (templated) + :param proxy_user: user to impersonate when running the job. (templated) :param livy_conn_id: reference to a pre-defined Livy Connection. + :param livy_conn_auth_type: The auth type for the Livy Connection. :param polling_interval: time in seconds between polling for job completion. Don't poll for values >=0 :param extra_options: A dictionary of options, where key is string and value depends on the option that's being modified. 
@@ -57,63 +59,66 @@ class LivyOperator(BaseOperator): See Tenacity documentation at https://github.com/jd/tenacity """ - template_fields: Sequence[str] = ('spark_params',) + template_fields: Sequence[str] = ("spark_params",) + template_fields_renderers = {"spark_params": "json"} def __init__( self, *, file: str, - class_name: Optional[str] = None, - args: Optional[Sequence[Union[str, int, float]]] = None, - conf: Optional[Dict[Any, Any]] = None, - jars: Optional[Sequence[str]] = None, - py_files: Optional[Sequence[str]] = None, - files: Optional[Sequence[str]] = None, - driver_memory: Optional[str] = None, - driver_cores: Optional[Union[int, str]] = None, - executor_memory: Optional[str] = None, - executor_cores: Optional[Union[int, str]] = None, - num_executors: Optional[Union[int, str]] = None, - archives: Optional[Sequence[str]] = None, - queue: Optional[str] = None, - name: Optional[str] = None, - proxy_user: Optional[str] = None, - livy_conn_id: str = 'livy_default', + class_name: str | None = None, + args: Sequence[str | int | float] | None = None, + conf: dict[Any, Any] | None = None, + jars: Sequence[str] | None = None, + py_files: Sequence[str] | None = None, + files: Sequence[str] | None = None, + driver_memory: str | None = None, + driver_cores: int | str | None = None, + executor_memory: str | None = None, + executor_cores: int | str | None = None, + num_executors: int | str | None = None, + archives: Sequence[str] | None = None, + queue: str | None = None, + name: str | None = None, + proxy_user: str | None = None, + livy_conn_id: str = "livy_default", + livy_conn_auth_type: Any | None = None, polling_interval: int = 0, - extra_options: Optional[Dict[str, Any]] = None, - extra_headers: Optional[Dict[str, Any]] = None, - retry_args: Optional[Dict[str, Any]] = None, + extra_options: dict[str, Any] | None = None, + extra_headers: dict[str, Any] | None = None, + retry_args: dict[str, Any] | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.spark_params = { - 'file': file, - 'class_name': class_name, - 'args': args, - 'jars': jars, - 'py_files': py_files, - 'files': files, - 'driver_memory': driver_memory, - 'driver_cores': driver_cores, - 'executor_memory': executor_memory, - 'executor_cores': executor_cores, - 'num_executors': num_executors, - 'archives': archives, - 'queue': queue, - 'name': name, - 'conf': conf, - 'proxy_user': proxy_user, + "file": file, + "class_name": class_name, + "args": args, + "jars": jars, + "py_files": py_files, + "files": files, + "driver_memory": driver_memory, + "driver_cores": driver_cores, + "executor_memory": executor_memory, + "executor_cores": executor_cores, + "num_executors": num_executors, + "archives": archives, + "queue": queue, + "name": name, + "conf": conf, + "proxy_user": proxy_user, } self._livy_conn_id = livy_conn_id + self._livy_conn_auth_type = livy_conn_auth_type self._polling_interval = polling_interval self._extra_options = extra_options or {} self._extra_headers = extra_headers or {} - self._livy_hook: Optional[LivyHook] = None - self._batch_id: Union[int, str] + self._livy_hook: LivyHook | None = None + self._batch_id: int | str self.retry_args = retry_args def get_hook(self) -> LivyHook: @@ -121,25 +126,27 @@ def get_hook(self) -> LivyHook: Get valid hook. 
:return: hook - :rtype: LivyHook """ if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook): self._livy_hook = LivyHook( livy_conn_id=self._livy_conn_id, extra_headers=self._extra_headers, extra_options=self._extra_options, + auth_type=self._livy_conn_auth_type, ) return self._livy_hook - def execute(self, context: "Context") -> Any: + def execute(self, context: Context) -> Any: self._batch_id = self.get_hook().post_batch(**self.spark_params) if self._polling_interval > 0: self.poll_for_termination(self._batch_id) + context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"]) + return self._batch_id - def poll_for_termination(self, batch_id: Union[int, str]) -> None: + def poll_for_termination(self, batch_id: int | str) -> None: """ Pool Livy for batch termination. @@ -148,7 +155,7 @@ def poll_for_termination(self, batch_id: Union[int, str]) -> None: hook = self.get_hook() state = hook.get_batch_state(batch_id, retry_args=self.retry_args) while state not in hook.TERMINAL_STATES: - self.log.debug('Batch with id %s is in state: %s', batch_id, state.value) + self.log.debug("Batch with id %s is in state: %s", batch_id, state.value) sleep(self._polling_interval) state = hook.get_batch_state(batch_id, retry_args=self.retry_args) self.log.info("Batch with id %s terminated with state: %s", batch_id, state.value) diff --git a/airflow/providers/apache/livy/provider.yaml b/airflow/providers/apache/livy/provider.yaml index c624e0b9d360a..8e69ef935d03a 100644 --- a/airflow/providers/apache/livy/provider.yaml +++ b/airflow/providers/apache/livy/provider.yaml @@ -22,6 +22,9 @@ description: | `Apache Livy `__ versions: + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -32,8 +35,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-http integrations: - integration-name: Apache Livy @@ -58,9 +62,6 @@ hooks: python-modules: - airflow.providers.apache.livy.hooks.livy -hook-class-names: - # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.livy.hooks.livy.LivyHook connection-types: - hook-class-name: airflow.providers.apache.livy.hooks.livy.LivyHook diff --git a/airflow/providers/apache/livy/sensors/livy.py b/airflow/providers/apache/livy/sensors/livy.py index 4c3419f2af4b2..d0838b187a965 100644 --- a/airflow/providers/apache/livy/sensors/livy.py +++ b/airflow/providers/apache/livy/sensors/livy.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains the Apache Livy sensor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.apache.livy.hooks.livy import LivyHook from airflow.sensors.base import BaseSensorOperator @@ -34,20 +35,22 @@ class LivySensor(BaseSensorOperator): depends on the option that's being modified. 
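Besides the batch id that ``execute()`` already returns through XCom, the Spark application id is now pushed under the ``app_id`` key. A minimal sketch of pairing the operator with the sensor, assuming the usual DAG context and a placeholder application path::

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.apache.livy.operators.livy import LivyOperator
    from airflow.providers.apache.livy.sensors.livy import LivySensor

    with DAG(
        dag_id="example_livy_sensor",
        start_date=datetime(2022, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        submit = LivyOperator(
            task_id="submit",
            file="/user/spark/app.jar",   # placeholder
            polling_interval=0,           # do not block; let the sensor do the waiting
        )
        wait = LivySensor(
            task_id="wait_for_batch",
            # the batch id returned by execute() is available as the default XCom;
            # the Spark application id is pushed under the extra "app_id" key
            batch_id="{{ ti.xcom_pull(task_ids='submit') }}",
            livy_conn_id="livy_default",
            poke_interval=30,
        )
        submit >> wait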
""" - template_fields: Sequence[str] = ('batch_id',) + template_fields: Sequence[str] = ("batch_id",) def __init__( self, *, - batch_id: Union[int, str], - livy_conn_id: str = 'livy_default', - extra_options: Optional[Dict[str, Any]] = None, + batch_id: int | str, + livy_conn_id: str = "livy_default", + livy_conn_auth_type: Any | None = None, + extra_options: dict[str, Any] | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.batch_id = batch_id self._livy_conn_id = livy_conn_id - self._livy_hook: Optional[LivyHook] = None + self._livy_conn_auth_type = livy_conn_auth_type + self._livy_hook: LivyHook | None = None self._extra_options = extra_options or {} def get_hook(self) -> LivyHook: @@ -55,13 +58,16 @@ def get_hook(self) -> LivyHook: Get valid hook. :return: hook - :rtype: LivyHook """ if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook): - self._livy_hook = LivyHook(livy_conn_id=self._livy_conn_id, extra_options=self._extra_options) + self._livy_hook = LivyHook( + livy_conn_id=self._livy_conn_id, + extra_options=self._extra_options, + auth_type=self._livy_conn_auth_type, + ) return self._livy_hook - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: batch_id = self.batch_id status = self.get_hook().get_batch_state(batch_id) diff --git a/airflow/providers/apache/pig/.latest-doc-only-change.txt b/airflow/providers/apache/pig/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/apache/pig/.latest-doc-only-change.txt +++ b/airflow/providers/apache/pig/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/pig/CHANGELOG.rst b/airflow/providers/apache/pig/CHANGELOG.rst index c599c06bad8bc..f8338b3e9c91f 100644 --- a/airflow/providers/apache/pig/CHANGELOG.rst +++ b/airflow/providers/apache/pig/CHANGELOG.rst @@ -16,9 +16,65 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +You cannot use ``pig_properties`` any more as connection extras. If you want to add extra parameters +to ``pig`` command, you need to do it via ``pig_properties`` (string list) of the PigCliHook (new parameter) +or via ``pig_opts`` (string with options separated by spaces) or ``pig_properties`` (string list) in +the PigOperator . Any use of ``pig_properties`` extras in connection will raise an exception, +informing that you need to remove them and pass them as parameters. + +Both ``pig_properties`` and ``pig_opts`` are now templated fields in the PigOperator. + +* ``Pig cli connection properties cannot be passed by connection extra (#27644)`` + + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``AIP-47 - Migrate apache pig DAGs to new design #22439 (#24212)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/apache/pig/example_dags/example_pig.py b/airflow/providers/apache/pig/example_dags/example_pig.py deleted file mode 100644 index ed1b34ab0c8a4..0000000000000 --- a/airflow/providers/apache/pig/example_dags/example_pig.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Example DAG demonstrating the usage of the PigOperator.""" -from datetime import datetime - -from airflow import DAG -from airflow.providers.apache.pig.operators.pig import PigOperator - -dag = DAG( - dag_id='example_pig_operator', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) - -# [START create_pig] -run_this = PigOperator( - task_id="run_example_pig_script", - pig="ls /;", - pig_opts="-x local", - dag=dag, -) -# [END create_pig] diff --git a/airflow/providers/apache/pig/hooks/pig.py b/airflow/providers/apache/pig/hooks/pig.py index ae9db33c3db53..023b308e13e2c 100644 --- a/airflow/providers/apache/pig/hooks/pig.py +++ b/airflow/providers/apache/pig/hooks/pig.py @@ -15,63 +15,72 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import subprocess from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import Any, List, Optional +from typing import Any from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook class PigCliHook(BaseHook): - """ - Simple wrapper around the pig CLI. 
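To make the 4.0.0 migration concrete, a rough sketch of the new style, reusing the pig script and options from the removed example DAG above and the property value documented in the old hook docstring; it assumes the task is declared inside a DAG as usual::

    from airflow.providers.apache.pig.operators.pig import PigOperator

    # before 4.0.0: {"pig_properties": "-Dpig.tmpfilecompression=true"} lived in the connection extra
    # from 4.0.0:   pass the properties (and/or pig_opts) on the operator; both are templated fields
    run_pig = PigOperator(
        task_id="run_pig_script",
        pig="ls /;",
        pig_opts="-x local",                               # space separated options string
        pig_properties=["-Dpig.tmpfilecompression=true"],  # list of property strings
        pig_cli_conn_id="pig_cli_default",
    )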
- - Note that you can also set default pig CLI properties using the - ``pig_properties`` to be used in your connection as in - ``{"pig_properties": "-Dpig.tmpfilecompression=true"}`` + """Simple wrapper around the pig CLI. + :param pig_cli_conn_id: Connection id used by the hook + :param pig_properties: additional properties added after pig cli command as list of strings. """ - conn_name_attr = 'pig_cli_conn_id' - default_conn_name = 'pig_cli_default' - conn_type = 'pig_cli' - hook_name = 'Pig Client Wrapper' + conn_name_attr = "pig_cli_conn_id" + default_conn_name = "pig_cli_default" + conn_type = "pig_cli" + hook_name = "Pig Client Wrapper" - def __init__(self, pig_cli_conn_id: str = default_conn_name) -> None: + def __init__( + self, pig_cli_conn_id: str = default_conn_name, pig_properties: list[str] | None = None + ) -> None: super().__init__() conn = self.get_connection(pig_cli_conn_id) - self.pig_properties = conn.extra_dejson.get('pig_properties', '') + conn_pig_properties = conn.extra_dejson.get("pig_properties") + if conn_pig_properties: + raise RuntimeError( + "The PigCliHook used to have possibility of passing `pig_properties` to the Hook," + " however with the 4.0.0 version of `apache-pig` provider it has been removed. You should" + " use ``pig_opts`` (space separated string) or ``pig_properties`` (string list) in the" + " PigOperator. You can also pass ``pig-properties`` in the PigCliHook `init`. Currently," + f" the {pig_cli_conn_id} connection has those extras: `{conn_pig_properties}`." + ) + self.pig_properties = pig_properties if pig_properties else [] self.conn = conn self.sub_process = None - def run_cli(self, pig: str, pig_opts: Optional[str] = None, verbose: bool = True) -> Any: + def run_cli(self, pig: str, pig_opts: str | None = None, verbose: bool = True) -> Any: """ - Run an pig script using the pig cli + Run a pig script using the pig cli >>> ph = PigCliHook() >>> result = ph.run_cli("ls /;", pig_opts="-x mapreduce") >>> ("hdfs://" in result) True """ - with TemporaryDirectory(prefix='airflow_pigop_') as tmp_dir: + with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir: with NamedTemporaryFile(dir=tmp_dir) as f: - f.write(pig.encode('utf-8')) + f.write(pig.encode("utf-8")) f.flush() fname = f.name - pig_bin = 'pig' - cmd_extra: List[str] = [] + pig_bin = "pig" + cmd_extra: list[str] = [] pig_cmd = [pig_bin] if self.pig_properties: - pig_properties_list = self.pig_properties.split() - pig_cmd.extend(pig_properties_list) + pig_cmd.extend(self.pig_properties) if pig_opts: pig_opts_list = pig_opts.split() pig_cmd.extend(pig_opts_list) - pig_cmd.extend(['-f', fname] + cmd_extra) + pig_cmd.extend(["-f", fname] + cmd_extra) if verbose: self.log.info("%s", " ".join(pig_cmd)) @@ -79,9 +88,9 @@ def run_cli(self, pig: str, pig_opts: Optional[str] = None, verbose: bool = True pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True ) self.sub_process = sub_process - stdout = '' - for line in iter(sub_process.stdout.readline, b''): - stdout += line.decode('utf-8') + stdout = "" + for line in iter(sub_process.stdout.readline, b""): + stdout += line.decode("utf-8") if verbose: self.log.info(line.strip()) sub_process.wait() diff --git a/airflow/providers/apache/pig/operators/pig.py b/airflow/providers/apache/pig/operators/pig.py index 5a285530dc637..544895cea09dd 100644 --- a/airflow/providers/apache/pig/operators/pig.py +++ b/airflow/providers/apache/pig/operators/pig.py @@ -15,8 +15,10 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import re -from typing import TYPE_CHECKING, Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.apache.pig.hooks.pig import PigCliHook @@ -36,23 +38,25 @@ class PigOperator(BaseOperator): you may want to use this along with the ``DAG(user_defined_macros=myargs)`` parameter. View the DAG object documentation for more details. - :param pig_opts: pig options, such as: -x tez, -useHCatalog, ... + :param pig_opts: pig options, such as: -x tez, -useHCatalog, ... - space separated list + :param pig_properties: pig properties, additional pig properties passed as list """ - template_fields: Sequence[str] = ('pig',) + template_fields: Sequence[str] = ("pig", "pig_opts", "pig_properties") template_ext: Sequence[str] = ( - '.pig', - '.piglatin', + ".pig", + ".piglatin", ) - ui_color = '#f0e4ec' + ui_color = "#f0e4ec" def __init__( self, *, pig: str, - pig_cli_conn_id: str = 'pig_cli_default', + pig_cli_conn_id: str = "pig_cli_default", pigparams_jinja_translate: bool = False, - pig_opts: Optional[str] = None, + pig_opts: str | None = None, + pig_properties: list[str] | None = None, **kwargs: Any, ) -> None: @@ -61,15 +65,16 @@ def __init__( self.pig = pig self.pig_cli_conn_id = pig_cli_conn_id self.pig_opts = pig_opts - self.hook: Optional[PigCliHook] = None + self.pig_properties = pig_properties + self.hook: PigCliHook | None = None def prepare_template(self): if self.pigparams_jinja_translate: self.pig = re.sub(r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", self.pig) - def execute(self, context: 'Context'): - self.log.info('Executing: %s', self.pig) - self.hook = PigCliHook(pig_cli_conn_id=self.pig_cli_conn_id) + def execute(self, context: Context): + self.log.info("Executing: %s", self.pig) + self.hook = PigCliHook(pig_cli_conn_id=self.pig_cli_conn_id, pig_properties=self.pig_properties) self.hook.run_cli(pig=self.pig, pig_opts=self.pig_opts) def on_kill(self): diff --git a/airflow/providers/apache/pig/provider.yaml b/airflow/providers/apache/pig/provider.yaml index 81540ae4420f6..df1c55a8a7b61 100644 --- a/airflow/providers/apache/pig/provider.yaml +++ b/airflow/providers/apache/pig/provider.yaml @@ -22,6 +22,8 @@ description: | `Apache Pig `__ versions: + - 4.0.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +32,8 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 integrations: - integration-name: Apache Pig @@ -51,10 +53,6 @@ hooks: python-modules: - airflow.providers.apache.pig.hooks.pig -hook-class-names: - # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.pig.hooks.pig.PigCliHook - connection-types: - connection-type: pig_cli hook-class-name: airflow.providers.apache.pig.hooks.pig.PigCliHook diff --git a/airflow/providers/apache/pinot/CHANGELOG.rst b/airflow/providers/apache/pinot/CHANGELOG.rst index 923a27315fe43..cafdffda1dbcd 100644 --- a/airflow/providers/apache/pinot/CHANGELOG.rst +++ b/airflow/providers/apache/pinot/CHANGELOG.rst @@ -16,9 +16,98 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. 
+ The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +The admin command is now hard-coded to ``pinot-admin.sh``. The ``pinot-admin.sh`` command must be available +on the path in order to use PinotAdminHook. + +* ``The pinot-admin.sh command is now hard-coded. (#27641)`` + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Bump pinotdb version (#27201)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix PinotDB dependencies (#26705)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py index 55ddce0bcc0cc..909b0182ddcb5 100644 --- a/airflow/providers/apache/pinot/hooks/pinot.py +++ b/airflow/providers/apache/pinot/hooks/pinot.py @@ -15,17 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import os import subprocess -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Iterable, Mapping from pinotdb import connect from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.hooks.dbapi import DbApiHook from airflow.models import Connection +from airflow.providers.common.sql.hooks.sql import DbApiHook class PinotAdminHook(BaseHook): @@ -45,7 +46,10 @@ class PinotAdminHook(BaseHook): following PR: https://github.com/apache/incubator-pinot/pull/4110 :param conn_id: The name of the connection to use. 
- :param cmd_path: The filepath to the pinot-admin.sh executable + :param cmd_path: Do not modify the parameter. It used to be the filepath to the pinot-admin.sh + executable but in version 4.0.0 of apache-pinot provider, value of this parameter must + remain the default value: `pinot-admin.sh`. It is left here to not accidentally override + the `pinot_admin_system_exit` in case positional parameters were used to initialize the hook. :param pinot_admin_system_exit: If true, the result is evaluated based on the status code. Otherwise, the result is evaluated as a failure if "Error" or "Exception" is in the output message. @@ -61,7 +65,15 @@ def __init__( conn = self.get_connection(conn_id) self.host = conn.host self.port = str(conn.port) - self.cmd_path = conn.extra_dejson.get("cmd_path", cmd_path) + if cmd_path != "pinot-admin.sh": + raise RuntimeError( + "In version 4.0.0 of the PinotAdminHook the cmd_path has been hard-coded to" + " pinot-admin.sh. In order to avoid accidental using of this parameter as" + " positional `pinot_admin_system_exit` the `cmd_parameter`" + " parameter is left here but you should not modify it. Make sure that " + " `pinot-admin.sh` is on your PATH and do not change cmd_path value." + ) + self.cmd_path = "pinot-admin.sh" self.pinot_admin_system_exit = conn.extra_dejson.get( "pinot_admin_system_exit", pinot_admin_system_exit ) @@ -102,24 +114,24 @@ def add_table(self, file_path: str, with_exec: bool = True) -> Any: def create_segment( self, - generator_config_file: Optional[str] = None, - data_dir: Optional[str] = None, - segment_format: Optional[str] = None, - out_dir: Optional[str] = None, - overwrite: Optional[str] = None, - table_name: Optional[str] = None, - segment_name: Optional[str] = None, - time_column_name: Optional[str] = None, - schema_file: Optional[str] = None, - reader_config_file: Optional[str] = None, - enable_star_tree_index: Optional[str] = None, - star_tree_index_spec_file: Optional[str] = None, - hll_size: Optional[str] = None, - hll_columns: Optional[str] = None, - hll_suffix: Optional[str] = None, - num_threads: Optional[str] = None, - post_creation_verification: Optional[str] = None, - retry: Optional[str] = None, + generator_config_file: str | None = None, + data_dir: str | None = None, + segment_format: str | None = None, + out_dir: str | None = None, + overwrite: str | None = None, + table_name: str | None = None, + segment_name: str | None = None, + time_column_name: str | None = None, + schema_file: str | None = None, + reader_config_file: str | None = None, + enable_star_tree_index: str | None = None, + star_tree_index_spec_file: str | None = None, + hll_size: str | None = None, + hll_columns: str | None = None, + hll_suffix: str | None = None, + num_threads: str | None = None, + post_creation_verification: str | None = None, + retry: str | None = None, ) -> Any: """Create Pinot segment by run CreateSegment command""" cmd = ["CreateSegment"] @@ -180,7 +192,7 @@ def create_segment( self.run_cli(cmd) - def upload_segment(self, segment_dir: str, table_name: Optional[str] = None) -> Any: + def upload_segment(self, segment_dir: str, table_name: str | None = None) -> Any: """ Upload Segment with run UploadSegment command @@ -196,16 +208,14 @@ def upload_segment(self, segment_dir: str, table_name: Optional[str] = None) -> cmd += ["-tableName", table_name] self.run_cli(cmd) - def run_cli(self, cmd: List[str], verbose: bool = True) -> str: + def run_cli(self, cmd: list[str], verbose: bool = True) -> str: """ Run command with pinot-admin.sh 
:param cmd: List of command going to be run by pinot-admin.sh script :param verbose: """ - command = [self.cmd_path] - command.extend(cmd) - + command = [self.cmd_path, *cmd] env = None if self.pinot_admin_system_exit: env = os.environ.copy() @@ -214,13 +224,12 @@ def run_cli(self, cmd: List[str], verbose: bool = True) -> str: if verbose: self.log.info(" ".join(command)) - with subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, env=env ) as sub_process: stdout = "" if sub_process.stdout: - for line in iter(sub_process.stdout.readline, b''): + for line in iter(sub_process.stdout.readline, b""): stdout += line.decode("utf-8") if verbose: self.log.info(line.decode("utf-8").strip()) @@ -246,8 +255,8 @@ class PinotDbApiHook(DbApiHook): https://docs.pinot.apache.org/users/api/querying-pinot-using-standard-sql """ - conn_name_attr = 'pinot_broker_conn_id' - default_conn_name = 'pinot_broker_default' + conn_name_attr = "pinot_broker_conn_id" + default_conn_name = "pinot_broker_default" supports_autocommit = False def get_conn(self) -> Any: @@ -257,10 +266,10 @@ def get_conn(self) -> Any: pinot_broker_conn = connect( host=conn.host, port=conn.port, - path=conn.extra_dejson.get('endpoint', '/query/sql'), - scheme=conn.extra_dejson.get('schema', 'http'), + path=conn.extra_dejson.get("endpoint", "/query/sql"), + scheme=conn.extra_dejson.get("schema", "http"), ) - self.log.info('Get the connection to pinot broker on %s', conn.host) + self.log.info("Get the connection to pinot broker on %s", conn.host) return pinot_broker_conn def get_uri(self) -> str: @@ -272,12 +281,14 @@ def get_uri(self) -> str: conn = self.get_connection(getattr(self, self.conn_name_attr)) host = conn.host if conn.port is not None: - host += f':{conn.port}' - conn_type = 'http' if not conn.conn_type else conn.conn_type - endpoint = conn.extra_dejson.get('endpoint', 'query/sql') - return f'{conn_type}://{host}/{endpoint}' + host += f":{conn.port}" + conn_type = conn.conn_type or "http" + endpoint = conn.extra_dejson.get("endpoint", "query/sql") + return f"{conn_type}://{host}/{endpoint}" - def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: + def get_records( + self, sql: str | list[str], parameters: Iterable | Mapping | None = None, **kwargs + ) -> Any: """ Executes the sql and returns a set of records. @@ -289,7 +300,7 @@ def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Itera cur.execute(sql) return cur.fetchall() - def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: + def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any: """ Executes the sql and returns the first resulting row. 
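For orientation, a rough sketch of using both Pinot hooks after these changes: ``pinot-admin.sh`` is simply expected on ``PATH`` (``cmd_path`` stays at its default), and the broker connection keeps its ``endpoint``/``schema`` extras. Connection ids, file paths, table name and the queries are placeholders::

    from airflow.providers.apache.pinot.hooks.pinot import PinotAdminHook, PinotDbApiHook

    # Admin side: pinot-admin.sh is hard-coded and resolved from PATH.
    admin = PinotAdminHook(conn_id="pinot_admin_default", pinot_admin_system_exit=True)
    admin.add_table(file_path="/tmp/table-config.json", with_exec=True)
    admin.upload_segment(segment_dir="/tmp/segments", table_name="myTable")

    # Query side: the broker connection goes through the standard DB-API hook.
    db = PinotDbApiHook(pinot_broker_conn_id="pinot_broker_default")
    rows = db.get_records("SELECT playerName FROM baseballStats LIMIT 5")
    first = db.get_first("SELECT count(*) FROM baseballStats")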
@@ -308,7 +319,7 @@ def insert_rows( self, table: str, rows: str, - target_fields: Optional[str] = None, + target_fields: str | None = None, commit_every: int = 1000, replace: bool = False, **kwargs: Any, diff --git a/airflow/providers/apache/pinot/provider.yaml b/airflow/providers/apache/pinot/provider.yaml index 000827c9e141c..e8b02fbd4105c 100644 --- a/airflow/providers/apache/pinot/provider.yaml +++ b/airflow/providers/apache/pinot/provider.yaml @@ -22,6 +22,11 @@ description: | `Apache Pinot `__ versions: + - 4.0.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +35,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - pinotdb>0.4.7 integrations: - integration-name: Apache Pinot diff --git a/airflow/providers/apache/spark/.latest-doc-only-change.txt b/airflow/providers/apache/spark/.latest-doc-only-change.txt index cda183acd3b04..ff7136e07d744 100644 --- a/airflow/providers/apache/spark/.latest-doc-only-change.txt +++ b/airflow/providers/apache/spark/.latest-doc-only-change.txt @@ -1 +1 @@ -cb73053211367e2c2dd76d5279cdc7dc7b190124 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/spark/CHANGELOG.rst b/airflow/providers/apache/spark/CHANGELOG.rst index f6e4cf48ba18d..7b0cca49af7e8 100644 --- a/airflow/providers/apache/spark/CHANGELOG.rst +++ b/airflow/providers/apache/spark/CHANGELOG.rst @@ -16,9 +16,74 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +The ``spark-binary`` connection extra could be set to any binary, but with 4.0.0 version only two values +are allowed for it ``spark-submit`` and ``spark2-submit``. + +The ``spark-home`` connection extra is not allowed any more - the binary should be available on the +PATH in order to use SparkSubmitHook and SparkSubmitOperator. + +* ``Remove custom spark home and custom binaries for spark (#27646)`` + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Bug Fixes +~~~~~~~~~ + +* ``Add typing for airflow/configuration.py (#23716)`` +* ``Fix backwards-compatibility introduced by fixing mypy problems (#24230)`` + +Misc +~~~~ + +* ``AIP-47 - Migrate spark DAGs to new design #22439 (#24210)`` +* ``chore: Refactoring and Cleaning Apache Providers (#24219)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc.py b/airflow/providers/apache/spark/hooks/spark_jdbc.py index df9d715be0dcc..87abfb863eae9 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import os -from typing import Any, Dict, Optional +from typing import Any from airflow.exceptions import AirflowException from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook @@ -85,41 +86,41 @@ class SparkJDBCHook(SparkSubmitHook): types. """ - conn_name_attr = 'spark_conn_id' - default_conn_name = 'spark_default' - conn_type = 'spark_jdbc' - hook_name = 'Spark JDBC' + conn_name_attr = "spark_conn_id" + default_conn_name = "spark_default" + conn_type = "spark_jdbc" + hook_name = "Spark JDBC" def __init__( self, - spark_app_name: str = 'airflow-spark-jdbc', + spark_app_name: str = "airflow-spark-jdbc", spark_conn_id: str = default_conn_name, - spark_conf: Optional[Dict[str, Any]] = None, - spark_py_files: Optional[str] = None, - spark_files: Optional[str] = None, - spark_jars: Optional[str] = None, - num_executors: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, + spark_conf: dict[str, Any] | None = None, + spark_py_files: str | None = None, + spark_files: str | None = None, + spark_jars: str | None = None, + num_executors: int | None = None, + executor_cores: int | None = None, + executor_memory: str | None = None, + driver_memory: str | None = None, verbose: bool = False, - principal: Optional[str] = None, - keytab: Optional[str] = None, - cmd_type: str = 'spark_to_jdbc', - jdbc_table: Optional[str] = None, - jdbc_conn_id: str = 'jdbc-default', - jdbc_driver: Optional[str] = None, - metastore_table: Optional[str] = None, + principal: str | None = None, + keytab: str | None = None, + cmd_type: str = "spark_to_jdbc", + jdbc_table: str | None = None, + jdbc_conn_id: str = "jdbc-default", + jdbc_driver: str | None = None, + metastore_table: str | None = None, jdbc_truncate: bool = False, - save_mode: Optional[str] = None, - save_format: Optional[str] = None, - batch_size: Optional[int] = None, - fetch_size: Optional[int] = None, - num_partitions: Optional[int] = None, - partition_column: Optional[str] = None, - lower_bound: Optional[str] = None, - upper_bound: Optional[str] = None, - 
create_table_column_types: Optional[str] = None, + save_mode: str | None = None, + save_format: str | None = None, + batch_size: int | None = None, + fetch_size: int | None = None, + num_partitions: int | None = None, + partition_column: str | None = None, + lower_bound: str | None = None, + upper_bound: str | None = None, + create_table_column_types: str | None = None, *args: Any, **kwargs: Any, ): @@ -154,73 +155,73 @@ def __init__( self._create_table_column_types = create_table_column_types self._jdbc_connection = self._resolve_jdbc_connection() - def _resolve_jdbc_connection(self) -> Dict[str, Any]: - conn_data = {'url': '', 'schema': '', 'conn_prefix': '', 'user': '', 'password': ''} + def _resolve_jdbc_connection(self) -> dict[str, Any]: + conn_data = {"url": "", "schema": "", "conn_prefix": "", "user": "", "password": ""} try: conn = self.get_connection(self._jdbc_conn_id) if conn.port: - conn_data['url'] = f"{conn.host}:{conn.port}" + conn_data["url"] = f"{conn.host}:{conn.port}" else: - conn_data['url'] = conn.host - conn_data['schema'] = conn.schema - conn_data['user'] = conn.login - conn_data['password'] = conn.password + conn_data["url"] = conn.host + conn_data["schema"] = conn.schema + conn_data["user"] = conn.login + conn_data["password"] = conn.password extra = conn.extra_dejson - conn_data['conn_prefix'] = extra.get('conn_prefix', '') + conn_data["conn_prefix"] = extra.get("conn_prefix", "") except AirflowException: self.log.debug( "Could not load jdbc connection string %s, defaulting to %s", self._jdbc_conn_id, "" ) return conn_data - def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any: + def _build_jdbc_application_arguments(self, jdbc_conn: dict[str, Any]) -> Any: arguments = [] arguments += ["-cmdType", self._cmd_type] - if self._jdbc_connection['url']: + if self._jdbc_connection["url"]: arguments += [ - '-url', + "-url", f"{jdbc_conn['conn_prefix']}{jdbc_conn['url']}/{jdbc_conn['schema']}", ] - if self._jdbc_connection['user']: - arguments += ['-user', self._jdbc_connection['user']] - if self._jdbc_connection['password']: - arguments += ['-password', self._jdbc_connection['password']] + if self._jdbc_connection["user"]: + arguments += ["-user", self._jdbc_connection["user"]] + if self._jdbc_connection["password"]: + arguments += ["-password", self._jdbc_connection["password"]] if self._metastore_table: - arguments += ['-metastoreTable', self._metastore_table] + arguments += ["-metastoreTable", self._metastore_table] if self._jdbc_table: - arguments += ['-jdbcTable', self._jdbc_table] + arguments += ["-jdbcTable", self._jdbc_table] if self._jdbc_truncate: - arguments += ['-jdbcTruncate', str(self._jdbc_truncate)] + arguments += ["-jdbcTruncate", str(self._jdbc_truncate)] if self._jdbc_driver: - arguments += ['-jdbcDriver', self._jdbc_driver] + arguments += ["-jdbcDriver", self._jdbc_driver] if self._batch_size: - arguments += ['-batchsize', str(self._batch_size)] + arguments += ["-batchsize", str(self._batch_size)] if self._fetch_size: - arguments += ['-fetchsize', str(self._fetch_size)] + arguments += ["-fetchsize", str(self._fetch_size)] if self._num_partitions: - arguments += ['-numPartitions', str(self._num_partitions)] + arguments += ["-numPartitions", str(self._num_partitions)] if self._partition_column and self._lower_bound and self._upper_bound and self._num_partitions: # these 3 parameters need to be used all together to take effect. 
arguments += [ - '-partitionColumn', + "-partitionColumn", self._partition_column, - '-lowerBound', + "-lowerBound", self._lower_bound, - '-upperBound', + "-upperBound", self._upper_bound, ] if self._save_mode: - arguments += ['-saveMode', self._save_mode] + arguments += ["-saveMode", self._save_mode] if self._save_format: - arguments += ['-saveFormat', self._save_format] + arguments += ["-saveFormat", self._save_format] if self._create_table_column_types: - arguments += ['-createTableColumnTypes', self._create_table_column_types] + arguments += ["-createTableColumnTypes", self._create_table_column_types] return arguments def submit_jdbc_job(self) -> None: """Submit Spark JDBC job""" self._application_args = self._build_jdbc_application_arguments(self._jdbc_connection) - self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py") + self.submit(application=f"{os.path.dirname(os.path.abspath(__file__))}/spark_jdbc_script.py") def get_conn(self) -> Any: pass diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py index c354de6ab0ca4..2cc10584d7752 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import argparse -from typing import Any, List, Optional +from typing import Any from pyspark.sql import SparkSession @@ -27,11 +28,11 @@ def set_common_options( spark_source: Any, - url: str = 'localhost:5432', - jdbc_table: str = 'default.default', - user: str = 'root', - password: str = 'root', - driver: str = 'driver', + url: str = "localhost:5432", + jdbc_table: str = "default.default", + user: str = "root", + password: str = "root", + driver: str = "driver", ) -> Any: """ Get Spark source from JDBC connection @@ -44,12 +45,12 @@ def set_common_options( :param driver: JDBC resource driver """ spark_source = ( - spark_source.format('jdbc') - .option('url', url) - .option('dbtable', jdbc_table) - .option('user', user) - .option('password', password) - .option('driver', driver) + spark_source.format("jdbc") + .option("url", url) + .option("dbtable", jdbc_table) + .option("user", user) + .option("password", password) + .option("driver", driver) ) return spark_source @@ -75,11 +76,11 @@ def spark_write_to_jdbc( # now set write-specific options if truncate: - writer = writer.option('truncate', truncate) + writer = writer.option("truncate", truncate) if batch_size: - writer = writer.option('batchsize', batch_size) + writer = writer.option("batchsize", batch_size) if num_partitions: - writer = writer.option('numPartitions', num_partitions) + writer = writer.option("numPartitions", num_partitions) if create_table_column_types: writer = writer.option("createTableColumnTypes", create_table_column_types) @@ -108,39 +109,39 @@ def spark_read_from_jdbc( # now set specific read options if fetch_size: - reader = reader.option('fetchsize', fetch_size) + reader = reader.option("fetchsize", fetch_size) if num_partitions: - reader = reader.option('numPartitions', num_partitions) + reader = reader.option("numPartitions", num_partitions) if partition_column and lower_bound and upper_bound: reader = ( - reader.option('partitionColumn', partition_column) - .option('lowerBound', lower_bound) - .option('upperBound', upper_bound) + 
reader.option("partitionColumn", partition_column) + .option("lowerBound", lower_bound) + .option("upperBound", upper_bound) ) reader.load().write.saveAsTable(metastore_table, format=save_format, mode=save_mode) -def _parse_arguments(args: Optional[List[str]] = None) -> Any: - parser = argparse.ArgumentParser(description='Spark-JDBC') - parser.add_argument('-cmdType', dest='cmd_type', action='store') - parser.add_argument('-url', dest='url', action='store') - parser.add_argument('-user', dest='user', action='store') - parser.add_argument('-password', dest='password', action='store') - parser.add_argument('-metastoreTable', dest='metastore_table', action='store') - parser.add_argument('-jdbcTable', dest='jdbc_table', action='store') - parser.add_argument('-jdbcDriver', dest='jdbc_driver', action='store') - parser.add_argument('-jdbcTruncate', dest='truncate', action='store') - parser.add_argument('-saveMode', dest='save_mode', action='store') - parser.add_argument('-saveFormat', dest='save_format', action='store') - parser.add_argument('-batchsize', dest='batch_size', action='store') - parser.add_argument('-fetchsize', dest='fetch_size', action='store') - parser.add_argument('-name', dest='name', action='store') - parser.add_argument('-numPartitions', dest='num_partitions', action='store') - parser.add_argument('-partitionColumn', dest='partition_column', action='store') - parser.add_argument('-lowerBound', dest='lower_bound', action='store') - parser.add_argument('-upperBound', dest='upper_bound', action='store') - parser.add_argument('-createTableColumnTypes', dest='create_table_column_types', action='store') +def _parse_arguments(args: list[str] | None = None) -> Any: + parser = argparse.ArgumentParser(description="Spark-JDBC") + parser.add_argument("-cmdType", dest="cmd_type", action="store") + parser.add_argument("-url", dest="url", action="store") + parser.add_argument("-user", dest="user", action="store") + parser.add_argument("-password", dest="password", action="store") + parser.add_argument("-metastoreTable", dest="metastore_table", action="store") + parser.add_argument("-jdbcTable", dest="jdbc_table", action="store") + parser.add_argument("-jdbcDriver", dest="jdbc_driver", action="store") + parser.add_argument("-jdbcTruncate", dest="truncate", action="store") + parser.add_argument("-saveMode", dest="save_mode", action="store") + parser.add_argument("-saveFormat", dest="save_format", action="store") + parser.add_argument("-batchsize", dest="batch_size", action="store") + parser.add_argument("-fetchsize", dest="fetch_size", action="store") + parser.add_argument("-name", dest="name", action="store") + parser.add_argument("-numPartitions", dest="num_partitions", action="store") + parser.add_argument("-partitionColumn", dest="partition_column", action="store") + parser.add_argument("-lowerBound", dest="lower_bound", action="store") + parser.add_argument("-upperBound", dest="upper_bound", action="store") + parser.add_argument("-createTableColumnTypes", dest="create_table_column_types", action="store") return parser.parse_args(args=args) diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py index 2621bf9a4099b..d6f8f56c27732 100644 --- a/airflow/providers/apache/spark/hooks/spark_sql.py +++ b/airflow/providers/apache/spark/hooks/spark_sql.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
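The script arguments above largely map onto the ``SparkJDBCHook``/``SparkJDBCOperator`` parameters (the url/user/password come from the JDBC connection). A sketch of the operator side, assuming a DAG context; the JDBC connection id, driver, tables and bounds are placeholders, and the partition arguments only take effect when set together with ``num_partitions``::

    from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator

    jdbc_to_spark = SparkJDBCOperator(
        task_id="jdbc_to_spark",
        cmd_type="jdbc_to_spark",                # or "spark_to_jdbc" for the other direction
        jdbc_conn_id="jdbc_default",             # placeholder JDBC connection
        jdbc_driver="org.postgresql.Driver",
        jdbc_table="public.source_table",        # placeholder
        metastore_table="staging.source_table",  # placeholder
        # partition_column/lower_bound/upper_bound only take effect when used together,
        # mirroring the check in _build_jdbc_application_arguments
        num_partitions=4,
        partition_column="id",
        lower_bound="0",
        upper_bound="100000",
        save_mode="overwrite",
        save_format="parquet",
    )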
-# +from __future__ import annotations + import subprocess -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.hooks.base import BaseHook @@ -49,30 +50,30 @@ class SparkSqlHook(BaseHook): (Default: The ``queue`` value set in the Connection, or ``"default"``) """ - conn_name_attr = 'conn_id' - default_conn_name = 'spark_sql_default' - conn_type = 'spark_sql' - hook_name = 'Spark SQL' + conn_name_attr = "conn_id" + default_conn_name = "spark_sql_default" + conn_type = "spark_sql" + hook_name = "Spark SQL" def __init__( self, sql: str, - conf: Optional[str] = None, + conf: str | None = None, conn_id: str = default_conn_name, - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - master: Optional[str] = None, - name: str = 'default-name', - num_executors: Optional[int] = None, + total_executor_cores: int | None = None, + executor_cores: int | None = None, + executor_memory: str | None = None, + keytab: str | None = None, + principal: str | None = None, + master: str | None = None, + name: str = "default-name", + num_executors: int | None = None, verbose: bool = True, - yarn_queue: Optional[str] = None, + yarn_queue: str | None = None, ) -> None: super().__init__() - options: Dict = {} - conn: Optional[Connection] = None + options: dict = {} + conn: Connection | None = None try: conn = self.get_connection(conn_id) @@ -109,7 +110,7 @@ def __init__( def get_conn(self) -> Any: pass - def _prepare_command(self, cmd: Union[str, List[str]]) -> List[str]: + def _prepare_command(self, cmd: str | list[str]) -> list[str]: """ Construct the spark-sql command to execute. Verbose output is enabled as default. diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py index 0f5dc2f7307cc..bfc08eda6402d 100644 --- a/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/airflow/providers/apache/spark/hooks/spark_submit.py @@ -15,12 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + +import contextlib import os import re import subprocess import time -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Iterator from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException @@ -28,17 +30,16 @@ from airflow.security.kerberos import renew_from_kt from airflow.utils.log.logging_mixin import LoggingMixin -try: +with contextlib.suppress(ImportError, NameError): from airflow.kubernetes import kube_client -except (ImportError, NameError): - pass + +ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit"] class SparkSubmitHook(BaseHook, LoggingMixin): """ This hook is a wrapper around the spark-submit binary to kick off a spark-submit job. - It requires that the "spark-submit" binary is in the PATH or the spark_home to be - supplied. + It requires that the "spark-submit" binary is in the PATH. :param conf: Arbitrary Spark configuration properties :param spark_conn_id: The :ref:`spark connection id ` as configured @@ -80,46 +81,46 @@ class SparkSubmitHook(BaseHook, LoggingMixin): Some distros may use spark2-submit. 
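With ``spark-home`` rejected and ``spark-binary`` limited to the two allowed values, a compliant Spark connection only needs host/port plus the remaining extras (``queue``, ``deploy-mode``, ``spark-binary``, ``namespace``). A sketch using the ``Connection`` model, with a placeholder master host::

    from airflow.models.connection import Connection

    # "spark-home" must not appear in the extra any more; spark-submit (or spark2-submit)
    # has to be resolvable on the PATH of the worker that runs the task.
    spark_conn = Connection(
        conn_id="spark_default",
        conn_type="spark",
        host="spark://spark-master",   # placeholder master
        port=7077,
        extra='{"queue": "default", "deploy-mode": "cluster", "spark-binary": "spark-submit"}',
    )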
""" - conn_name_attr = 'conn_id' - default_conn_name = 'spark_default' - conn_type = 'spark' - hook_name = 'Spark' + conn_name_attr = "conn_id" + default_conn_name = "spark_default" + conn_type = "spark" + hook_name = "Spark" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'login', 'password'], + "hidden_fields": ["schema", "login", "password"], "relabeling": {}, } def __init__( self, - conf: Optional[Dict[str, Any]] = None, - conn_id: str = 'spark_default', - files: Optional[str] = None, - py_files: Optional[str] = None, - archives: Optional[str] = None, - driver_class_path: Optional[str] = None, - jars: Optional[str] = None, - java_class: Optional[str] = None, - packages: Optional[str] = None, - exclude_packages: Optional[str] = None, - repositories: Optional[str] = None, - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - proxy_user: Optional[str] = None, - name: str = 'default-name', - num_executors: Optional[int] = None, + conf: dict[str, Any] | None = None, + conn_id: str = "spark_default", + files: str | None = None, + py_files: str | None = None, + archives: str | None = None, + driver_class_path: str | None = None, + jars: str | None = None, + java_class: str | None = None, + packages: str | None = None, + exclude_packages: str | None = None, + repositories: str | None = None, + total_executor_cores: int | None = None, + executor_cores: int | None = None, + executor_memory: str | None = None, + driver_memory: str | None = None, + keytab: str | None = None, + principal: str | None = None, + proxy_user: str | None = None, + name: str = "default-name", + num_executors: int | None = None, status_poll_interval: int = 1, - application_args: Optional[List[Any]] = None, - env_vars: Optional[Dict[str, Any]] = None, + application_args: list[Any] | None = None, + env_vars: dict[str, Any] | None = None, verbose: bool = False, - spark_binary: Optional[str] = None, + spark_binary: str | None = None, ) -> None: super().__init__() self._conf = conf or {} @@ -146,42 +147,47 @@ def __init__( self._application_args = application_args self._env_vars = env_vars self._verbose = verbose - self._submit_sp: Optional[Any] = None - self._yarn_application_id: Optional[str] = None - self._kubernetes_driver_pod: Optional[str] = None + self._submit_sp: Any | None = None + self._yarn_application_id: str | None = None + self._kubernetes_driver_pod: str | None = None self._spark_binary = spark_binary + if self._spark_binary is not None and self._spark_binary not in ALLOWED_SPARK_BINARIES: + raise RuntimeError( + f"The spark-binary extra can be on of {ALLOWED_SPARK_BINARIES} and it" + f" was `{spark_binary}`. Please make sure your spark binary is one of the" + f" allowed ones and that it is available on the PATH" + ) self._connection = self._resolve_connection() - self._is_yarn = 'yarn' in self._connection['master'] - self._is_kubernetes = 'k8s' in self._connection['master'] + self._is_yarn = "yarn" in self._connection["master"] + self._is_kubernetes = "k8s" in self._connection["master"] if self._is_kubernetes and kube_client is None: raise RuntimeError( f"{self._connection['master']} specified by kubernetes dependencies are not installed!" 
) self._should_track_driver_status = self._resolve_should_track_driver_status() - self._driver_id: Optional[str] = None - self._driver_status: Optional[str] = None - self._spark_exit_code: Optional[int] = None - self._env: Optional[Dict[str, Any]] = None + self._driver_id: str | None = None + self._driver_status: str | None = None + self._spark_exit_code: int | None = None + self._env: dict[str, Any] | None = None def _resolve_should_track_driver_status(self) -> bool: """ - Determines whether or not this hook should poll the spark driver status through + Determines whether this hook should poll the spark driver status through subsequent spark-submit status requests after the initial spark-submit request :return: if the driver status should be tracked """ - return 'spark://' in self._connection['master'] and self._connection['deploy_mode'] == 'cluster' + return "spark://" in self._connection["master"] and self._connection["deploy_mode"] == "cluster" - def _resolve_connection(self) -> Dict[str, Any]: + def _resolve_connection(self) -> dict[str, Any]: # Build from connection master or default to yarn if not available conn_data = { - 'master': 'yarn', - 'queue': None, - 'deploy_mode': None, - 'spark_home': None, - 'spark_binary': self._spark_binary or "spark-submit", - 'namespace': None, + "master": "yarn", + "queue": None, + "deploy_mode": None, + "spark_binary": self._spark_binary or "spark-submit", + "namespace": None, } try: @@ -189,44 +195,47 @@ def _resolve_connection(self) -> Dict[str, Any]: # k8s://https://: conn = self.get_connection(self._conn_id) if conn.port: - conn_data['master'] = f"{conn.host}:{conn.port}" + conn_data["master"] = f"{conn.host}:{conn.port}" else: - conn_data['master'] = conn.host + conn_data["master"] = conn.host # Determine optional yarn queue from the extra field extra = conn.extra_dejson - conn_data['queue'] = extra.get('queue') - conn_data['deploy_mode'] = extra.get('deploy-mode') - conn_data['spark_home'] = extra.get('spark-home') - conn_data['spark_binary'] = self._spark_binary or extra.get('spark-binary', "spark-submit") - conn_data['namespace'] = extra.get('namespace') + conn_data["queue"] = extra.get("queue") + conn_data["deploy_mode"] = extra.get("deploy-mode") + spark_binary = self._spark_binary or extra.get("spark-binary", "spark-submit") + if spark_binary not in ALLOWED_SPARK_BINARIES: + raise RuntimeError( + f"The `spark-binary` extra can be on of {ALLOWED_SPARK_BINARIES} and it" + f" was `{spark_binary}`. Please make sure your spark binary is one of the" + " allowed ones and that it is available on the PATH" + ) + conn_spark_home = extra.get("spark-home") + if conn_spark_home: + raise RuntimeError( + "The `spark-home` extra is not allowed any more. Please make sure your `spark-submit` or" + " `spark2-submit` are available on the PATH." 
+ ) + conn_data["spark_binary"] = spark_binary + conn_data["namespace"] = extra.get("namespace") except AirflowException: self.log.info( - "Could not load connection string %s, defaulting to %s", self._conn_id, conn_data['master'] + "Could not load connection string %s, defaulting to %s", self._conn_id, conn_data["master"] ) - if 'spark.kubernetes.namespace' in self._conf: - conn_data['namespace'] = self._conf['spark.kubernetes.namespace'] + if "spark.kubernetes.namespace" in self._conf: + conn_data["namespace"] = self._conf["spark.kubernetes.namespace"] return conn_data def get_conn(self) -> Any: pass - def _get_spark_binary_path(self) -> List[str]: - # If the spark_home is passed then build the spark-submit executable path using - # the spark_home; otherwise assume that spark-submit is present in the path to - # the executing user - if self._connection['spark_home']: - connection_cmd = [ - os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary']) - ] - else: - connection_cmd = [self._connection['spark_binary']] - - return connection_cmd + def _get_spark_binary_path(self) -> list[str]: + # Assume that spark-submit is present in the path to the executing user + return [self._connection["spark_binary"]] - def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str: + def _mask_cmd(self, connection_cmd: str | list[str]) -> str: # Mask any password related fields in application args with key value pair # where key contains password (case insensitive), e.g. HivePassword='abc' connection_cmd_masked = re.sub( @@ -243,14 +252,14 @@ def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str: # (matched above); if the value is quoted, # it may contain whitespace. r"(\2)", # Optional matching quote. - r'\1******\3', - ' '.join(connection_cmd), + r"\1******\3", + " ".join(connection_cmd), flags=re.I, ) return connection_cmd_masked - def _build_spark_submit_command(self, application: str) -> List[str]: + def _build_spark_submit_command(self, application: str) -> list[str]: """ Construct the spark-submit command to execute. 
@@ -260,7 +269,7 @@ def _build_spark_submit_command(self, application: str) -> List[str]: connection_cmd = self._get_spark_binary_path() # The url of the spark master - connection_cmd += ["--master", self._connection['master']] + connection_cmd += ["--master", self._connection["master"]] for key in self._conf: connection_cmd += ["--conf", f"{key}={str(self._conf[key])}"] @@ -273,11 +282,11 @@ def _build_spark_submit_command(self, application: str) -> List[str]: tmpl = "spark.kubernetes.driverEnv.{}={}" for key in self._env_vars: connection_cmd += ["--conf", tmpl.format(key, str(self._env_vars[key]))] - elif self._env_vars and self._connection['deploy_mode'] != "cluster": + elif self._env_vars and self._connection["deploy_mode"] != "cluster": self._env = self._env_vars # Do it on Popen of the process - elif self._env_vars and self._connection['deploy_mode'] == "cluster": + elif self._env_vars and self._connection["deploy_mode"] == "cluster": raise AirflowException("SparkSubmitHook env_vars is not supported in standalone-cluster mode.") - if self._is_kubernetes and self._connection['namespace']: + if self._is_kubernetes and self._connection["namespace"]: connection_cmd += [ "--conf", f"spark.kubernetes.namespace={self._connection['namespace']}", @@ -320,10 +329,10 @@ def _build_spark_submit_command(self, application: str) -> List[str]: connection_cmd += ["--class", self._java_class] if self._verbose: connection_cmd += ["--verbose"] - if self._connection['queue']: - connection_cmd += ["--queue", self._connection['queue']] - if self._connection['deploy_mode']: - connection_cmd += ["--deploy-mode", self._connection['deploy_mode']] + if self._connection["queue"]: + connection_cmd += ["--queue", self._connection["queue"]] + if self._connection["deploy_mode"]: + connection_cmd += ["--deploy-mode", self._connection["deploy_mode"]] # The actual script to execute connection_cmd += [application] @@ -336,15 +345,15 @@ def _build_spark_submit_command(self, application: str) -> List[str]: return connection_cmd - def _build_track_driver_status_command(self) -> List[str]: + def _build_track_driver_status_command(self) -> list[str]: """ Construct the command to poll the driver status. :return: full command to be executed """ curl_max_wait_time = 30 - spark_host = self._connection['master'] - if spark_host.endswith(':6066'): + spark_host = self._connection["master"] + if spark_host.endswith(":6066"): spark_host = spark_host.replace("spark://", "http://") connection_cmd = [ "/usr/bin/curl", @@ -355,9 +364,7 @@ def _build_track_driver_status_command(self) -> List[str]: self.log.info(connection_cmd) # The driver id so we can poll for its status - if self._driver_id: - pass - else: + if not self._driver_id: raise AirflowException( "Invalid status: attempted to poll driver status but no driver id is known. Giving up." 
) @@ -367,7 +374,7 @@ def _build_track_driver_status_command(self) -> List[str]: connection_cmd = self._get_spark_binary_path() # The url to the spark master - connection_cmd += ["--master", self._connection['master']] + connection_cmd += ["--master", self._connection["master"]] # The driver id so we can poll for its status if self._driver_id: @@ -457,8 +464,8 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: line = line.strip() # If we run yarn cluster mode, we want to extract the application id from # the logs so we can kill the application when we stop it unexpectedly - if self._is_yarn and self._connection['deploy_mode'] == 'cluster': - match = re.search('(application[0-9_]+)', line) + if self._is_yarn and self._connection["deploy_mode"] == "cluster": + match = re.search("(application[0-9_]+)", line) if match: self._yarn_application_id = match.groups()[0] self.log.info("Identified spark driver id: %s", self._yarn_application_id) @@ -466,13 +473,13 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: # If we run Kubernetes cluster mode, we want to extract the driver pod id # from the logs so we can kill the application when we stop it unexpectedly elif self._is_kubernetes: - match = re.search(r'\s*pod name: ((.+?)-([a-z0-9]+)-driver)', line) + match = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver)", line) if match: self._kubernetes_driver_pod = match.groups()[0] self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) # Store the Spark Exit code - match_exit_code = re.search(r'\s*[eE]xit code: (\d+)', line) + match_exit_code = re.search(r"\s*[eE]xit code: (\d+)", line) if match_exit_code: self._spark_exit_code = int(match_exit_code.groups()[0]) @@ -480,7 +487,7 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: # we need to extract the driver id from the logs. This allows us to poll for # the status using the driver id. Also, we can kill the driver when needed. elif self._should_track_driver_status and not self._driver_id: - match_driver_id = re.search(r'(driver-[0-9\-]+)', line) + match_driver_id = re.search(r"(driver-[0-9\-]+)", line) if match_driver_id: self._driver_id = match_driver_id.groups()[0] self.log.info("identified spark driver id: %s", self._driver_id) @@ -505,7 +512,7 @@ def _process_spark_status_log(self, itr: Iterator[Any]) -> None: # Check if the log line is about the driver status and extract the status. if "driverState" in line: - self._driver_status = line.split(' : ')[1].replace(',', '').replace('\"', '').strip() + self._driver_status = line.split(" : ")[1].replace(",", "").replace('"', "").strip() driver_found = True self.log.debug("spark driver status log: %s", line) @@ -577,23 +584,16 @@ def _start_driver_status_tracking(self) -> None: f"returncode = {returncode}" ) - def _build_spark_driver_kill_command(self) -> List[str]: + def _build_spark_driver_kill_command(self) -> list[str]: """ Construct the spark-submit command to kill a driver. 
:return: full command to kill a driver """ - # If the spark_home is passed then build the spark-submit executable path using - # the spark_home; otherwise assume that spark-submit is present in the path to - # the executing user - if self._connection['spark_home']: - connection_cmd = [ - os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary']) - ] - else: - connection_cmd = [self._connection['spark_binary']] + # Assume that spark-submit is present in the path to the executing user + connection_cmd = [self._connection["spark_binary"]] # The url to the spark master - connection_cmd += ["--master", self._connection['master']] + connection_cmd += ["--master", self._connection["master"]] # The actual kill command if self._driver_id: @@ -607,20 +607,17 @@ def on_kill(self) -> None: """Kill Spark submit command""" self.log.debug("Kill Command is being called") - if self._should_track_driver_status: - if self._driver_id: - self.log.info('Killing driver %s on cluster', self._driver_id) + if self._should_track_driver_status and self._driver_id: + self.log.info("Killing driver %s on cluster", self._driver_id) - kill_cmd = self._build_spark_driver_kill_command() - with subprocess.Popen( - kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) as driver_kill: - self.log.info( - "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() - ) + kill_cmd = self._build_spark_driver_kill_command() + with subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as driver_kill: + self.log.info( + "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() + ) if self._submit_sp and self._submit_sp.poll() is None: - self.log.info('Sending kill signal to %s', self._connection['spark_binary']) + self.log.info("Sending kill signal to %s", self._connection["spark_binary"]) self._submit_sp.kill() if self._yarn_application_id: @@ -632,7 +629,8 @@ def on_kill(self) -> None: # we still attempt to kill the yarn application renew_from_kt(self._principal, self._keytab, exit_on_fail=False) env = os.environ.copy() - env["KRB5CCNAME"] = airflow_conf.get_mandatory_value('kerberos', 'ccache') + ccache = airflow_conf.get_mandatory_value("kerberos", "ccache") + env["KRB5CCNAME"] = ccache with subprocess.Popen( kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE @@ -640,7 +638,7 @@ def on_kill(self) -> None: self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) if self._kubernetes_driver_pod: - self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod) + self.log.info("Killing pod %s on Kubernetes", self._kubernetes_driver_pod) # Currently only instantiate Kubernetes client for killing a spark pod. try: @@ -649,7 +647,7 @@ def on_kill(self) -> None: client = kube_client.get_kube_client() api_response = client.delete_namespaced_pod( self._kubernetes_driver_pod, - self._connection['namespace'], + self._connection["namespace"], body=kubernetes.client.V1DeleteOptions(), pretty=True, ) diff --git a/airflow/providers/apache/spark/operators/spark_jdbc.py b/airflow/providers/apache/spark/operators/spark_jdbc.py index 87f244be50899..7dd035db40fa7 100644 --- a/airflow/providers/apache/spark/operators/spark_jdbc.py +++ b/airflow/providers/apache/spark/operators/spark_jdbc.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License.
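As a side note, the log scraping the hook relies on for standalone-cluster submits is easiest to see on a concrete line; the tiny sketch below uses the same ``driver-...`` pattern as above, but the sample log line is invented for illustration::

    import re

    sample = "INFO ClientEndpoint: Driver successfully submitted as driver-20221027123456-0001"
    match = re.search(r"(driver-[0-9\-]+)", sample)
    if match:
        driver_id = match.groups()[0]
        print(driver_id)  # driver-20221027123456-0001 -- later used for status polling and --kill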
-# -from typing import TYPE_CHECKING, Any, Dict, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator @@ -96,34 +97,34 @@ class SparkJDBCOperator(SparkSubmitOperator): def __init__( self, *, - spark_app_name: str = 'airflow-spark-jdbc', - spark_conn_id: str = 'spark-default', - spark_conf: Optional[Dict[str, Any]] = None, - spark_py_files: Optional[str] = None, - spark_files: Optional[str] = None, - spark_jars: Optional[str] = None, - num_executors: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, + spark_app_name: str = "airflow-spark-jdbc", + spark_conn_id: str = "spark-default", + spark_conf: dict[str, Any] | None = None, + spark_py_files: str | None = None, + spark_files: str | None = None, + spark_jars: str | None = None, + num_executors: int | None = None, + executor_cores: int | None = None, + executor_memory: str | None = None, + driver_memory: str | None = None, verbose: bool = False, - principal: Optional[str] = None, - keytab: Optional[str] = None, - cmd_type: str = 'spark_to_jdbc', - jdbc_table: Optional[str] = None, - jdbc_conn_id: str = 'jdbc-default', - jdbc_driver: Optional[str] = None, - metastore_table: Optional[str] = None, + principal: str | None = None, + keytab: str | None = None, + cmd_type: str = "spark_to_jdbc", + jdbc_table: str | None = None, + jdbc_conn_id: str = "jdbc-default", + jdbc_driver: str | None = None, + metastore_table: str | None = None, jdbc_truncate: bool = False, - save_mode: Optional[str] = None, - save_format: Optional[str] = None, - batch_size: Optional[int] = None, - fetch_size: Optional[int] = None, - num_partitions: Optional[int] = None, - partition_column: Optional[str] = None, - lower_bound: Optional[str] = None, - upper_bound: Optional[str] = None, - create_table_column_types: Optional[str] = None, + save_mode: str | None = None, + save_format: str | None = None, + batch_size: int | None = None, + fetch_size: int | None = None, + num_partitions: int | None = None, + partition_column: str | None = None, + lower_bound: str | None = None, + upper_bound: str | None = None, + create_table_column_types: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -155,9 +156,9 @@ def __init__( self._lower_bound = lower_bound self._upper_bound = upper_bound self._create_table_column_types = create_table_column_types - self._hook: Optional[SparkJDBCHook] = None + self._hook: SparkJDBCHook | None = None - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: """Call the SparkSubmitHook to run the provided spark job""" if self._hook is None: self._hook = self._get_hook() diff --git a/airflow/providers/apache/spark/operators/spark_sql.py b/airflow/providers/apache/spark/operators/spark_sql.py index 33f19e9c43287..a5150d281aebe 100644 --- a/airflow/providers/apache/spark/operators/spark_sql.py +++ b/airflow/providers/apache/spark/operators/spark_sql.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
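To make the reworked ``SparkJDBCOperator`` signature above concrete, here is a minimal usage sketch; the connection ids, table names and driver class are placeholders, not values required by the provider::

    from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator

    jdbc_to_spark = SparkJDBCOperator(
        task_id="jdbc_to_spark",
        cmd_type="jdbc_to_spark",  # the default, "spark_to_jdbc", pushes data the other way
        jdbc_conn_id="jdbc-default",
        jdbc_table="public.orders",
        jdbc_driver="org.postgresql.Driver",
        metastore_table="orders",
        save_mode="overwrite",
        spark_conn_id="spark-default",
    )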
-# -from typing import TYPE_CHECKING, Any, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook @@ -51,26 +52,26 @@ class SparkSqlOperator(BaseOperator): (Default: The ``queue`` value set in the Connection, or ``"default"``) """ - template_fields: Sequence[str] = ('_sql',) + template_fields: Sequence[str] = ("_sql",) template_ext: Sequence[str] = (".sql", ".hql") - template_fields_renderers = {'_sql': 'sql'} + template_fields_renderers = {"_sql": "sql"} def __init__( self, *, sql: str, - conf: Optional[str] = None, - conn_id: str = 'spark_sql_default', - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - master: Optional[str] = None, - name: str = 'default-name', - num_executors: Optional[int] = None, + conf: str | None = None, + conn_id: str = "spark_sql_default", + total_executor_cores: int | None = None, + executor_cores: int | None = None, + executor_memory: str | None = None, + keytab: str | None = None, + principal: str | None = None, + master: str | None = None, + name: str = "default-name", + num_executors: int | None = None, verbose: bool = True, - yarn_queue: Optional[str] = None, + yarn_queue: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -87,9 +88,9 @@ def __init__( self._num_executors = num_executors self._verbose = verbose self._yarn_queue = yarn_queue - self._hook: Optional[SparkSqlHook] = None + self._hook: SparkSqlHook | None = None - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: """Call the SparkSqlHook to run the provided sql query""" if self._hook is None: self._hook = self._get_hook() diff --git a/airflow/providers/apache/spark/operators/spark_submit.py b/airflow/providers/apache/spark/operators/spark_submit.py index db1114cf201dc..b0b2961c73f4c 100644 --- a/airflow/providers/apache/spark/operators/spark_submit.py +++ b/airflow/providers/apache/spark/operators/spark_submit.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook @@ -29,8 +30,7 @@ class SparkSubmitOperator(BaseOperator): """ This hook is a wrapper around the spark-submit binary to kick off a spark-submit job. - It requires that the "spark-submit" binary is in the PATH or the spark-home is set - in the extra on the connection. + It requires that the "spark-submit" binary is in the PATH. .. 
seealso:: For more information on how to use this operator, take a look at the guide: @@ -73,52 +73,52 @@ class SparkSubmitOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_application', - '_conf', - '_files', - '_py_files', - '_jars', - '_driver_class_path', - '_packages', - '_exclude_packages', - '_keytab', - '_principal', - '_proxy_user', - '_name', - '_application_args', - '_env_vars', + "_application", + "_conf", + "_files", + "_py_files", + "_jars", + "_driver_class_path", + "_packages", + "_exclude_packages", + "_keytab", + "_principal", + "_proxy_user", + "_name", + "_application_args", + "_env_vars", ) - ui_color = WEB_COLORS['LIGHTORANGE'] + ui_color = WEB_COLORS["LIGHTORANGE"] def __init__( self, *, - application: str = '', - conf: Optional[Dict[str, Any]] = None, - conn_id: str = 'spark_default', - files: Optional[str] = None, - py_files: Optional[str] = None, - archives: Optional[str] = None, - driver_class_path: Optional[str] = None, - jars: Optional[str] = None, - java_class: Optional[str] = None, - packages: Optional[str] = None, - exclude_packages: Optional[str] = None, - repositories: Optional[str] = None, - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - proxy_user: Optional[str] = None, - name: str = 'arrow-spark', - num_executors: Optional[int] = None, + application: str = "", + conf: dict[str, Any] | None = None, + conn_id: str = "spark_default", + files: str | None = None, + py_files: str | None = None, + archives: str | None = None, + driver_class_path: str | None = None, + jars: str | None = None, + java_class: str | None = None, + packages: str | None = None, + exclude_packages: str | None = None, + repositories: str | None = None, + total_executor_cores: int | None = None, + executor_cores: int | None = None, + executor_memory: str | None = None, + driver_memory: str | None = None, + keytab: str | None = None, + principal: str | None = None, + proxy_user: str | None = None, + name: str = "arrow-spark", + num_executors: int | None = None, status_poll_interval: int = 1, - application_args: Optional[List[Any]] = None, - env_vars: Optional[Dict[str, Any]] = None, + application_args: list[Any] | None = None, + env_vars: dict[str, Any] | None = None, verbose: bool = False, - spark_binary: Optional[str] = None, + spark_binary: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -147,10 +147,10 @@ def __init__( self._env_vars = env_vars self._verbose = verbose self._spark_binary = spark_binary - self._hook: Optional[SparkSubmitHook] = None + self._hook: SparkSubmitHook | None = None self._conn_id = conn_id - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: """Call the SparkSubmitHook to run the provided spark job""" if self._hook is None: self._hook = self._get_hook() diff --git a/airflow/providers/apache/spark/provider.yaml b/airflow/providers/apache/spark/provider.yaml index c0a2dd23c21c8..a40fa651f3084 100644 --- a/airflow/providers/apache/spark/provider.yaml +++ b/airflow/providers/apache/spark/provider.yaml @@ -22,6 +22,8 @@ description: | `Apache Spark `__ versions: + - 4.0.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -35,8 +37,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - pyspark integrations: - integration-name: Apache 
Spark @@ -61,10 +64,6 @@ hooks: - airflow.providers.apache.spark.hooks.spark_sql - airflow.providers.apache.spark.hooks.spark_submit -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook - - airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook - - airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook connection-types: - hook-class-name: airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook diff --git a/airflow/providers/apache/sqoop/.latest-doc-only-change.txt b/airflow/providers/apache/sqoop/.latest-doc-only-change.txt index e7c3c940c9c77..ff7136e07d744 100644 --- a/airflow/providers/apache/sqoop/.latest-doc-only-change.txt +++ b/airflow/providers/apache/sqoop/.latest-doc-only-change.txt @@ -1 +1 @@ -602abe8394fafe7de54df7e73af56de848cdf617 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/apache/sqoop/CHANGELOG.rst b/airflow/providers/apache/sqoop/CHANGELOG.rst index 5178f1d269566..a5785dd689508 100644 --- a/airflow/providers/apache/sqoop/CHANGELOG.rst +++ b/airflow/providers/apache/sqoop/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/apache/sqoop/hooks/sqoop.py b/airflow/providers/apache/sqoop/hooks/sqoop.py index 65ed7500cb5ae..43a1c105885b6 100644 --- a/airflow/providers/apache/sqoop/hooks/sqoop.py +++ b/airflow/providers/apache/sqoop/hooks/sqoop.py @@ -15,12 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
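Before moving on to the Sqoop provider, a minimal usage sketch for the ``SparkSubmitOperator`` changes above: with ``spark_home`` support gone, ``spark-submit`` is simply resolved from ``PATH``. The application path and connection id are placeholders::

    from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

    submit_etl = SparkSubmitOperator(
        task_id="submit_etl",
        application="/opt/spark-apps/etl.py",  # spark-submit itself must be on PATH
        conn_id="spark_default",
        conf={"spark.executor.instances": "2"},
        verbose=True,
    )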
-# - """This module contains a sqoop 1.x hook""" +from __future__ import annotations + import subprocess from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import Any from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook @@ -48,30 +48,30 @@ class SqoopHook(BaseHook): :param properties: Properties to set via the -D argument """ - conn_name_attr = 'conn_id' - default_conn_name = 'sqoop_default' - conn_type = 'sqoop' - hook_name = 'Sqoop' + conn_name_attr = "conn_id" + default_conn_name = "sqoop_default" + conn_type = "sqoop" + hook_name = "Sqoop" def __init__( self, conn_id: str = default_conn_name, verbose: bool = False, - num_mappers: Optional[int] = None, - hcatalog_database: Optional[str] = None, - hcatalog_table: Optional[str] = None, - properties: Optional[Dict[str, Any]] = None, + num_mappers: int | None = None, + hcatalog_database: str | None = None, + hcatalog_table: str | None = None, + properties: dict[str, Any] | None = None, ) -> None: # No mutable types in the default parameters super().__init__() self.conn = self.get_connection(conn_id) connection_parameters = self.conn.extra_dejson - self.job_tracker = connection_parameters.get('job_tracker', None) - self.namenode = connection_parameters.get('namenode', None) - self.libjars = connection_parameters.get('libjars', None) - self.files = connection_parameters.get('files', None) - self.archives = connection_parameters.get('archives', None) - self.password_file = connection_parameters.get('password_file', None) + self.job_tracker = connection_parameters.get("job_tracker", None) + self.namenode = connection_parameters.get("namenode", None) + self.libjars = connection_parameters.get("libjars", None) + self.files = connection_parameters.get("files", None) + self.archives = connection_parameters.get("archives", None) + self.password_file = connection_parameters.get("password_file", None) self.hcatalog_database = hcatalog_database self.hcatalog_table = hcatalog_table self.verbose = verbose @@ -83,17 +83,17 @@ def __init__( def get_conn(self) -> Any: return self.conn - def cmd_mask_password(self, cmd_orig: List[str]) -> List[str]: + def cmd_mask_password(self, cmd_orig: list[str]) -> list[str]: """Mask command password for safety""" cmd = deepcopy(cmd_orig) try: - password_index = cmd.index('--password') - cmd[password_index + 1] = 'MASKED' + password_index = cmd.index("--password") + cmd[password_index + 1] = "MASKED" except ValueError: self.log.debug("No password in sqoop cmd") return cmd - def popen(self, cmd: List[str], **kwargs: Any) -> None: + def popen(self, cmd: list[str], **kwargs: Any) -> None: """ Remote Popen @@ -101,7 +101,7 @@ def popen(self, cmd: List[str], **kwargs: Any) -> None: :param kwargs: extra arguments to Popen (see subprocess.Popen) :return: handle to subprocess """ - masked_cmd = ' '.join(self.cmd_mask_password(cmd)) + masked_cmd = " ".join(self.cmd_mask_password(cmd)) self.log.info("Executing command: %s", masked_cmd) with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) as sub_process: self.sub_process_pid = sub_process.pid @@ -112,7 +112,7 @@ def popen(self, cmd: List[str], **kwargs: Any) -> None: if sub_process.returncode: raise AirflowException(f"Sqoop command failed: {masked_cmd}") - def _prepare_command(self, export: bool = False) -> List[str]: + def _prepare_command(self, export: bool = False) -> list[str]: sqoop_cmd_type = "export" if export else "import" connection_cmd = ["sqoop", sqoop_cmd_type] @@ 
-149,7 +149,7 @@ def _prepare_command(self, export: bool = False) -> List[str]: connect_str += f":{self.conn.port}" if self.conn.schema: self.log.info("CONNECTION TYPE %s", self.conn.conn_type) - if self.conn.conn_type != 'mssql': + if self.conn.conn_type != "mssql": connect_str += f"/{self.conn.schema}" else: connect_str += f";databaseName={self.conn.schema}" @@ -158,7 +158,7 @@ def _prepare_command(self, export: bool = False) -> List[str]: return connection_cmd @staticmethod - def _get_export_format_argument(file_type: str = 'text') -> List[str]: + def _get_export_format_argument(file_type: str = "text") -> list[str]: if file_type == "avro": return ["--as-avrodatafile"] elif file_type == "sequence": @@ -172,14 +172,14 @@ def _get_export_format_argument(file_type: str = 'text') -> List[str]: def _import_cmd( self, - target_dir: Optional[str], + target_dir: str | None, append: bool, file_type: str, - split_by: Optional[str], - direct: Optional[bool], + split_by: str | None, + direct: bool | None, driver: Any, extra_import_options: Any, - ) -> List[str]: + ) -> list[str]: cmd = self._prepare_command(export=False) @@ -202,7 +202,7 @@ def _import_cmd( if extra_import_options: for key, value in extra_import_options.items(): - cmd += [f'--{key}'] + cmd += [f"--{key}"] if value: cmd += [str(value)] @@ -211,16 +211,16 @@ def _import_cmd( def import_table( self, table: str, - target_dir: Optional[str] = None, + target_dir: str | None = None, append: bool = False, file_type: str = "text", - columns: Optional[str] = None, - split_by: Optional[str] = None, - where: Optional[str] = None, + columns: str | None = None, + split_by: str | None = None, + where: str | None = None, direct: bool = False, driver: Any = None, - extra_import_options: Optional[Dict[str, Any]] = None, - schema: Optional[str] = None, + extra_import_options: dict[str, Any] | None = None, + schema: str | None = None, ) -> Any: """ Imports table from remote location to target dir. 
Arguments are @@ -257,13 +257,13 @@ def import_table( def import_query( self, query: str, - target_dir: Optional[str] = None, + target_dir: str | None = None, append: bool = False, file_type: str = "text", - split_by: Optional[str] = None, - direct: Optional[bool] = None, - driver: Optional[Any] = None, - extra_import_options: Optional[Dict[str, Any]] = None, + split_by: str | None = None, + direct: bool | None = None, + driver: Any | None = None, + extra_import_options: dict[str, Any] | None = None, ) -> Any: """ Imports a specific query from the rdbms to hdfs @@ -288,21 +288,21 @@ def import_query( def _export_cmd( self, table: str, - export_dir: Optional[str] = None, - input_null_string: Optional[str] = None, - input_null_non_string: Optional[str] = None, - staging_table: Optional[str] = None, + export_dir: str | None = None, + input_null_string: str | None = None, + input_null_non_string: str | None = None, + staging_table: str | None = None, clear_staging_table: bool = False, - enclosed_by: Optional[str] = None, - escaped_by: Optional[str] = None, - input_fields_terminated_by: Optional[str] = None, - input_lines_terminated_by: Optional[str] = None, - input_optionally_enclosed_by: Optional[str] = None, + enclosed_by: str | None = None, + escaped_by: str | None = None, + input_fields_terminated_by: str | None = None, + input_lines_terminated_by: str | None = None, + input_optionally_enclosed_by: str | None = None, batch: bool = False, relaxed_isolation: bool = False, - extra_export_options: Optional[Dict[str, Any]] = None, - schema: Optional[str] = None, - ) -> List[str]: + extra_export_options: dict[str, Any] | None = None, + schema: str | None = None, + ) -> list[str]: cmd = self._prepare_command(export=True) @@ -344,7 +344,7 @@ def _export_cmd( if extra_export_options: for key, value in extra_export_options.items(): - cmd += [f'--{key}'] + cmd += [f"--{key}"] if value: cmd += [str(value)] @@ -359,20 +359,20 @@ def _export_cmd( def export_table( self, table: str, - export_dir: Optional[str] = None, - input_null_string: Optional[str] = None, - input_null_non_string: Optional[str] = None, - staging_table: Optional[str] = None, + export_dir: str | None = None, + input_null_string: str | None = None, + input_null_non_string: str | None = None, + staging_table: str | None = None, clear_staging_table: bool = False, - enclosed_by: Optional[str] = None, - escaped_by: Optional[str] = None, - input_fields_terminated_by: Optional[str] = None, - input_lines_terminated_by: Optional[str] = None, - input_optionally_enclosed_by: Optional[str] = None, + enclosed_by: str | None = None, + escaped_by: str | None = None, + input_fields_terminated_by: str | None = None, + input_lines_terminated_by: str | None = None, + input_optionally_enclosed_by: str | None = None, batch: bool = False, relaxed_isolation: bool = False, - extra_export_options: Optional[Dict[str, Any]] = None, - schema: Optional[str] = None, + extra_export_options: dict[str, Any] | None = None, + schema: str | None = None, ) -> None: """ Exports Hive table to remote location. Arguments are copies of direct diff --git a/airflow/providers/apache/sqoop/operators/sqoop.py b/airflow/providers/apache/sqoop/operators/sqoop.py index 3ad7c5fe1a464..f6a44ebbf64a0 100644 --- a/airflow/providers/apache/sqoop/operators/sqoop.py +++ b/airflow/providers/apache/sqoop/operators/sqoop.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
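A short usage sketch for the ``SqoopHook`` API touched above; the connection id, table and HDFS path are placeholders::

    from airflow.providers.apache.sqoop.hooks.sqoop import SqoopHook

    hook = SqoopHook(conn_id="sqoop_default", num_mappers=4)

    # Pull a table into HDFS as Avro, splitting the work on the primary key.
    hook.import_table(
        table="customers",
        target_dir="/data/raw/customers",
        file_type="avro",
        split_by="id",
        extra_import_options={"compress": None},  # a key with no value becomes a bare --compress flag
    )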
-# """This module contains a sqoop 1 operator""" +from __future__ import annotations + import os import signal -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -84,71 +85,71 @@ class SqoopOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'conn_id', - 'cmd_type', - 'table', - 'query', - 'target_dir', - 'file_type', - 'columns', - 'split_by', - 'where', - 'export_dir', - 'input_null_string', - 'input_null_non_string', - 'staging_table', - 'enclosed_by', - 'escaped_by', - 'input_fields_terminated_by', - 'input_lines_terminated_by', - 'input_optionally_enclosed_by', - 'properties', - 'extra_import_options', - 'driver', - 'extra_export_options', - 'hcatalog_database', - 'hcatalog_table', - 'schema', + "conn_id", + "cmd_type", + "table", + "query", + "target_dir", + "file_type", + "columns", + "split_by", + "where", + "export_dir", + "input_null_string", + "input_null_non_string", + "staging_table", + "enclosed_by", + "escaped_by", + "input_fields_terminated_by", + "input_lines_terminated_by", + "input_optionally_enclosed_by", + "properties", + "extra_import_options", + "driver", + "extra_export_options", + "hcatalog_database", + "hcatalog_table", + "schema", ) - template_fields_renderers = {'query': 'sql'} - ui_color = '#7D8CA4' + template_fields_renderers = {"query": "sql"} + ui_color = "#7D8CA4" def __init__( self, *, - conn_id: str = 'sqoop_default', - cmd_type: str = 'import', - table: Optional[str] = None, - query: Optional[str] = None, - target_dir: Optional[str] = None, + conn_id: str = "sqoop_default", + cmd_type: str = "import", + table: str | None = None, + query: str | None = None, + target_dir: str | None = None, append: bool = False, - file_type: str = 'text', - columns: Optional[str] = None, - num_mappers: Optional[int] = None, - split_by: Optional[str] = None, - where: Optional[str] = None, - export_dir: Optional[str] = None, - input_null_string: Optional[str] = None, - input_null_non_string: Optional[str] = None, - staging_table: Optional[str] = None, + file_type: str = "text", + columns: str | None = None, + num_mappers: int | None = None, + split_by: str | None = None, + where: str | None = None, + export_dir: str | None = None, + input_null_string: str | None = None, + input_null_non_string: str | None = None, + staging_table: str | None = None, clear_staging_table: bool = False, - enclosed_by: Optional[str] = None, - escaped_by: Optional[str] = None, - input_fields_terminated_by: Optional[str] = None, - input_lines_terminated_by: Optional[str] = None, - input_optionally_enclosed_by: Optional[str] = None, + enclosed_by: str | None = None, + escaped_by: str | None = None, + input_fields_terminated_by: str | None = None, + input_lines_terminated_by: str | None = None, + input_optionally_enclosed_by: str | None = None, batch: bool = False, direct: bool = False, - driver: Optional[Any] = None, + driver: Any | None = None, verbose: bool = False, relaxed_isolation: bool = False, - properties: Optional[Dict[str, Any]] = None, - hcatalog_database: Optional[str] = None, - hcatalog_table: Optional[str] = None, + properties: dict[str, Any] | None = None, + hcatalog_database: str | None = None, + hcatalog_table: str | None = None, create_hcatalog_table: bool = False, - extra_import_options: Optional[Dict[str, Any]] = None, - extra_export_options: Optional[Dict[str, Any]] = None, - schema: Optional[str] = None, + 
extra_import_options: dict[str, Any] | None = None, + extra_export_options: dict[str, Any] | None = None, + schema: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -184,15 +185,15 @@ def __init__( self.properties = properties self.extra_import_options = extra_import_options or {} self.extra_export_options = extra_export_options or {} - self.hook: Optional[SqoopHook] = None + self.hook: SqoopHook | None = None self.schema = schema - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: """Execute sqoop job""" if self.hook is None: self.hook = self._get_hook() - if self.cmd_type == 'export': + if self.cmd_type == "export": self.hook.export_table( table=self.table, # type: ignore export_dir=self.export_dir, @@ -210,15 +211,15 @@ def execute(self, context: "Context") -> None: extra_export_options=self.extra_export_options, schema=self.schema, ) - elif self.cmd_type == 'import': + elif self.cmd_type == "import": # add create hcatalog table to extra import options if option passed # if new params are added to constructor can pass them in here # so don't modify sqoop_hook for each param if self.create_hcatalog_table: - self.extra_import_options['create-hcatalog-table'] = '' + self.extra_import_options["create-hcatalog-table"] = "" if self.table and self.query: - raise AirflowException('Cannot specify query and table together. Need to specify either or.') + raise AirflowException("Cannot specify query and table together. Need to specify either or.") if self.table: self.hook.import_table( @@ -253,7 +254,7 @@ def execute(self, context: "Context") -> None: def on_kill(self) -> None: if self.hook is None: self.hook = self._get_hook() - self.log.info('Sending SIGTERM signal to bash process group') + self.log.info("Sending SIGTERM signal to bash process group") os.killpg(os.getpgid(self.hook.sub_process_pid), signal.SIGTERM) def _get_hook(self) -> SqoopHook: diff --git a/airflow/providers/apache/sqoop/provider.yaml b/airflow/providers/apache/sqoop/provider.yaml index 51b94f251f81f..b542763fd0f85 100644 --- a/airflow/providers/apache/sqoop/provider.yaml +++ b/airflow/providers/apache/sqoop/provider.yaml @@ -22,6 +22,8 @@ description: | `Apache Sqoop `__ versions: + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -32,8 +34,8 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 integrations: - integration-name: Apache Sqoop @@ -53,9 +55,6 @@ hooks: python-modules: - airflow.providers.apache.sqoop.hooks.sqoop -hook-class-names: - # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.apache.sqoop.hooks.sqoop.SqoopHook connection-types: - hook-class-name: airflow.providers.apache.sqoop.hooks.sqoop.SqoopHook diff --git a/airflow/providers/arangodb/.latest-doc-only-change.txt b/airflow/providers/arangodb/.latest-doc-only-change.txt new file mode 100644 index 0000000000000..ff7136e07d744 --- /dev/null +++ b/airflow/providers/arangodb/.latest-doc-only-change.txt @@ -0,0 +1 @@ +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/arangodb/CHANGELOG.rst b/airflow/providers/arangodb/CHANGELOG.rst index 1bafa2d67956f..d6dfb2f0d4010 100644 --- a/airflow/providers/arangodb/CHANGELOG.rst +++ b/airflow/providers/arangodb/CHANGELOG.rst @@ -17,8 +17,51 @@ specific language governing permissions and limitations under the License. +.. 
NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- + +2.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Fix links to sources for examples (#24386)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Clean up f-strings in logging calls (#23597)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 1.0.0 ..... diff --git a/airflow/providers/arangodb/example_dags/example_arangodb.py b/airflow/providers/arangodb/example_dags/example_arangodb.py index f9da187cfb665..71c6346ef6073 100644 --- a/airflow/providers/arangodb/example_dags/example_arangodb.py +++ b/airflow/providers/arangodb/example_dags/example_arangodb.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + from datetime import datetime from airflow.models.dag import DAG @@ -21,9 +23,9 @@ from airflow.providers.arangodb.sensors.arangodb import AQLSensor dag = DAG( - 'example_arangodb_operator', + "example_arangodb_operator", start_date=datetime(2021, 1, 1), - tags=['example'], + tags=["example"], catchup=False, ) @@ -41,7 +43,7 @@ # [START howto_aql_sensor_template_file_arangodb] -sensor = AQLSensor( +sensor2 = AQLSensor( task_id="aql_sensor_template_file", query="search_judy.sql", timeout=60, @@ -55,7 +57,7 @@ # [START howto_aql_operator_arangodb] operator = AQLOperator( - task_id='aql_operator', + task_id="aql_operator", query="FOR doc IN students RETURN doc", dag=dag, result_processor=lambda cursor: print([document["name"] for document in cursor]), @@ -65,8 +67,8 @@ # [START howto_aql_operator_template_file_arangodb] -operator = AQLOperator( - task_id='aql_operator_template_file', +operator2 = AQLOperator( + task_id="aql_operator_template_file", dag=dag, result_processor=lambda cursor: print([document["name"] for document in cursor]), query="search_all.sql", diff --git a/airflow/providers/arangodb/hooks/arangodb.py b/airflow/providers/arangodb/hooks/arangodb.py index f88ed9fb33cfa..a30583e837859 100644 --- a/airflow/providers/arangodb/hooks/arangodb.py +++ b/airflow/providers/arangodb/hooks/arangodb.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module allows connecting to a ArangoDB.""" -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any from arango import AQLQueryExecuteError, ArangoClient as ArangoDBClient from arango.result import Result @@ -35,10 +36,10 @@ class ArangoDBHook(BaseHook): :param arangodb_conn_id: Reference to :ref:`ArangoDB connection id `. 
""" - conn_name_attr = 'arangodb_conn_id' - default_conn_name = 'arangodb_default' - conn_type = 'arangodb' - hook_name = 'ArangoDB' + conn_name_attr = "arangodb_conn_id" + default_conn_name = "arangodb_default" + conn_type = "arangodb" + hook_name = "ArangoDB" def __init__(self, arangodb_conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -48,7 +49,7 @@ def __init__(self, arangodb_conn_id: str = default_conn_name, *args, **kwargs) - self.password = None self.db_conn = None self.arangodb_conn_id = arangodb_conn_id - self.client: Optional[ArangoDBClient] = None + self.client: ArangoDBClient | None = None self.get_conn() def get_conn(self) -> ArangoDBClient: @@ -90,7 +91,7 @@ def create_collection(self, name): self.db_conn.create_collection(name) return True else: - self.log.info('Collection already exists: %s', name) + self.log.info("Collection already exists: %s", name) return False def create_database(self, name): @@ -98,7 +99,7 @@ def create_database(self, name): self.db_conn.create_database(name) return True else: - self.log.info('Database already exists: %s', name) + self.log.info("Database already exists: %s", name) return False def create_graph(self, name): @@ -106,24 +107,24 @@ def create_graph(self, name): self.db_conn.create_graph(name) return True else: - self.log.info('Graph already exists: %s', name) + self.log.info("Graph already exists: %s", name) return False @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: return { - "hidden_fields": ['port', 'extra'], + "hidden_fields": ["port", "extra"], "relabeling": { - 'host': 'ArangoDB Host URL or comma separated list of URLs (coordinators in a cluster)', - 'schema': 'ArangoDB Database', - 'login': 'ArangoDB Username', - 'password': 'ArangoDB Password', + "host": "ArangoDB Host URL or comma separated list of URLs (coordinators in a cluster)", + "schema": "ArangoDB Database", + "login": "ArangoDB Username", + "password": "ArangoDB Password", }, "placeholders": { - 'host': 'eg."http://127.0.0.1:8529" or "http://127.0.0.1:8529,http://127.0.0.1:8530"' - ' (coordinators in a cluster)', - 'schema': '_system', - 'login': 'root', - 'password': 'password', + "host": 'eg."http://127.0.0.1:8529" or "http://127.0.0.1:8529,http://127.0.0.1:8530"' + " (coordinators in a cluster)", + "schema": "_system", + "login": "root", + "password": "password", }, } diff --git a/airflow/providers/arangodb/operators/arangodb.py b/airflow/providers/arangodb/operators/arangodb.py index 716ca455d7e0a..8f8c357713e74 100644 --- a/airflow/providers/arangodb/operators/arangodb.py +++ b/airflow/providers/arangodb/operators/arangodb.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Callable, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Sequence from airflow.models import BaseOperator from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook @@ -38,7 +40,7 @@ class AQLOperator(BaseOperator): :param arangodb_conn_id: Reference to :ref:`ArangoDB connection id `. 
""" - template_fields: Sequence[str] = ('query',) + template_fields: Sequence[str] = ("query",) template_ext: Sequence[str] = (".sql",) template_fields_renderers = {"query": "sql"} @@ -47,8 +49,8 @@ def __init__( self, *, query: str, - arangodb_conn_id: str = 'arangodb_default', - result_processor: Optional[Callable] = None, + arangodb_conn_id: str = "arangodb_default", + result_processor: Callable | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -56,8 +58,8 @@ def __init__( self.query = query self.result_processor = result_processor - def execute(self, context: 'Context'): - self.log.info('Executing: %s', self.query) + def execute(self, context: Context): + self.log.info("Executing: %s", self.query) hook = ArangoDBHook(arangodb_conn_id=self.arangodb_conn_id) result = hook.query(self.query) if self.result_processor: diff --git a/airflow/providers/arangodb/provider.yaml b/airflow/providers/arangodb/provider.yaml index 129be5b8a0d9c..b830c10976360 100644 --- a/airflow/providers/arangodb/provider.yaml +++ b/airflow/providers/arangodb/provider.yaml @@ -20,7 +20,14 @@ package-name: apache-airflow-providers-arangodb name: ArangoDB description: | `ArangoDB `__ + +dependencies: + - apache-airflow>=2.3.0 + - python-arango>=7.3.2 + versions: + - 2.1.0 + - 2.0.0 - 1.0.0 integrations: diff --git a/airflow/providers/arangodb/sensors/arangodb.py b/airflow/providers/arangodb/sensors/arangodb.py index ee9d0d2a9004d..541cb9b38f8e3 100644 --- a/airflow/providers/arangodb/sensors/arangodb.py +++ b/airflow/providers/arangodb/sensors/arangodb.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook @@ -36,7 +38,7 @@ class AQLSensor(BaseSensorOperator): :param arangodb_db: Target ArangoDB name. """ - template_fields: Sequence[str] = ('query',) + template_fields: Sequence[str] = ("query",) template_ext: Sequence[str] = (".sql",) template_fields_renderers = {"query": "sql"} @@ -46,7 +48,7 @@ def __init__(self, *, query: str, arangodb_conn_id: str = "arangodb_default", ** self.arangodb_conn_id = arangodb_conn_id self.query = query - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.log.info("Sensor running the following query: %s", self.query) hook = ArangoDBHook(self.arangodb_conn_id) records = hook.query(self.query, count=True).count() diff --git a/airflow/providers/asana/.latest-doc-only-change.txt b/airflow/providers/asana/.latest-doc-only-change.txt index ab24993f57139..ff7136e07d744 100644 --- a/airflow/providers/asana/.latest-doc-only-change.txt +++ b/airflow/providers/asana/.latest-doc-only-change.txt @@ -1 +1 @@ -8b6b0848a3cacf9999477d6af4d2a87463f03026 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/asana/CHANGELOG.rst b/airflow/providers/asana/CHANGELOG.rst index 5262e12f92bc0..a868ec662a007 100644 --- a/airflow/providers/asana/CHANGELOG.rst +++ b/airflow/providers/asana/CHANGELOG.rst @@ -15,9 +15,75 @@ specific language governing permissions and limitations under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. 
+ The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +* In AsanaHook, non-prefixed extra fields are supported and are preferred. So you should update your + connection to replace ``extra__asana__workspace`` with ``workspace`` etc. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Allow and prefer non-prefixed extra fields for AsanaHook (#27043)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + + +2.0.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate Asana example DAGs to new design #22440 (#24131)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Fix new MyPy errors in main (#22884)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 1.1.3 ..... diff --git a/airflow/providers/asana/hooks/asana.py b/airflow/providers/asana/hooks/asana.py index 45816437f81b8..544a5afb59961 100644 --- a/airflow/providers/asana/hooks/asana.py +++ b/airflow/providers/asana/hooks/asana.py @@ -15,22 +15,47 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Connect to Asana.""" -import sys -from typing import Any, Dict, Optional +from __future__ import annotations + +from functools import wraps +from typing import Any from asana import Client # type: ignore[attr-defined] from asana.error import NotFoundError # type: ignore[attr-defined] -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.hooks.base import BaseHook + +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version.
+ """ + + def dec(func): + @wraps(func) + def inner(): + field_behaviors = func() + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec + + class AsanaHook(BaseHook): """Wrapper around Asana Python client library.""" @@ -43,34 +68,48 @@ def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.connection = self.get_connection(conn_id) extras = self.connection.extra_dejson - self.workspace = extras.get("extra__asana__workspace") or None - self.project = extras.get("extra__asana__project") or None + self.workspace = self._get_field(extras, "workspace") or None + self.project = self._get_field(extras, "project") or None + + def _get_field(self, extras: dict, field_name: str): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = "extra__asana__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." + ) + if field_name in extras: + return extras[field_name] or None + prefixed_name = f"{backcompat_prefix}{field_name}" + return extras.get(prefixed_name) or None def get_conn(self) -> Client: return self.client @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { - "extra__asana__workspace": StringField(lazy_gettext("Workspace"), widget=BS3TextFieldWidget()), - "extra__asana__project": StringField(lazy_gettext("Project"), widget=BS3TextFieldWidget()), + "workspace": StringField(lazy_gettext("Workspace"), widget=BS3TextFieldWidget()), + "project": StringField(lazy_gettext("Project"), widget=BS3TextFieldWidget()), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="asana") + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { "hidden_fields": ["port", "host", "login", "schema"], "relabeling": {}, "placeholders": { "password": "Asana personal access token", - "extra__asana__workspace": "Asana workspace gid", - "extra__asana__project": "Asana project gid", + "workspace": "Asana workspace gid", + "project": "Asana project gid", }, } @@ -85,7 +124,7 @@ def client(self) -> Client: return Client.access_token(self.connection.password) - def create_task(self, task_name: str, params: Optional[dict]) -> dict: + def create_task(self, task_name: str, params: dict | None) -> dict: """ Creates an Asana task. 
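The backcompat lookup above is easier to follow with concrete extras; below is a simplified mirror of ``_get_field`` (without the guard against already-prefixed names), using made-up gids::

    extras_old = {"extra__asana__workspace": "11111"}        # pre-3.0 style connection extras
    extras_new = {"workspace": "11111", "project": "22222"}  # preferred non-prefixed style


    def resolve(extras: dict, field_name: str):
        """Prefer the short field name, fall back to the extra__asana__ prefixed one."""
        if field_name in extras:
            return extras[field_name] or None
        return extras.get(f"extra__asana__{field_name}") or None


    assert resolve(extras_old, "workspace") == "11111"
    assert resolve(extras_new, "workspace") == "11111"
    assert resolve(extras_new, "project") == "22222"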
@@ -99,7 +138,7 @@ def create_task(self, task_name: str, params: Optional[dict]) -> dict: response = self.client.tasks.create(params=merged_params) return response - def _merge_create_task_parameters(self, task_name: str, task_params: Optional[dict]) -> dict: + def _merge_create_task_parameters(self, task_name: str, task_params: dict | None) -> dict: """ Merge create_task parameters with default params from the connection. @@ -107,7 +146,7 @@ def _merge_create_task_parameters(self, task_name: str, task_params: Optional[di :param task_params: Other task parameters which should override defaults from the connection :return: A dict of merged parameters to use in the new task """ - merged_params: Dict[str, Any] = {"name": task_name} + merged_params: dict[str, Any] = {"name": task_name} if self.project: merged_params["projects"] = [self.project] # Only use default workspace if user did not provide a project id @@ -145,7 +184,7 @@ def delete_task(self, task_id: str) -> dict: self.log.info("Asana task %s not found for deletion.", task_id) return {} - def find_task(self, params: Optional[dict]) -> list: + def find_task(self, params: dict | None) -> list: """ Retrieves a list of Asana tasks that match search parameters. @@ -158,7 +197,7 @@ def find_task(self, params: Optional[dict]) -> list: response = self.client.tasks.find_all(params=merged_params) return list(response) - def _merge_find_task_parameters(self, search_parameters: Optional[dict]) -> dict: + def _merge_find_task_parameters(self, search_parameters: dict | None) -> dict: """ Merge find_task parameters with default params from the connection. diff --git a/airflow/providers/asana/operators/asana_tasks.py b/airflow/providers/asana/operators/asana_tasks.py index 66d88291a3903..9d204153d4e16 100644 --- a/airflow/providers/asana/operators/asana_tasks.py +++ b/airflow/providers/asana/operators/asana_tasks.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
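A brief usage sketch for the task helpers above; the connection id, task name and search parameters are illustrative, and the response is assumed to carry the new task's ``gid`` as Asana task objects normally do::

    from airflow.providers.asana.hooks.asana import AsanaHook

    hook = AsanaHook(conn_id="asana_default")

    # Create a task; default project/workspace come from the connection extras.
    created = hook.create_task("Review data pipeline", {"notes": "Created from Airflow"})

    # Search for tasks in an explicit project.
    matches = hook.find_task({"project": "22222", "completed_since": "now"})

    # Clean up the task created above.
    hook.delete_task(created["gid"])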
+from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.asana.hooks.asana import AsanaHook @@ -47,7 +48,7 @@ def __init__( *, conn_id: str, name: str, - task_parameters: Optional[dict] = None, + task_parameters: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -56,7 +57,7 @@ def __init__( self.name = name self.task_parameters = task_parameters - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: hook = AsanaHook(conn_id=self.conn_id) response = hook.create_task(self.name, self.task_parameters) self.log.info(response) @@ -93,7 +94,7 @@ def __init__( self.asana_task_gid = asana_task_gid self.task_parameters = task_parameters - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = AsanaHook(conn_id=self.conn_id) response = hook.update_task(self.asana_task_gid, self.task_parameters) self.log.info(response) @@ -123,7 +124,7 @@ def __init__( self.conn_id = conn_id self.asana_task_gid = asana_task_gid - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = AsanaHook(conn_id=self.conn_id) response = hook.delete_task(self.asana_task_gid) self.log.info(response) @@ -148,7 +149,7 @@ def __init__( self, *, conn_id: str, - search_parameters: Optional[dict] = None, + search_parameters: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -156,7 +157,7 @@ def __init__( self.conn_id = conn_id self.search_parameters = search_parameters - def execute(self, context: 'Context') -> list: + def execute(self, context: Context) -> list: hook = AsanaHook(conn_id=self.conn_id) response = hook.find_task(self.search_parameters) self.log.info(response) diff --git a/airflow/providers/asana/provider.yaml b/airflow/providers/asana/provider.yaml index e82dd7c6ecc45..d8b27c99080b1 100644 --- a/airflow/providers/asana/provider.yaml +++ b/airflow/providers/asana/provider.yaml @@ -22,14 +22,18 @@ description: | `Asana `__ versions: + - 3.0.0 + - 2.0.1 + - 2.0.0 - 1.1.3 - 1.1.2 - 1.1.1 - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - asana>=0.10 integrations: - integration-name: Asana @@ -48,8 +52,6 @@ hooks: python-modules: - airflow.providers.asana.hooks.asana -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.asana.hooks.asana.AsanaHook connection-types: - hook-class-name: airflow.providers.asana.hooks.asana.AsanaHook diff --git a/airflow/providers/apache/drill/example_dags/__init__.py b/airflow/providers/atlassian/__init__.py similarity index 100% rename from airflow/providers/apache/drill/example_dags/__init__.py rename to airflow/providers/atlassian/__init__.py diff --git a/airflow/providers/atlassian/jira/CHANGELOG.rst b/airflow/providers/atlassian/jira/CHANGELOG.rst new file mode 100644 index 0000000000000..e316872b87bfb --- /dev/null +++ b/airflow/providers/atlassian/jira/CHANGELOG.rst @@ -0,0 +1,45 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + +Changelog +--------- + +1.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/airflow/providers/apache/kylin/example_dags/__init__.py b/airflow/providers/atlassian/jira/__init__.py similarity index 100% rename from airflow/providers/apache/kylin/example_dags/__init__.py rename to airflow/providers/atlassian/jira/__init__.py diff --git a/airflow/providers/apache/livy/example_dags/__init__.py b/airflow/providers/atlassian/jira/hooks/__init__.py similarity index 100% rename from airflow/providers/apache/livy/example_dags/__init__.py rename to airflow/providers/atlassian/jira/hooks/__init__.py diff --git a/airflow/providers/atlassian/jira/hooks/jira.py b/airflow/providers/atlassian/jira/hooks/jira.py new file mode 100644 index 0000000000000..7ce1a9e80a16a --- /dev/null +++ b/airflow/providers/atlassian/jira/hooks/jira.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for JIRA""" +from __future__ import annotations + +from typing import Any + +from jira import JIRA +from jira.exceptions import JIRAError + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class JiraHook(BaseHook): + """ + Jira interaction hook, a Wrapper around JIRA Python SDK. 
+ + :param jira_conn_id: reference to a pre-defined Jira Connection + """ + + default_conn_name = "jira_default" + conn_type = "jira" + conn_name_attr = "jira_conn_id" + hook_name = "JIRA" + + def __init__(self, jira_conn_id: str = default_conn_name, proxies: Any | None = None) -> None: + super().__init__() + self.jira_conn_id = jira_conn_id + self.proxies = proxies + self.client: JIRA | None = None + self.get_conn() + + def get_conn(self) -> JIRA: + if not self.client: + self.log.debug("Creating Jira client for conn_id: %s", self.jira_conn_id) + + get_server_info = True + validate = True + extra_options = {} + if not self.jira_conn_id: + raise AirflowException("Failed to create jira client. no jira_conn_id provided") + + conn = self.get_connection(self.jira_conn_id) + if conn.extra is not None: + extra_options = conn.extra_dejson + # only required attributes are taken for now, + # more can be added ex: async, logging, max_retries + + # verify + if "verify" in extra_options and extra_options["verify"].lower() == "false": + extra_options["verify"] = False + + # validate + if "validate" in extra_options and extra_options["validate"].lower() == "false": + validate = False + + if "get_server_info" in extra_options and extra_options["get_server_info"].lower() == "false": + get_server_info = False + + try: + self.client = JIRA( + conn.host, + options=extra_options, + basic_auth=(conn.login, conn.password), + get_server_info=get_server_info, + validate=validate, + proxies=self.proxies, + ) + except JIRAError as jira_error: + raise AirflowException(f"Failed to create jira client, jira error: {str(jira_error)}") + except Exception as e: + raise AirflowException(f"Failed to create jira client, error: {str(e)}") + + return self.client diff --git a/airflow/providers/dbt/cloud/example_dags/__init__.py b/airflow/providers/atlassian/jira/operators/__init__.py similarity index 100% rename from airflow/providers/dbt/cloud/example_dags/__init__.py rename to airflow/providers/atlassian/jira/operators/__init__.py diff --git a/airflow/providers/atlassian/jira/operators/jira.py b/airflow/providers/atlassian/jira/operators/jira.py new file mode 100644 index 0000000000000..b1fe7bb05cd65 --- /dev/null +++ b/airflow/providers/atlassian/jira/operators/jira.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Sequence + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.atlassian.jira.hooks.jira import JIRAError, JiraHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class JiraOperator(BaseOperator): + """ + JiraOperator to interact and perform action on Jira issue tracking system. 
+ This operator is designed to use Jira Python SDK: http://jira.readthedocs.io + + :param jira_conn_id: reference to a pre-defined Jira Connection + :param jira_method: method name from Jira Python SDK to be called + :param jira_method_args: required method parameters for the jira_method. (templated) + :param result_processor: function to further process the response from Jira + :param get_jira_resource_method: function or operator to get jira resource + on which the provided jira_method will be executed + """ + + template_fields: Sequence[str] = ("jira_method_args",) + + def __init__( + self, + *, + jira_method: str, + jira_conn_id: str = "jira_default", + jira_method_args: dict | None = None, + result_processor: Callable | None = None, + get_jira_resource_method: Callable | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.jira_conn_id = jira_conn_id + self.method_name = jira_method + self.jira_method_args = jira_method_args + self.result_processor = result_processor + self.get_jira_resource_method = get_jira_resource_method + + def execute(self, context: Context) -> Any: + try: + if self.get_jira_resource_method is not None: + # if get_jira_resource_method is provided, jira_method will be executed on + # resource returned by executing the get_jira_resource_method. + # This makes all the provided methods of JIRA sdk accessible and usable + # directly at the JiraOperator without additional wrappers. + # ref: http://jira.readthedocs.io/en/latest/api.html + if isinstance(self.get_jira_resource_method, JiraOperator): + resource = self.get_jira_resource_method.execute(**context) + else: + resource = self.get_jira_resource_method(**context) + else: + # Default method execution is on the top level jira client resource + hook = JiraHook(jira_conn_id=self.jira_conn_id) + resource = hook.client + + # Current Jira-Python SDK (1.0.7) has issue with pickling the jira response. + # ex: self.xcom_push(context, key='operator_response', value=jira_response) + # This could potentially throw error if jira_result is not picklable + jira_result = getattr(resource, self.method_name)(**self.jira_method_args) + if self.result_processor: + return self.result_processor(context, jira_result) + + return jira_result + + except JIRAError as jira_error: + raise AirflowException(f"Failed to execute jiraOperator, error: {str(jira_error)}") + except Exception as e: + raise AirflowException(f"Jira operator error: {str(e)}") diff --git a/airflow/providers/atlassian/jira/provider.yaml b/airflow/providers/atlassian/jira/provider.yaml new file mode 100644 index 0000000000000..3e68f29ec4bd8 --- /dev/null +++ b/airflow/providers/atlassian/jira/provider.yaml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
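A hedged usage sketch of the ``JiraOperator`` added above: ``create_issue`` is just one Jira Python SDK method the operator can dispatch to, and the connection id, project key and field values are illustrative assumptions::

    from airflow.providers.atlassian.jira.operators.jira import JiraOperator

    create_ticket = JiraOperator(
        task_id="create_ticket",
        jira_conn_id="jira_default",
        jira_method="create_issue",                  # any method exposed by the JIRA client
        jira_method_args={
            "fields": {
                "project": {"key": "AIRFLOW"},       # hypothetical project key
                "summary": "Data quality check failed",
                "issuetype": {"name": "Task"},
            }
        },
        # the raw Issue object may not pickle cleanly, so keep only its key for XCom
        result_processor=lambda context, issue: issue.key,
    )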
+ +--- +package-name: apache-airflow-providers-atlassian-jira +name: Atlassian Jira +description: | + `Atlassian Jira `__ + +versions: + - 1.1.0 + - 1.0.0 + +dependencies: + - apache-airflow>=2.3.0 + - JIRA>1.0.7 + +integrations: + - integration-name: Atlassian Jira + external-doc-url: https://www.atlassian.com/pl/software/jira + logo: /integration-logos/jira/Jira.png + tags: [software] + +operators: + - integration-name: Atlassian Jira + python-modules: + - airflow.providers.atlassian.jira.operators.jira + +sensors: + - integration-name: Atlassian Jira + python-modules: + - airflow.providers.atlassian.jira.sensors.jira + +hooks: + - integration-name: Atlassian Jira + python-modules: + - airflow.providers.atlassian.jira.hooks.jira + +connection-types: + - hook-class-name: airflow.providers.atlassian.jira.hooks.jira.JiraHook + connection-type: jira diff --git a/airflow/providers/elasticsearch/example_dags/__init__.py b/airflow/providers/atlassian/jira/sensors/__init__.py similarity index 100% rename from airflow/providers/elasticsearch/example_dags/__init__.py rename to airflow/providers/atlassian/jira/sensors/__init__.py diff --git a/airflow/providers/atlassian/jira/sensors/jira.py b/airflow/providers/atlassian/jira/sensors/jira.py new file mode 100644 index 0000000000000..cd3af90dab2ae --- /dev/null +++ b/airflow/providers/atlassian/jira/sensors/jira.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Sequence + +from jira.resources import Issue, Resource + +from airflow.providers.atlassian.jira.hooks.jira import JiraHook +from airflow.providers.atlassian.jira.operators.jira import JIRAError +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class JiraSensor(BaseSensorOperator): + """ + Monitors a jira ticket for any change. 
+ + :param jira_conn_id: reference to a pre-defined Jira Connection + :param method_name: method name from jira-python-sdk to be execute + :param method_params: parameters for the method method_name + :param result_processor: function that return boolean and act as a sensor response + """ + + def __init__( + self, + *, + method_name: str, + jira_conn_id: str = "jira_default", + method_params: dict | None = None, + result_processor: Callable | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.jira_conn_id = jira_conn_id + self.result_processor = None + if result_processor is not None: + self.result_processor = result_processor + self.method_name = method_name + self.method_params = method_params + + def poke(self, context: Context) -> Any: + hook = JiraHook(jira_conn_id=self.jira_conn_id) + resource = hook.get_conn() + jira_result = getattr(resource, self.method_name)(**self.method_params) + if self.result_processor is None: + return jira_result + return self.result_processor(jira_result) + + +class JiraTicketSensor(JiraSensor): + """ + Monitors a jira ticket for given change in terms of function. + + :param jira_conn_id: reference to a pre-defined Jira Connection + :param ticket_id: id of the ticket to be monitored + :param field: field of the ticket to be monitored + :param expected_value: expected value of the field + :param result_processor: function that return boolean and act as a sensor response + """ + + template_fields: Sequence[str] = ("ticket_id",) + + def __init__( + self, + *, + jira_conn_id: str = "jira_default", + ticket_id: str | None = None, + field: str | None = None, + expected_value: str | None = None, + field_checker_func: Callable | None = None, + **kwargs, + ) -> None: + + self.jira_conn_id = jira_conn_id + self.ticket_id = ticket_id + self.field = field + self.expected_value = expected_value + if field_checker_func is None: + field_checker_func = self.issue_field_checker + + super().__init__(jira_conn_id=jira_conn_id, result_processor=field_checker_func, **kwargs) + + def poke(self, context: Context) -> Any: + self.log.info("Jira Sensor checking for change in ticket: %s", self.ticket_id) + + self.method_name = "issue" + self.method_params = {"id": self.ticket_id, "fields": self.field} + return JiraSensor.poke(self, context=context) + + def issue_field_checker(self, issue: Issue) -> bool | None: + """Check issue using different conditions to prepare to evaluate sensor.""" + result = None + try: + if issue is not None and self.field is not None and self.expected_value is not None: + + field_val = getattr(issue.fields, self.field) + if field_val is not None: + if isinstance(field_val, list): + result = self.expected_value in field_val + elif isinstance(field_val, str): + result = self.expected_value.lower() == field_val.lower() + elif isinstance(field_val, Resource) and getattr(field_val, "name"): + result = self.expected_value.lower() == field_val.name.lower() + else: + self.log.warning( + "Not implemented checker for issue field %s which " + "is neither string nor list nor Jira Resource", + self.field, + ) + + except JIRAError as jira_error: + self.log.error("Jira error while checking with expected value: %s", jira_error) + except Exception: + self.log.exception("Error while checking with expected value %s:", self.expected_value) + if result is True: + self.log.info( + "Issue field %s has expected value %s, returning success", self.field, self.expected_value + ) + else: + self.log.info("Issue field %s don't have expected value %s yet.", 
self.field, self.expected_value) + return result diff --git a/airflow/providers/celery/.latest-doc-only-change.txt b/airflow/providers/celery/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/celery/.latest-doc-only-change.txt +++ b/airflow/providers/celery/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/celery/CHANGELOG.rst b/airflow/providers/celery/CHANGELOG.rst index 4a291303f63af..33e86a4d7fb5c 100644 --- a/airflow/providers/celery/CHANGELOG.rst +++ b/airflow/providers/celery/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.4 ..... diff --git a/airflow/providers/celery/provider.yaml b/airflow/providers/celery/provider.yaml index 058796dc23302..8adc4c1c6c12c 100644 --- a/airflow/providers/celery/provider.yaml +++ b/airflow/providers/celery/provider.yaml @@ -22,6 +22,8 @@ description: | `Celery `__ versions: + - 3.1.0 + - 3.0.0 - 2.1.4 - 2.1.3 - 2.1.2 @@ -31,8 +33,14 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.2.0 +dependencies: + - apache-airflow>=2.3.0 + # The Celery is known to introduce problems when upgraded to a MAJOR version. Airflow Core + # Uses Celery for CeleryExecutor, and we also know that Kubernetes Python client follows SemVer + # (https://docs.celeryq.dev/en/stable/contributing.html?highlight=semver#versions). 
+ # Make sure that the limit here is synchronized with [celery] extra in the airflow core + - celery>=5.2.3,<6 + - flower>=1.0.0 integrations: - integration-name: Celery diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py index 5a7674ae6bc09..78e412c06a0b6 100644 --- a/airflow/providers/celery/sensors/celery_queue.py +++ b/airflow/providers/celery/sensors/celery_queue.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from celery.app import control @@ -36,13 +37,13 @@ class CeleryQueueSensor(BaseSensorOperator): :param target_task_id: Task id for checking """ - def __init__(self, *, celery_queue: str, target_task_id: Optional[str] = None, **kwargs) -> None: + def __init__(self, *, celery_queue: str, target_task_id: str | None = None, **kwargs) -> None: super().__init__(**kwargs) self.celery_queue = celery_queue self.target_task_id = target_task_id - def _check_task_id(self, context: 'Context') -> bool: + def _check_task_id(self, context: Context) -> bool: """ Gets the returned Celery result from the Airflow task ID provided to the sensor, and returns True if the @@ -50,13 +51,12 @@ def _check_task_id(self, context: 'Context') -> bool: :param context: Airflow's execution context :return: True if task has been executed, otherwise False - :rtype: bool """ - ti = context['ti'] + ti = context["ti"] celery_result = ti.xcom_pull(task_ids=self.target_task_id) return celery_result.ready() - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: if self.target_task_id: return self._check_task_id(context) @@ -71,8 +71,8 @@ def poke(self, context: 'Context') -> bool: scheduled = len(scheduled[self.celery_queue]) active = len(active[self.celery_queue]) - self.log.info('Checking if celery queue %s is empty.', self.celery_queue) + self.log.info("Checking if celery queue %s is empty.", self.celery_queue) return reserved == 0 and scheduled == 0 and active == 0 except KeyError: - raise KeyError(f'Could not locate Celery queue {self.celery_queue}') + raise KeyError(f"Could not locate Celery queue {self.celery_queue}") diff --git a/airflow/providers/cloudant/.latest-doc-only-change.txt b/airflow/providers/cloudant/.latest-doc-only-change.txt index ab24993f57139..ff7136e07d744 100644 --- a/airflow/providers/cloudant/.latest-doc-only-change.txt +++ b/airflow/providers/cloudant/.latest-doc-only-change.txt @@ -1 +1 @@ -8b6b0848a3cacf9999477d6af4d2a87463f03026 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/cloudant/CHANGELOG.rst b/airflow/providers/cloudant/CHANGELOG.rst index a0167e8d356ba..8e452f1088931 100644 --- a/airflow/providers/cloudant/CHANGELOG.rst +++ b/airflow/providers/cloudant/CHANGELOG.rst @@ -16,9 +16,54 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. 
+ +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Fix new MyPy errors in main (#22884)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/cloudant/hooks/cloudant.py b/airflow/providers/cloudant/hooks/cloudant.py index 2398376dc4b24..248ee4d32c82a 100644 --- a/airflow/providers/cloudant/hooks/cloudant.py +++ b/airflow/providers/cloudant/hooks/cloudant.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """Hook for Cloudant""" -from typing import Any, Dict +from __future__ import annotations + +from typing import Any from cloudant import cloudant # type: ignore[attr-defined] @@ -33,17 +35,17 @@ class CloudantHook(BaseHook): :param cloudant_conn_id: The connection id to authenticate and get a session object from cloudant. """ - conn_name_attr = 'cloudant_conn_id' - default_conn_name = 'cloudant_default' - conn_type = 'cloudant' - hook_name = 'Cloudant' + conn_name_attr = "cloudant_conn_id" + default_conn_name = "cloudant_default" + conn_type = "cloudant" + hook_name = "Cloudant" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['port', 'extra'], - "relabeling": {'host': 'Account', 'login': 'Username (or API Key)', 'schema': 'Database'}, + "hidden_fields": ["port", "extra"], + "relabeling": {"host": "Account", "login": "Username (or API Key)", "schema": "Database"}, } def __init__(self, cloudant_conn_id: str = default_conn_name) -> None: @@ -61,7 +63,6 @@ def get_conn(self) -> cloudant: - 'password' equals the 'Password' (required) :return: an authorized cloudant session context manager object. 
- :rtype: cloudant """ conn = self.get_connection(self.cloudant_conn_id) @@ -72,6 +73,6 @@ def get_conn(self) -> cloudant: return cloudant_session def _validate_connection(self, conn: cloudant) -> None: - for conn_param in ['login', 'password']: + for conn_param in ["login", "password"]: if not getattr(conn, conn_param): - raise AirflowException(f'missing connection parameter {conn_param}') + raise AirflowException(f"missing connection parameter {conn_param}") diff --git a/airflow/providers/cloudant/provider.yaml b/airflow/providers/cloudant/provider.yaml index 9c4479f1859c5..8495f93af0b22 100644 --- a/airflow/providers/cloudant/provider.yaml +++ b/airflow/providers/cloudant/provider.yaml @@ -22,6 +22,8 @@ description: | `IBM Cloudant `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +32,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - cloudant>=2.0 integrations: - integration-name: IBM Cloudant @@ -44,9 +47,6 @@ hooks: python-modules: - airflow.providers.cloudant.hooks.cloudant -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.cloudant.hooks.cloudant.CloudantHook - connection-types: - hook-class-name: airflow.providers.cloudant.hooks.cloudant.CloudantHook connection-type: cloudant diff --git a/airflow/providers/cncf/kubernetes/CHANGELOG.rst b/airflow/providers/cncf/kubernetes/CHANGELOG.rst index bb2d236594044..f7b293fd1fd40 100644 --- a/airflow/providers/cncf/kubernetes/CHANGELOG.rst +++ b/airflow/providers/cncf/kubernetes/CHANGELOG.rst @@ -16,21 +16,192 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- -main -.... +5.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +Previously KubernetesPodOperator considered some settings from the Airflow config's ``kubernetes`` section. +Such consideration was deprecated in 4.1.0 and is now removed. If you previously relied on the Airflow +config, and you want client generation to have non-default configuration, you will need to define your +configuration in an Airflow connection and set KPO to use the connection. See kubernetes provider +documentation on defining a kubernetes Airflow connection for details. + +Drop support for providing ``resource`` as dict in ``KubernetesPodOperator``. You +should use ``container_resources`` with ``V1ResourceRequirements``. + +Param ``node_selectors`` has been removed in ``KubernetesPodOperator``; use ``node_selector`` instead. + +The following backcompat modules for KubernetesPodOperator are removed and you must now use +the corresponding objects from the kubernetes library: + +* ``airflow.providers.cncf.kubernetes.backcompat.pod`` +* ``airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env`` +* ``airflow.providers.cncf.kubernetes.backcompat.volume`` +* ``airflow.providers.cncf.kubernetes.backcompat.volume_mount`` + +In ``KubernetesHook.get_namespace``, if a connection is defined but a namespace isn't set, we +currently return 'default'; this behavior is deprecated. In the next release, we'll return ``None``. 
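For readers upgrading, a minimal, hedged sketch of what a ``KubernetesPodOperator`` call could look like under these changes; the connection id, image and resource figures are illustrative placeholders rather than values required by this release::

    from kubernetes.client import models as k8s

    from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator

    run_job = KubernetesPodOperator(
        task_id="run_job",
        kubernetes_conn_id="k8s_conn",       # client settings now come from an Airflow connection
        image="python:3.9-slim",
        cmds=["python", "-c", "print('hello')"],
        # 'resources' as a dict is gone; pass a V1ResourceRequirements via container_resources
        container_resources=k8s.V1ResourceRequirements(
            requests={"cpu": "100m", "memory": "128Mi"},
            limits={"cpu": "500m", "memory": "256Mi"},
        ),
        node_selector={"kubernetes.io/os": "linux"},   # 'node_selectors' was removed
        # 'name' and 'namespace' are omitted: per the Features notes below, the pod name
        # falls back to task_id and the namespace is resolved from the connection,
        # in-cluster detection, or finally 'default'.
    )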
+ +* ``Remove deprecated backcompat objects for KPO (#27518)`` +* ``Remove support for node_selectors param in KPO (#27515)`` +* ``Remove unused backcompat method in k8s hook (#27490)`` +* ``Drop support for providing ''resource'' as dict in ''KubernetesPodOperator'' (#27197)`` +* ``Deprecate use of core get_kube_client in PodManager (#26848)`` +* ``Don't consider airflow core conf for KPO (#26849)`` + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Use log.exception where more economical than log.error (#27517)`` + +Features +~~~~~~~~ + +Previously, ``name`` was a required argument for KubernetesPodOperator (when also not supplying pod +template or full pod spec). Now, if ``name`` is not supplied, ``task_id`` will be used. + +KubernetsPodOperator argument ``namespace`` is now optional. If not supplied via KPO param or pod +template file or full pod spec, then we'll check the airflow conn, +then if in a k8s pod, try to infer the namespace from the container, then finally +will use the ``default`` namespace. + + +* ``Add container_resources as KubernetesPodOperator templatable (#27457)`` +* ``Add deprecation warning re unset namespace in k8s hook (#27202)`` +* ``add container_name option for SparkKubernetesSensor (#26560)`` +* ``Allow xcom sidecar container image to be configurable in KPO (#26766)`` +* ``Improve task_id to pod name conversion (#27524)`` +* ``Make pod name optional in KubernetesPodOperator (#27120)`` +* ``Make namespace optional for KPO (#27116)`` +* ``Enable template rendering for env_vars field for the @task.kubernetes decorator (#27433)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix KubernetesHook fail on an attribute absence (#25787)`` +* ``Fix log message for kubernetes hooks (#26999)`` +* ``Remove extra__kubernetes__ prefix from k8s hook extras (#27021)`` +* ``KPO should use hook's get namespace method to get namespace (#27516)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + +4.4.0 +..... + +Features +~~~~~~~~ + +* ``feat(KubernetesPodOperator): Add support of container_security_context (#25530)`` +* ``Add @task.kubernetes taskflow decorator (#25663)`` +* ``pretty print KubernetesPodOperator rendered template env_vars (#25850)`` + +Bug Fixes +~~~~~~~~~ + +* ``Avoid calculating all elements when one item is needed (#26377)`` +* ``Wait for xcom sidecar container to start before sidecar exec (#25055)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare to release cncf.kubernetes provider (#26588)`` + +4.3.0 +..... Features ~~~~~~~~ -KubernetesPodOperator now uses KubernetesHook -````````````````````````````````````````````` +* ``Improve taskflow type hints with ParamSpec (#25173)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix xcom_sidecar stuck problem (#24993)`` + +4.2.0 +..... -Previously, KubernetesPodOperator relied on core Airflow configuration (namely setting for kubernetes executor) for certain settings used in client generation. Now KubernetesPodOperator uses KubernetesHook, and the consideration of core k8s settings is officially deprecated. 
+Features +~~~~~~~~ + +* ``Add 'airflow_kpo_in_cluster' label to KPO pods (#24658)`` +* ``Use found pod for deletion in KubernetesPodOperator (#22092)`` + +Bug Fixes +~~~~~~~~~ + +* ``Revert "Fix await_container_completion condition (#23883)" (#24474)`` +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +Misc +~~~~ +* ``Rename 'resources' arg in Kub op to k8s_resources (#24673)`` -If you are using the Airflow configuration settings (e.g. as opposed to operator params) to configure the kubernetes client, then prior to the next major release you will need to add an Airflow connection and set your KPO tasks to use that connection. +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Only assert stuff for mypy when type checking (#24937)`` + * ``Remove 'xcom_push' flag from providers (#24823)`` + * ``More typing and minor refactor for kubernetes (#24719)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Use our yaml util in all providers (#24720)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +4.1.0 +..... + +Features +~~~~~~~~ + +* Previously, KubernetesPodOperator relied on core Airflow configuration (namely setting for kubernetes + executor) for certain settings used in client generation. Now KubernetesPodOperator + uses KubernetesHook, and the consideration of core k8s settings is officially deprecated. + +* If you are using the Airflow configuration settings (e.g. as opposed to operator params) to + configure the kubernetes client, then prior to the next major release you will need to + add an Airflow connection and set your KPO tasks to use that connection. + +* ``Use KubernetesHook to create api client in KubernetesPodOperator (#20578)`` +* ``[FEATURE] KPO use K8S hook (#22086)`` +* ``Add param docs to KubernetesHook and KubernetesPodOperator (#23955) (#24054)`` + +Bug Fixes +~~~~~~~~~ + +* ``Use "remote" pod when patching KPO pod as "checked" (#23676)`` +* ``Don't use the root logger in KPO _suppress function (#23835)`` +* ``Fix await_container_completion condition (#23883)`` + +Misc +~~~~ + +* ``Migrate Cncf.Kubernetes example DAGs to new design #22441 (#24132)`` +* ``Clean up f-strings in logging calls (#23597)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``pydocstyle D202 added (#24221)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` 4.0.2 ..... @@ -46,7 +217,6 @@ Bug Fixes appropriate section above if needed. Do not delete the lines(!): * ``Add YANKED to yanked releases of the cncf.kubernetes (#23378)`` -.. 
Review and move the new changes to one of the sections above: * ``Fix k8s pod.execute randomly stuck indefinitely by logs consumption (#23497) (#23618)`` * ``Revert "Fix k8s pod.execute randomly stuck indefinitely by logs consumption (#23497) (#23618)" (#23656)`` @@ -127,7 +297,7 @@ Features ~~~~~~~~ * ``Add map_index label to mapped KubernetesPodOperator (#21916)`` -* ``Change KubePodOperator labels from exeuction_date to run_id (#21960)`` +* ``Change KubernetesPodOperator labels from execution_date to run_id (#21960)`` Misc ~~~~ @@ -196,7 +366,7 @@ Notes on changes KubernetesPodOperator and PodLauncher Overview '''''''' -Generally speaking if you did not subclass ``KubernetesPodOperator`` and you didn't use the ``PodLauncher`` class directly, +Generally speaking if you did not subclass ``KubernetesPodOperator`` and you did not use the ``PodLauncher`` class directly, then you don't need to worry about this change. If however you have subclassed ``KubernetesPodOperator``, what follows are some notes on the changes in this release. @@ -392,7 +562,7 @@ Breaking changes Features ~~~~~~~~ -* ``Add 'KubernetesPodOperat' 'pod-template-file' jinja template support (#15942)`` +* ``Add 'KubernetesPodOperator' 'pod-template-file' jinja template support (#15942)`` * ``Save pod name to xcom for KubernetesPodOperator (#15755)`` Bug Fixes @@ -401,7 +571,7 @@ Bug Fixes * ``Bug Fix Pod-Template Affinity Ignored due to empty Affinity K8S Object (#15787)`` * ``Bug Pod Template File Values Ignored (#16095)`` * ``Fix issue with parsing error logs in the KPO (#15638)`` -* ``Fix unsuccessful KubernetesPod final_state call when 'is_delete_operator_pod=True' (#15490)`` +* ``Fix unsuccessful KubernetesPodOperator final_state call when 'is_delete_operator_pod=True' (#15490)`` .. Below changes are excluded from the changelog. Move them to appropriate section above if needed. Do not delete the lines(!): @@ -423,7 +593,7 @@ Bug Fixes ~~~~~~~~~ * ``Fix timeout when using XCom with KubernetesPodOperator (#15388)`` -* ``Fix labels on the pod created by ''KubernetsPodOperator'' (#15492)`` +* ``Fix labels on the pod created by ''KubernetesPodOperator'' (#15492)`` 1.1.0 ..... diff --git a/airflow/providers/cncf/kubernetes/__init__.py b/airflow/providers/cncf/kubernetes/__init__.py index 0998e31143fc8..3ca5974776c6c 100644 --- a/airflow/providers/cncf/kubernetes/__init__.py +++ b/airflow/providers/cncf/kubernetes/__init__.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import sys if sys.version_info < (3, 7): @@ -33,7 +35,7 @@ def _reduce_Logger(logger): if logging.getLogger(logger.name) is not logger: import pickle - raise pickle.PicklingError('logger cannot be pickled') + raise pickle.PicklingError("logger cannot be pickled") return logging.getLogger, (logger.name,) def _reduce_RootLogger(logger): diff --git a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py index bf2b8329f8ccd..9d37da983e519 100644 --- a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py +++ b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Executes task in a Kubernetes POD""" - -from typing import List +from __future__ import annotations from kubernetes.client import ApiClient, models as k8s @@ -63,20 +62,6 @@ def convert_volume_mount(volume_mount) -> k8s.V1VolumeMount: return _convert_kube_model_object(volume_mount, k8s.V1VolumeMount) -def convert_resources(resources) -> k8s.V1ResourceRequirements: - """ - Converts an airflow Resources object into a k8s.V1ResourceRequirements - - :param resources: - :return: k8s.V1ResourceRequirements - """ - if isinstance(resources, dict): - from airflow.providers.cncf.kubernetes.backcompat.pod import Resources - - resources = Resources(**resources) - return _convert_kube_model_object(resources, k8s.V1ResourceRequirements) - - def convert_port(port) -> k8s.V1ContainerPort: """ Converts an airflow Port object into a k8s.V1ContainerPort @@ -87,7 +72,7 @@ def convert_port(port) -> k8s.V1ContainerPort: return _convert_kube_model_object(port, k8s.V1ContainerPort) -def convert_env_vars(env_vars) -> List[k8s.V1EnvVar]: +def convert_env_vars(env_vars) -> list[k8s.V1EnvVar]: """ Converts a dictionary into a list of env_vars @@ -115,7 +100,7 @@ def convert_pod_runtime_info_env(pod_runtime_info_envs) -> k8s.V1EnvVar: return _convert_kube_model_object(pod_runtime_info_envs, k8s.V1EnvVar) -def convert_image_pull_secrets(image_pull_secrets) -> List[k8s.V1LocalObjectReference]: +def convert_image_pull_secrets(image_pull_secrets) -> list[k8s.V1LocalObjectReference]: """ Converts a PodRuntimeInfoEnv into an k8s.V1EnvVar diff --git a/airflow/providers/cncf/kubernetes/backcompat/pod.py b/airflow/providers/cncf/kubernetes/backcompat/pod.py deleted file mode 100644 index 7f18117e18bfa..0000000000000 --- a/airflow/providers/cncf/kubernetes/backcompat/pod.py +++ /dev/null @@ -1,119 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Classes for interacting with Kubernetes API. - -This module is deprecated. Please use :mod:`kubernetes.client.models.V1ResourceRequirements` -and :mod:`kubernetes.client.models.V1ContainerPort`. -""" - -import warnings - -from kubernetes.client import models as k8s - -warnings.warn( - ( - "This module is deprecated. Please use `kubernetes.client.models.V1ResourceRequirements`" - " and `kubernetes.client.models.V1ContainerPort`." 
- ), - DeprecationWarning, - stacklevel=2, -) - - -class Resources: - """backwards compat for Resources.""" - - __slots__ = ( - 'request_memory', - 'request_cpu', - 'limit_memory', - 'limit_cpu', - 'limit_gpu', - 'request_ephemeral_storage', - 'limit_ephemeral_storage', - ) - - """ - :param request_memory: requested memory - :param request_cpu: requested CPU number - :param request_ephemeral_storage: requested ephemeral storage - :param limit_memory: limit for memory usage - :param limit_cpu: Limit for CPU used - :param limit_gpu: Limits for GPU used - :param limit_ephemeral_storage: Limit for ephemeral storage - """ - - def __init__( - self, - request_memory=None, - request_cpu=None, - request_ephemeral_storage=None, - limit_memory=None, - limit_cpu=None, - limit_gpu=None, - limit_ephemeral_storage=None, - ): - self.request_memory = request_memory - self.request_cpu = request_cpu - self.request_ephemeral_storage = request_ephemeral_storage - self.limit_memory = limit_memory - self.limit_cpu = limit_cpu - self.limit_gpu = limit_gpu - self.limit_ephemeral_storage = limit_ephemeral_storage - - def to_k8s_client_obj(self): - """ - Converts to k8s object. - - @rtype: object - """ - limits_raw = { - 'cpu': self.limit_cpu, - 'memory': self.limit_memory, - 'nvidia.com/gpu': self.limit_gpu, - 'ephemeral-storage': self.limit_ephemeral_storage, - } - requests_raw = { - 'cpu': self.request_cpu, - 'memory': self.request_memory, - 'ephemeral-storage': self.request_ephemeral_storage, - } - - limits = {k: v for k, v in limits_raw.items() if v} - requests = {k: v for k, v in requests_raw.items() if v} - resource_req = k8s.V1ResourceRequirements(limits=limits, requests=requests) - return resource_req - - -class Port: - """POD port""" - - __slots__ = ('name', 'container_port') - - def __init__(self, name=None, container_port=None): - """Creates port""" - self.name = name - self.container_port = container_port - - def to_k8s_client_obj(self): - """ - Converts to k8s object. - - :rtype: object - """ - return k8s.V1ContainerPort(name=self.name, container_port=self.container_port) diff --git a/airflow/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py b/airflow/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py deleted file mode 100644 index f08aecff33a89..0000000000000 --- a/airflow/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Classes for interacting with Kubernetes API. - -This module is deprecated. Please use :mod:`kubernetes.client.models.V1EnvVar`. -""" - -import warnings - -import kubernetes.client.models as k8s - -warnings.warn( - "This module is deprecated. 
Please use `kubernetes.client.models.V1EnvVar`.", - DeprecationWarning, - stacklevel=2, -) - - -class PodRuntimeInfoEnv: - """Defines Pod runtime information as environment variable.""" - - def __init__(self, name, field_path): - """ - Adds Kubernetes pod runtime information as environment variables such as namespace, pod IP, pod name. - Full list of options can be found in kubernetes documentation. - - :param name: the name of the environment variable - :param field_path: path to pod runtime info. Ex: metadata.namespace | status.podIP - """ - self.name = name - self.field_path = field_path - - def to_k8s_client_obj(self): - """Converts to k8s object. - - :return: kubernetes.client.models.V1EnvVar - """ - return k8s.V1EnvVar( - name=self.name, - value_from=k8s.V1EnvVarSource(field_ref=k8s.V1ObjectFieldSelector(field_path=self.field_path)), - ) diff --git a/airflow/providers/cncf/kubernetes/backcompat/volume.py b/airflow/providers/cncf/kubernetes/backcompat/volume.py deleted file mode 100644 index c51ce8a551e38..0000000000000 --- a/airflow/providers/cncf/kubernetes/backcompat/volume.py +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`kubernetes.client.models.V1Volume`.""" - -import warnings - -from kubernetes.client import models as k8s - -warnings.warn( - "This module is deprecated. Please use `kubernetes.client.models.V1Volume`.", - DeprecationWarning, - stacklevel=2, -) - - -class Volume: - """Backward compatible Volume""" - - def __init__(self, name, configs): - """Adds Kubernetes Volume to pod. allows pod to access features like ConfigMaps - and Persistent Volumes - - :param name: the name of the volume mount - :param configs: dictionary of any features needed for volume. We purposely keep this - vague since there are multiple volume types with changing configs. - """ - self.name = name - self.configs = configs - - def to_k8s_client_obj(self) -> k8s.V1Volume: - """ - Converts to k8s object. 
- - :return: Volume Mount k8s object - """ - resp = k8s.V1Volume(name=self.name) - for k, v in self.configs.items(): - snake_key = Volume._convert_to_snake_case(k) - if hasattr(resp, snake_key): - setattr(resp, snake_key, v) - else: - raise AttributeError(f"V1Volume does not have attribute {k}") - return resp - - # source: https://www.geeksforgeeks.org/python-program-to-convert-camel-case-string-to-snake-case/ - @staticmethod - def _convert_to_snake_case(input_string): - return ''.join('_' + i.lower() if i.isupper() else i for i in input_string).lstrip('_') diff --git a/airflow/providers/cncf/kubernetes/backcompat/volume_mount.py b/airflow/providers/cncf/kubernetes/backcompat/volume_mount.py deleted file mode 100644 index f9faed9d04a97..0000000000000 --- a/airflow/providers/cncf/kubernetes/backcompat/volume_mount.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Classes for interacting with Kubernetes API""" - -import warnings - -from kubernetes.client import models as k8s - -warnings.warn( - "This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.", - DeprecationWarning, - stacklevel=2, -) - - -class VolumeMount: - """Backward compatible VolumeMount""" - - __slots__ = ('name', 'mount_path', 'sub_path', 'read_only') - - def __init__(self, name, mount_path, sub_path, read_only): - """ - Initialize a Kubernetes Volume Mount. Used to mount pod level volumes to - running container. - - :param name: the name of the volume mount - :param mount_path: - :param sub_path: subpath within the volume mount - :param read_only: whether to access pod with read-only mode - """ - self.name = name - self.mount_path = mount_path - self.sub_path = sub_path - self.read_only = read_only - - def to_k8s_client_obj(self) -> k8s.V1VolumeMount: - """ - Converts to k8s object. - - :return: Volume Mount k8s object - """ - return k8s.V1VolumeMount( - name=self.name, mount_path=self.mount_path, sub_path=self.sub_path, read_only=self.read_only - ) diff --git a/airflow/mypy/__init__.py b/airflow/providers/cncf/kubernetes/decorators/__init__.py similarity index 100% rename from airflow/mypy/__init__.py rename to airflow/providers/cncf/kubernetes/decorators/__init__.py diff --git a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py new file mode 100644 index 0000000000000..f68927c676433 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import inspect +import os +import pickle +import uuid +from tempfile import TemporaryDirectory +from textwrap import dedent +from typing import TYPE_CHECKING, Callable, Sequence + +from kubernetes.client import models as k8s + +from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory +from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator +from airflow.providers.cncf.kubernetes.python_kubernetes_script import ( + remove_task_decorator, + write_python_script, +) + +if TYPE_CHECKING: + from airflow.utils.context import Context + +_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT" + +_FILENAME_IN_CONTAINER = "/tmp/script.py" + + +def _generate_decode_command() -> str: + return ( + f'python -c "import base64, os;' + rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];" + rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"' + ) + + +def _read_file_contents(filename): + with open(filename) as script_file: + return script_file.read() + + +class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator): + custom_operator_name = "@task.kubernetes" + + # `cmds` and `arguments` are used internally by the operator + template_fields: Sequence[str] = tuple( + {"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"} + ) + + # since we won't mutate the arguments, we should just do the shallow copy + # there are some cases we can't deepcopy the objects (e.g protobuf). 
+ shallow_copy_attrs: Sequence[str] = ("python_callable",) + + def __init__(self, namespace: str = "default", **kwargs) -> None: + self.pickling_library = pickle + super().__init__( + namespace=namespace, + name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"), + cmds=["bash"], + arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"], + **kwargs, + ) + + def _get_python_source(self): + raw_source = inspect.getsource(self.python_callable) + res = dedent(raw_source) + res = remove_task_decorator(res, "@task.kubernetes") + return res + + def execute(self, context: Context): + with TemporaryDirectory(prefix="venv") as tmp_dir: + script_filename = os.path.join(tmp_dir, "script.py") + py_source = self._get_python_source() + + jinja_context = { + "op_args": self.op_args, + "op_kwargs": self.op_kwargs, + "pickling_library": self.pickling_library.__name__, + "python_callable": self.python_callable.__name__, + "python_callable_source": py_source, + "string_args_global": False, + } + write_python_script(jinja_context=jinja_context, filename=script_filename) + + self.env_vars = [ + *self.env_vars, + k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)), + ] + return super().execute(context) + + +def kubernetes_task( + python_callable: Callable | None = None, + multiple_outputs: bool | None = None, + **kwargs, +) -> TaskDecorator: + """Kubernetes operator decorator. + + This wraps a function to be executed in K8s using KubernetesPodOperator. + Also accepts any argument that DockerOperator will via ``kwargs``. Can be + reused in a single DAG. + + :param python_callable: Function to decorate + :param multiple_outputs: if set, function return value will be + unrolled to multiple XCom values. Dict will unroll to xcom values with + keys as XCom keys. Defaults to False. + """ + return task_decorator_factory( + python_callable=python_callable, + multiple_outputs=multiple_outputs, + decorated_operator_class=_KubernetesDecoratedOperator, + **kwargs, + ) diff --git a/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py b/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py deleted file mode 100644 index d01d4b1328c68..0000000000000 --- a/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example DAG which uses SparkKubernetesOperator and SparkKubernetesSensor. -In this example, we create two tasks which execute sequentially. -The first task is to submit sparkApplication on Kubernetes cluster(the example uses spark-pi application). -and the second task is to check the final state of the sparkApplication that submitted in the first state. 
- -Spark-on-k8s operator is required to be already installed on Kubernetes -https://github.com/GoogleCloudPlatform/spark-on-k8s-operator -""" - -from datetime import datetime, timedelta - -# [START import_module] -# The DAG object; we'll need this to instantiate a DAG -from airflow import DAG - -# Operators; we need this to operate! -from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator -from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor - -# [END import_module] - - -# [START instantiate_dag] - -dag = DAG( - 'spark_pi', - default_args={'max_active_runs': 1}, - description='submit spark-pi as sparkApplication on kubernetes', - schedule_interval=timedelta(days=1), - start_date=datetime(2021, 1, 1), - catchup=False, -) - -t1 = SparkKubernetesOperator( - task_id='spark_pi_submit', - namespace="default", - application_file="example_spark_kubernetes_spark_pi.yaml", - do_xcom_push=True, - dag=dag, -) - -t2 = SparkKubernetesSensor( - task_id='spark_pi_monitor', - namespace="default", - application_name="{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}", - dag=dag, -) -t1 >> t2 diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index e15dce67ef40a..85b76d5f52970 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -14,29 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys +from __future__ import annotations + import tempfile import warnings -from typing import Any, Dict, Generator, List, Optional, Tuple, Union - -from kubernetes.config import ConfigException - -from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from typing import TYPE_CHECKING, Any, Generator from kubernetes import client, config, watch +from kubernetes.config import ConfigException -try: - import airflow.utils.yaml as yaml -except ImportError: - import yaml # type: ignore[no-redef] - +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive +from airflow.utils import yaml def _load_body_to_dict(body): @@ -51,10 +42,10 @@ class KubernetesHook(BaseHook): """ Creates Kubernetes API connection. - - use in cluster configuration by using ``extra__kubernetes__in_cluster`` in connection - - use custom config by providing path to the file using ``extra__kubernetes__kube_config_path`` + - use in cluster configuration by using extra field ``in_cluster`` in connection + - use custom config by providing path to the file using extra field ``kube_config_path`` in connection - use custom configuration by providing content of kubeconfig file via - ``extra__kubernetes__kube_config`` in connection + extra field ``kube_config`` in connection - use default config by providing no extras This hook check for configuration option in the above order. Once an option is present it will @@ -76,53 +67,52 @@ class KubernetesHook(BaseHook): :param disable_tcp_keepalive: Set to ``True`` if you want to disable keepalive logic. 
""" - conn_name_attr = 'kubernetes_conn_id' - default_conn_name = 'kubernetes_default' - conn_type = 'kubernetes' - hook_name = 'Kubernetes Cluster Connection' + conn_name_attr = "kubernetes_conn_id" + default_conn_name = "kubernetes_default" + conn_type = "kubernetes" + hook_name = "Kubernetes Cluster Connection" + + DEFAULT_NAMESPACE = "default" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import BooleanField, StringField return { - "extra__kubernetes__in_cluster": BooleanField(lazy_gettext('In cluster configuration')), - "extra__kubernetes__kube_config_path": StringField( - lazy_gettext('Kube config path'), widget=BS3TextFieldWidget() - ), - "extra__kubernetes__kube_config": StringField( - lazy_gettext('Kube config (JSON format)'), widget=BS3TextFieldWidget() + "in_cluster": BooleanField(lazy_gettext("In cluster configuration")), + "kube_config_path": StringField(lazy_gettext("Kube config path"), widget=BS3TextFieldWidget()), + "kube_config": StringField( + lazy_gettext("Kube config (JSON format)"), widget=BS3TextFieldWidget() ), - "extra__kubernetes__namespace": StringField( - lazy_gettext('Namespace'), widget=BS3TextFieldWidget() + "namespace": StringField(lazy_gettext("Namespace"), widget=BS3TextFieldWidget()), + "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()), + "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")), + "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")), + "xcom_sidecar_container_image": StringField( + lazy_gettext("XCom sidecar image"), widget=BS3TextFieldWidget() ), - "extra__kubernetes__cluster_context": StringField( - lazy_gettext('Cluster context'), widget=BS3TextFieldWidget() - ), - "extra__kubernetes__disable_verify_ssl": BooleanField(lazy_gettext('Disable SSL')), - "extra__kubernetes__disable_tcp_keepalive": BooleanField(lazy_gettext('Disable TCP keepalive')), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['host', 'schema', 'login', 'password', 'port', 'extra'], + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], "relabeling": {}, } def __init__( self, - conn_id: Optional[str] = default_conn_name, - client_configuration: Optional[client.Configuration] = None, - cluster_context: Optional[str] = None, - config_file: Optional[str] = None, - in_cluster: Optional[bool] = None, - disable_verify_ssl: Optional[bool] = None, - disable_tcp_keepalive: Optional[bool] = None, + conn_id: str | None = default_conn_name, + client_configuration: client.Configuration | None = None, + cluster_context: str | None = None, + config_file: str | None = None, + in_cluster: bool | None = None, + disable_verify_ssl: bool | None = None, + disable_tcp_keepalive: bool | None = None, ) -> None: super().__init__() self.conn_id = conn_id @@ -132,14 +122,7 @@ def __init__( self.in_cluster = in_cluster self.disable_verify_ssl = disable_verify_ssl self.disable_tcp_keepalive = disable_tcp_keepalive - - # these params used for transition in KPO to K8s hook - # for a deprecation period we will continue to consider k8s settings from airflow.cfg - self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None - 
self._deprecated_core_disable_verify_ssl: Optional[bool] = None - self._deprecated_core_in_cluster: Optional[bool] = None - self._deprecated_core_cluster_context: Optional[str] = None - self._deprecated_core_config_file: Optional[str] = None + self._is_in_cluster: bool | None = None @staticmethod def _coalesce_param(*params): @@ -157,7 +140,12 @@ def conn_extras(self): return extras def _get_field(self, field_name): - if field_name.startswith('extra_'): + """ + Prior to Airflow 2.3, in order to make use of UI customizations for extra fields, + we needed to store them with the prefix ``extra__kubernetes__``. This method + handles the backcompat, i.e. if the extra dict contains prefixed fields. + """ + if field_name.startswith("extra__"): raise ValueError( f"Got prefixed name {field_name}; please remove the 'extra__kubernetes__' prefix " f"when using this method." @@ -167,31 +155,12 @@ def _get_field(self, field_name): prefixed_name = f"extra__kubernetes__{field_name}" return self.conn_extras.get(prefixed_name) or None - @staticmethod - def _deprecation_warning_core_param(deprecation_warnings): - settings_list_str = ''.join([f"\n\t{k}={v!r}" for k, v in deprecation_warnings]) - warnings.warn( - f"\nApplying core Airflow settings from section [kubernetes] with the following keys:" - f"{settings_list_str}\n" - "In a future release, KubernetesPodOperator will no longer consider core\n" - "Airflow settings; define an Airflow connection instead.", - DeprecationWarning, - ) - - def get_conn(self) -> Any: + def get_conn(self) -> client.ApiClient: """Returns kubernetes api session for use with requests""" - - in_cluster = self._coalesce_param( - self.in_cluster, self.conn_extras.get("extra__kubernetes__in_cluster") or None - ) - cluster_context = self._coalesce_param( - self.cluster_context, self.conn_extras.get("extra__kubernetes__cluster_context") or None - ) - kubeconfig_path = self._coalesce_param( - self.config_file, self.conn_extras.get("extra__kubernetes__kube_config_path") or None - ) - - kubeconfig = self.conn_extras.get("extra__kubernetes__kube_config") or None + in_cluster = self._coalesce_param(self.in_cluster, self._get_field("in_cluster")) + cluster_context = self._coalesce_param(self.cluster_context, self._get_field("cluster_context")) + kubeconfig_path = self._coalesce_param(self.config_file, self._get_field("kube_config_path")) + kubeconfig = self._get_field("kube_config") num_selected_configuration = len([o for o in [in_cluster, kubeconfig, kubeconfig_path] if o]) if num_selected_configuration > 1: @@ -208,30 +177,6 @@ def get_conn(self) -> Any: self.disable_tcp_keepalive, _get_bool(self._get_field("disable_tcp_keepalive")) ) - # BEGIN apply settings from core kubernetes configuration - # this section should be removed in next major release - deprecation_warnings: List[Tuple[str, Any]] = [] - if disable_verify_ssl is None and self._deprecated_core_disable_verify_ssl is True: - deprecation_warnings.append(('verify_ssl', False)) - disable_verify_ssl = self._deprecated_core_disable_verify_ssl - # by default, hook will try in_cluster first. so we only need to - # apply core airflow config and alert when False and in_cluster not otherwise set. 
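With the deprecated core [kubernetes] shim removed and the extra__kubernetes__ prefix dropped, cluster access is configured entirely through un-prefixed connection extras (the prefixed form is still read for backwards compatibility via _get_field). A sketch of such a connection; all values are placeholders:

    import json

    from airflow.models.connection import Connection

    conn = Connection(
        conn_id="kubernetes_default",
        conn_type="kubernetes",
        extra=json.dumps(
            {
                "in_cluster": False,
                "kube_config_path": "~/.kube/config",
                "namespace": "team-a",
                "xcom_sidecar_container_image": "alpine:3.16",
            }
        ),
    )
    print(conn.get_uri())  # can be exported as AIRFLOW_CONN_KUBERNETES_DEFAULT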
- if in_cluster is None and self._deprecated_core_in_cluster is False: - deprecation_warnings.append(('in_cluster', self._deprecated_core_in_cluster)) - in_cluster = self._deprecated_core_in_cluster - if not cluster_context and self._deprecated_core_cluster_context: - deprecation_warnings.append(('cluster_context', self._deprecated_core_cluster_context)) - cluster_context = self._deprecated_core_cluster_context - if not kubeconfig_path and self._deprecated_core_config_file: - deprecation_warnings.append(('config_file', self._deprecated_core_config_file)) - kubeconfig_path = self._deprecated_core_config_file - if disable_tcp_keepalive is None and self._deprecated_core_disable_tcp_keepalive is True: - deprecation_warnings.append(('enable_tcp_keepalive', False)) - disable_tcp_keepalive = True - if deprecation_warnings: - self._deprecation_warning_core_param(deprecation_warnings) - # END apply settings from core kubernetes configuration - if disable_verify_ssl is True: _disable_verify_ssl() if disable_tcp_keepalive is not True: @@ -239,11 +184,13 @@ def get_conn(self) -> Any: if in_cluster: self.log.debug("loading kube_config from: in_cluster configuration") + self._is_in_cluster = True config.load_incluster_config() return client.ApiClient() if kubeconfig_path is not None: self.log.debug("loading kube_config from: %s", kubeconfig_path) + self._is_in_cluster = False config.load_kube_config( config_file=kubeconfig_path, client_configuration=self.client_configuration, @@ -256,6 +203,7 @@ def get_conn(self) -> Any: self.log.debug("loading kube_config from: connection kube_config") temp_config.write(kubeconfig.encode()) temp_config.flush() + self._is_in_cluster = False config.load_kube_config( config_file=temp_config.name, client_configuration=self.client_configuration, @@ -265,36 +213,47 @@ def get_conn(self) -> Any: return self._get_default_client(cluster_context=cluster_context) - def _get_default_client(self, *, cluster_context=None): + def _get_default_client(self, *, cluster_context: str | None = None) -> client.ApiClient: # if we get here, then no configuration has been supplied # we should try in_cluster since that's most likely # but failing that just load assuming a kubeconfig file # in the default location try: config.load_incluster_config(client_configuration=self.client_configuration) + self._is_in_cluster = True except ConfigException: self.log.debug("loading kube_config from: default file") + self._is_in_cluster = False config.load_kube_config( client_configuration=self.client_configuration, context=cluster_context, ) return client.ApiClient() + @property + def is_in_cluster(self) -> bool: + """Expose whether the hook is configured with ``load_incluster_config`` or not""" + if self._is_in_cluster is not None: + return self._is_in_cluster + self.api_client # so we can determine if we are in_cluster or not + if TYPE_CHECKING: + assert self._is_in_cluster is not None + return self._is_in_cluster + @cached_property - def api_client(self) -> Any: + def api_client(self) -> client.ApiClient: """Cached Kubernetes API client""" return self.get_conn() @cached_property - def core_v1_client(self): + def core_v1_client(self) -> client.CoreV1Api: return client.CoreV1Api(api_client=self.api_client) def create_custom_object( - self, group: str, version: str, plural: str, body: Union[str, dict], namespace: Optional[str] = None + self, group: str, version: str, plural: str, body: str | dict, namespace: str | None = None ): """ Creates custom resource definition object in Kubernetes - :param group: 
api group :param version: api version :param plural: api plural @@ -302,35 +261,40 @@ def create_custom_object( :param namespace: kubernetes namespace """ api = client.CustomObjectsApi(self.api_client) - if namespace is None: - namespace = self.get_namespace() + namespace = namespace or self._get_namespace() or self.DEFAULT_NAMESPACE + if isinstance(body, str): body_dict = _load_body_to_dict(body) else: body_dict = body - try: - api.delete_namespaced_custom_object( - group=group, - version=version, - namespace=namespace, - plural=plural, - name=body_dict["metadata"]["name"], - ) - self.log.warning("Deleted SparkApplication with the same name.") - except client.rest.ApiException: - self.log.info("SparkApp %s not found.", body_dict['metadata']['name']) + + # Attribute "name" is not mandatory if "generateName" is used instead + if "name" in body_dict["metadata"]: + try: + api.delete_namespaced_custom_object( + group=group, + version=version, + namespace=namespace, + plural=plural, + name=body_dict["metadata"]["name"], + ) + + self.log.warning("Deleted SparkApplication with the same name") + except client.rest.ApiException: + self.log.info("SparkApplication %s not found", body_dict["metadata"]["name"]) try: response = api.create_namespaced_custom_object( group=group, version=version, namespace=namespace, plural=plural, body=body_dict ) + self.log.debug("Response: %s", response) return response except client.rest.ApiException as e: raise AirflowException(f"Exception when calling -> create_custom_object: {e}\n") def get_custom_object( - self, group: str, version: str, plural: str, name: str, namespace: Optional[str] = None + self, group: str, version: str, plural: str, name: str, namespace: str | None = None ): """ Get custom resource definition object from Kubernetes @@ -342,8 +306,7 @@ def get_custom_object( :param namespace: kubernetes namespace """ api = client.CustomObjectsApi(self.api_client) - if namespace is None: - namespace = self.get_namespace() + namespace = namespace or self._get_namespace() or self.DEFAULT_NAMESPACE try: response = api.get_namespaced_custom_object( group=group, version=version, namespace=namespace, plural=plural, name=name @@ -352,21 +315,45 @@ def get_custom_object( except client.rest.ApiException as e: raise AirflowException(f"Exception when calling -> get_custom_object: {e}\n") - def get_namespace(self) -> Optional[str]: - """Returns the namespace that defined in the connection""" + def get_namespace(self) -> str | None: + """ + Returns the namespace defined in the connection or 'default'. + + TODO: in provider version 6.0, return None when namespace not defined in connection + """ + namespace = self._get_namespace() + if self.conn_id and not namespace: + warnings.warn( + "Airflow connection defined but namespace is not set; returning 'default'. In " + "cncf.kubernetes provider version 6.0 we will return None when namespace is " + "not defined in the connection so that it's clear whether user intends 'default' or " + "whether namespace is unset (which is required in order to apply precedence logic in " + "KubernetesPodOperator).", + DeprecationWarning, + ) + return "default" + return namespace + + def _get_namespace(self) -> str | None: + """ + Returns the namespace that defined in the connection + + TODO: in provider version 6.0, get rid of this method and make it the behavior of get_namespace. 
+ """ if self.conn_id: - connection = self.get_connection(self.conn_id) - extras = connection.extra_dejson - namespace = extras.get("extra__kubernetes__namespace", "default") - return namespace + return self._get_field("namespace") return None + def get_xcom_sidecar_container_image(self): + """Returns the xcom sidecar image that defined in the connection""" + return self._get_field("xcom_sidecar_container_image") + def get_pod_log_stream( self, pod_name: str, - container: Optional[str] = "", - namespace: Optional[str] = None, - ) -> Tuple[watch.Watch, Generator[str, None, None]]: + container: str | None = "", + namespace: str | None = None, + ) -> tuple[watch.Watch, Generator[str, None, None]]: """ Retrieves a log stream for a container in a kubernetes pod. @@ -374,23 +361,22 @@ def get_pod_log_stream( :param container: container name :param namespace: kubernetes namespace """ - api = client.CoreV1Api(self.api_client) watcher = watch.Watch() return ( watcher, watcher.stream( - api.read_namespaced_pod_log, + self.core_v1_client.read_namespaced_pod_log, name=pod_name, container=container, - namespace=namespace if namespace else self.get_namespace(), + namespace=namespace or self._get_namespace() or self.DEFAULT_NAMESPACE, ), ) def get_pod_logs( self, pod_name: str, - container: Optional[str] = "", - namespace: Optional[str] = None, + container: str | None = "", + namespace: str | None = None, ): """ Retrieves a container's log from the specified pod. @@ -399,16 +385,15 @@ def get_pod_logs( :param container: container name :param namespace: kubernetes namespace """ - api = client.CoreV1Api(self.api_client) - return api.read_namespaced_pod_log( + return self.core_v1_client.read_namespaced_pod_log( name=pod_name, container=container, _preload_content=False, - namespace=namespace if namespace else self.get_namespace(), + namespace=namespace or self._get_namespace() or self.DEFAULT_NAMESPACE, ) -def _get_bool(val) -> Optional[bool]: +def _get_bool(val) -> bool | None: """ Converts val to bool if can be done with certainty. If we cannot infer intention we return None. @@ -416,8 +401,8 @@ def _get_bool(val) -> Optional[bool]: if isinstance(val, bool): return val elif isinstance(val, str): - if val.strip().lower() == 'true': + if val.strip().lower() == "true": return True - elif val.strip().lower() == 'false': + elif val.strip().lower() == "false": return False return None diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 69f120e823b3b..c1735158ec1ec 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -15,17 +15,18 @@ # specific language governing permissions and limitations # under the License. 
"""Executes task in a Kubernetes POD""" +from __future__ import annotations + import json import logging import re -import sys import warnings from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from kubernetes.client import CoreV1Api, models as k8s -from airflow.configuration import conf +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.kubernetes import pod_generator from airflow.kubernetes.pod_generator import PodGenerator @@ -38,7 +39,6 @@ convert_image_pull_secrets, convert_pod_runtime_info_env, convert_port, - convert_resources, convert_toleration, convert_volume, convert_volume_mount, @@ -56,17 +56,37 @@ from airflow.utils.helpers import prune_dict, validate_key from airflow.version import version as airflow_version -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - if TYPE_CHECKING: import jinja2 from airflow.utils.context import Context +def _task_id_to_pod_name(val: str) -> str: + """ + Given a task_id, convert it to a pod name. + Adds a 0 if start or end char is invalid. + Replaces any other invalid char with `-`. + + :param val: non-empty string, presumed to be a task id + :return valid kubernetes object name. + """ + if not val: + raise ValueError("_task_id_to_pod_name requires non-empty string.") + val = val.lower() + if not re.match(r"[a-z0-9]", val[0]): + val = f"0{val}" + if not re.match(r"[a-z0-9]", val[-1]): + val = f"{val}0" + val = re.sub(r"[^a-z0-9\-.]", "-", val) + if len(val) > 253: + raise ValueError( + f"Pod name {val} is longer than 253 characters. " + "See https://kubernetes.io/docs/concepts/overview/working-with-objects/names/." + ) + return val + + class PodReattachFailure(AirflowException): """When we expect to be able to find a pod but cannot.""" @@ -102,6 +122,7 @@ class KubernetesPodOperator(BaseOperator): :param volume_mounts: volumeMounts for the launched pod. :param volumes: volumes for the launched pod. Includes ConfigMaps and PersistentVolumes. :param env_vars: Environment variables initialized in the container. (templated) + :param env_from: (Optional) List of sources to populate environment variables in the container. :param secrets: Kubernetes secrets to inject in the container. They can be exposed as environment vars or files in a volume. :param in_cluster: run kubernetes client with in_cluster configuration. @@ -116,7 +137,7 @@ class KubernetesPodOperator(BaseOperator): :param annotations: non-identifying metadata you can attach to the Pod. Can be a large range of data, and can include characters that are not permitted by labels. - :param resources: resources for the launched pod. + :param container_resources: resources for the launched pod. (templated) :param affinity: affinity scheduling rules for the launched pod. :param config_file: The path to the Kubernetes config file. (templated) If not specified, default value is ``~/.kube/config`` @@ -131,6 +152,7 @@ class KubernetesPodOperator(BaseOperator): :param hostnetwork: If True enable host networking on the pod. :param tolerations: A list of kubernetes tolerations. :param security_context: security options the pod should run with (PodSecurityContext). + :param container_security_context: security options the container should run with. :param dnspolicy: dnspolicy for the pod. 
:param schedulername: Specify a schedulername for the pod :param full_pod_spec: The complete podSpec @@ -141,76 +163,95 @@ class KubernetesPodOperator(BaseOperator): XCom when the container completes. :param pod_template_file: path to pod template file (templated) :param priority_class_name: priority class name for the launched Pod + :param pod_runtime_info_envs: (Optional) A list of environment variables, + to be set in the container. :param termination_grace_period: Termination grace period if task killed in UI, defaults to kubernetes default - :param: kubernetes_conn_id: To retrieve credentials for your k8s cluster from an Airflow connection + :param configmaps: (Optional) A list of names of config maps from which it collects ConfigMaps + to populate the environment variables with. The contents of the target + ConfigMap's Data field will represent the key-value pairs as environment variables. + Extends env_from. """ - BASE_CONTAINER_NAME = 'base' - POD_CHECKED_KEY = 'already_checked' + BASE_CONTAINER_NAME = "base" + POD_CHECKED_KEY = "already_checked" template_fields: Sequence[str] = ( - 'image', - 'cmds', - 'arguments', - 'env_vars', - 'labels', - 'config_file', - 'pod_template_file', - 'namespace', + "image", + "cmds", + "arguments", + "env_vars", + "labels", + "config_file", + "pod_template_file", + "namespace", + "container_resources", ) + template_fields_renderers = {"env_vars": "py"} def __init__( self, *, - kubernetes_conn_id: Optional[str] = None, # 'kubernetes_default', - namespace: Optional[str] = None, - image: Optional[str] = None, - name: Optional[str] = None, - random_name_suffix: Optional[bool] = True, - cmds: Optional[List[str]] = None, - arguments: Optional[List[str]] = None, - ports: Optional[List[k8s.V1ContainerPort]] = None, - volume_mounts: Optional[List[k8s.V1VolumeMount]] = None, - volumes: Optional[List[k8s.V1Volume]] = None, - env_vars: Optional[List[k8s.V1EnvVar]] = None, - env_from: Optional[List[k8s.V1EnvFromSource]] = None, - secrets: Optional[List[Secret]] = None, - in_cluster: Optional[bool] = None, - cluster_context: Optional[str] = None, - labels: Optional[Dict] = None, + kubernetes_conn_id: str | None = None, # 'kubernetes_default', + namespace: str | None = None, + image: str | None = None, + name: str | None = None, + random_name_suffix: bool | None = True, + cmds: list[str] | None = None, + arguments: list[str] | None = None, + ports: list[k8s.V1ContainerPort] | None = None, + volume_mounts: list[k8s.V1VolumeMount] | None = None, + volumes: list[k8s.V1Volume] | None = None, + env_vars: list[k8s.V1EnvVar] | None = None, + env_from: list[k8s.V1EnvFromSource] | None = None, + secrets: list[Secret] | None = None, + in_cluster: bool | None = None, + cluster_context: str | None = None, + labels: dict | None = None, reattach_on_restart: bool = True, startup_timeout_seconds: int = 120, get_logs: bool = True, - image_pull_policy: Optional[str] = None, - annotations: Optional[Dict] = None, - resources: Optional[k8s.V1ResourceRequirements] = None, - affinity: Optional[k8s.V1Affinity] = None, - config_file: Optional[str] = None, - node_selectors: Optional[dict] = None, - node_selector: Optional[dict] = None, - image_pull_secrets: Optional[List[k8s.V1LocalObjectReference]] = None, - service_account_name: Optional[str] = None, + image_pull_policy: str | None = None, + annotations: dict | None = None, + container_resources: k8s.V1ResourceRequirements | None = None, + affinity: k8s.V1Affinity | None = None, + config_file: str | None = None, + node_selector: dict | 
None = None, + image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None, + service_account_name: str | None = None, is_delete_operator_pod: bool = True, hostnetwork: bool = False, - tolerations: Optional[List[k8s.V1Toleration]] = None, - security_context: Optional[Dict] = None, - dnspolicy: Optional[str] = None, - schedulername: Optional[str] = None, - full_pod_spec: Optional[k8s.V1Pod] = None, - init_containers: Optional[List[k8s.V1Container]] = None, + tolerations: list[k8s.V1Toleration] | None = None, + security_context: dict | None = None, + container_security_context: dict | None = None, + dnspolicy: str | None = None, + schedulername: str | None = None, + full_pod_spec: k8s.V1Pod | None = None, + init_containers: list[k8s.V1Container] | None = None, log_events_on_failure: bool = False, do_xcom_push: bool = False, - pod_template_file: Optional[str] = None, - priority_class_name: Optional[str] = None, - pod_runtime_info_envs: Optional[List[k8s.V1EnvVar]] = None, - termination_grace_period: Optional[int] = None, - configmaps: Optional[List[str]] = None, + pod_template_file: str | None = None, + priority_class_name: str | None = None, + pod_runtime_info_envs: list[k8s.V1EnvVar] | None = None, + termination_grace_period: int | None = None, + configmaps: list[str] | None = None, **kwargs, ) -> None: - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") - super().__init__(resources=None, **kwargs) + # TODO: remove in provider 6.0.0 release. This is a mitigate step to advise users to switch to the + # container_resources parameter. + if isinstance(kwargs.get("resources"), k8s.V1ResourceRequirements): + raise AirflowException( + "Specifying resources for the launched pod with 'resources' is deprecated. " + "Use 'container_resources' instead." + ) + # TODO: remove in provider 6.0.0 release. This is a mitigate step to advise users to switch to the + # node_selector parameter. + if "node_selectors" in kwargs: + raise ValueError( + "Param `node_selectors` supplied. This param is no longer supported. " + "Use `node_selector` instead." + ) + super().__init__(**kwargs) self.kubernetes_conn_id = kubernetes_conn_id self.do_xcom_push = do_xcom_push self.image = image @@ -234,19 +275,10 @@ def __init__( self.reattach_on_restart = reattach_on_restart self.get_logs = get_logs self.image_pull_policy = image_pull_policy - if node_selectors: - # Node selectors is incorrect based on k8s API - warnings.warn( - "node_selectors is deprecated. 
Please use node_selector instead.", DeprecationWarning - ) - self.node_selector = node_selectors - elif node_selector: - self.node_selector = node_selector - else: - self.node_selector = {} + self.node_selector = node_selector or {} self.annotations = annotations or {} self.affinity = convert_affinity(affinity) if affinity else {} - self.k8s_resources = convert_resources(resources) if resources else {} + self.container_resources = container_resources self.config_file = config_file self.image_pull_secrets = convert_image_pull_secrets(image_pull_secrets) if image_pull_secrets else [] self.service_account_name = service_account_name @@ -256,6 +288,7 @@ def __init__( [convert_toleration(toleration) for toleration in tolerations] if tolerations else [] ) self.security_context = security_context or {} + self.container_security_context = container_security_context self.dnspolicy = dnspolicy self.schedulername = schedulername self.full_pod_spec = full_pod_spec @@ -266,25 +299,37 @@ def __init__( self.name = self._set_name(name) self.random_name_suffix = random_name_suffix self.termination_grace_period = termination_grace_period - self.pod_request_obj: Optional[k8s.V1Pod] = None - self.pod: Optional[k8s.V1Pod] = None + self.pod_request_obj: k8s.V1Pod | None = None + self.pod: k8s.V1Pod | None = None + + @cached_property + def _incluster_namespace(self): + from pathlib import Path + + path = Path("/var/run/secrets/kubernetes.io/serviceaccount/namespace") + return path.exists() and path.read_text() or None def _render_nested_template_fields( self, content: Any, - context: 'Context', - jinja_env: "jinja2.Environment", + context: Context, + jinja_env: jinja2.Environment, seen_oids: set, ) -> None: if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar): seen_oids.add(id(content)) - self._do_render_template_fields(content, ('value', 'name'), context, jinja_env, seen_oids) + self._do_render_template_fields(content, ("value", "name"), context, jinja_env, seen_oids) + return + + if id(content) not in seen_oids and isinstance(content, k8s.V1ResourceRequirements): + seen_oids.add(id(content)) + self._do_render_template_fields(content, ("limits", "requests"), context, jinja_env, seen_oids) return super()._render_nested_template_fields(content, context, jinja_env, seen_oids) @staticmethod - def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool = True) -> dict: + def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]: """ Generate labels for the pod to track the pod in case of Operator crash @@ -294,26 +339,25 @@ def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool if not context: return {} - ti = context['ti'] - run_id = context['run_id'] + ti = context["ti"] + run_id = context["run_id"] labels = { - 'dag_id': ti.dag_id, - 'task_id': ti.task_id, - 'run_id': run_id, - 'kubernetes_pod_operator': 'True', + "dag_id": ti.dag_id, + "task_id": ti.task_id, + "run_id": run_id, + "kubernetes_pod_operator": "True", } - # If running on Airflow 2.3+: - map_index = getattr(ti, 'map_index', -1) + map_index = ti.map_index if map_index >= 0: - labels['map_index'] = map_index + labels["map_index"] = map_index if include_try_number: labels.update(try_number=ti.try_number) # In the case of sub dags this is just useful - if context['dag'].is_subdag: - labels['parent_dag_id'] = context['dag'].parent_dag.dag_id + if context["dag"].parent_dag: + labels["parent_dag_id"] = context["dag"].parent_dag.dag_id # Ensure that 
label is valid for Kube, # and if not truncate/remove invalid chars and replace with short hash. for label_id, label in labels.items(): @@ -326,21 +370,24 @@ def pod_manager(self) -> PodManager: return PodManager(kube_client=self.client) def get_hook(self): + warnings.warn("get_hook is deprecated. Please use hook instead.", DeprecationWarning, stacklevel=2) + return self.hook + + @cached_property + def hook(self) -> KubernetesHook: hook = KubernetesHook( conn_id=self.kubernetes_conn_id, in_cluster=self.in_cluster, config_file=self.config_file, cluster_context=self.cluster_context, ) - self._patch_deprecated_k8s_settings(hook) return hook @cached_property def client(self) -> CoreV1Api: - hook = self.get_hook() - return hook.core_v1_client + return self.hook.core_v1_client - def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]: + def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None: """Returns an already-running pod for this task instance if one exists.""" label_selector = self._build_find_pod_label_selector(context, exclude_checked=exclude_checked) pod_list = self.client.list_namespaced_pod( @@ -351,15 +398,15 @@ def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s. pod = None num_pods = len(pod_list) if num_pods > 1: - raise AirflowException(f'More than one pod running with labels {label_selector}') + raise AirflowException(f"More than one pod running with labels {label_selector}") elif num_pods == 1: pod = pod_list[0] self.log.info("Found matching pod %s with labels %s", pod.metadata.name, pod.metadata.labels) - self.log.info("`try_number` of task_instance: %s", context['ti'].try_number) - self.log.info("`try_number` of pod: %s", pod.metadata.labels['try_number']) + self.log.info("`try_number` of task_instance: %s", context["ti"].try_number) + self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"]) return pod - def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context): + def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s.V1Pod: if self.reattach_on_restart: pod = self.find_pod(self.namespace or pod_request_obj.metadata.namespace, context=context) if pod: @@ -368,7 +415,7 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context): self.pod_manager.create_pod(pod=pod_request_obj) return pod_request_obj - def await_pod_start(self, pod): + def await_pod_start(self, pod: k8s.V1Pod): try: self.pod_manager.await_pod_start(pod=pod, startup_timeout=self.startup_timeout_seconds) except PodLaunchFailedException: @@ -377,13 +424,17 @@ def await_pod_start(self, pod): self.log.error("Pod Event: %s - %s", event.reason, event.message) raise - def extract_xcom(self, pod): + def extract_xcom(self, pod: k8s.V1Pod): """Retrieves xcom value and kills xcom sidecar container""" result = self.pod_manager.extract_xcom(pod) - self.log.info("xcom result: \n%s", result) - return json.loads(result) + if isinstance(result, str) and result.rstrip() == "__airflow_xcom_result_empty__": + self.log.info("Result file is empty.") + return None + else: + self.log.info("xcom result: \n%s", result) + return json.loads(result) - def execute(self, context: 'Context'): + def execute(self, context: Context): remote_pod = None try: self.pod_request_obj = self.build_pod_request_obj(context) @@ -391,6 +442,8 @@ def execute(self, context: 'Context'): pod_request_obj=self.pod_request_obj, context=context, ) + # get remote pod for use in cleanup methods + 
remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) self.await_pod_start(pod=self.pod) if self.get_logs: @@ -405,6 +458,7 @@ def execute(self, context: 'Context'): ) if self.do_xcom_push: + self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) result = self.extract_xcom(pod=self.pod) remote_pod = self.pod_manager.await_pod_completion(self.pod) finally: @@ -412,14 +466,14 @@ def execute(self, context: 'Context'): pod=self.pod or self.pod_request_obj, remote_pod=remote_pod, ) - ti = context['ti'] - ti.xcom_push(key='pod_name', value=self.pod.metadata.name) - ti.xcom_push(key='pod_namespace', value=self.pod.metadata.namespace) + ti = context["ti"] + ti.xcom_push(key="pod_name", value=self.pod.metadata.name) + ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace) if self.do_xcom_push: return result def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): - pod_phase = remote_pod.status.phase if hasattr(remote_pod, 'status') else None + pod_phase = remote_pod.status.phase if hasattr(remote_pod, "status") else None if not self.is_delete_operator_pod: with _suppress(Exception): self.patch_already_checked(remote_pod) @@ -429,40 +483,39 @@ def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): for event in self.pod_manager.read_pod_events(pod).items: self.log.error("Pod Event: %s - %s", event.reason, event.message) with _suppress(Exception): - self.process_pod_deletion(pod) + self.process_pod_deletion(remote_pod) error_message = get_container_termination_message(remote_pod, self.BASE_CONTAINER_NAME) error_message = "\n" + error_message if error_message else "" raise AirflowException( - f'Pod {pod and pod.metadata.name} returned a failure:{error_message}\n{remote_pod}' + f"Pod {pod and pod.metadata.name} returned a failure:{error_message}\n{remote_pod}" ) else: with _suppress(Exception): - self.process_pod_deletion(pod) + self.process_pod_deletion(remote_pod) - def process_pod_deletion(self, pod): - if self.is_delete_operator_pod: - self.log.info("Deleting pod: %s", pod.metadata.name) - self.pod_manager.delete_pod(pod) - else: - self.log.info("skipping deleting pod: %s", pod.metadata.name) + def process_pod_deletion(self, pod: k8s.V1Pod): + if pod is not None: + if self.is_delete_operator_pod: + self.log.info("Deleting pod: %s", pod.metadata.name) + self.pod_manager.delete_pod(pod) + else: + self.log.info("skipping deleting pod: %s", pod.metadata.name) - def _build_find_pod_label_selector(self, context: Optional[dict] = None, *, exclude_checked=True) -> str: + def _build_find_pod_label_selector(self, context: Context | None = None, *, exclude_checked=True) -> str: labels = self._get_ti_pod_labels(context, include_try_number=False) - label_strings = [f'{label_id}={label}' for label_id, label in sorted(labels.items())] - labels_value = ','.join(label_strings) + label_strings = [f"{label_id}={label}" for label_id, label in sorted(labels.items())] + labels_value = ",".join(label_strings) if exclude_checked: - labels_value += f',{self.POD_CHECKED_KEY}!=True' - labels_value += ',!airflow-worker' + labels_value += f",{self.POD_CHECKED_KEY}!=True" + labels_value += ",!airflow-worker" return labels_value - def _set_name(self, name): - if name is None: - if self.pod_template_file or self.full_pod_spec: - return None - raise AirflowException("`name` is required unless `pod_template_file` or `full_pod_spec` is set") - - validate_key(name, max_length=220) - return re.sub(r'[^a-z0-9-]+', '-', name.lower()) + @staticmethod + def _set_name(name: str | None) -> str 
| None: + if name is not None: + validate_key(name, max_length=220) + return re.sub(r"[^a-z0-9-]+", "-", name.lower()) + return None def patch_already_checked(self, pod: k8s.V1Pod): """Add an "already checked" annotation to ensure we don't reattach on retries""" @@ -481,7 +534,7 @@ def on_kill(self) -> None: kwargs.update(grace_period_seconds=self.termination_grace_period) self.client.delete_namespaced_pod(**kwargs) - def build_pod_request_obj(self, context=None): + def build_pod_request_obj(self, context: Context | None = None) -> k8s.V1Pod: """ Returns V1Pod object based on pod template file, full pod spec, and other operator parameters. @@ -497,7 +550,7 @@ def build_pod_request_obj(self, context=None): elif self.full_pod_spec: pod_template = self.full_pod_spec else: - pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="name")) + pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta()) pod = k8s.V1Pod( api_version="v1", @@ -520,11 +573,12 @@ def build_pod_request_obj(self, context=None): command=self.cmds, ports=self.ports, image_pull_policy=self.image_pull_policy, - resources=self.k8s_resources, + resources=self.container_resources, volume_mounts=self.volume_mounts, args=self.arguments, env=self.env_vars, env_from=self.env_from, + security_context=self.container_security_context, ) ], image_pull_secrets=self.image_pull_secrets, @@ -533,7 +587,7 @@ def build_pod_request_obj(self, context=None): security_context=self.security_context, dns_policy=self.dnspolicy, scheduler_name=self.schedulername, - restart_policy='Never', + restart_policy="Never", priority_class_name=self.priority_class_name, volumes=self.volumes, ), @@ -541,18 +595,30 @@ def build_pod_request_obj(self, context=None): pod = PodGenerator.reconcile_pods(pod_template, pod) + if not pod.metadata.name: + pod.metadata.name = _task_id_to_pod_name(self.task_id) + if self.random_name_suffix: pod.metadata.name = PodGenerator.make_unique_pod_id(pod.metadata.name) + if not pod.metadata.namespace: + # todo: replace with call to `hook.get_namespace` in 6.0, when it doesn't default to `default`. + # if namespace not actually defined in hook, we want to check k8s if in cluster + hook_namespace = self.hook._get_namespace() + pod_namespace = self.namespace or hook_namespace or self._incluster_namespace or "default" + pod.metadata.namespace = pod_namespace + for secret in self.secrets: self.log.debug("Adding secret to task %s", self.task_id) pod = secret.attach_to_pod(pod) if self.do_xcom_push: self.log.debug("Adding xcom sidecar to task %s", self.task_id) - pod = xcom_sidecar.add_xcom_sidecar(pod) + pod = xcom_sidecar.add_xcom_sidecar( + pod, sidecar_container_image=self.hook.get_xcom_sidecar_container_image() + ) labels = self._get_ti_pod_labels(context) - self.log.info("Creating pod %s with labels: %s", pod.metadata.name, labels) + self.log.info("Building pod %s with labels: %s", pod.metadata.name, labels) # Merge Pod Identifying labels with labels passed to operator pod.metadata.labels.update(labels) @@ -560,7 +626,8 @@ def build_pod_request_obj(self, context=None): # And a label to identify that pod is launched by KubernetesPodOperator pod.metadata.labels.update( { - 'airflow_version': airflow_version.replace('+', '-'), + "airflow_version": airflow_version.replace("+", "-"), + "airflow_kpo_in_cluster": str(self.hook.is_in_cluster), } ) pod_mutation_hook(pod) @@ -573,40 +640,7 @@ def dry_run(self) -> None: one in a dry_run) and excludes all empty elements. 
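When no explicit name, pod_template_file, or full_pod_spec supplies one, build_pod_request_obj above now derives the pod name from the task_id through the new private _task_id_to_pod_name helper. Its behaviour, illustrated (private API, shown only for clarity):

    from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import _task_id_to_pod_name

    print(_task_id_to_pod_name("My_Task"))  # -> "my-task"
    print(_task_id_to_pod_name("_task_"))   # -> "0-task-0"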
""" pod = self.build_pod_request_obj() - print(yaml.dump(prune_dict(pod.to_dict(), mode='strict'))) - - def _patch_deprecated_k8s_settings(self, hook: KubernetesHook): - """ - Here we read config from core Airflow config [kubernetes] section. - In a future release we will stop looking at this section and require users - to use Airflow connections to configure KPO. - - When we find values there that we need to apply on the hook, we patch special - hook attributes here. - """ - - # default for enable_tcp_keepalive is True; patch if False - if conf.getboolean('kubernetes', 'enable_tcp_keepalive') is False: - hook._deprecated_core_disable_tcp_keepalive = True - - # default verify_ssl is True; patch if False. - if conf.getboolean('kubernetes', 'verify_ssl') is False: - hook._deprecated_core_disable_verify_ssl = True - - # default for in_cluster is True; patch if False and no KPO param. - conf_in_cluster = conf.getboolean('kubernetes', 'in_cluster') - if self.in_cluster is None and conf_in_cluster is False: - hook._deprecated_core_in_cluster = conf_in_cluster - - # there's no default for cluster context; if we get something (and no KPO param) patch it. - conf_cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None) - if not self.cluster_context and conf_cluster_context: - hook._deprecated_core_cluster_context = conf_cluster_context - - # there's no default for config_file; if we get something (and no KPO param) patch it. - conf_config_file = conf.get('kubernetes', 'config_file', fallback=None) - if not self.config_file and conf_config_file: - hook._deprecated_core_config_file = conf_config_file + print(yaml.dump(prune_dict(pod.to_dict(), mode="strict"))) class _suppress(AbstractContextManager): @@ -630,5 +664,5 @@ def __exit__(self, exctype, excinst, exctb): if caught_error: self.exception = excinst logger = logging.getLogger(__name__) - logger.error(str(excinst), exc_info=True) + logger.exception(excinst) return caught_error diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index fbf0aebd4948c..ff3828ffb0d4a 100644 --- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook @@ -41,18 +43,18 @@ class SparkKubernetesOperator(BaseOperator): :param api_version: kubernetes api version of sparkApplication """ - template_fields: Sequence[str] = ('application_file', 'namespace') - template_ext: Sequence[str] = ('.yaml', '.yml', '.json') - ui_color = '#f4a460' + template_fields: Sequence[str] = ("application_file", "namespace") + template_ext: Sequence[str] = (".yaml", ".yml", ".json") + ui_color = "#f4a460" def __init__( self, *, application_file: str, - namespace: Optional[str] = None, - kubernetes_conn_id: str = 'kubernetes_default', - api_group: str = 'sparkoperator.k8s.io', - api_version: str = 'v1beta2', + namespace: str | None = None, + kubernetes_conn_id: str = "kubernetes_default", + api_group: str = "sparkoperator.k8s.io", + api_version: str = "v1beta2", **kwargs, ) -> None: super().__init__(**kwargs) @@ -63,7 +65,7 @@ def __init__( self.api_version = api_version self.plural = "sparkapplications" - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = KubernetesHook(conn_id=self.kubernetes_conn_id) self.log.info("Creating sparkApplication") response = hook.create_custom_object( diff --git a/airflow/providers/cncf/kubernetes/provider.yaml b/airflow/providers/cncf/kubernetes/provider.yaml index ce2ce89eddeec..3de2fb8717c80 100644 --- a/airflow/providers/cncf/kubernetes/provider.yaml +++ b/airflow/providers/cncf/kubernetes/provider.yaml @@ -22,6 +22,11 @@ description: | `Kubernetes `__ versions: + - 5.0.0 + - 4.4.0 + - 4.3.0 + - 4.2.0 + - 4.1.0 - 4.0.2 - 4.0.1 - 4.0.0 @@ -43,8 +48,18 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: +dependencies: - apache-airflow>=2.3.0 + - cryptography>=2.0.0 + # The Kubernetes API is known to introduce problems when upgraded to a MAJOR version. Airflow Core + # Uses Kubernetes for Kubernetes executor, and we also know that Kubernetes Python client follows SemVer + # (https://github.com/kubernetes-client/python#compatibility). This is a crucial component of Airflow + # So we should limit it to the next MAJOR version and only deliberately bump the version when we + # tested it, and we know it can be bumped. Bumping this version should also be connected with + # limiting minimum airflow version supported in cncf.kubernetes provider, due to the + # potential breaking changes in Airflow Core as well (kubernetes is added as extra, so Airflow + # core is not hard-limited via install-requirements, only by extra). 
+ - kubernetes>=21.7.0,<24 integrations: - integration-name: Kubernetes @@ -74,9 +89,11 @@ hooks: python-modules: - airflow.providers.cncf.kubernetes.hooks.kubernetes -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook connection-types: - hook-class-name: airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook connection-type: kubernetes + +task-decorators: + - class-name: airflow.providers.cncf.kubernetes.decorators.kubernetes.kubernetes_task + name: kubernetes diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 new file mode 100644 index 0000000000000..c961f10de4e5c --- /dev/null +++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 @@ -0,0 +1,44 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +-#} + +import {{ pickling_library }} +import sys + +{# Check whether Airflow is available in the environment. + # If it is, we'll want to ensure that we integrate any macros that are being provided + # by plugins prior to unpickling the task context. #} +if sys.version_info >= (3,6): + try: + from airflow.plugins_manager import integrate_macros_plugins + integrate_macros_plugins() + except ImportError: + {# Airflow is not available in this environment, therefore we won't + # be able to integrate any plugin macros. #} + pass + +{% if op_args or op_kwargs %} +with open(sys.argv[1], "rb") as file: + arg_dict = {{ pickling_library }}.load(file) +{% else %} +arg_dict = {"args": [], "kwargs": {}} +{% endif %} + +# Script +{{ python_callable_source }} +res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"]) diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.py b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py new file mode 100644 index 0000000000000..785daf6e56fb2 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utilities for using the kubernetes decorator""" +from __future__ import annotations + +import os +from collections import deque + +import jinja2 + + +def _balance_parens(after_decorator): + num_paren = 1 + after_decorator = deque(after_decorator) + after_decorator.popleft() + while num_paren: + current = after_decorator.popleft() + if current == "(": + num_paren = num_paren + 1 + elif current == ")": + num_paren = num_paren - 1 + return "".join(after_decorator) + + +def remove_task_decorator(python_source: str, task_decorator_name: str) -> str: + """ + Removed @kubernetes_task + + :param python_source: + """ + if task_decorator_name not in python_source: + return python_source + split = python_source.split(task_decorator_name) + before_decorator, after_decorator = split[0], split[1] + if after_decorator[0] == "(": + after_decorator = _balance_parens(after_decorator) + if after_decorator[0] == "\n": + after_decorator = after_decorator[1:] + return before_decorator + after_decorator + + +def write_python_script( + jinja_context: dict, + filename: str, + render_template_as_native_obj: bool = False, +): + """ + Renders the python script to a file to execute in the virtual environment. + + :param jinja_context: The jinja context variables to unpack and replace with its placeholders in the + template file. + :param filename: The name of the file to dump the rendered script to. + :param render_template_as_native_obj: If ``True``, rendered Jinja template would be converted + to a native Python object + """ + template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__)) + template_env: jinja2.Environment + if render_template_as_native_obj: + template_env = jinja2.nativetypes.NativeEnvironment( + loader=template_loader, undefined=jinja2.StrictUndefined + ) + else: + template_env = jinja2.Environment(loader=template_loader, undefined=jinja2.StrictUndefined) + template = template_env.get_template("python_kubernetes_script.jinja2") + template.stream(**jinja_context).dump(filename) diff --git a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py index 15ac40bcdb90a..6c09202ab8782 100644 --- a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from kubernetes import client @@ -37,6 +39,7 @@ class SparkKubernetesSensor(BaseSensorOperator): :param application_name: spark Application resource name :param namespace: the kubernetes namespace where the sparkApplication reside in + :param container_name: the kubernetes container name where the sparkApplication reside in :param kubernetes_conn_id: The :ref:`kubernetes connection` to Kubernetes cluster. 
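The python_kubernetes_script helpers above prepare the decorated callable for shipping: remove_task_decorator strips the @task.kubernetes decorator (including its argument list) from the captured source, and write_python_script renders the Jinja template around it. For example:

    from airflow.providers.cncf.kubernetes.python_kubernetes_script import remove_task_decorator

    src = "@task.kubernetes(image='python:3.9')\ndef f():\n    return 1\n"
    print(remove_task_decorator(src, "@task.kubernetes"))
    # def f():
    #     return 1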
:param attach_log: determines whether logs for driver pod should be appended to the sensor log @@ -53,16 +56,18 @@ def __init__( *, application_name: str, attach_log: bool = False, - namespace: Optional[str] = None, + namespace: str | None = None, + container_name: str = "spark-kubernetes-driver", kubernetes_conn_id: str = "kubernetes_default", - api_group: str = 'sparkoperator.k8s.io', - api_version: str = 'v1beta2', + api_group: str = "sparkoperator.k8s.io", + api_version: str = "v1beta2", **kwargs, ) -> None: super().__init__(**kwargs) self.application_name = application_name self.attach_log = attach_log self.namespace = namespace + self.container_name = container_name self.kubernetes_conn_id = kubernetes_conn_id self.hook = KubernetesHook(conn_id=self.kubernetes_conn_id) self.api_group = api_group @@ -82,7 +87,9 @@ def _log_driver(self, application_state: str, response: dict) -> None: log_method = self.log.error if application_state in self.FAILURE_STATES else self.log.info try: log = "" - for line in self.hook.get_pod_logs(driver_pod_name, namespace=namespace): + for line in self.hook.get_pod_logs( + driver_pod_name, namespace=namespace, container=self.container_name + ): log += line.decode() log_method(log) except client.rest.ApiException as e: @@ -94,7 +101,7 @@ def _log_driver(self, application_state: str, response: dict) -> None: e, ) - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.log.info("Poking: %s", self.application_name) response = self.hook.get_custom_object( group=self.api_group, diff --git a/airflow/providers/cncf/kubernetes/utils/__init__.py b/airflow/providers/cncf/kubernetes/utils/__init__.py index 9cebc80e1d022..84e243c6dbc76 100644 --- a/airflow/providers/cncf/kubernetes/utils/__init__.py +++ b/airflow/providers/cncf/kubernetes/utils/__init__.py @@ -14,4 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -__all__ = ['xcom_sidecar', 'pod_manager'] +__all__ = ["xcom_sidecar", "pod_manager"] diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index 27c9439dbde1d..ffd3e7ec4bd88 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Launches PODs""" +from __future__ import annotations + import json import math import time @@ -22,7 +24,7 @@ from contextlib import closing from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Iterable, Optional, Tuple, cast +from typing import TYPE_CHECKING, Iterable, cast import pendulum import tenacity @@ -60,10 +62,10 @@ class PodPhase: See https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase. 
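SparkKubernetesSensor above gains a container_name parameter (defaulting to "spark-kubernetes-driver") which is forwarded to get_pod_logs when attaching driver logs. A brief usage sketch; the DAG and application values are placeholders:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor

    with DAG(
        dag_id="spark_pi_monitor_example",
        start_date=datetime(2022, 1, 1),
        schedule_interval=None,
        catchup=False,
    ):
        SparkKubernetesSensor(
            task_id="spark_pi_monitor",
            application_name="spark-pi",  # usually pulled from the submit task's XCom
            namespace="default",
            container_name="spark-kubernetes-driver",
            attach_log=True,
        )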
""" - PENDING = 'Pending' - RUNNING = 'Running' - FAILED = 'Failed' - SUCCEEDED = 'Succeeded' + PENDING = "Pending" + RUNNING = "Running" + FAILED = "Failed" + SUCCEEDED = "Succeeded" terminal_states = {FAILED, SUCCEEDED} @@ -76,7 +78,7 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool: container_statuses = pod.status.container_statuses if pod and pod.status else None if not container_statuses: return False - container_status = next(iter([x for x in container_statuses if x.name == container_name]), None) + container_status = next((x for x in container_statuses if x.name == container_name), None) if not container_status: return False return container_status.state.running is not None @@ -85,9 +87,9 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool: def get_container_termination_message(pod: V1Pod, container_name: str): try: container_statuses = pod.status.container_statuses - container_status = next(iter([x for x in container_statuses if x.name == container_name]), None) + container_status = next((x for x in container_statuses if x.name == container_name), None) return container_status.state.terminated.message if container_status else None - except AttributeError: + except (AttributeError, TypeError): return None @@ -96,7 +98,7 @@ class PodLoggingStatus: """Used for returning the status of the pod and last log time when exiting from `fetch_container_logs`""" running: bool - last_log_time: Optional[DateTime] + last_log_time: DateTime | None class PodManager(LoggingMixin): @@ -109,7 +111,7 @@ def __init__( self, kube_client: client.CoreV1Api = None, in_cluster: bool = True, - cluster_context: Optional[str] = None, + cluster_context: str | None = None, ): """ Creates the launcher. @@ -119,7 +121,16 @@ def __init__( :param cluster_context: context of the cluster """ super().__init__() - self._client = kube_client or get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context) + if kube_client: + self._client = kube_client + else: + self._client = get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context) + warnings.warn( + "`kube_client` not supplied to PodManager. " + "This will be a required argument in a future release. 
" + "Please use KubernetesHook to create the client before calling.", + DeprecationWarning, + ) self._watch = watch.Watch() def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: @@ -127,15 +138,15 @@ def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: sanitized_pod = self._client.api_client.sanitize_for_serialization(pod) json_pod = json.dumps(sanitized_pod, indent=2) - self.log.debug('Pod Creation Request: \n%s', json_pod) + self.log.debug("Pod Creation Request: \n%s", json_pod) try: resp = self._client.create_namespaced_pod( body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs ) - self.log.debug('Pod Creation Response: %s', resp) + self.log.debug("Pod Creation Response: %s", resp) except Exception as e: self.log.exception( - 'Exception when attempting to create Namespaced Pod: %s', str(json_pod).replace("\n", " ") + "Exception when attempting to create Namespaced Pod: %s", str(json_pod).replace("\n", " ") ) raise e return resp @@ -194,14 +205,14 @@ def follow_container_logs(self, pod: V1Pod, container_name: str) -> PodLoggingSt return self.fetch_container_logs(pod=pod, container_name=container_name, follow=True) def fetch_container_logs( - self, pod: V1Pod, container_name: str, *, follow=False, since_time: Optional[DateTime] = None + self, pod: V1Pod, container_name: str, *, follow=False, since_time: DateTime | None = None ) -> PodLoggingStatus: """ Follows the logs of container and streams to airflow logging. Returns when container exits. """ - def consume_logs(*, since_time: Optional[DateTime] = None, follow: bool = True) -> Optional[DateTime]: + def consume_logs(*, since_time: DateTime | None = None, follow: bool = True) -> DateTime | None: """ Tries to follow container logs until container completes. For a long-running container, sometimes the log read may be interrupted @@ -221,7 +232,7 @@ def consume_logs(*, since_time: Optional[DateTime] = None, follow: bool = True) follow=follow, ) for raw_line in logs: - line = raw_line.decode('utf-8', errors="backslashreplace") + line = raw_line.decode("utf-8", errors="backslashreplace") timestamp, message = self.parse_log_line(line) self.log.info(message) except BaseHTTPError as e: @@ -249,7 +260,7 @@ def consume_logs(*, since_time: Optional[DateTime] = None, follow: bool = True) return PodLoggingStatus(running=True, last_log_time=last_log_time) else: self.log.warning( - 'Pod %s log read interrupted but container %s still running', + "Pod %s log read interrupted but container %s still running", pod.metadata.name, container_name, ) @@ -264,25 +275,24 @@ def await_pod_completion(self, pod: V1Pod) -> V1Pod: Monitors a pod and returns the final state :param pod: pod spec that will be monitored - :return: Tuple[State, Optional[str]] + :return: tuple[State, str | None] """ while True: remote_pod = self.read_pod(pod) if remote_pod.status.phase in PodPhase.terminal_states: break - self.log.info('Pod %s has phase %s', pod.metadata.name, remote_pod.status.phase) + self.log.info("Pod %s has phase %s", pod.metadata.name, remote_pod.status.phase) time.sleep(2) return remote_pod - def parse_log_line(self, line: str) -> Tuple[Optional[DateTime], str]: + def parse_log_line(self, line: str) -> tuple[DateTime | None, str]: """ Parse K8s log line and returns the final state :param line: k8s log line :return: timestamp and log message - :rtype: Tuple[str, str] """ - split_at = line.find(' ') + split_at = line.find(" ") if split_at == -1: self.log.error( "Error parsing timestamp (no timestamp in message %r). 
" @@ -309,18 +319,18 @@ def read_pod_logs( self, pod: V1Pod, container_name: str, - tail_lines: Optional[int] = None, + tail_lines: int | None = None, timestamps: bool = False, - since_seconds: Optional[int] = None, + since_seconds: int | None = None, follow=True, ) -> Iterable[bytes]: """Reads log from the POD""" additional_kwargs = {} if since_seconds: - additional_kwargs['since_seconds'] = since_seconds + additional_kwargs["since_seconds"] = since_seconds if tail_lines: - additional_kwargs['tail_lines'] = tail_lines + additional_kwargs["tail_lines"] = tail_lines try: return self._client.read_namespaced_pod_log( @@ -333,18 +343,18 @@ def read_pod_logs( **additional_kwargs, ) except BaseHTTPError: - self.log.exception('There was an error reading the kubernetes API.') + self.log.exception("There was an error reading the kubernetes API.") raise @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) - def read_pod_events(self, pod: V1Pod) -> "CoreV1EventList": + def read_pod_events(self, pod: V1Pod) -> CoreV1EventList: """Reads events from the POD""" try: return self._client.list_namespaced_event( namespace=pod.metadata.namespace, field_selector=f"involvedObject.name={pod.metadata.name}" ) except BaseHTTPError as e: - raise AirflowException(f'There was an error reading the kubernetes API: {e}') + raise AirflowException(f"There was an error reading the kubernetes API: {e}") @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) def read_pod(self, pod: V1Pod) -> V1Pod: @@ -352,7 +362,19 @@ def read_pod(self, pod: V1Pod) -> V1Pod: try: return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace) except BaseHTTPError as e: - raise AirflowException(f'There was an error reading the kubernetes API: {e}') + raise AirflowException(f"There was an error reading the kubernetes API: {e}") + + def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None: + self.log.info("Checking if xcom sidecar container is started.") + warned = False + while True: + if self.container_is_running(pod, PodDefaults.SIDECAR_CONTAINER_NAME): + self.log.info("The xcom sidecar container is started.") + break + if not warned: + self.log.warning("The xcom sidecar container is not yet started.") + warned = True + time.sleep(1) def extract_xcom(self, pod: V1Pod) -> str: """Retrieves XCom value and kills xcom sidecar container""" @@ -362,7 +384,7 @@ def extract_xcom(self, pod: V1Pod) -> str: pod.metadata.name, pod.metadata.namespace, container=PodDefaults.SIDECAR_CONTAINER_NAME, - command=['/bin/sh'], + command=["/bin/sh"], stdin=True, stdout=True, stderr=True, @@ -370,17 +392,20 @@ def extract_xcom(self, pod: V1Pod) -> str: _preload_content=False, ) ) as resp: - result = self._exec_pod_command(resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json') - self._exec_pod_command(resp, 'kill -s SIGINT 1') + result = self._exec_pod_command( + resp, + f"if [ -s {PodDefaults.XCOM_MOUNT_PATH}/return.json ]; then cat {PodDefaults.XCOM_MOUNT_PATH}/return.json; else echo __airflow_xcom_result_empty__; fi", # noqa + ) + self._exec_pod_command(resp, "kill -s SIGINT 1") if result is None: - raise AirflowException(f'Failed to extract xcom from pod: {pod.metadata.name}') + raise AirflowException(f"Failed to extract xcom from pod: {pod.metadata.name}") return result - def _exec_pod_command(self, resp, command: str) -> Optional[str]: + def _exec_pod_command(self, resp, command: str) -> str | None: res = None if resp.is_open(): - 
self.log.info('Running command... %s\n', command) - resp.write_stdin(command + '\n') + self.log.info("Running command... %s\n", command) + resp.write_stdin(command + "\n") while resp.is_open(): resp.update(timeout=1) while resp.peek_stdout(): diff --git a/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py b/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py index a8c0ea4c1936f..81b3047993691 100644 --- a/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py +++ b/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py @@ -19,6 +19,8 @@ by attaching a sidecar container that blocks the pod from completing until Airflow has pulled result data into the worker for xcom serialization. """ +from __future__ import annotations + import copy from kubernetes.client import models as k8s @@ -27,15 +29,15 @@ class PodDefaults: """Static defaults for Pods""" - XCOM_MOUNT_PATH = '/airflow/xcom' - SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' + XCOM_MOUNT_PATH = "/airflow/xcom" + SIDECAR_CONTAINER_NAME = "airflow-xcom-sidecar" XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 1; done;' - VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH) - VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource()) + VOLUME_MOUNT = k8s.V1VolumeMount(name="xcom", mount_path=XCOM_MOUNT_PATH) + VOLUME = k8s.V1Volume(name="xcom", empty_dir=k8s.V1EmptyDirVolumeSource()) SIDECAR_CONTAINER = k8s.V1Container( name=SIDECAR_CONTAINER_NAME, - command=['sh', '-c', XCOM_CMD], - image='alpine', + command=["sh", "-c", XCOM_CMD], + image="alpine", volume_mounts=[VOLUME_MOUNT], resources=k8s.V1ResourceRequirements( requests={ @@ -45,13 +47,15 @@ class PodDefaults: ) -def add_xcom_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod: +def add_xcom_sidecar(pod: k8s.V1Pod, *, sidecar_container_image=None) -> k8s.V1Pod: """Adds sidecar""" pod_cp = copy.deepcopy(pod) pod_cp.spec.volumes = pod.spec.volumes or [] pod_cp.spec.volumes.insert(0, PodDefaults.VOLUME) pod_cp.spec.containers[0].volume_mounts = pod_cp.spec.containers[0].volume_mounts or [] pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaults.VOLUME_MOUNT) - pod_cp.spec.containers.append(PodDefaults.SIDECAR_CONTAINER) + sidecar = copy.deepcopy(PodDefaults.SIDECAR_CONTAINER) + sidecar.image = sidecar_container_image or PodDefaults.SIDECAR_CONTAINER.image + pod_cp.spec.containers.append(sidecar) return pod_cp diff --git a/airflow/providers/github/example_dags/__init__.py b/airflow/providers/common/__init__.py similarity index 100% rename from airflow/providers/github/example_dags/__init__.py rename to airflow/providers/common/__init__.py diff --git a/airflow/providers/common/sql/.latest-doc-only-change.txt b/airflow/providers/common/sql/.latest-doc-only-change.txt new file mode 100644 index 0000000000000..ff7136e07d744 --- /dev/null +++ b/airflow/providers/common/sql/.latest-doc-only-change.txt @@ -0,0 +1 @@ +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/common/sql/CHANGELOG.rst b/airflow/providers/common/sql/CHANGELOG.rst new file mode 100644 index 0000000000000..16c1cd098338e --- /dev/null +++ b/airflow/providers/common/sql/CHANGELOG.rst @@ -0,0 +1,127 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
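Two of the behaviour changes in this file are easiest to see from caller code. First, the `PodManager` constructor above now warns when no `kube_client` is passed; a minimal sketch of the suggested pattern, assuming `KubernetesHook` exposes a `core_v1_client` property (as recent provider versions do) and that a `kubernetes_default` connection exists:

```python
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager

hook = KubernetesHook(conn_id="kubernetes_default")
# Passing an explicit CoreV1Api client avoids the DeprecationWarning added above
# and keeps connection handling inside the hook rather than in get_kube_client().
pod_manager = PodManager(kube_client=hook.core_v1_client)
```

Second, `add_xcom_sidecar` now accepts a `sidecar_container_image` override, which matters for air-gapped clusters that mirror `alpine` into a private registry. The pod spec and registry URL below are placeholders:

```python
from kubernetes.client import models as k8s

from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import add_xcom_sidecar

pod = k8s.V1Pod(
    metadata=k8s.V1ObjectMeta(name="example-pod", namespace="default"),
    spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base", image="python:3.9")]),
)
# The helper deep-copies the pod, mounts the shared /airflow/xcom volume into the
# first container, and appends the sidecar built from the overridden image.
pod_with_sidecar = add_xcom_sidecar(
    pod, sidecar_container_image="registry.example.com/mirror/alpine:3.16"
)
```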
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + + +Changelog +--------- + +1.3.1 +..... + +This release fixes a few errors that were introduced in common.sql operator while refactoring common parts: + +* ``_process_output`` method in ``SQLExecuteQueryOperator`` has now consistent semantics and typing, it + can also modify the returned (and stored in XCom) values in the operators that derive from the + ``SQLExecuteQueryOperator``). +* last description of the cursor whether to return scalar values are now stored in DBApiHook + +Lack of consistency in the operator caused ``1.3.0`` to be yanked - the ``1.3.0`` should not be used - if +you have ``1.3.0`` installed, upgrade to ``1.3.1``. + +Bug Fixes +~~~~~~~~~ + +* ``Restore removed (but used) methods in common.sql (#27843)`` +* ``Fix errors in Databricks SQL operator introduced when refactoring (#27854)`` + + +1.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` +* ``Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)`` +* ``DbApiHook consistent insert_rows logging (#26758)`` + +Bug Fixes +~~~~~~~~~ + +* ``Common sql bugfixes and improvements (#26761)`` +* ``Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +1.2.0 +..... + +Features +~~~~~~~~ + +* ``Make placeholder style configurable (#25939)`` +* ``Better error message for pre-common-sql providers (#26051)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix (and test) SQLTableCheckOperator on postgresql (#25821)`` +* ``Don't use Pandas for SQLTableCheckOperator (#25822)`` +* ``Discard semicolon stripping in SQL hook (#25855)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + + +1.1.0 +..... 
+ +Features +~~~~~~~~ + +* ``Improve taskflow type hints with ParamSpec (#25173)`` +* ``Move all "old" SQL operators to common.sql providers (#25350)`` +* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` +* ``Common SQLCheckOperators Various Functionality Update (#25164)`` + +Bug Fixes +~~~~~~~~~ + +* ``Allow Legacy SqlSensor to use the common.sql providers (#25293)`` +* ``Fix fetch_all_handler & db-api tests for it (#25430)`` +* ``Align Common SQL provider logo location (#25538)`` +* ``Fix SQL split string to include ';-less' statements (#25713)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Fix CHANGELOG for common.sql provider and add amazon commit (#25636)`` + +1.0.0 +..... + +Initial version of the provider. +Adds ``SQLColumnCheckOperator`` and ``SQLTableCheckOperator``. +Moves ``DBApiHook``, ``SQLSensor`` and ``ConnectorProtocol`` to the provider. diff --git a/airflow/providers/google/ads/example_dags/__init__.py b/airflow/providers/common/sql/__init__.py similarity index 100% rename from airflow/providers/google/ads/example_dags/__init__.py rename to airflow/providers/common/sql/__init__.py diff --git a/airflow/providers/google/firebase/example_dags/__init__.py b/airflow/providers/common/sql/hooks/__init__.py similarity index 100% rename from airflow/providers/google/firebase/example_dags/__init__.py rename to airflow/providers/common/sql/hooks/__init__.py diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py new file mode 100644 index 0000000000000..df808430fd9aa --- /dev/null +++ b/airflow/providers/common/sql/hooks/sql.py @@ -0,0 +1,433 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
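The `fetch_all_handler` and `fetch_one_handler` helpers defined just below in this new module operate on any PEP 249 cursor; a dependency-free sketch using the standard-library `sqlite3` module (the table and queries are made up):

```python
import sqlite3

from airflow.providers.common.sql.hooks.sql import fetch_all_handler, fetch_one_handler

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

cur.execute("SELECT 1")
print(fetch_all_handler(cur))  # [(1,)] -- cursor.description is populated for SELECTs

cur.execute("CREATE TABLE t (x INT)")
print(fetch_one_handler(cur))  # None -- DDL leaves cursor.description unset

conn.close()
```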
+from __future__ import annotations + +from contextlib import closing +from datetime import datetime +from typing import Any, Callable, Iterable, Mapping, Sequence, cast + +import sqlparse +from packaging.version import Version +from sqlalchemy import create_engine +from typing_extensions import Protocol + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from airflow.version import version + + +def fetch_all_handler(cursor) -> list[tuple] | None: + """Handler for DbApiHook.run() to return results""" + if cursor.description is not None: + return cursor.fetchall() + else: + return None + + +def fetch_one_handler(cursor) -> list[tuple] | None: + """Handler for DbApiHook.run() to return results""" + if cursor.description is not None: + return cursor.fetchone() + else: + return None + + +class ConnectorProtocol(Protocol): + """A protocol where you can connect to a database.""" + + def connect(self, host: str, port: int, username: str, schema: str) -> Any: + """ + Connect to a database. + + :param host: The database host to connect to. + :param port: The database port to connect to. + :param username: The database username used for the authentication. + :param schema: The database schema to connect to. + :return: the authorized connection object. + """ + + +# In case we are running it on Airflow 2.4+, we should use BaseHook, but on Airflow 2.3 and below +# We want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise +# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to +# accept the new Hooks as not derived from the original DbApiHook +if Version(version) < Version("2.4"): + try: + from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook + except ImportError: + # just in case we have a problem with circular import + BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] +else: + BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] + + +class DbApiHook(BaseForDbApiHook): + """ + Abstract base class for sql hooks. + + :param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that + if you change the schema parameter value in the constructor of the derived Hook, such change + should be done before calling the ``DBApiHook.__init__()``. + :param log_sql: Whether to log SQL query when it's executed. Defaults to *True*. + """ + + # Override to provide the connection name. + conn_name_attr: str + # Override to have a default connection id for a particular dbHook + default_conn_name = "default_conn_id" + # Override if this db supports autocommit. 
+ supports_autocommit = False + # Override with the object that exposes the connect method + connector: ConnectorProtocol | None = None + # Override with db-specific query to check connection + _test_connection_sql = "select 1" + # Override with the db-specific value used for placeholders + placeholder: str = "%s" + + def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs): + super().__init__() + if not self.conn_name_attr: + raise AirflowException("conn_name_attr is not defined") + elif len(args) == 1: + setattr(self, self.conn_name_attr, args[0]) + elif self.conn_name_attr not in kwargs: + setattr(self, self.conn_name_attr, self.default_conn_name) + else: + setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr]) + # We should not make schema available in deriving hooks for backwards compatibility + # If a hook deriving from DBApiHook has a need to access schema, then it should retrieve it + # from kwargs and store it on its own. We do not run "pop" here as we want to give the + # Hook deriving from the DBApiHook to still have access to the field in its constructor + self.__schema = schema + self.log_sql = log_sql + self.scalar_return_last = False + self.last_description: Sequence[Sequence] | None = None + + def get_conn(self): + """Returns a connection object""" + db = self.get_connection(getattr(self, cast(str, self.conn_name_attr))) + return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema) + + def get_uri(self) -> str: + """ + Extract the URI from the connection. + + :return: the extracted uri. + """ + conn = self.get_connection(getattr(self, self.conn_name_attr)) + conn.schema = self.__schema or conn.schema + return conn.get_uri() + + def get_sqlalchemy_engine(self, engine_kwargs=None): + """ + Get an sqlalchemy_engine object. + + :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`. + :return: the created engine. + """ + if engine_kwargs is None: + engine_kwargs = {} + return create_engine(self.get_uri(), **engine_kwargs) + + def get_pandas_df(self, sql, parameters=None, **kwargs): + """ + Executes the sql and returns a pandas dataframe + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :param parameters: The parameters to render the SQL query with. + :param kwargs: (optional) passed into pandas.io.sql.read_sql method + """ + try: + from pandas.io import sql as psql + except ImportError: + raise Exception( + "pandas library not installed, run: pip install " + "'apache-airflow-providers-common-sql[pandas]'." + ) + + with closing(self.get_conn()) as conn: + return psql.read_sql(sql, con=conn, params=parameters, **kwargs) + + def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, **kwargs): + """ + Executes the sql and returns a generator + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :param parameters: The parameters to render the SQL query with + :param chunksize: number of rows to include in each chunk + :param kwargs: (optional) passed into pandas.io.sql.read_sql method + """ + try: + from pandas.io import sql as psql + except ImportError: + raise Exception( + "pandas library not installed, run: pip install " + "'apache-airflow-providers-common-sql[pandas]'." 
+ ) + + with closing(self.get_conn()) as conn: + yield from psql.read_sql(sql, con=conn, params=parameters, chunksize=chunksize, **kwargs) + + def get_records( + self, + sql: str | list[str], + parameters: Iterable | Mapping | None = None, + ) -> Any: + """ + Executes the sql and returns a set of records. + + :param sql: the sql statement to be executed (str) or a list of sql statements to execute + :param parameters: The parameters to render the SQL query with. + """ + return self.run(sql=sql, parameters=parameters, handler=fetch_all_handler) + + def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any: + """ + Executes the sql and returns the first resulting row. + + :param sql: the sql statement to be executed (str) or a list of sql statements to execute + :param parameters: The parameters to render the SQL query with. + """ + return self.run(sql=sql, parameters=parameters, handler=fetch_one_handler) + + @staticmethod + def strip_sql_string(sql: str) -> str: + return sql.strip().rstrip(";") + + @staticmethod + def split_sql_string(sql: str) -> list[str]: + """ + Splits string into multiple SQL expressions + + :param sql: SQL string potentially consisting of multiple expressions + :return: list of individual expressions + """ + splits = sqlparse.split(sqlparse.format(sql, strip_comments=True)) + statements: list[str] = list(filter(None, splits)) + return statements + + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping | None = None, + handler: Callable | None = None, + split_statements: bool = False, + return_last: bool = True, + ) -> Any | list[Any] | None: + """ + Runs a command or a list of commands. Pass a list of sql + statements to the sql parameter to get them to execute + sequentially + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :param autocommit: What to set the connection's autocommit setting to + before executing the query. + :param parameters: The parameters to render the SQL query with. + :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the ALL SQL expressions if handler was provided. + """ + self.scalar_return_last = isinstance(sql, str) and return_last + if isinstance(sql, str): + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [sql] + + if sql: + self.log.debug("Executing following statements against DB: %s", list(sql)) + else: + raise ValueError("List of SQL statements is empty") + + with closing(self.get_conn()) as conn: + if self.supports_autocommit: + self.set_autocommit(conn, autocommit) + + with closing(conn.cursor()) as cur: + results = [] + for sql_statement in sql: + self._run_command(cur, sql_statement, parameters) + + if handler is not None: + result = handler(cur) + results.append(result) + self.last_description = cur.description + + # If autocommit was set to False or db does not support autocommit, we do a manual commit. 
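To make the `run()` contract above concrete, here is a toy hook. The class name, connection-id attribute and SQL are made up for illustration, and an in-memory SQLite database stands in for a connection resolved from Airflow metadata:

```python
import sqlite3

from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler


class InMemoryHook(DbApiHook):
    """Toy subclass: a real hook would resolve an Airflow Connection in get_conn()."""

    conn_name_attr = "in_memory_conn_id"
    default_conn_name = "in_memory_default"
    placeholder = "?"  # sqlite-style bind parameter

    def get_conn(self):
        return sqlite3.connect(":memory:")


hook = InMemoryHook()
results = hook.run(
    sql="CREATE TABLE t (x INT); INSERT INTO t VALUES (1); SELECT x FROM t",
    split_statements=True,   # the string is split into three statements
    return_last=False,       # keep one handler result per statement
    handler=fetch_all_handler,
)
print(results)  # [None, None, [(1,)]] -- only the SELECT produces rows
```

With `return_last=True` (the default) and a plain string, only the last handler result is returned, which is the scalar behaviour the operator layer relies on.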
+ if not self.get_autocommit(conn): + conn.commit() + + if handler is None: + return None + elif self.scalar_return_last: + return results[-1] + else: + return results + + def _run_command(self, cur, sql_statement, parameters): + """Runs a statement using an already open cursor.""" + if self.log_sql: + self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) + + if parameters: + cur.execute(sql_statement, parameters) + else: + cur.execute(sql_statement) + + # According to PEP 249, this is -1 when query result is not applicable. + if cur.rowcount >= 0: + self.log.info("Rows affected: %s", cur.rowcount) + + def set_autocommit(self, conn, autocommit): + """Sets the autocommit flag on the connection""" + if not self.supports_autocommit and autocommit: + self.log.warning( + "%s connection doesn't support autocommit but autocommit activated.", + getattr(self, self.conn_name_attr), + ) + conn.autocommit = autocommit + + def get_autocommit(self, conn) -> bool: + """ + Get autocommit setting for the provided connection. + Return True if conn.autocommit is set to True. + Return False if conn.autocommit is not set or set to False or conn + does not support autocommit. + + :param conn: Connection to get autocommit setting from. + :return: connection autocommit setting. + """ + return getattr(conn, "autocommit", False) and self.supports_autocommit + + def get_cursor(self): + """Returns a cursor""" + return self.get_conn().cursor() + + @classmethod + def _generate_insert_sql(cls, table, values, target_fields, replace, **kwargs) -> str: + """ + Helper class method that generates the INSERT SQL statement. + The REPLACE variant is specific to MySQL syntax. + + :param table: Name of the target table + :param values: The row to insert into the table + :param target_fields: The names of the columns to fill in the table + :param replace: Whether to replace instead of insert + :return: The generated INSERT or REPLACE SQL statement + """ + placeholders = [ + cls.placeholder, + ] * len(values) + + if target_fields: + target_fields = ", ".join(target_fields) + target_fields = f"({target_fields})" + else: + target_fields = "" + + if not replace: + sql = "INSERT INTO " + else: + sql = "REPLACE INTO " + sql += f"{table} {target_fields} VALUES ({','.join(placeholders)})" + return sql + + def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, **kwargs): + """ + A generic way to insert a set of tuples into a table, + a new transaction is created every commit_every rows + + :param table: Name of the target table + :param rows: The rows to insert into the table + :param target_fields: The names of the columns to fill in the table + :param commit_every: The maximum number of rows to insert in one + transaction. Set to 0 to insert all rows in one transaction. + :param replace: Whether to replace instead of insert + """ + i = 0 + with closing(self.get_conn()) as conn: + if self.supports_autocommit: + self.set_autocommit(conn, False) + + conn.commit() + + with closing(conn.cursor()) as cur: + for i, row in enumerate(rows, 1): + lst = [] + for cell in row: + lst.append(self._serialize_cell(cell, conn)) + values = tuple(lst) + sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs) + self.log.debug("Generated sql: %s", sql) + cur.execute(sql, values) + if commit_every and i % commit_every == 0: + conn.commit() + self.log.info("Loaded %s rows into %s so far", i, table) + + conn.commit() + self.log.info("Done loading. 
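`insert_rows()` above assembles its statements through `_generate_insert_sql`, a plain classmethod, so the generated SQL can be previewed without a live connection. Table and column names are placeholders, and the leading underscore marks this as a private detail of the hook rather than a public API:

```python
from airflow.providers.common.sql.hooks.sql import DbApiHook

sql = DbApiHook._generate_insert_sql(
    table="users",                  # placeholder table name
    values=("alice", 1),            # one row; its length drives the placeholder count
    target_fields=["name", "id"],   # column list rendered as "(name, id)"
    replace=False,                  # True switches to MySQL-style REPLACE INTO
)
print(sql)  # INSERT INTO users (name, id) VALUES (%s,%s)
```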
Loaded a total of %s rows into %s", i, table) + + @staticmethod + def _serialize_cell(cell, conn=None) -> str | None: + """ + Returns the SQL literal of the cell as a string. + + :param cell: The cell to insert into the table + :param conn: The database connection + :return: The serialized cell + """ + if cell is None: + return None + if isinstance(cell, datetime): + return cell.isoformat() + return str(cell) + + def bulk_dump(self, table, tmp_file): + """ + Dumps a database table into a tab-delimited file + + :param table: The name of the source table + :param tmp_file: The path of the target file + """ + raise NotImplementedError() + + def bulk_load(self, table, tmp_file): + """ + Loads a tab-delimited file into a database table + + :param table: The name of the target table + :param tmp_file: The path of the file to load into the table + """ + raise NotImplementedError() + + def test_connection(self): + """Tests the connection using db-specific query""" + status, message = False, "" + try: + if self.get_first(self._test_connection_sql): + status = True + message = "Connection successfully tested" + except Exception as e: + status = False + message = str(e) + + return status, message diff --git a/airflow/providers/influxdb/example_dags/__init__.py b/airflow/providers/common/sql/operators/__init__.py similarity index 100% rename from airflow/providers/influxdb/example_dags/__init__.py rename to airflow/providers/common/sql/operators/__init__.py diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py new file mode 100644 index 0000000000000..314af43003488 --- /dev/null +++ b/airflow/providers/common/sql/operators/sql.py @@ -0,0 +1,1104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import re +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs, overload + +from airflow.compat.functools import cached_property +from airflow.exceptions import AirflowException, AirflowFailException +from airflow.hooks.base import BaseHook +from airflow.models import BaseOperator, SkipMixin +from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler +from airflow.typing_compat import Literal + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +def _convert_to_float_if_possible(s: str) -> float | str: + try: + return float(s) + except (ValueError, TypeError): + return s + + +def _parse_boolean(val: str) -> str | bool: + """Try to parse a string into boolean. + + Raises ValueError if the input is not a valid true- or false-like string value. 
+ """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + if val in ("n", "no", "f", "false", "off", "0"): + return False + raise ValueError(f"{val!r} is not a boolean-like string value") + + +def _get_failed_checks(checks, col=None): + """ + IMPORTANT!!! Keep it for compatibility with released 8.4.0 version of google provider. + + Unfortunately the provider used _get_failed_checks and parse_boolean as imports and we should + keep those methods to avoid 8.4.0 version from failing. + """ + if col: + return [ + f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n" + for check, check_values in checks.items() + if not check_values["success"] + ] + return [ + f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + for check, check_values in checks.items() + if not check_values["success"] + ] + + +parse_boolean = _parse_boolean +""" +IMPORTANT!!! Keep it for compatibility with released 8.4.0 version of google provider. + +Unfortunately the provider used _get_failed_checks and parse_boolean as imports and we should +keep those methods to avoid 8.4.0 version from failing. +""" + + +_PROVIDERS_MATCHER = re.compile(r"airflow\.providers\.(.*)\.hooks.*") + +_MIN_SUPPORTED_PROVIDERS_VERSION = { + "amazon": "4.1.0", + "apache.drill": "2.1.0", + "apache.druid": "3.1.0", + "apache.hive": "3.1.0", + "apache.pinot": "3.1.0", + "databricks": "3.1.0", + "elasticsearch": "4.1.0", + "exasol": "3.1.0", + "google": "8.2.0", + "jdbc": "3.1.0", + "mssql": "3.1.0", + "mysql": "3.1.0", + "odbc": "3.1.0", + "oracle": "3.1.0", + "postgres": "5.1.0", + "presto": "3.1.0", + "qubole": "3.1.0", + "slack": "5.1.0", + "snowflake": "3.1.0", + "sqlite": "3.1.0", + "trino": "3.1.0", + "vertica": "3.1.0", +} + + +class BaseSQLOperator(BaseOperator): + """ + This is a base class for generic SQL Operator to get a DB Hook + + The provided method is .get_db_hook(). The default behavior will try to + retrieve the DB hook based on connection type. + You can customize the behavior by overriding the .get_db_hook() method. + + :param conn_id: reference to a specific database + """ + + def __init__( + self, + *, + conn_id: str | None = None, + database: str | None = None, + hook_params: dict | None = None, + retry_on_failure: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.conn_id = conn_id + self.database = database + self.hook_params = {} if hook_params is None else hook_params + self.retry_on_failure = retry_on_failure + + @cached_property + def _hook(self): + """Get DB Hook based on connection type""" + self.log.debug("Get connection for %s", self.conn_id) + conn = BaseHook.get_connection(self.conn_id) + hook = conn.get_hook(hook_params=self.hook_params) + if not isinstance(hook, DbApiHook): + from airflow.hooks.dbapi_hook import DbApiHook as _DbApiHook + + if isinstance(hook, _DbApiHook): + # This case might happen if user installed common.sql provider but did not upgrade the + # Other provider's versions to a version that supports common.sql provider + class_module = hook.__class__.__module__ + match = _PROVIDERS_MATCHER.match(class_module) + if match: + provider = match.group(1) + min_version = _MIN_SUPPORTED_PROVIDERS_VERSION.get(provider) + if min_version: + raise AirflowException( + f"You are trying to use common-sql with {hook.__class__.__name__}," + f" but the Hook class comes from provider {provider} that does not support it." + f" Please upgrade provider {provider} to at least {min_version}." 
+ ) + raise AirflowException( + f"You are trying to use `common-sql` with {hook.__class__.__name__}," + " but its provider does not support it. Please upgrade the provider to a version that" + " supports `common-sql`. The hook class should be a subclass of" + " `airflow.providers.common.sql.hooks.sql.DbApiHook`." + f" Got {hook.__class__.__name__} Hook with class hierarchy: {hook.__class__.mro()}" + ) + + if self.database: + hook.schema = self.database + + return hook + + def get_db_hook(self) -> DbApiHook: + """ + Get the database hook for the connection. + + :return: the database hook object. + """ + return self._hook + + def _raise_exception(self, exception_string: str) -> NoReturn: + if self.retry_on_failure: + raise AirflowException(exception_string) + raise AirflowFailException(exception_string) + + +class SQLExecuteQueryOperator(BaseSQLOperator): + """ + Executes SQL code in a specific database + :param sql: the SQL code or string pointing to a template file to be executed (templated). + File must have a '.sql' extensions. + :param autocommit: (optional) if True, each command is automatically committed (default: False). + :param parameters: (optional) the parameters to render the SQL query with. + :param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler). + :param split_statements: (optional) if split single SQL string into statements (default: False). + :param return_last: (optional) if return the result of only last statement (default: True). + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SQLExecuteQueryOperator` + """ + + template_fields: Sequence[str] = ("sql", "parameters") + template_ext: Sequence[str] = (".sql", ".json") + template_fields_renderers = {"sql": "sql", "parameters": "json"} + ui_color = "#cdaaed" + + def __init__( + self, + *, + sql: str | list[str], + autocommit: bool = False, + parameters: Mapping | Iterable | None = None, + handler: Callable[[Any], Any] = fetch_all_handler, + split_statements: bool = False, + return_last: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.autocommit = autocommit + self.parameters = parameters + self.handler = handler + self.split_statements = split_statements + self.return_last = return_last + + @overload + def _process_output( + self, results: Any, description: Sequence[Sequence] | None, scalar_results: Literal[True] + ) -> Any: + pass + + @overload + def _process_output( + self, results: list[Any], description: Sequence[Sequence] | None, scalar_results: Literal[False] + ) -> Any: + pass + + def _process_output( + self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool + ) -> Any: + """ + Can be overridden by the subclass in case some extra processing is needed. + The "process_output" method can override the returned output - augmenting or processing the + output as needed - the output returned will be returned as execute return value and if + do_xcom_push is set to True, it will be set as XCom returned + + :param results: results in the form of list of rows. 
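A minimal sketch of wiring up the `SQLExecuteQueryOperator` introduced above; the connection id and SQL are placeholders:

```python
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

create_and_count = SQLExecuteQueryOperator(
    task_id="create_and_count",
    conn_id="my_db",  # placeholder connection id
    sql="CREATE TABLE IF NOT EXISTS t (x INT); SELECT COUNT(*) FROM t",
    split_statements=True,  # run the two statements separately
    return_last=True,       # the XCom value holds only the final SELECT's rows
)
```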
+ :param description: as returned by ``cur.description`` in the Python DBAPI + :param scalar_results: True if result is single scalar value rather than list of rows + """ + return results + + def execute(self, context): + self.log.info("Executing: %s", self.sql) + hook = self.get_db_hook() + if self.do_xcom_push: + output = hook.run( + sql=self.sql, + autocommit=self.autocommit, + parameters=self.parameters, + handler=self.handler, + split_statements=self.split_statements, + return_last=self.return_last, + ) + else: + output = hook.run( + sql=self.sql, + autocommit=self.autocommit, + parameters=self.parameters, + split_statements=self.split_statements, + ) + + return self._process_output(output, hook.last_description, hook.scalar_return_last) + + def prepare_template(self) -> None: + """Parse template file for attribute parameters.""" + if isinstance(self.parameters, str): + self.parameters = ast.literal_eval(self.parameters) + + +class SQLColumnCheckOperator(BaseSQLOperator): + """ + Performs one or more of the templated checks in the column_checks dictionary. + Checks are performed on a per-column basis specified by the column_mapping. + Each check can take one or more of the following options: + - equal_to: an exact value to equal, cannot be used with other comparison options + - greater_than: value that result should be strictly greater than + - less_than: value that results should be strictly less than + - geq_to: value that results should be greater than or equal to + - leq_to: value that results should be less than or equal to + - tolerance: the percentage that the result may be off from the expected value + - partition_clause: an extra clause passed into a WHERE statement to partition data + + :param table: the table to run checks on + :param column_mapping: the dictionary of columns and their associated checks, e.g. + + .. code-block:: python + + { + "col_name": { + "null_check": { + "equal_to": 0, + "partition_clause": "foreign_key IS NOT NULL", + }, + "min": { + "greater_than": 5, + "leq_to": 10, + "tolerance": 0.2, + }, + "max": {"less_than": 1000, "geq_to": 10, "tolerance": 0.01}, + } + } + + :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by + the operator that creates partition_clauses for the checks to run on, e.g. + + .. code-block:: python + + "date = '1970-01-01'" + + :param conn_id: the connection ID used to connect to the database + :param database: name of database which overwrite the defined one in connection + :param accept_none: whether or not to accept None values returned by the query. If true, converts None + to 0. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SQLColumnCheckOperator` + """ + + template_fields = ("partition_clause",) + + sql_check_template = """ + SELECT '{column}' AS col_name, '{check}' AS check_type, {column}_{check} AS check_result + FROM (SELECT {check_statement} AS {column}_{check} FROM {table} {partition_clause}) AS sq + """ + + column_checks = { + "null_check": "SUM(CASE WHEN {column} IS NULL THEN 1 ELSE 0 END)", + "distinct_check": "COUNT(DISTINCT({column}))", + "unique_check": "COUNT({column}) - COUNT(DISTINCT({column}))", + "min": "MIN({column})", + "max": "MAX({column})", + } + + def __init__( + self, + *, + table: str, + column_mapping: dict[str, dict[str, Any]], + partition_clause: str | None = None, + conn_id: str | None = None, + database: str | None = None, + accept_none: bool = True, + **kwargs, + ): + super().__init__(conn_id=conn_id, database=database, **kwargs) + + self.table = table + self.column_mapping = column_mapping + self.partition_clause = partition_clause + self.accept_none = accept_none + + def _build_checks_sql(): + for column, checks in self.column_mapping.items(): + for check, check_values in checks.items(): + self._column_mapping_validation(check, check_values) + yield self._generate_sql_query(column, checks) + + checks_sql = "UNION ALL".join(_build_checks_sql()) + + self.sql = f"SELECT col_name, check_type, check_result FROM ({checks_sql}) AS check_columns" + + def execute(self, context: Context): + hook = self.get_db_hook() + records = hook.get_records(self.sql) + + if not records: + self._raise_exception(f"The following query returned zero rows: {self.sql}") + + self.log.info("Record: %s", records) + + for column, check, result in records: + tolerance = self.column_mapping[column][check].get("tolerance") + + self.column_mapping[column][check]["result"] = result + self.column_mapping[column][check]["success"] = self._get_match( + self.column_mapping[column][check], result, tolerance + ) + + failed_tests = [ + f"Column: {col}\n\tCheck: {check},\n\tCheck Values: {check_values}\n" + for col, checks in self.column_mapping.items() + for check, check_values in checks.items() + if not check_values["success"] + ] + if failed_tests: + exception_string = ( + f"Test failed.\nResults:\n{records!s}\n" + f"The following tests have failed:\n{''.join(failed_tests)}" + ) + self._raise_exception(exception_string) + + self.log.info("All tests have passed") + + def _generate_sql_query(self, column, checks): + def _generate_partition_clause(check): + if self.partition_clause and "partition_clause" not in checks[check]: + return f"WHERE {self.partition_clause}" + elif not self.partition_clause and "partition_clause" in checks[check]: + return f"WHERE {checks[check]['partition_clause']}" + elif self.partition_clause and "partition_clause" in checks[check]: + return f"WHERE {self.partition_clause} AND {checks[check]['partition_clause']}" + else: + return "" + + checks_sql = "UNION ALL".join( + self.sql_check_template.format( + check_statement=self.column_checks[check].format(column=column), + check=check, + table=self.table, + column=column, + partition_clause=_generate_partition_clause(check), + ) + for check in checks + ) + return checks_sql + + def _get_match(self, check_values, record, tolerance=None) -> bool: + if record is None and self.accept_none: + record = 0 + match_boolean = True + if "geq_to" in check_values: + if tolerance is not None: + match_boolean = record >= check_values["geq_to"] * (1 - 
tolerance) + else: + match_boolean = record >= check_values["geq_to"] + elif "greater_than" in check_values: + if tolerance is not None: + match_boolean = record > check_values["greater_than"] * (1 - tolerance) + else: + match_boolean = record > check_values["greater_than"] + if "leq_to" in check_values: + if tolerance is not None: + match_boolean = record <= check_values["leq_to"] * (1 + tolerance) and match_boolean + else: + match_boolean = record <= check_values["leq_to"] and match_boolean + elif "less_than" in check_values: + if tolerance is not None: + match_boolean = record < check_values["less_than"] * (1 + tolerance) and match_boolean + else: + match_boolean = record < check_values["less_than"] and match_boolean + if "equal_to" in check_values: + if tolerance is not None: + match_boolean = ( + check_values["equal_to"] * (1 - tolerance) + <= record + <= check_values["equal_to"] * (1 + tolerance) + ) and match_boolean + else: + match_boolean = record == check_values["equal_to"] and match_boolean + return match_boolean + + def _column_mapping_validation(self, check, check_values): + if check not in self.column_checks: + raise AirflowException(f"Invalid column check: {check}.") + if ( + "greater_than" not in check_values + and "geq_to" not in check_values + and "less_than" not in check_values + and "leq_to" not in check_values + and "equal_to" not in check_values + ): + raise ValueError( + "Please provide one or more of: less_than, leq_to, " + "greater_than, geq_to, or equal_to in the check's dict." + ) + + if "greater_than" in check_values and "less_than" in check_values: + if check_values["greater_than"] >= check_values["less_than"]: + raise ValueError( + "greater_than should be strictly less than " + "less_than. Use geq_to or leq_to for " + "overlapping equality." + ) + + if "greater_than" in check_values and "leq_to" in check_values: + if check_values["greater_than"] >= check_values["leq_to"]: + raise ValueError( + "greater_than must be strictly less than leq_to. " + "Use geq_to with leq_to for overlapping equality." + ) + + if "geq_to" in check_values and "less_than" in check_values: + if check_values["geq_to"] >= check_values["less_than"]: + raise ValueError( + "geq_to should be strictly less than less_than. " + "Use leq_to with geq_to for overlapping equality." + ) + + if "geq_to" in check_values and "leq_to" in check_values: + if check_values["geq_to"] > check_values["leq_to"]: + raise ValueError("geq_to should be less than or equal to leq_to.") + + if "greater_than" in check_values and "geq_to" in check_values: + raise ValueError("Only supply one of greater_than or geq_to.") + + if "less_than" in check_values and "leq_to" in check_values: + raise ValueError("Only supply one of less_than or leq_to.") + + if ( + "greater_than" in check_values + or "geq_to" in check_values + or "less_than" in check_values + or "leq_to" in check_values + ) and "equal_to" in check_values: + raise ValueError( + "equal_to cannot be passed with a greater or less than " + "function. To specify 'greater than or equal to' or " + "'less than or equal to', use geq_to or leq_to." + ) + + +class SQLTableCheckOperator(BaseSQLOperator): + """ + Performs one or more of the checks provided in the checks dictionary. + Checks should be written to return a boolean result. + + :param table: the table to run checks on + :param checks: the dictionary of checks, where check names are followed by a dictionary containing at + least a check statement, and optionally a partition clause, e.g.: + + .. 
code-block:: python + + { + "row_count_check": {"check_statement": "COUNT(*) = 1000"}, + "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, + "third_check": {"check_statement": "MIN(col) = 1", "partition_clause": "col IS NOT NULL"}, + } + + + :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by + the operator that creates partition_clauses for the checks to run on, e.g. + + .. code-block:: python + + "date = '1970-01-01'" + + :param conn_id: the connection ID used to connect to the database + :param database: name of database which overwrite the defined one in connection + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SQLTableCheckOperator` + """ + + template_fields = ("partition_clause",) + + sql_check_template = """ + SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result + FROM (SELECT CASE WHEN {check_statement} THEN 1 ELSE 0 END AS {check_name} + FROM {table} {partition_clause}) AS sq + """ + + def __init__( + self, + *, + table: str, + checks: dict[str, dict[str, Any]], + partition_clause: str | None = None, + conn_id: str | None = None, + database: str | None = None, + **kwargs, + ): + super().__init__(conn_id=conn_id, database=database, **kwargs) + + self.table = table + self.checks = checks + self.partition_clause = partition_clause + self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table" + + def execute(self, context: Context): + hook = self.get_db_hook() + records = hook.get_records(self.sql) + + if not records: + self._raise_exception(f"The following query returned zero rows: {self.sql}") + + self.log.info("Record:\n%s", records) + + for row in records: + check, result = row + self.checks[check]["success"] = _parse_boolean(str(result)) + + failed_tests = [ + f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + for check, check_values in self.checks.items() + if not check_values["success"] + ] + if failed_tests: + exception_string = ( + f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n" + f"The following tests have failed:\n{', '.join(failed_tests)}" + ) + self._raise_exception(exception_string) + + self.log.info("All tests have passed") + + def _generate_sql_query(self): + def _generate_partition_clause(check_name): + if self.partition_clause and "partition_clause" not in self.checks[check_name]: + return f"WHERE {self.partition_clause}" + elif not self.partition_clause and "partition_clause" in self.checks[check_name]: + return f"WHERE {self.checks[check_name]['partition_clause']}" + elif self.partition_clause and "partition_clause" in self.checks[check_name]: + return f"WHERE {self.partition_clause} AND {self.checks[check_name]['partition_clause']}" + else: + return "" + + return "UNION ALL".join( + self.sql_check_template.format( + check_statement=value["check_statement"], + check_name=check_name, + table=self.table, + partition_clause=_generate_partition_clause(check_name), + ) + for check_name, value in self.checks.items() + ) + + +class SQLCheckOperator(BaseSQLOperator): + """ + Performs checks against a db. The ``SQLCheckOperator`` expects + a sql query that will return a single row. Each value on that + first row is evaluated using python ``bool`` casting. If any of the + values return ``False`` the check is failed and errors out. 
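Putting the two data-quality operators above together; table, column and connection names are placeholders:

```python
from airflow.providers.common.sql.operators.sql import (
    SQLColumnCheckOperator,
    SQLTableCheckOperator,
)

column_checks = SQLColumnCheckOperator(
    task_id="column_checks",
    conn_id="my_db",
    table="orders",
    column_mapping={
        "order_id": {
            "null_check": {"equal_to": 0},    # no NULL order ids
            "distinct_check": {"geq_to": 1},  # at least one distinct value
        }
    },
)

table_checks = SQLTableCheckOperator(
    task_id="table_checks",
    conn_id="my_db",
    table="orders",
    checks={"row_count_check": {"check_statement": "COUNT(*) > 0"}},
    partition_clause="order_date = '{{ ds }}'",  # templated; merged into the generated WHERE clause
)
```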
+ + Note that Python bool casting evals the following as ``False``: + + * ``False`` + * ``0`` + * Empty string (``""``) + * Empty list (``[]``) + * Empty dictionary or set (``{}``) + + Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if + the count ``== 0``. You can craft much more complex query that could, + for instance, check that the table has the same number of rows as + the source table upstream, or that the count of today's partition is + greater than yesterday's partition, or that a set of metrics are less + than 3 standard deviation for the 7 day average. + + This operator can be used as a data quality check in your pipeline, and + depending on where you put it in your DAG, you have the choice to + stop the critical path, preventing from + publishing dubious data, or on the side and receive email alerts + without stopping the progress of the DAG. + + :param sql: the sql to be executed. (templated) + :param conn_id: the connection ID used to connect to the database. + :param database: name of database which overwrite the defined one in connection + :param parameters: (optional) the parameters to render the SQL query with. + """ + + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = ( + ".hql", + ".sql", + ) + template_fields_renderers = {"sql": "sql"} + ui_color = "#fff7e6" + + def __init__( + self, + *, + sql: str, + conn_id: str | None = None, + database: str | None = None, + parameters: Iterable | Mapping | None = None, + **kwargs, + ) -> None: + super().__init__(conn_id=conn_id, database=database, **kwargs) + self.sql = sql + self.parameters = parameters + + def execute(self, context: Context): + self.log.info("Executing SQL check: %s", self.sql) + records = self.get_db_hook().get_first(self.sql, self.parameters) + + self.log.info("Record: %s", records) + if not records: + self._raise_exception(f"The following query returned zero rows: {self.sql}") + elif not all(bool(r) for r in records): + self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") + + self.log.info("Success.") + + +class SQLValueCheckOperator(BaseSQLOperator): + """ + Performs a simple value check using sql code. + + :param sql: the sql to be executed. (templated) + :param conn_id: the connection ID used to connect to the database. 
+ :param database: name of database which overwrite the defined one in connection + """ + + __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"} + template_fields: Sequence[str] = ( + "sql", + "pass_value", + ) + template_ext: Sequence[str] = ( + ".hql", + ".sql", + ) + template_fields_renderers = {"sql": "sql"} + ui_color = "#fff7e6" + + def __init__( + self, + *, + sql: str, + pass_value: Any, + tolerance: Any = None, + conn_id: str | None = None, + database: str | None = None, + **kwargs, + ): + super().__init__(conn_id=conn_id, database=database, **kwargs) + self.sql = sql + self.pass_value = str(pass_value) + tol = _convert_to_float_if_possible(tolerance) + self.tol = tol if isinstance(tol, float) else None + self.has_tolerance = self.tol is not None + + def execute(self, context: Context): + self.log.info("Executing SQL check: %s", self.sql) + records = self.get_db_hook().get_first(self.sql) + + if not records: + self._raise_exception(f"The following query returned zero rows: {self.sql}") + + pass_value_conv = _convert_to_float_if_possible(self.pass_value) + is_numeric_value_check = isinstance(pass_value_conv, float) + + tolerance_pct_str = str(self.tol * 100) + "%" if self.tol is not None else None + error_msg = ( + "Test failed.\nPass value:{pass_value_conv}\n" + "Tolerance:{tolerance_pct_str}\n" + "Query:\n{sql}\nResults:\n{records!s}" + ).format( + pass_value_conv=pass_value_conv, + tolerance_pct_str=tolerance_pct_str, + sql=self.sql, + records=records, + ) + + if not is_numeric_value_check: + tests = self._get_string_matches(records, pass_value_conv) + elif is_numeric_value_check: + try: + numeric_records = self._to_float(records) + except (ValueError, TypeError): + raise AirflowException(f"Converting a result to float failed.\n{error_msg}") + tests = self._get_numeric_matches(numeric_records, pass_value_conv) + else: + tests = [] + + if not all(tests): + self._raise_exception(error_msg) + + def _to_float(self, records): + return [float(record) for record in records] + + def _get_string_matches(self, records, pass_value_conv): + return [str(record) == pass_value_conv for record in records] + + def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv): + if self.has_tolerance: + return [ + numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol) + for record in numeric_records + ] + + return [record == numeric_pass_value_conv for record in numeric_records] + + +class SQLIntervalCheckOperator(BaseSQLOperator): + """ + Checks that the values of metrics given as SQL expressions are within + a certain tolerance of the ones from days_back before. + + :param table: the table name + :param conn_id: the connection ID used to connect to the database. + :param database: name of database which will overwrite the defined one in connection + :param days_back: number of days between ds and the ds we want to check + against. Defaults to 7 days + :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds' + :param ratio_formula: which formula to use to compute the ratio between + the two metrics. Assuming cur is the metric of today and ref is + the metric to today - days_back. 
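The `SQLValueCheckOperator` above compares a query result against an expected value with an optional tolerance; everything below except the parameter names is a placeholder:

```python
from airflow.providers.common.sql.operators.sql import SQLValueCheckOperator

row_count_matches = SQLValueCheckOperator(
    task_id="row_count_matches",
    conn_id="my_db",
    sql="SELECT COUNT(*) FROM orders WHERE order_date = '{{ ds }}'",
    pass_value=10000,  # expected row count
    tolerance=0.1,     # accept results within +/- 10% of pass_value
)
```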
+ + max_over_min: computes max(cur, ref) / min(cur, ref) + relative_diff: computes abs(cur-ref) / ref + + Default: 'max_over_min' + :param ignore_zero: whether we should ignore zero metrics + :param metrics_thresholds: a dictionary of ratios indexed by metrics + """ + + __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"} + template_fields: Sequence[str] = ("sql1", "sql2") + template_ext: Sequence[str] = ( + ".hql", + ".sql", + ) + template_fields_renderers = {"sql1": "sql", "sql2": "sql"} + ui_color = "#fff7e6" + + ratio_formulas = { + "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref), + "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref, + } + + def __init__( + self, + *, + table: str, + metrics_thresholds: dict[str, int], + date_filter_column: str | None = "ds", + days_back: SupportsAbs[int] = -7, + ratio_formula: str | None = "max_over_min", + ignore_zero: bool = True, + conn_id: str | None = None, + database: str | None = None, + **kwargs, + ): + super().__init__(conn_id=conn_id, database=database, **kwargs) + if ratio_formula not in self.ratio_formulas: + msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}" + + raise AirflowFailException( + msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas) + ) + self.ratio_formula = ratio_formula + self.ignore_zero = ignore_zero + self.table = table + self.metrics_thresholds = metrics_thresholds + self.metrics_sorted = sorted(metrics_thresholds.keys()) + self.date_filter_column = date_filter_column + self.days_back = -abs(days_back) + sqlexp = ", ".join(self.metrics_sorted) + sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}=" + + self.sql1 = sqlt + "'{{ ds }}'" + self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'" + + def execute(self, context: Context): + hook = self.get_db_hook() + self.log.info("Using ratio formula: %s", self.ratio_formula) + self.log.info("Executing SQL check: %s", self.sql2) + row2 = hook.get_first(self.sql2) + self.log.info("Executing SQL check: %s", self.sql1) + row1 = hook.get_first(self.sql1) + + if not row2: + self._raise_exception(f"The following query returned zero rows: {self.sql2}") + if not row1: + self._raise_exception(f"The following query returned zero rows: {self.sql1}") + + current = dict(zip(self.metrics_sorted, row1)) + reference = dict(zip(self.metrics_sorted, row2)) + + ratios: dict[str, int | None] = {} + test_results = {} + + for metric in self.metrics_sorted: + cur = current[metric] + ref = reference[metric] + threshold = self.metrics_thresholds[metric] + if cur == 0 or ref == 0: + ratios[metric] = None + test_results[metric] = self.ignore_zero + else: + ratio_metric = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric]) + ratios[metric] = ratio_metric + if ratio_metric is not None: + test_results[metric] = ratio_metric < threshold + else: + test_results[metric] = self.ignore_zero + + self.log.info( + ( + "Current metric for %s: %s\n" + "Past metric for %s: %s\n" + "Ratio for %s: %s\n" + "Threshold: %s\n" + ), + metric, + cur, + metric, + ref, + metric, + ratios[metric], + threshold, + ) + + if not all(test_results.values()): + failed_tests = [it[0] for it in test_results.items() if not it[1]] + self.log.warning( + "The following %s tests out of %s failed:", + len(failed_tests), + len(self.metrics_sorted), + ) + for k in failed_tests: + self.log.warning( + "'%s' check failed. 
%s is above %s", + k, + ratios[k], + self.metrics_thresholds[k], + ) + self._raise_exception(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}") + + self.log.info("All tests have passed") + + +class SQLThresholdCheckOperator(BaseSQLOperator): + """ + Performs a value check using sql code against a minimum threshold + and a maximum threshold. Thresholds can be in the form of a numeric + value OR a sql statement that results a numeric. + + :param sql: the sql to be executed. (templated) + :param conn_id: the connection ID used to connect to the database. + :param database: name of database which overwrite the defined one in connection + :param min_threshold: numerical value or min threshold sql to be executed (templated) + :param max_threshold: numerical value or max threshold sql to be executed (templated) + """ + + template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold") + template_ext: Sequence[str] = ( + ".hql", + ".sql", + ) + template_fields_renderers = {"sql": "sql"} + + def __init__( + self, + *, + sql: str, + min_threshold: Any, + max_threshold: Any, + conn_id: str | None = None, + database: str | None = None, + **kwargs, + ): + super().__init__(conn_id=conn_id, database=database, **kwargs) + self.sql = sql + self.min_threshold = _convert_to_float_if_possible(min_threshold) + self.max_threshold = _convert_to_float_if_possible(max_threshold) + + def execute(self, context: Context): + hook = self.get_db_hook() + result = hook.get_first(self.sql)[0] + if not result: + self._raise_exception(f"The following query returned zero rows: {self.sql}") + + if isinstance(self.min_threshold, float): + lower_bound = self.min_threshold + else: + lower_bound = hook.get_first(self.min_threshold)[0] + + if isinstance(self.max_threshold, float): + upper_bound = self.max_threshold + else: + upper_bound = hook.get_first(self.max_threshold)[0] + + meta_data = { + "result": result, + "task_id": self.task_id, + "min_threshold": lower_bound, + "max_threshold": upper_bound, + "within_threshold": lower_bound <= result <= upper_bound, + } + + self.push(meta_data) + if not meta_data["within_threshold"]: + result = ( + round(meta_data.get("result"), 2) # type: ignore[arg-type] + if meta_data.get("result") is not None + else "" + ) + error_msg = ( + f'Threshold Check: "{meta_data.get("task_id")}" failed.\n' + f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n' + f'Check description: {meta_data.get("description")}\n' + f"SQL: {self.sql}\n" + f"Result: {result} is not within thresholds " + f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}' + ) + self._raise_exception(error_msg) + + self.log.info("Test %s Successful.", self.task_id) + + def push(self, meta_data): + """ + Optional: Send data check info and metadata to an external database. + Default functionality will log metadata. + """ + info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items()) + self.log.info("Log from %s:\n%s", self.dag_id, info) + + +class BranchSQLOperator(BaseSQLOperator, SkipMixin): + """ + Allows a DAG to "branch" or follow a specified path based on the results of a SQL query. + + :param sql: The SQL code to be executed, should return true or false (templated) + Template reference are recognized by str ending in '.sql'. + Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1) + or string (true/y/yes/1/on/false/n/no/0/off). 
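Taken together, the threshold check above and the branching behaviour described here form a natural quality gate. A sketch under the assumption that a connection ``my_db`` and an ``events`` table exist (all names are placeholders)::

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.providers.common.sql.operators.sql import (
        BranchSQLOperator,
        SQLThresholdCheckOperator,
    )

    with DAG(
        dag_id="example_sql_quality_gate",
        schedule_interval="@daily",
        start_date=datetime(2022, 1, 1),
        catchup=False,
    ) as dag:
        # Fail unless today's row count lies between the two bounds; either
        # bound may also be a SQL statement that returns a single number.
        row_count_within_bounds = SQLThresholdCheckOperator(
            task_id="row_count_within_bounds",
            conn_id="my_db",
            sql="SELECT COUNT(*) FROM events WHERE ds = '{{ ds }}'",
            min_threshold=1000,
            max_threshold="SELECT 10 * AVG(daily_rows) FROM events_stats",
        )

        # Branch on a boolean-like result: 0 means False, anything else True.
        has_late_data = BranchSQLOperator(
            task_id="has_late_data",
            conn_id="my_db",
            sql="SELECT COUNT(*) FROM events WHERE ds < '{{ ds }}' AND loaded_on = '{{ ds }}'",
            follow_task_ids_if_true=["reprocess_history"],
            follow_task_ids_if_false=["publish"],
        )

        reprocess_history = EmptyOperator(task_id="reprocess_history")
        publish = EmptyOperator(task_id="publish")

        row_count_within_bounds >> has_late_data >> [reprocess_history, publish]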
+ :param follow_task_ids_if_true: task id or task ids to follow if query returns true + :param follow_task_ids_if_false: task id or task ids to follow if query returns false + :param conn_id: the connection ID used to connect to the database. + :param database: name of database which overwrite the defined one in connection + :param parameters: (optional) the parameters to render the SQL query with. + """ + + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#a22034" + ui_fgcolor = "#F7F7F7" + + def __init__( + self, + *, + sql: str, + follow_task_ids_if_true: list[str], + follow_task_ids_if_false: list[str], + conn_id: str = "default_conn_id", + database: str | None = None, + parameters: Iterable | Mapping | None = None, + **kwargs, + ) -> None: + super().__init__(conn_id=conn_id, database=database, **kwargs) + self.sql = sql + self.parameters = parameters + self.follow_task_ids_if_true = follow_task_ids_if_true + self.follow_task_ids_if_false = follow_task_ids_if_false + + def execute(self, context: Context): + self.log.info( + "Executing: %s (with parameters %s) with connection: %s", + self.sql, + self.parameters, + self.conn_id, + ) + record = self.get_db_hook().get_first(self.sql, self.parameters) + if not record: + raise AirflowException( + "No rows returned from sql query. Operator expected True or False return value." + ) + + if isinstance(record, list): + if isinstance(record[0], list): + query_result = record[0][0] + else: + query_result = record[0] + elif isinstance(record, tuple): + query_result = record[0] + else: + query_result = record + + self.log.info("Query returns %s, type '%s'", query_result, type(query_result)) + + follow_branch = None + try: + if isinstance(query_result, bool): + if query_result: + follow_branch = self.follow_task_ids_if_true + elif isinstance(query_result, str): + # return result is not Boolean, try to convert from String to Boolean + if _parse_boolean(query_result): + follow_branch = self.follow_task_ids_if_true + elif isinstance(query_result, int): + if bool(query_result): + follow_branch = self.follow_task_ids_if_true + else: + raise AirflowException( + f"Unexpected query return result '{query_result}' type '{type(query_result)}'" + ) + + if follow_branch is None: + follow_branch = self.follow_task_ids_if_false + except ValueError: + raise AirflowException( + f"Unexpected query return result '{query_result}' type '{type(query_result)}'" + ) + + self.skip_all_except(context["ti"], follow_branch) diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml new file mode 100644 index 0000000000000..4527dfff586a6 --- /dev/null +++ b/airflow/providers/common/sql/provider.yaml @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-common-sql +name: Common SQL +description: | + `Common SQL Provider `__ + +versions: + - 1.3.1 + - 1.3.0 + - 1.2.0 + - 1.1.0 + - 1.0.0 + +dependencies: + - sqlparse>=0.4.2 + +additional-extras: + - name: pandas + dependencies: + - pandas>=0.17.1 + +integrations: + - integration-name: Common SQL + external-doc-url: https://en.wikipedia.org/wiki/SQL + how-to-guide: + - /docs/apache-airflow-providers-common-sql/operators.rst + logo: /integration-logos/common/sql/sql.png + tags: [software] + +operators: + - integration-name: Common SQL + python-modules: + - airflow.providers.common.sql.operators.sql + +hooks: + - integration-name: Common SQL + python-modules: + - airflow.providers.common.sql.hooks.sql + +sensors: + - integration-name: Common SQL + python-modules: + - airflow.providers.common.sql.sensors.sql diff --git a/airflow/providers/jdbc/example_dags/__init__.py b/airflow/providers/common/sql/sensors/__init__.py similarity index 100% rename from airflow/providers/jdbc/example_dags/__init__.py rename to airflow/providers/common/sql/sensors/__init__.py diff --git a/airflow/providers/common/sql/sensors/sql.py b/airflow/providers/common/sql/sensors/sql.py new file mode 100644 index 0000000000000..d58802dc98b03 --- /dev/null +++ b/airflow/providers/common/sql/sensors/sql.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, Sequence + +from airflow import AirflowException +from airflow.hooks.base import BaseHook +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.sensors.base import BaseSensorOperator + + +class SqlSensor(BaseSensorOperator): + """ + Runs a sql statement repeatedly until a criteria is met. It will keep trying until + success or failure criteria are met, or if the first cell is not in (0, '0', '', None). + Optional success and failure callables are called with the first cell returned as the argument. + If success callable is defined the sensor will keep retrying until the criteria is met. + If failure callable is defined and the criteria is met the sensor will raise AirflowException. + Failure criteria is evaluated before success criteria. A fail_on_empty boolean can also + be passed to the sensor in which case it will fail if no rows have been returned + + :param conn_id: The connection to run the sensor against + :param sql: The sql to run. To pass, it needs to return at least one cell + that contains a non-zero / empty string value. + :param parameters: The parameters to render the SQL query with (optional). 
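A minimal usage sketch for the sensor described here; the connection id, table, and timing values are placeholders, and the ``success`` callable simply restates the default truthiness check on the first cell::

    from airflow.providers.common.sql.sensors.sql import SqlSensor

    # Defined inside a DAG context, like the operators above.
    wait_for_partition = SqlSensor(
        task_id="wait_for_partition",
        conn_id="my_db",  # placeholder connection id
        sql="SELECT COUNT(*) FROM events WHERE ds = '{{ ds }}'",
        success=lambda first_cell: bool(first_cell),
        fail_on_empty=False,
        poke_interval=60,
        timeout=60 * 60,
        mode="reschedule",
    )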
+ :param success: Success criteria for the sensor is a Callable that takes first_cell + as the only argument, and returns a boolean (optional). + :param failure: Failure criteria for the sensor is a Callable that takes first_cell + as the only argument and return a boolean (optional). + :param fail_on_empty: Explicitly fail on no rows returned. + :param hook_params: Extra config params to be passed to the underlying hook. + Should match the desired hook constructor params. + """ + + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = ( + ".hql", + ".sql", + ) + ui_color = "#7c7287" + + def __init__( + self, + *, + conn_id, + sql, + parameters=None, + success=None, + failure=None, + fail_on_empty=False, + hook_params=None, + **kwargs, + ): + self.conn_id = conn_id + self.sql = sql + self.parameters = parameters + self.success = success + self.failure = failure + self.fail_on_empty = fail_on_empty + self.hook_params = hook_params + super().__init__(**kwargs) + + def _get_hook(self): + conn = BaseHook.get_connection(self.conn_id) + hook = conn.get_hook(hook_params=self.hook_params) + if not isinstance(hook, DbApiHook): + raise AirflowException( + f"The connection type is not supported by {self.__class__.__name__}. " + f"The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}" + ) + return hook + + def poke(self, context: Any): + hook = self._get_hook() + + self.log.info("Poking: %s (with parameters %s)", self.sql, self.parameters) + records = hook.get_records(self.sql, self.parameters) + if not records: + if self.fail_on_empty: + raise AirflowException("No rows returned, raising as per fail_on_empty flag") + else: + return False + first_cell = records[0][0] + if self.failure is not None: + if callable(self.failure): + if self.failure(first_cell): + raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True") + else: + raise AirflowException(f"self.failure is present, but not callable -> {self.failure}") + if self.success is not None: + if callable(self.success): + return self.success(first_cell) + else: + raise AirflowException(f"self.success is present, but not callable -> {self.success}") + return bool(first_cell) diff --git a/airflow/providers/databricks/CHANGELOG.rst b/airflow/providers/databricks/CHANGELOG.rst index beb539e8a7b69..0d446920542b1 100644 --- a/airflow/providers/databricks/CHANGELOG.rst +++ b/airflow/providers/databricks/CHANGELOG.rst @@ -16,9 +16,137 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.4.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Replace urlparse with urlsplit (#27389)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` +* ``Use new job search API for triggering Databricks job by name (#27446)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.3.0 +..... 
+ +Features +~~~~~~~~ + +* ``DatabricksSubmitRunOperator dbt task support (#25623)`` + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` +* ``Remove duplicated connection-type within the provider (#26628)`` + +Bug Fixes +~~~~~~~~~ + +* ``Databricks: fix provider name in the User-Agent string (#25873)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``D400 first line should end with period batch02 (#25268)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Databricks: update user-agent string (#25578)`` +* ``More improvements in the Databricks operators (#25260)`` +* ``Improved telemetry for Databricks provider (#25115)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + +Bug Fixes +~~~~~~~~~ + +* ``Databricks: fix test_connection implementation (#25114)`` +* ``Do not convert boolean values to string in deep_string_coerce function (#25394)`` +* ``Correctly handle output of the failed tasks (#25427)`` +* ``Databricks: Fix provider for Airflow 2.2.x (#25674)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``updated documentation for databricks operator (#24599)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Added databricks_conn_id as templated field (#24945)`` +* ``Add 'test_connection' method to Databricks hook (#24617)`` +* ``Move all SQL classes to common-sql provider (#24836)`` + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Automatically detect if non-lazy logging interpolation is used (#24910)`` + * ``Remove "bad characters" from our codebase (#24841)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Add Deferrable Databricks operators (#19736)`` +* ``Add git_source to DatabricksSubmitRunOperator (#23620)`` + +Bug Fixes +~~~~~~~~~ + +* ``fix: DatabricksSubmitRunOperator and DatabricksRunNowOperator cannot define .json as template_ext (#23622) (#23641)`` +* ``Fix UnboundLocalError when sql is empty list in DatabricksSqlHook (#23815)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``AIP-47 - Migrate databricks DAGs to new design #22442 (#24203)`` + * ``Introduce 'flake8-implicit-str-concat' plugin to static checks (#23873)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.7.0 ..... @@ -62,7 +190,6 @@ Misc * ``Fix new MyPy errors in main (#22884)`` * ``Prepare mid-April provider documentation. (#22819)`` -.. 
Review and move the new changes to one of the sections above: * ``Prepare for RC2 release of March Databricks provider (#22979)`` 2.5.0 diff --git a/airflow/providers/databricks/example_dags/example_databricks.py b/airflow/providers/databricks/example_dags/example_databricks.py deleted file mode 100644 index bea9038afeb0a..0000000000000 --- a/airflow/providers/databricks/example_dags/example_databricks.py +++ /dev/null @@ -1,75 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example DAG which uses the DatabricksSubmitRunOperator. -In this example, we create two tasks which execute sequentially. -The first task is to run a notebook at the workspace path "/test" -and the second task is to run a JAR uploaded to DBFS. Both, -tasks use new clusters. - -Because we have set a downstream dependency on the notebook task, -the spark jar task will NOT run until the notebook task completes -successfully. - -The definition of a successful run is if the run has a result_state of "SUCCESS". -For more information about the state of a run refer to -https://docs.databricks.com/api/latest/jobs.html#runstate -""" - -from datetime import datetime - -from airflow import DAG -from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator - -with DAG( - dag_id='example_databricks_operator', - schedule_interval='@daily', - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - # [START howto_operator_databricks_json] - # Example of using the JSON parameter to initialize the operator. - new_cluster = { - 'spark_version': '9.1.x-scala2.12', - 'node_type_id': 'r3.xlarge', - 'aws_attributes': {'availability': 'ON_DEMAND'}, - 'num_workers': 8, - } - - notebook_task_params = { - 'new_cluster': new_cluster, - 'notebook_task': { - 'notebook_path': '/Users/airflow@example.com/PrepareData', - }, - } - - notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params) - # [END howto_operator_databricks_json] - - # [START howto_operator_databricks_named] - # Example of using the named parameters of DatabricksSubmitRunOperator - # to initialize the operator. 
- spark_jar_task = DatabricksSubmitRunOperator( - task_id='spark_jar_task', - new_cluster=new_cluster, - spark_jar_task={'main_class_name': 'com.example.ProcessData'}, - libraries=[{'jar': 'dbfs:/lib/etl-0.1.jar'}], - ) - # [END howto_operator_databricks_named] - notebook_task >> spark_jar_task diff --git a/airflow/providers/databricks/example_dags/example_databricks_repos.py b/airflow/providers/databricks/example_dags/example_databricks_repos.py deleted file mode 100644 index e33d32044f5df..0000000000000 --- a/airflow/providers/databricks/example_dags/example_databricks_repos.py +++ /dev/null @@ -1,74 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from datetime import datetime - -from airflow import DAG -from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator -from airflow.providers.databricks.operators.databricks_repos import ( - DatabricksReposCreateOperator, - DatabricksReposDeleteOperator, - DatabricksReposUpdateOperator, -) - -default_args = { - 'owner': 'airflow', - 'databricks_conn_id': 'databricks', -} - -with DAG( - dag_id='example_databricks_repos_operator', - schedule_interval='@daily', - start_date=datetime(2021, 1, 1), - default_args=default_args, - tags=['example'], - catchup=False, -) as dag: - # [START howto_operator_databricks_repo_create] - # Example of creating a Databricks Repo - repo_path = "/Repos/user@domain.com/demo-repo" - git_url = "https://github.com/test/test" - create_repo = DatabricksReposCreateOperator(task_id='create_repo', repo_path=repo_path, git_url=git_url) - # [END howto_operator_databricks_repo_create] - - # [START howto_operator_databricks_repo_update] - # Example of updating a Databricks Repo to the latest code - repo_path = "/Repos/user@domain.com/demo-repo" - update_repo = DatabricksReposUpdateOperator(task_id='update_repo', repo_path=repo_path, branch="releases") - # [END howto_operator_databricks_repo_update] - - notebook_task_params = { - 'new_cluster': { - 'spark_version': '9.1.x-scala2.12', - 'node_type_id': 'r3.xlarge', - 'aws_attributes': {'availability': 'ON_DEMAND'}, - 'num_workers': 8, - }, - 'notebook_task': { - 'notebook_path': f'{repo_path}/PrepareData', - }, - } - - notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params) - - # [START howto_operator_databricks_repo_delete] - # Example of deleting a Databricks Repo - repo_path = "/Repos/user@domain.com/demo-repo" - delete_repo = DatabricksReposDeleteOperator(task_id='delete_repo', repo_path=repo_path) - # [END howto_operator_databricks_repo_delete] - - (create_repo >> update_repo >> notebook_task >> delete_repo) diff --git a/airflow/providers/databricks/example_dags/example_databricks_sql.py b/airflow/providers/databricks/example_dags/example_databricks_sql.py 
deleted file mode 100644 index 6032c0fb03032..0000000000000 --- a/airflow/providers/databricks/example_dags/example_databricks_sql.py +++ /dev/null @@ -1,113 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example DAG which uses the DatabricksSubmitRunOperator. -In this example, we create two tasks which execute sequentially. -The first task is to run a notebook at the workspace path "/test" -and the second task is to run a JAR uploaded to DBFS. Both, -tasks use new clusters. - -Because we have set a downstream dependency on the notebook task, -the spark jar task will NOT run until the notebook task completes -successfully. - -The definition of a successful run is if the run has a result_state of "SUCCESS". -For more information about the state of a run refer to -https://docs.databricks.com/api/latest/jobs.html#runstate -""" - -from datetime import datetime - -from airflow import DAG -from airflow.providers.databricks.operators.databricks_sql import ( - DatabricksCopyIntoOperator, - DatabricksSqlOperator, -) - -with DAG( - dag_id='example_databricks_sql_operator', - schedule_interval='@daily', - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - connection_id = 'my_connection' - sql_endpoint_name = "My Endpoint" - - # [START howto_operator_databricks_sql_multiple] - # Example of using the Databricks SQL Operator to perform multiple operations. - create = DatabricksSqlOperator( - databricks_conn_id=connection_id, - sql_endpoint_name=sql_endpoint_name, - task_id='create_and_populate_table', - sql=[ - "drop table if exists default.my_airflow_table", - "create table default.my_airflow_table(id int, v string)", - "insert into default.my_airflow_table values (1, 'test 1'), (2, 'test 2')", - ], - ) - # [END howto_operator_databricks_sql_multiple] - - # [START howto_operator_databricks_sql_select] - # Example of using the Databricks SQL Operator to select data. - select = DatabricksSqlOperator( - databricks_conn_id=connection_id, - sql_endpoint_name=sql_endpoint_name, - task_id='select_data', - sql="select * from default.my_airflow_table", - ) - # [END howto_operator_databricks_sql_select] - - # [START howto_operator_databricks_sql_select_file] - # Example of using the Databricks SQL Operator to select data into a file with JSONL format. - select_into_file = DatabricksSqlOperator( - databricks_conn_id=connection_id, - sql_endpoint_name=sql_endpoint_name, - task_id='select_data_into_file', - sql="select * from default.my_airflow_table", - output_path="/tmp/1.jsonl", - output_format="jsonl", - ) - # [END howto_operator_databricks_sql_select_file] - - # [START howto_operator_databricks_sql_multiple_file] - # Example of using the Databricks SQL Operator to select data. 
- # SQL statements should be in the file with name test.sql - create_file = DatabricksSqlOperator( - databricks_conn_id=connection_id, - sql_endpoint_name=sql_endpoint_name, - task_id='create_and_populate_from_file', - sql="test.sql", - ) - # [END howto_operator_databricks_sql_multiple_file] - - # [START howto_operator_databricks_copy_into] - # Example of importing data using COPY_INTO SQL command - import_csv = DatabricksCopyIntoOperator( - task_id='import_csv', - databricks_conn_id=connection_id, - sql_endpoint_name=sql_endpoint_name, - table_name="my_table", - file_format="CSV", - file_location="abfss://container@account.dfs.core.windows.net/my-data/csv", - format_options={'header': 'true'}, - force_copy=True, - ) - # [END howto_operator_databricks_copy_into] - - (create >> create_file >> import_csv >> select >> select_into_file) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 400bbe895588a..bb8a1dc88080e 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -25,8 +25,10 @@ or the ``api/2.1/jobs/runs/submit`` `endpoint `_. """ +from __future__ import annotations + import json -from typing import Any, Dict, List, Optional +from typing import Any from requests import exceptions as requests_exceptions @@ -37,27 +39,29 @@ START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start") TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete") -RUN_NOW_ENDPOINT = ('POST', 'api/2.1/jobs/run-now') -SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.1/jobs/runs/submit') -GET_RUN_ENDPOINT = ('GET', 'api/2.1/jobs/runs/get') -CANCEL_RUN_ENDPOINT = ('POST', 'api/2.1/jobs/runs/cancel') -OUTPUT_RUNS_JOB_ENDPOINT = ('GET', 'api/2.1/jobs/runs/get-output') +RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now") +SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit") +GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get") +CANCEL_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel") +OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output") + +INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install") +UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall") -INSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/install') -UNINSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/uninstall') +LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list") -LIST_JOBS_ENDPOINT = ('GET', 'api/2.1/jobs/list') +WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status") -WORKSPACE_GET_STATUS_ENDPOINT = ('GET', 'api/2.0/workspace/get-status') +RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"] -RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] +SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions") class RunState: """Utility class for the run state concept of Databricks runs.""" def __init__( - self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs + self, life_cycle_state: str, result_state: str = "", state_message: str = "", *args, **kwargs ) -> None: self.life_cycle_state = life_cycle_state self.result_state = result_state @@ -69,17 +73,17 @@ def is_terminal(self) -> bool: if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES: raise AirflowException( ( - 'Unexpected life cycle state: {}: If the state has ' - 'been introduced recently, please check the Databricks user ' - 'guide for troubleshooting information' + 
"Unexpected life cycle state: {}: If the state has " + "been introduced recently, please check the Databricks user " + "guide for troubleshooting information" ).format(self.life_cycle_state) ) - return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR') + return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR") @property def is_successful(self) -> bool: """True if the result state is SUCCESS""" - return self.result_state == 'SUCCESS' + return self.result_state == "SUCCESS" def __eq__(self, other: object) -> bool: if not isinstance(other, RunState): @@ -97,7 +101,7 @@ def to_json(self) -> str: return json.dumps(self.__dict__) @classmethod - def from_json(cls, data: str) -> 'RunState': + def from_json(cls, data: str) -> RunState: return RunState(**json.loads(data)) @@ -115,7 +119,7 @@ class DatabricksHook(BaseDatabricksHook): :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. """ - hook_name = 'Databricks' + hook_name = "Databricks" def __init__( self, @@ -123,9 +127,10 @@ def __init__( timeout_seconds: int = 180, retry_limit: int = 3, retry_delay: float = 1.0, - retry_args: Optional[Dict[Any, Any]] = None, + retry_args: dict[Any, Any] | None = None, + caller: str = "DatabricksHook", ) -> None: - super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args) + super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller) def run_now(self, json: dict) -> int: """ @@ -133,10 +138,9 @@ def run_now(self, json: dict) -> int: :param json: The data used in the body of the request to the ``run-now`` endpoint. :return: the run_id as an int - :rtype: str """ response = self._do_api_call(RUN_NOW_ENDPOINT, json) - return response['run_id'] + return response["run_id"] def submit_run(self, json: dict) -> int: """ @@ -144,46 +148,53 @@ def submit_run(self, json: dict) -> int: :param json: The data used in the body of the request to the ``submit`` endpoint. :return: the run_id as an int - :rtype: str """ response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json) - return response['run_id'] + return response["run_id"] - def list_jobs(self, limit: int = 25, offset: int = 0, expand_tasks: bool = False) -> List[Dict[str, Any]]: + def list_jobs( + self, limit: int = 25, offset: int = 0, expand_tasks: bool = False, job_name: str | None = None + ) -> list[dict[str, Any]]: """ Lists the jobs in the Databricks Job Service. :param limit: The limit/batch size used to retrieve jobs. :param offset: The offset of the first job to return, relative to the most recently created job. :param expand_tasks: Whether to include task and cluster details in the response. + :param job_name: Optional name of a job to search. :return: A list of jobs. 
""" has_more = True - jobs = [] + all_jobs = [] while has_more: - json = { - 'limit': limit, - 'offset': offset, - 'expand_tasks': expand_tasks, + payload: dict[str, Any] = { + "limit": limit, + "expand_tasks": expand_tasks, + "offset": offset, } - response = self._do_api_call(LIST_JOBS_ENDPOINT, json) - jobs += response['jobs'] if 'jobs' in response else [] - has_more = response.get('has_more', False) + if job_name: + payload["name"] = job_name + response = self._do_api_call(LIST_JOBS_ENDPOINT, payload) + jobs = response.get("jobs", []) + if job_name: + all_jobs += [j for j in jobs if j["settings"]["name"] == job_name] + else: + all_jobs += jobs + has_more = response.get("has_more", False) if has_more: - offset += len(response['jobs']) + offset += len(jobs) - return jobs + return all_jobs - def find_job_id_by_name(self, job_name: str) -> Optional[int]: + def find_job_id_by_name(self, job_name: str) -> int | None: """ Finds job id by its name. If there are multiple jobs with the same name, raises AirflowException. :param job_name: The name of the job to look up. :return: The job_id as an int or None if no job was found. """ - all_jobs = self.list_jobs() - matching_jobs = [j for j in all_jobs if j['settings']['name'] == job_name] + matching_jobs = self.list_jobs(job_name=job_name) if len(matching_jobs) > 1: raise AirflowException( @@ -193,7 +204,7 @@ def find_job_id_by_name(self, job_name: str) -> Optional[int]: if not matching_jobs: return None else: - return matching_jobs[0]['job_id'] + return matching_jobs[0]["job_id"] def get_run_page_url(self, run_id: int) -> str: """ @@ -202,9 +213,9 @@ def get_run_page_url(self, run_id: int) -> str: :param run_id: id of the run :return: URL of the run page """ - json = {'run_id': run_id} + json = {"run_id": run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) - return response['run_page_url'] + return response["run_page_url"] async def a_get_run_page_url(self, run_id: int) -> str: """ @@ -212,9 +223,9 @@ async def a_get_run_page_url(self, run_id: int) -> str: :param run_id: id of the run :return: URL of the run page """ - json = {'run_id': run_id} + json = {"run_id": run_id} response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) - return response['run_page_url'] + return response["run_page_url"] def get_job_id(self, run_id: int) -> int: """ @@ -223,9 +234,9 @@ def get_job_id(self, run_id: int) -> int: :param run_id: id of the run :return: Job id for given Databricks run """ - json = {'run_id': run_id} + json = {"run_id": run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) - return response['job_id'] + return response["job_id"] def get_run_state(self, run_id: int) -> RunState: """ @@ -242,9 +253,9 @@ def get_run_state(self, run_id: int) -> RunState: :param run_id: id of the run :return: state of the run """ - json = {'run_id': run_id} + json = {"run_id": run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) - state = response['state'] + state = response["state"] return RunState(**state) async def a_get_run_state(self, run_id: int) -> RunState: @@ -253,11 +264,33 @@ async def a_get_run_state(self, run_id: int) -> RunState: :param run_id: id of the run :return: state of the run """ - json = {'run_id': run_id} + json = {"run_id": run_id} response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) - state = response['state'] + state = response["state"] return RunState(**state) + def get_run(self, run_id: int) -> dict[str, Any]: + """ + Retrieve run information. 
+ + :param run_id: id of the run + :return: state of the run + """ + json = {"run_id": run_id} + response = self._do_api_call(GET_RUN_ENDPOINT, json) + return response + + async def a_get_run(self, run_id: int) -> dict[str, Any]: + """ + Async version of `get_run`. + + :param run_id: id of the run + :return: state of the run + """ + json = {"run_id": run_id} + response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) + return response + def get_run_state_str(self, run_id: int) -> str: """ Return the string representation of RunState. @@ -305,7 +338,7 @@ def get_run_output(self, run_id: int) -> dict: :param run_id: id of the run :return: output of the run """ - json = {'run_id': run_id} + json = {"run_id": run_id} run_output = self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json) return run_output @@ -315,7 +348,7 @@ def cancel_run(self, run_id: int) -> None: :param run_id: id of the run """ - json = {'run_id': run_id} + json = {"run_id": run_id} self._do_api_call(CANCEL_RUN_ENDPOINT, json) def restart_cluster(self, json: dict) -> None: @@ -362,7 +395,7 @@ def uninstall(self, json: dict) -> None: """ self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json) - def update_repo(self, repo_id: str, json: Dict[str, Any]) -> dict: + def update_repo(self, repo_id: str, json: dict[str, Any]) -> dict: """ Updates given Databricks Repos @@ -370,7 +403,7 @@ def update_repo(self, repo_id: str, json: Dict[str, Any]) -> dict: :param json: payload :return: metadata from update """ - repos_endpoint = ('PATCH', f'api/2.0/repos/{repo_id}') + repos_endpoint = ("PATCH", f"api/2.0/repos/{repo_id}") return self._do_api_call(repos_endpoint, json) def delete_repo(self, repo_id: str): @@ -380,31 +413,44 @@ def delete_repo(self, repo_id: str): :param repo_id: ID of Databricks Repos :return: """ - repos_endpoint = ('DELETE', f'api/2.0/repos/{repo_id}') + repos_endpoint = ("DELETE", f"api/2.0/repos/{repo_id}") self._do_api_call(repos_endpoint) - def create_repo(self, json: Dict[str, Any]) -> dict: + def create_repo(self, json: dict[str, Any]) -> dict: """ Creates a Databricks Repos :param json: payload :return: """ - repos_endpoint = ('POST', 'api/2.0/repos') + repos_endpoint = ("POST", "api/2.0/repos") return self._do_api_call(repos_endpoint, json) - def get_repo_by_path(self, path: str) -> Optional[str]: + def get_repo_by_path(self, path: str) -> str | None: """ Obtains Repos ID by path :param path: path to a repository :return: Repos ID if it exists, None if doesn't. 
""" try: - result = self._do_api_call(WORKSPACE_GET_STATUS_ENDPOINT, {'path': path}, wrap_http_errors=False) - if result.get('object_type', '') == 'REPO': - return str(result['object_id']) + result = self._do_api_call(WORKSPACE_GET_STATUS_ENDPOINT, {"path": path}, wrap_http_errors=False) + if result.get("object_type", "") == "REPO": + return str(result["object_id"]) except requests_exceptions.HTTPError as e: if e.response.status_code != 404: raise e return None + + def test_connection(self) -> tuple[bool, str]: + """Test the Databricks connectivity from UI""" + hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) + try: + hook._do_api_call(endpoint_info=SPARK_VERSIONS_ENDPOINT).get("versions") + status = True + message = "Connection successfully tested" + except Exception as e: + status = False + message = str(e) + + return status, message diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py index 5b18dad9303ef..50ab2eff6a251 100644 --- a/airflow/providers/databricks/hooks/databricks_base.py +++ b/airflow/providers/databricks/hooks/databricks_base.py @@ -22,11 +22,13 @@ operators talk to the ``api/2.0/jobs/runs/submit`` `endpoint `_. """ +from __future__ import annotations + import copy -import sys +import platform import time -from typing import Any, Dict, Optional, Tuple -from urllib.parse import urlparse +from typing import Any +from urllib.parse import urlsplit import aiohttp import requests @@ -43,16 +45,11 @@ ) from airflow import __version__ +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import Connection - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - -USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'} +from airflow.providers_manager import ProvidersManager # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token # https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints @@ -81,17 +78,17 @@ class BaseDatabricksHook(BaseHook): :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. 
""" - conn_name_attr = 'databricks_conn_id' - default_conn_name = 'databricks_default' - conn_type = 'databricks' + conn_name_attr: str = "databricks_conn_id" + default_conn_name = "databricks_default" + conn_type = "databricks" extra_parameters = [ - 'token', - 'host', - 'use_azure_managed_identity', - 'azure_ad_endpoint', - 'azure_resource_id', - 'azure_tenant_id', + "token", + "host", + "use_azure_managed_identity", + "azure_ad_endpoint", + "azure_resource_id", + "azure_tenant_id", ] def __init__( @@ -100,25 +97,27 @@ def __init__( timeout_seconds: int = 180, retry_limit: int = 3, retry_delay: float = 1.0, - retry_args: Optional[Dict[Any, Any]] = None, + retry_args: dict[Any, Any] | None = None, + caller: str = "Unknown", ) -> None: super().__init__() self.databricks_conn_id = databricks_conn_id self.timeout_seconds = timeout_seconds if retry_limit < 1: - raise ValueError('Retry limit must be greater than or equal to 1') + raise ValueError("Retry limit must be greater than or equal to 1") self.retry_limit = retry_limit self.retry_delay = retry_delay - self.aad_tokens: Dict[str, dict] = {} + self.aad_tokens: dict[str, dict] = {} self.aad_timeout_seconds = 10 + self.caller = caller def my_after_func(retry_state): self._log_request_error(retry_state.attempt_number, retry_state.outcome) if retry_args: self.retry_args = copy.copy(retry_args) - self.retry_args['retry'] = retry_if_exception(self._retryable_error) - self.retry_args['after'] = my_after_func + self.retry_args["retry"] = retry_if_exception(self._retryable_error) + self.retry_args["after"] = my_after_func else: self.retry_args = dict( stop=stop_after_attempt(self.retry_limit), @@ -134,10 +133,28 @@ def databricks_conn(self) -> Connection: def get_conn(self) -> Connection: return self.databricks_conn + @cached_property + def user_agent_header(self) -> dict[str, str]: + return {"user-agent": self.user_agent_value} + + @cached_property + def user_agent_value(self) -> str: + manager = ProvidersManager() + package_name = manager.hooks[BaseDatabricksHook.conn_type].package_name # type: ignore[union-attr] + provider = manager.providers[package_name] + version = provider.version + python_version = platform.python_version() + system = platform.system().lower() + ua_string = ( + f"databricks-airflow/{version} _/0.0.0 python/{python_version} os/{system} " + f"airflow/{__version__} operator/{self.caller}" + ) + return ua_string + @cached_property def host(self) -> str: - if 'host' in self.databricks_conn.extra_dejson: - host = self._parse_host(self.databricks_conn.extra_dejson['host']) + if "host" in self.databricks_conn.extra_dejson: + host = self._parse_host(self.databricks_conn.extra_dejson["host"]) else: host = self._parse_host(self.databricks_conn.host) @@ -154,8 +171,7 @@ async def __aexit__(self, *err): @staticmethod def _parse_host(host: str) -> str: """ - The purpose of this function is to be robust to improper connections - settings provided by users, specifically in the host field. + This function is resistant to incorrect connection settings provided by users, in the host field. 
For example -- when users supply ``https://xx.cloud.databricks.com`` as the host, we must strip out the protocol to get the host.:: @@ -170,7 +186,7 @@ def _parse_host(host: str) -> str: assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com' """ - urlparse_host = urlparse(host).hostname + urlparse_host = urlsplit(host).hostname if urlparse_host: # In this case, host = https://xx.cloud.databricks.com return urlparse_host @@ -180,33 +196,35 @@ def _parse_host(host: str) -> str: def _get_retry_object(self) -> Retrying: """ - Instantiates a retry object + Instantiate a retry object. :return: instance of Retrying class """ return Retrying(**self.retry_args) def _a_get_retry_object(self) -> AsyncRetrying: """ - Instantiates an async retry object + Instantiate an async retry object. :return: instance of AsyncRetrying class """ return AsyncRetrying(**self.retry_args) def _get_aad_token(self, resource: str) -> str: """ - Function to get AAD token for given resource. Supports managed identity or service principal auth + Function to get AAD token for given resource. + + Supports managed identity or service principal auth. :param resource: resource to issue token to :return: AAD token, or raise an exception """ aad_token = self.aad_tokens.get(resource) if aad_token and self._is_aad_token_valid(aad_token): - return aad_token['token'] + return aad_token["token"] - self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...') + self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...") try: for attempt in self._get_retry_object(): with attempt: - if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): + if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): params = { "api-version": "2018-02-01", "resource": resource, @@ -214,11 +232,11 @@ def _get_aad_token(self, resource: str) -> str: resp = requests.get( AZURE_METADATA_SERVICE_TOKEN_URL, params=params, - headers={**USER_AGENT_HEADER, "Metadata": "true"}, + headers={**self.user_agent_header, "Metadata": "true"}, timeout=self.aad_timeout_seconds, ) else: - tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id'] + tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"] data = { "grant_type": "client_credentials", "client_id": self.databricks_conn.login, @@ -232,8 +250,8 @@ def _get_aad_token(self, resource: str) -> str: AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), data=data, headers={ - **USER_AGENT_HEADER, - 'Content-Type': 'application/x-www-form-urlencoded', + **self.user_agent_header, + "Content-Type": "application/x-www-form-urlencoded", }, timeout=self.aad_timeout_seconds, ) @@ -241,19 +259,19 @@ def _get_aad_token(self, resource: str) -> str: resp.raise_for_status() jsn = resp.json() if ( - 'access_token' not in jsn - or jsn.get('token_type') != 'Bearer' - or 'expires_on' not in jsn + "access_token" not in jsn + or jsn.get("token_type") != "Bearer" + or "expires_on" not in jsn ): raise AirflowException(f"Can't get necessary data from AAD token: {jsn}") - token = jsn['access_token'] - self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])} + token = jsn["access_token"] + self.aad_tokens[resource] = {"token": token, "expires_on": int(jsn["expires_on"])} break except RetryError: - raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.') + raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. 
Giving up.") except requests_exceptions.HTTPError as e: - raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}') + raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}") return token @@ -265,13 +283,13 @@ async def _a_get_aad_token(self, resource: str) -> str: """ aad_token = self.aad_tokens.get(resource) if aad_token and self._is_aad_token_valid(aad_token): - return aad_token['token'] + return aad_token["token"] - self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...') + self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...") try: async for attempt in self._a_get_retry_object(): with attempt: - if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): + if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): params = { "api-version": "2018-02-01", "resource": resource, @@ -279,13 +297,13 @@ async def _a_get_aad_token(self, resource: str) -> str: async with self._session.get( url=AZURE_METADATA_SERVICE_TOKEN_URL, params=params, - headers={**USER_AGENT_HEADER, "Metadata": "true"}, + headers={**self.user_agent_header, "Metadata": "true"}, timeout=self.aad_timeout_seconds, ) as resp: resp.raise_for_status() jsn = await resp.json() else: - tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id'] + tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"] data = { "grant_type": "client_credentials", "client_id": self.databricks_conn.login, @@ -299,42 +317,42 @@ async def _a_get_aad_token(self, resource: str) -> str: url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id), data=data, headers={ - **USER_AGENT_HEADER, - 'Content-Type': 'application/x-www-form-urlencoded', + **self.user_agent_header, + "Content-Type": "application/x-www-form-urlencoded", }, timeout=self.aad_timeout_seconds, ) as resp: resp.raise_for_status() jsn = await resp.json() if ( - 'access_token' not in jsn - or jsn.get('token_type') != 'Bearer' - or 'expires_on' not in jsn + "access_token" not in jsn + or jsn.get("token_type") != "Bearer" + or "expires_on" not in jsn ): raise AirflowException(f"Can't get necessary data from AAD token: {jsn}") - token = jsn['access_token'] - self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])} + token = jsn["access_token"] + self.aad_tokens[resource] = {"token": token, "expires_on": int(jsn["expires_on"])} break except RetryError: - raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.') + raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.") except aiohttp.ClientResponseError as err: - raise AirflowException(f'Response: {err.message}, Status Code: {err.status}') + raise AirflowException(f"Response: {err.message}, Status Code: {err.status}") return token def _get_aad_headers(self) -> dict: """ - Fills AAD headers if necessary (SPN is outside of the workspace) + Fill AAD headers if necessary (SPN is outside of the workspace). 
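The headers above are only filled when ``azure_resource_id`` is present in the connection extras, i.e. when the service principal lives outside the workspace. Such a connection might be configured roughly like this (all values are dummies)::

    import json

    from airflow.models import Connection

    databricks_spn_conn = Connection(
        conn_id="databricks_default",
        conn_type="databricks",
        host="adb-1234567890123456.7.azuredatabricks.net",  # dummy workspace host
        login="00000000-0000-0000-0000-000000000000",  # dummy SPN client id
        password="dummy-client-secret",
        extra=json.dumps(
            {
                "azure_tenant_id": "11111111-1111-1111-1111-111111111111",  # dummy tenant
                # Only needed when the SPN is not a member of the workspace itself:
                "azure_resource_id": "/subscriptions/xxx/resourceGroups/xxx/providers/"
                "Microsoft.Databricks/workspaces/my-workspace",  # dummy resource id
            }
        ),
    )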
:return: dictionary with filled AAD headers """ headers = {} - if 'azure_resource_id' in self.databricks_conn.extra_dejson: + if "azure_resource_id" in self.databricks_conn.extra_dejson: mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT) - headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ - 'azure_resource_id' + headers["X-Databricks-Azure-Workspace-Resource-Id"] = self.databricks_conn.extra_dejson[ + "azure_resource_id" ] - headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token + headers["X-Databricks-Azure-SP-Management-Token"] = mgmt_token return headers async def _a_get_aad_headers(self) -> dict: @@ -343,31 +361,31 @@ async def _a_get_aad_headers(self) -> dict: :return: dictionary with filled AAD headers """ headers = {} - if 'azure_resource_id' in self.databricks_conn.extra_dejson: + if "azure_resource_id" in self.databricks_conn.extra_dejson: mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT) - headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ - 'azure_resource_id' + headers["X-Databricks-Azure-Workspace-Resource-Id"] = self.databricks_conn.extra_dejson[ + "azure_resource_id" ] - headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token + headers["X-Databricks-Azure-SP-Management-Token"] = mgmt_token return headers @staticmethod def _is_aad_token_valid(aad_token: dict) -> bool: """ - Utility function to check AAD token hasn't expired yet + Utility function to check AAD token hasn't expired yet. + :param aad_token: dict with properties of AAD token :return: true if token is valid, false otherwise - :rtype: bool """ now = int(time.time()) - if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): + if aad_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME): return True return False @staticmethod def _check_azure_metadata_service() -> None: """ - Check for Azure Metadata Service + Check for Azure Metadata Service. https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service """ try: @@ -377,7 +395,7 @@ def _check_azure_metadata_service() -> None: headers={"Metadata": "true"}, timeout=2, ).json() - if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']: + if "compute" not in jsn or "azEnvironment" not in jsn["compute"]: raise AirflowException( f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" ) @@ -394,113 +412,112 @@ async def _a_check_azure_metadata_service(self): timeout=2, ) as resp: jsn = await resp.json() - if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']: + if "compute" not in jsn or "azEnvironment" not in jsn["compute"]: raise AirflowException( f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" ) except (requests_exceptions.RequestException, ValueError) as e: raise AirflowException(f"Can't reach Azure Metadata Service: {e}") - def _get_token(self, raise_error: bool = False) -> Optional[str]: - if 'token' in self.databricks_conn.extra_dejson: + def _get_token(self, raise_error: bool = False) -> str | None: + if "token" in self.databricks_conn.extra_dejson: self.log.info( - 'Using token auth. For security reasons, please set token in Password field instead of extra' + "Using token auth. 
For security reasons, please set token in Password field instead of extra" ) - return self.databricks_conn.extra_dejson['token'] + return self.databricks_conn.extra_dejson["token"] elif not self.databricks_conn.login and self.databricks_conn.password: - self.log.info('Using token auth.') + self.log.info("Using token auth.") return self.databricks_conn.password - elif 'azure_tenant_id' in self.databricks_conn.extra_dejson: + elif "azure_tenant_id" in self.databricks_conn.extra_dejson: if self.databricks_conn.login == "" or self.databricks_conn.password == "": raise AirflowException("Azure SPN credentials aren't provided") - self.log.info('Using AAD Token for SPN.') + self.log.info("Using AAD Token for SPN.") return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): - self.log.info('Using AAD Token for managed identity.') + elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): + self.log.info("Using AAD Token for managed identity.") self._check_azure_metadata_service() return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE) elif raise_error: - raise AirflowException('Token authentication isn\'t configured') + raise AirflowException("Token authentication isn't configured") return None - async def _a_get_token(self, raise_error: bool = False) -> Optional[str]: - if 'token' in self.databricks_conn.extra_dejson: + async def _a_get_token(self, raise_error: bool = False) -> str | None: + if "token" in self.databricks_conn.extra_dejson: self.log.info( - 'Using token auth. For security reasons, please set token in Password field instead of extra' + "Using token auth. For security reasons, please set token in Password field instead of extra" ) return self.databricks_conn.extra_dejson["token"] elif not self.databricks_conn.login and self.databricks_conn.password: - self.log.info('Using token auth.') + self.log.info("Using token auth.") return self.databricks_conn.password - elif 'azure_tenant_id' in self.databricks_conn.extra_dejson: + elif "azure_tenant_id" in self.databricks_conn.extra_dejson: if self.databricks_conn.login == "" or self.databricks_conn.password == "": raise AirflowException("Azure SPN credentials aren't provided") - self.log.info('Using AAD Token for SPN.') + self.log.info("Using AAD Token for SPN.") return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): - self.log.info('Using AAD Token for managed identity.') + elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): + self.log.info("Using AAD Token for managed identity.") await self._a_check_azure_metadata_service() return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE) elif raise_error: - raise AirflowException('Token authentication isn\'t configured') + raise AirflowException("Token authentication isn't configured") return None def _log_request_error(self, attempt_num: int, error: str) -> None: - self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error) + self.log.error("Attempt %s API Request to Databricks failed with reason: %s", attempt_num, error) def _do_api_call( self, - endpoint_info: Tuple[str, str], - json: Optional[Dict[str, Any]] = None, + endpoint_info: tuple[str, str], + json: dict[str, Any] | None = None, wrap_http_errors: bool = True, ): """ - Utility function to perform an API call with retries + Utility function to perform an API call with retries. 
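``_get_token()`` and ``_a_get_token()`` above resolve credentials in a fixed order: an explicit ``token`` key in the extras (still accepted but discouraged), a personal access token in the Password field with an empty Login, the Azure SPN flow, and finally managed identity. A hedged sketch of the recommended PAT setup, with a made-up host and token::

    from airflow.models import Connection

    # Hypothetical connection hitting the "token in Password field" branch above:
    # Login stays empty, the PAT goes into Password, the workspace host into Host.
    databricks_conn = Connection(
        conn_id="databricks_default",
        conn_type="databricks",
        host="adb-1234567890123456.7.azuredatabricks.net",
        password="<personal-access-token>",
    )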
:param endpoint_info: Tuple of method and endpoint :param json: Parameters for this API call. :return: If the api call returns a OK status code, this function returns the response in JSON. Otherwise, we throw an AirflowException. - :rtype: dict """ method, endpoint = endpoint_info # TODO: get rid of explicit 'api/' in the endpoint specification - url = f'https://{self.host}/{endpoint}' + url = f"https://{self.host}/{endpoint}" aad_headers = self._get_aad_headers() - headers = {**USER_AGENT_HEADER.copy(), **aad_headers} + headers = {**self.user_agent_header, **aad_headers} auth: AuthBase token = self._get_token() if token: auth = _TokenAuth(token) else: - self.log.info('Using basic auth.') + self.log.info("Using basic auth.") auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password) request_func: Any - if method == 'GET': + if method == "GET": request_func = requests.get - elif method == 'POST': + elif method == "POST": request_func = requests.post - elif method == 'PATCH': + elif method == "PATCH": request_func = requests.patch - elif method == 'DELETE': + elif method == "DELETE": request_func = requests.delete else: - raise AirflowException('Unexpected HTTP Method: ' + method) + raise AirflowException("Unexpected HTTP Method: " + method) try: for attempt in self._get_retry_object(): with attempt: response = request_func( url, - json=json if method in ('POST', 'PATCH') else None, - params=json if method == 'GET' else None, + json=json if method in ("POST", "PATCH") else None, + params=json if method == "GET" else None, auth=auth, headers=headers, timeout=self.timeout_seconds, @@ -508,16 +525,16 @@ def _do_api_call( response.raise_for_status() return response.json() except RetryError: - raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.') + raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.") except requests_exceptions.HTTPError as e: if wrap_http_errors: raise AirflowException( - f'Response: {e.response.content}, Status Code: {e.response.status_code}' + f"Response: {e.response.content}, Status Code: {e.response.status_code}" ) else: raise e - async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None): + async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, Any] | None = None): """ Async version of `_do_api_call()`. 
:param endpoint_info: Tuple of method and endpoint @@ -527,28 +544,28 @@ async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Di """ method, endpoint = endpoint_info - url = f'https://{self.host}/{endpoint}' + url = f"https://{self.host}/{endpoint}" aad_headers = await self._a_get_aad_headers() - headers = {**USER_AGENT_HEADER.copy(), **aad_headers} + headers = {**self.user_agent_header, **aad_headers} auth: aiohttp.BasicAuth token = await self._a_get_token() if token: auth = BearerAuth(token) else: - self.log.info('Using basic auth.') + self.log.info("Using basic auth.") auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password) request_func: Any - if method == 'GET': + if method == "GET": request_func = self._session.get - elif method == 'POST': + elif method == "POST": request_func = self._session.post - elif method == 'PATCH': + elif method == "PATCH": request_func = self._session.patch else: - raise AirflowException('Unexpected HTTP Method: ' + method) + raise AirflowException("Unexpected HTTP Method: " + method) try: async for attempt in self._a_get_retry_object(): with attempt: @@ -556,22 +573,22 @@ async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Di url, json=json, auth=auth, - headers={**headers, **USER_AGENT_HEADER}, + headers={**headers, **self.user_agent_header}, timeout=self.timeout_seconds, ) as response: response.raise_for_status() return await response.json() except RetryError: - raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.') + raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.") except aiohttp.ClientResponseError as err: - raise AirflowException(f'Response: {err.message}, Status Code: {err.status}') + raise AirflowException(f"Response: {err.message}, Status Code: {err.status}") @staticmethod def _get_error_code(exception: BaseException) -> str: if isinstance(exception, requests_exceptions.HTTPError): try: jsn = exception.response.json() - return jsn.get('error_code', '') + return jsn.get("error_code", "") except JSONDecodeError: pass @@ -587,7 +604,7 @@ def _retryable_error(exception: BaseException) -> bool: or exception.response.status_code == 429 or ( exception.response.status_code == 400 - and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK' + and BaseDatabricksHook._get_error_code(exception) == "COULD_NOT_ACQUIRE_LOCK" ) ) ): @@ -602,7 +619,9 @@ def _retryable_error(exception: BaseException) -> bool: class _TokenAuth(AuthBase): """ - Helper class for requests Auth field. AuthBase requires you to implement the __call__ + Helper class for requests Auth field. + + AuthBase requires you to implement the ``__call__`` magic function. 
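``_TokenAuth`` plugs into ``requests`` by mutating the prepared request just before it is sent. A standalone sketch of the same mechanism, with an illustrative class name and URL that are not part of the provider::

    import requests
    from requests.auth import AuthBase

    class BearerExample(AuthBase):
        """Illustrative stand-in mirroring the _TokenAuth helper."""

        def __init__(self, token: str) -> None:
            self.token = token

        def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
            # requests calls the auth object on the PreparedRequest before sending it.
            r.headers["Authorization"] = f"Bearer {self.token}"
            return r

    # requests.get("https://example.com/api/2.1/jobs/list", auth=BearerExample("abc123"))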
""" @@ -610,18 +629,18 @@ def __init__(self, token: str) -> None: self.token = token def __call__(self, r: PreparedRequest) -> PreparedRequest: - r.headers['Authorization'] = 'Bearer ' + self.token + r.headers["Authorization"] = "Bearer " + self.token return r class BearerAuth(aiohttp.BasicAuth): """aiohttp only ships BasicAuth, for Bearer auth we need a subclass of BasicAuth.""" - def __new__(cls, token: str) -> 'BearerAuth': + def __new__(cls, token: str) -> BearerAuth: return super().__new__(cls, token) # type: ignore def __init__(self, token: str) -> None: self.token = token def encode(self) -> str: - return f'Bearer {self.token}' + return f"Bearer {self.token}" diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 9d86b4dbe3ae6..f042435943889 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -14,22 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -import re from contextlib import closing from copy import copy -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Iterable, Mapping from databricks import sql # type: ignore[attr-defined] from databricks.sql.client import Connection # type: ignore[attr-defined] -from airflow import __version__ from airflow.exceptions import AirflowException -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook -LIST_SQL_ENDPOINTS_ENDPOINT = ('GET', 'api/2.0/sql/endpoints') -USER_AGENT_STRING = f'airflow-{__version__}' +LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints") class DatabricksSqlHook(BaseDatabricksHook, DbApiHook): @@ -52,22 +50,24 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook): :param kwargs: Additional parameters internal to Databricks SQL Connector parameters """ - hook_name = 'Databricks SQL' + hook_name = "Databricks SQL" + _test_connection_sql = "select 42" def __init__( self, databricks_conn_id: str = BaseDatabricksHook.default_conn_name, - http_path: Optional[str] = None, - sql_endpoint_name: Optional[str] = None, - session_configuration: Optional[Dict[str, str]] = None, - http_headers: Optional[List[Tuple[str, str]]] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, + http_path: str | None = None, + sql_endpoint_name: str | None = None, + session_configuration: dict[str, str] | None = None, + http_headers: list[tuple[str, str]] | None = None, + catalog: str | None = None, + schema: str | None = None, + caller: str = "DatabricksSqlHook", **kwargs, ) -> None: - super().__init__(databricks_conn_id) + super().__init__(databricks_conn_id, caller=caller) self._sql_conn = None - self._token: Optional[str] = None + self._token: str | None = None self._http_path = http_path self._sql_endpoint_name = sql_endpoint_name self.supports_autocommit = True @@ -77,19 +77,19 @@ def __init__( self.schema = schema self.additional_params = kwargs - def _get_extra_config(self) -> Dict[str, Optional[Any]]: + def _get_extra_config(self) -> dict[str, Any | None]: extra_params = copy(self.databricks_conn.extra_dejson) - for arg in ['http_path', 'session_configuration'] + self.extra_parameters: + for arg in ["http_path", "session_configuration"] + 
self.extra_parameters: if arg in extra_params: del extra_params[arg] return extra_params - def _get_sql_endpoint_by_name(self, endpoint_name) -> Dict[str, Any]: + def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]: result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT) - if 'endpoints' not in result: + if "endpoints" not in result: raise AirflowException("Can't list Databricks SQL endpoints") - lst = [endpoint for endpoint in result['endpoints'] if endpoint['name'] == endpoint_name] + lst = [endpoint for endpoint in result["endpoints"] if endpoint["name"] == endpoint_name] if len(lst) == 0: raise AirflowException(f"Can't f Databricks SQL endpoint with name '{endpoint_name}'") return lst[0] @@ -99,9 +99,9 @@ def get_conn(self) -> Connection: if not self._http_path: if self._sql_endpoint_name: endpoint = self._get_sql_endpoint_by_name(self._sql_endpoint_name) - self._http_path = endpoint['odbc_params']['path'] - elif 'http_path' in self.databricks_conn.extra_dejson: - self._http_path = self.databricks_conn.extra_dejson['http_path'] + self._http_path = endpoint["odbc_params"]["path"] + elif "http_path" in self.databricks_conn.extra_dejson: + self._http_path = self.databricks_conn.extra_dejson["http_path"] else: raise AirflowException( "http_path should be provided either explicitly, " @@ -120,7 +120,7 @@ def get_conn(self) -> Connection: requires_init = False if not self.session_config: - self.session_config = self.databricks_conn.extra_dejson.get('session_configuration') + self.session_config = self.databricks_conn.extra_dejson.get("session_configuration") if not self._sql_conn or requires_init: if self._sql_conn: # close already existing connection @@ -133,25 +133,21 @@ def get_conn(self) -> Connection: catalog=self.catalog, session_configuration=self.session_config, http_headers=self.http_headers, - _user_agent_entry=USER_AGENT_STRING, + _user_agent_entry=self.user_agent_value, **self._get_extra_config(), **self.additional_params, ) return self._sql_conn - @staticmethod - def maybe_split_sql_string(sql: str) -> List[str]: - """ - Splits strings consisting of multiple SQL expressions into an - TODO: do we need something more sophisticated? - - :param sql: SQL string potentially consisting of multiple expressions - :return: list of individual expressions - """ - splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""] - return splits - - def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None): + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping | None = None, + handler: Callable | None = None, + split_statements: bool = True, + return_last: bool = True, + ) -> Any | list[Any] | None: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -163,49 +159,44 @@ def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, hand before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. - :return: query results. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the LAST SQL expression if handler was provided. 
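With the reworked signature, ``run()`` only returns data when a ``handler`` is supplied, and ``return_last`` decides between a single result and a list of per-statement results. A small usage sketch, assuming a hypothetical endpoint name::

    from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

    hook = DatabricksSqlHook(
        databricks_conn_id="databricks_default",
        sql_endpoint_name="my-sql-endpoint",  # made-up endpoint name
    )

    # Two statements in one string: split_statements=True (default) splits them,
    # return_last=True (default) returns only the rows of the second one.
    rows = hook.run(
        sql="SELECT 1; SELECT current_date()",
        handler=lambda cursor: cursor.fetchall(),
    )

    # Without a handler the call returns None and is useful only for its side effects.
    hook.run(sql="CREATE TABLE IF NOT EXISTS default.demo (id INT)")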
""" + self.scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): - sql = self.maybe_split_sql_string(sql) + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [self.strip_sql_string(sql)] if sql: - self.log.debug("Executing %d statements", len(sql)) + self.log.debug("Executing following statements against Databricks DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") - conn = None + results = [] for sql_statement in sql: # when using AAD tokens, it could expire if previous query run longer than token lifetime - conn = self.get_conn() - with closing(conn.cursor()) as cur: - self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters) - if parameters: - cur.execute(sql_statement, parameters) - else: - cur.execute(sql_statement) - schema = cur.description - results = [] - if handler is not None: - cur = handler(cur) - for row in cur: - self.log.debug("Statement results: %s", row) - results.append(row) - - self.log.info("Rows affected: %s", cur.rowcount) - if conn: - conn.close() - self._sql_conn = None + with closing(self.get_conn()) as conn: + self.set_autocommit(conn, autocommit) + + with closing(conn.cursor()) as cur: + self._run_command(cur, sql_statement, parameters) - # Return only result of the last SQL expression - return schema, results + if handler is not None: + result = handler(cur) + results.append(result) + self.last_description = cur.description - def test_connection(self): - """Test the Databricks SQL connection by running a simple query.""" - try: - self.run(sql="select 42") - except Exception as e: - return False, str(e) - return True, "Connection successfully checked" + self._sql_conn = None + + if handler is None: + return None + elif self.scalar_return_last: + return results[-1] + else: + return results def bulk_dump(self, table, tmp_file): raise NotImplementedError() diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 8af4474b1315f..dbb44c70829a5 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -15,26 +15,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# """This module contains Databricks operators.""" +from __future__ import annotations import time from logging import Logger -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger -from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event +from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context -DEFER_METHOD_NAME = 'execute_complete' -XCOM_RUN_ID_KEY = 'run_id' -XCOM_RUN_PAGE_URL_KEY = 'run_page_url' +DEFER_METHOD_NAME = "execute_complete" +XCOM_RUN_ID_KEY = "run_id" +XCOM_RUN_PAGE_URL_KEY = "run_page_url" def _handle_databricks_operator_execution(operator, hook, log, context) -> None: @@ -45,35 +46,54 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None: :param context: Airflow context """ if operator.do_xcom_push and context is not None: - context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id) - log.info('Run submitted with run_id: %s', operator.run_id) + context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id) + log.info("Run submitted with run_id: %s", operator.run_id) run_page_url = hook.get_run_page_url(operator.run_id) if operator.do_xcom_push and context is not None: - context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url) + context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url) if operator.wait_for_termination: while True: - run_state = hook.get_run_state(operator.run_id) + run_info = hook.get_run(operator.run_id) + run_state = RunState(**run_info["state"]) if run_state.is_terminal: if run_state.is_successful: - log.info('%s completed successfully.', operator.task_id) - log.info('View run status, Spark UI, and logs at %s', run_page_url) + log.info("%s completed successfully.", operator.task_id) + log.info("View run status, Spark UI, and logs at %s", run_page_url) return else: - run_output = hook.get_run_output(operator.run_id) - notebook_error = run_output['error'] - error_message = ( - f'{operator.task_id} failed with terminal state: {run_state} ' - f'and with the error {notebook_error}' - ) + if run_state.result_state == "FAILED": + task_run_id = None + if "tasks" in run_info: + for task in run_info["tasks"]: + if task.get("state", {}).get("result_state", "") == "FAILED": + task_run_id = task["run_id"] + if task_run_id is not None: + run_output = hook.get_run_output(task_run_id) + if "error" in run_output: + notebook_error = run_output["error"] + else: + notebook_error = run_state.state_message + else: + notebook_error = run_state.state_message + error_message = ( + f"{operator.task_id} failed with terminal state: {run_state} " + f"and with the error {notebook_error}" + ) + else: + error_message = ( + f"{operator.task_id} failed with terminal state: {run_state} " + f"and with the error {run_state.state_message}" + ) raise AirflowException(error_message) + else: - log.info('%s in run state: %s', operator.task_id, run_state) - log.info('View run status, Spark UI, and logs at %s', 
run_page_url) - log.info('Sleeping for %s seconds.', operator.polling_period_seconds) + log.info("%s in run state: %s", operator.task_id, run_state) + log.info("View run status, Spark UI, and logs at %s", run_page_url) + log.info("Sleeping for %s seconds.", operator.polling_period_seconds) time.sleep(operator.polling_period_seconds) else: - log.info('View run status, Spark UI, and logs at %s', run_page_url) + log.info("View run status, Spark UI, and logs at %s", run_page_url) def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None: @@ -84,13 +104,13 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex :param context: Airflow context """ if operator.do_xcom_push and context is not None: - context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id) - log.info(f'Run submitted with run_id: {operator.run_id}') + context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id) + log.info("Run submitted with run_id: %s", operator.run_id) run_page_url = hook.get_run_page_url(operator.run_id) if operator.do_xcom_push and context is not None: - context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url) - log.info(f'View run status, Spark UI, and logs at {run_page_url}') + context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url) + log.info("View run status, Spark UI, and logs at %s", run_page_url) if operator.wait_for_termination: operator.defer( @@ -105,15 +125,15 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None: validate_trigger_event(event) - run_state = RunState.from_json(event['run_state']) - run_page_url = event['run_page_url'] - log.info(f'View run status, Spark UI, and logs at {run_page_url}') + run_state = RunState.from_json(event["run_state"]) + run_page_url = event["run_page_url"] + log.info("View run status, Spark UI, and logs at %s", run_page_url) if run_state.is_successful: - log.info('Job run completed successfully.') + log.info("Job run completed successfully.") return else: - error_message = f'Job run failed with terminal state: {run_state}' + error_message = f"Job run failed with terminal state: {run_state}" raise AirflowException(error_message) @@ -124,23 +144,11 @@ class DatabricksJobRunLink(BaseOperatorLink): def get_link( self, - operator, - dttm=None, + operator: BaseOperator, *, - ti_key: Optional["TaskInstanceKey"] = None, + ti_key: TaskInstanceKey, ) -> str: - if ti_key is not None: - run_page_url = XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key) - else: - assert dttm - run_page_url = XCom.get_one( - key=XCOM_RUN_PAGE_URL_KEY, - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - ) - - return run_page_url + return XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key) class DatabricksSubmitRunOperator(BaseOperator): @@ -150,62 +158,16 @@ class DatabricksSubmitRunOperator(BaseOperator): `_ API endpoint. - There are two ways to instantiate this operator. - - In the first way, you can take the JSON payload that you typically use - to call the ``api/2.1/jobs/runs/submit`` endpoint and pass it directly - to our ``DatabricksSubmitRunOperator`` through the ``json`` parameter. 
- For example :: - - json = { - 'new_cluster': { - 'spark_version': '2.1.0-db3-scala2.11', - 'num_workers': 2 - }, - 'notebook_task': { - 'notebook_path': '/Users/airflow@example.com/PrepareData', - }, - } - notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run', json=json) - - Another way to accomplish the same thing is to use the named parameters - of the ``DatabricksSubmitRunOperator`` directly. Note that there is exactly - one named parameter for each top level parameter in the ``runs/submit`` - endpoint. In this method, your code would look like this: :: - - new_cluster = { - 'spark_version': '10.1.x-scala2.12', - 'num_workers': 2 - } - notebook_task = { - 'notebook_path': '/Users/airflow@example.com/PrepareData', - } - notebook_run = DatabricksSubmitRunOperator( - task_id='notebook_run', - new_cluster=new_cluster, - notebook_task=notebook_task) - - In the case where both the json parameter **AND** the named parameters - are provided, they will be merged together. If there are conflicts during the merge, - the named parameters will take precedence and override the top level ``json`` keys. - - Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are - - ``spark_jar_task`` - - ``notebook_task`` - - ``spark_python_task`` - - ``spark_jar_task`` - - ``spark_submit_task`` - - ``pipeline_task`` - - ``new_cluster`` - - ``existing_cluster_id`` - - ``libraries`` - - ``run_name`` - - ``timeout_seconds`` + There are three ways to instantiate this operator. .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DatabricksSubmitRunOperator` + :param tasks: Array of Objects(RunSubmitTaskSettings) <= 100 items. + + .. seealso:: + https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit :param json: A JSON object containing API parameters which will be passed directly to the ``api/2.1/jobs/runs/submit`` endpoint. The other named parameters (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will @@ -219,28 +181,28 @@ class DatabricksSubmitRunOperator(BaseOperator): :param spark_jar_task: The main class and parameters for the JAR task. Note that the actual JAR is specified in the ``libraries``. *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. + *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified. This field will be templated. .. seealso:: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobssparkjartask :param notebook_task: The notebook path and parameters for the notebook task. *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. + *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified. This field will be templated. .. seealso:: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobsnotebooktask :param spark_python_task: The python file path and parameters to run the python file with. *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. + *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified. This field will be templated. .. 
seealso:: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobssparkpythontask :param spark_submit_task: Parameters needed to run a spark-submit command. *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. + *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified. This field will be templated. .. seealso:: @@ -248,11 +210,18 @@ class DatabricksSubmitRunOperator(BaseOperator): :param pipeline_task: Parameters needed to execute a Delta Live Tables pipeline task. The provided dictionary must contain at least ``pipeline_id`` field! *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. + *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified. This field will be templated. .. seealso:: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobspipelinetask + :param dbt_task: Parameters needed to execute a dbt task. + The provided dictionary must contain at least the ``commands`` field and the + ``git_source`` parameter also needs to be set. + *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` + *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified. + This field will be templated. + :param new_cluster: Specs for a new cluster on which this task will be run. *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified (except when ``pipeline_task`` is used). @@ -287,7 +256,7 @@ class DatabricksSubmitRunOperator(BaseOperator): :param databricks_conn_id: Reference to the :ref:`Databricks connection `. By default and in the common case this will be ``databricks_default``. To use token based authentication, provide the key ``token`` in the extra field for the - connection and create the key ``host`` and leave the ``host`` field empty. + connection and create the key ``host`` and leave the ``host`` field empty. (templated) :param polling_period_seconds: Controls the rate which we poll for the result of this run. By default the operator will poll every 30 seconds. 
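The inline examples were dropped from the docstring above; for reference, a sketch of the two most common instantiation styles (cluster spec and notebook path are placeholders, and a third style passes a ``tasks`` array with the same keys as the API payload)::

    from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator

    new_cluster = {"spark_version": "10.1.x-scala2.12", "num_workers": 2}
    notebook_task = {"notebook_path": "/Users/airflow@example.com/PrepareData"}

    # 1. Named parameters, one per top-level key of the runs/submit payload.
    notebook_run = DatabricksSubmitRunOperator(
        task_id="notebook_run",
        new_cluster=new_cluster,
        notebook_task=notebook_task,
    )

    # 2. The raw JSON payload; on conflicting keys the named parameters win.
    notebook_run_json = DatabricksSubmitRunOperator(
        task_id="notebook_run_json",
        json={"new_cluster": new_cluster, "notebook_task": notebook_task},
    )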
:param databricks_retry_limit: Amount of times retry if the Databricks backend is @@ -304,38 +273,39 @@ class DatabricksSubmitRunOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ('json',) - template_ext: Sequence[str] = ('.json',) + template_fields: Sequence[str] = ("json", "databricks_conn_id") + template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text - ui_color = '#1CB1C2' - ui_fgcolor = '#fff' + ui_color = "#1CB1C2" + ui_fgcolor = "#fff" operator_extra_links = (DatabricksJobRunLink(),) def __init__( self, *, - json: Optional[Any] = None, - tasks: Optional[List[object]] = None, - spark_jar_task: Optional[Dict[str, str]] = None, - notebook_task: Optional[Dict[str, str]] = None, - spark_python_task: Optional[Dict[str, Union[str, List[str]]]] = None, - spark_submit_task: Optional[Dict[str, List[str]]] = None, - pipeline_task: Optional[Dict[str, str]] = None, - new_cluster: Optional[Dict[str, object]] = None, - existing_cluster_id: Optional[str] = None, - libraries: Optional[List[Dict[str, str]]] = None, - run_name: Optional[str] = None, - timeout_seconds: Optional[int] = None, - databricks_conn_id: str = 'databricks_default', + json: Any | None = None, + tasks: list[object] | None = None, + spark_jar_task: dict[str, str] | None = None, + notebook_task: dict[str, str] | None = None, + spark_python_task: dict[str, str | list[str]] | None = None, + spark_submit_task: dict[str, list[str]] | None = None, + pipeline_task: dict[str, str] | None = None, + dbt_task: dict[str, str | list[str]] | None = None, + new_cluster: dict[str, object] | None = None, + existing_cluster_id: str | None = None, + libraries: list[dict[str, str]] | None = None, + run_name: str | None = None, + timeout_seconds: int | None = None, + databricks_conn_id: str = "databricks_default", polling_period_seconds: int = 30, databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, - databricks_retry_args: Optional[Dict[Any, Any]] = None, + databricks_retry_args: dict[Any, Any] | None = None, do_xcom_push: bool = True, - idempotency_token: Optional[str] = None, - access_control_list: Optional[List[Dict[str, str]]] = None, + idempotency_token: str | None = None, + access_control_list: list[dict[str, str]] | None = None, wait_for_termination: bool = True, - git_source: Optional[Dict[str, str]] = None, + git_source: dict[str, str] | None = None, **kwargs, ) -> None: """Creates a new ``DatabricksSubmitRunOperator``.""" @@ -348,74 +318,82 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination if tasks is not None: - self.json['tasks'] = tasks + self.json["tasks"] = tasks if spark_jar_task is not None: - self.json['spark_jar_task'] = spark_jar_task + self.json["spark_jar_task"] = spark_jar_task if notebook_task is not None: - self.json['notebook_task'] = notebook_task + self.json["notebook_task"] = notebook_task if spark_python_task is not None: - self.json['spark_python_task'] = spark_python_task + self.json["spark_python_task"] = spark_python_task if spark_submit_task is not None: - self.json['spark_submit_task'] = spark_submit_task + self.json["spark_submit_task"] = spark_submit_task if pipeline_task is not None: - self.json['pipeline_task'] = pipeline_task + self.json["pipeline_task"] = pipeline_task + if dbt_task is not None: + self.json["dbt_task"] = dbt_task if new_cluster is not None: - self.json['new_cluster'] = new_cluster + self.json["new_cluster"] = new_cluster if 
existing_cluster_id is not None: - self.json['existing_cluster_id'] = existing_cluster_id + self.json["existing_cluster_id"] = existing_cluster_id if libraries is not None: - self.json['libraries'] = libraries + self.json["libraries"] = libraries if run_name is not None: - self.json['run_name'] = run_name + self.json["run_name"] = run_name if timeout_seconds is not None: - self.json['timeout_seconds'] = timeout_seconds - if 'run_name' not in self.json: - self.json['run_name'] = run_name or kwargs['task_id'] + self.json["timeout_seconds"] = timeout_seconds + if "run_name" not in self.json: + self.json["run_name"] = run_name or kwargs["task_id"] if idempotency_token is not None: - self.json['idempotency_token'] = idempotency_token + self.json["idempotency_token"] = idempotency_token if access_control_list is not None: - self.json['access_control_list'] = access_control_list + self.json["access_control_list"] = access_control_list if git_source is not None: - self.json['git_source'] = git_source + self.json["git_source"] = git_source - self.json = deep_string_coerce(self.json) + if "dbt_task" in self.json and "git_source" not in self.json: + raise AirflowException("git_source is required for dbt_task") + + self.json = normalise_json_content(self.json) # This variable will be used in case our task gets killed. - self.run_id: Optional[int] = None + self.run_id: int | None = None self.do_xcom_push = do_xcom_push - def _get_hook(self) -> DatabricksHook: + @cached_property + def _hook(self): + return self._get_hook(caller="DatabricksSubmitRunOperator") + + def _get_hook(self, caller: str) -> DatabricksHook: return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, + caller=caller, ) - def execute(self, context: 'Context'): - hook = self._get_hook() - self.run_id = hook.submit_run(self.json) - _handle_databricks_operator_execution(self, hook, self.log, context) + def execute(self, context: Context): + self.run_id = self._hook.submit_run(self.json) + _handle_databricks_operator_execution(self, self._hook, self.log, context) def on_kill(self): if self.run_id: - hook = self._get_hook() - hook.cancel_run(self.run_id) + self._hook.cancel_run(self.run_id) self.log.info( - 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id + "Task: %s with run_id: %s was requested to be cancelled.", self.task_id, self.run_id ) else: - self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id) + self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id) class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator): """Deferrable version of ``DatabricksSubmitRunOperator``""" def execute(self, context): - hook = self._get_hook() + hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator") self.run_id = hook.submit_run(self.json) _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) - def execute_complete(self, context: Optional[dict], event: dict): + def execute_complete(self, context: dict | None, event: dict): _handle_deferrable_databricks_operator_completion(event, self.log) @@ -508,7 +486,7 @@ class DatabricksRunNowOperator(BaseOperator): The map is passed to the notebook and will be accessible through the dbutils.widgets.get function. See Widgets for more information. 
If not specified upon run-now, the triggered run will use the - job’s base parameters. notebook_params cannot be + job's base parameters. notebook_params cannot be specified in conjunction with jar_params. The json representation of this field (i.e. {"notebook_params":{"name":"john doe","age":"35"}}) cannot exceed 10,000 bytes. @@ -526,8 +504,8 @@ class DatabricksRunNowOperator(BaseOperator): .. seealso:: https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow - :param python_named_parameters: A list of parameters for jobs with python wheel tasks, - e.g. "python_named_parameters": {"name": "john doe", "age": "35"}. + :param python_named_params: A list of named parameters for jobs with python wheel tasks, + e.g. "python_named_params": {"name": "john doe", "age": "35"}. If specified upon run-now, it would overwrite the parameters specified in job setting. This field will be templated. @@ -560,9 +538,9 @@ class DatabricksRunNowOperator(BaseOperator): :param databricks_conn_id: Reference to the :ref:`Databricks connection `. By default and in the common case this will be ``databricks_default``. To use token based authentication, provide the key ``token`` in the extra field for the - connection and create the key ``host`` and leave the ``host`` field empty. + connection and create the key ``host`` and leave the ``host`` field empty. (templated) :param polling_period_seconds: Controls the rate which we poll for the result of - this run. By default the operator will poll every 30 seconds. + this run. By default, the operator will poll every 30 seconds. :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. :param databricks_retry_delay: Number of seconds to wait between retries (it @@ -573,30 +551,30 @@ class DatabricksRunNowOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ('json',) - template_ext: Sequence[str] = ('.json',) + template_fields: Sequence[str] = ("json", "databricks_conn_id") + template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text - ui_color = '#1CB1C2' - ui_fgcolor = '#fff' + ui_color = "#1CB1C2" + ui_fgcolor = "#fff" operator_extra_links = (DatabricksJobRunLink(),) def __init__( self, *, - job_id: Optional[str] = None, - job_name: Optional[str] = None, - json: Optional[Any] = None, - notebook_params: Optional[Dict[str, str]] = None, - python_params: Optional[List[str]] = None, - jar_params: Optional[List[str]] = None, - spark_submit_params: Optional[List[str]] = None, - python_named_parameters: Optional[Dict[str, str]] = None, - idempotency_token: Optional[str] = None, - databricks_conn_id: str = 'databricks_default', + job_id: str | None = None, + job_name: str | None = None, + json: Any | None = None, + notebook_params: dict[str, str] | None = None, + python_params: list[str] | None = None, + jar_params: list[str] | None = None, + spark_submit_params: list[str] | None = None, + python_named_params: dict[str, str] | None = None, + idempotency_token: str | None = None, + databricks_conn_id: str = "databricks_default", polling_period_seconds: int = 30, databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, - databricks_retry_args: Optional[Dict[Any, Any]] = None, + databricks_retry_args: dict[Any, Any] | None = None, do_xcom_push: bool = True, wait_for_termination: bool = True, **kwargs, @@ -612,66 +590,70 @@ def __init__( self.wait_for_termination = wait_for_termination 
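The constructor arguments above map one-to-one onto the ``jobs/run-now`` payload; note the rename from ``python_named_parameters`` to ``python_named_params``. A brief sketch with a made-up job id and parameters::

    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    run_existing_job = DatabricksRunNowOperator(
        task_id="run_existing_job",
        job_id="42",                                        # hypothetical job id
        notebook_params={"name": "john doe", "age": "35"},  # merged into the job's base parameters
        python_named_params={"data_date": "{{ ds }}"},      # the 'json' field is templated
    )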
if job_id is not None: - self.json['job_id'] = job_id + self.json["job_id"] = job_id if job_name is not None: - self.json['job_name'] = job_name - if 'job_id' in self.json and 'job_name' in self.json: + self.json["job_name"] = job_name + if "job_id" in self.json and "job_name" in self.json: raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'") if notebook_params is not None: - self.json['notebook_params'] = notebook_params + self.json["notebook_params"] = notebook_params if python_params is not None: - self.json['python_params'] = python_params - if python_named_parameters is not None: - self.json['python_named_parameters'] = python_named_parameters + self.json["python_params"] = python_params + if python_named_params is not None: + self.json["python_named_params"] = python_named_params if jar_params is not None: - self.json['jar_params'] = jar_params + self.json["jar_params"] = jar_params if spark_submit_params is not None: - self.json['spark_submit_params'] = spark_submit_params + self.json["spark_submit_params"] = spark_submit_params if idempotency_token is not None: - self.json['idempotency_token'] = idempotency_token + self.json["idempotency_token"] = idempotency_token - self.json = deep_string_coerce(self.json) + self.json = normalise_json_content(self.json) # This variable will be used in case our task gets killed. - self.run_id: Optional[int] = None + self.run_id: int | None = None self.do_xcom_push = do_xcom_push - def _get_hook(self) -> DatabricksHook: + @cached_property + def _hook(self): + return self._get_hook(caller="DatabricksRunNowOperator") + + def _get_hook(self, caller: str) -> DatabricksHook: return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, + caller=caller, ) - def execute(self, context: 'Context'): - hook = self._get_hook() - if 'job_name' in self.json: - job_id = hook.find_job_id_by_name(self.json['job_name']) + def execute(self, context: Context): + hook = self._hook + if "job_name" in self.json: + job_id = hook.find_job_id_by_name(self.json["job_name"]) if job_id is None: raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found") - self.json['job_id'] = job_id - del self.json['job_name'] + self.json["job_id"] = job_id + del self.json["job_name"] self.run_id = hook.run_now(self.json) _handle_databricks_operator_execution(self, hook, self.log, context) def on_kill(self): if self.run_id: - hook = self._get_hook() - hook.cancel_run(self.run_id) + self._hook.cancel_run(self.run_id) self.log.info( - 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id + "Task: %s with run_id: %s was requested to be cancelled.", self.task_id, self.run_id ) else: - self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id) + self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id) class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator): """Deferrable version of ``DatabricksRunNowOperator``""" def execute(self, context): - hook = self._get_hook() + hook = self._get_hook(caller="DatabricksRunNowDeferrableOperator") self.run_id = hook.run_now(self.json) _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) - def execute_complete(self, context: Optional[dict], event: dict): + def execute_complete(self, context: dict | None, event: dict): 
_handle_deferrable_databricks_operator_completion(event, self.log) diff --git a/airflow/providers/databricks/operators/databricks_repos.py b/airflow/providers/databricks/operators/databricks_repos.py index 97b46b8e244d6..f42114d474267 100644 --- a/airflow/providers/databricks/operators/databricks_repos.py +++ b/airflow/providers/databricks/operators/databricks_repos.py @@ -15,12 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Databricks operators.""" +from __future__ import annotations + import re -from typing import TYPE_CHECKING, Optional, Sequence -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Sequence +from urllib.parse import urlsplit +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.databricks.hooks.databricks import DatabricksHook @@ -46,7 +48,7 @@ class DatabricksReposCreateOperator(BaseOperator): :param databricks_conn_id: Reference to the :ref:`Databricks connection `. By default and in the common case this will be ``databricks_default``. To use token based authentication, provide the key ``token`` in the extra field for the - connection and create the key ``host`` and leave the ``host`` field empty. + connection and create the key ``host`` and leave the ``host`` field empty. (templated) :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. :param databricks_retry_delay: Number of seconds to wait between retries (it @@ -54,7 +56,7 @@ class DatabricksReposCreateOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ('repo_path', 'tag', 'branch') + template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id") __git_providers__ = { "github.com": "gitHub", @@ -69,12 +71,12 @@ def __init__( self, *, git_url: str, - git_provider: Optional[str] = None, - branch: Optional[str] = None, - tag: Optional[str] = None, - repo_path: Optional[str] = None, + git_provider: str | None = None, + branch: str | None = None, + tag: str | None = None, + repo_path: str | None = None, ignore_existing_repo: bool = False, - databricks_conn_id: str = 'databricks_default', + databricks_conn_id: str = "databricks_default", databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, **kwargs, @@ -104,7 +106,7 @@ def __init__( def __detect_repo_provider__(url): provider = None try: - netloc = urlparse(url).netloc + netloc = urlsplit(url).netloc idx = netloc.rfind("@") if idx != -1: netloc = netloc[(idx + 1) :] @@ -116,14 +118,16 @@ def __detect_repo_provider__(url): pass return provider - def _get_hook(self) -> DatabricksHook: + @cached_property + def _hook(self) -> DatabricksHook: return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, + caller="DatabricksReposCreateOperator", ) - def execute(self, context: 'Context'): + def execute(self, context: Context): """ Creates a Databricks Repo @@ -140,22 +144,21 @@ def execute(self, context: 'Context'): f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'" ) payload["path"] = self.repo_path - hook = self._get_hook() existing_repo_id = None if self.repo_path is not None: - existing_repo_id = hook.get_repo_by_path(self.repo_path) + 
existing_repo_id = self._hook.get_repo_by_path(self.repo_path) if existing_repo_id is not None and not self.ignore_existing_repo: raise AirflowException(f"Repo with path '{self.repo_path}' already exists") if existing_repo_id is None: - result = hook.create_repo(payload) + result = self._hook.create_repo(payload) repo_id = result["id"] else: repo_id = existing_repo_id # update repo if necessary if self.branch is not None: - hook.update_repo(str(repo_id), {'branch': str(self.branch)}) + self._hook.update_repo(str(repo_id), {"branch": str(self.branch)}) elif self.tag is not None: - hook.update_repo(str(repo_id), {'tag': str(self.tag)}) + self._hook.update_repo(str(repo_id), {"tag": str(self.tag)}) return repo_id @@ -173,7 +176,7 @@ class DatabricksReposUpdateOperator(BaseOperator): :param databricks_conn_id: Reference to the :ref:`Databricks connection `. By default and in the common case this will be ``databricks_default``. To use token based authentication, provide the key ``token`` in the extra field for the - connection and create the key ``host`` and leave the ``host`` field empty. + connection and create the key ``host`` and leave the ``host`` field empty. (templated) :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. :param databricks_retry_delay: Number of seconds to wait between retries (it @@ -181,16 +184,16 @@ class DatabricksReposUpdateOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ('repo_path', 'tag', 'branch') + template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id") def __init__( self, *, - branch: Optional[str] = None, - tag: Optional[str] = None, - repo_id: Optional[str] = None, - repo_path: Optional[str] = None, - databricks_conn_id: str = 'databricks_default', + branch: str | None = None, + tag: str | None = None, + repo_id: str | None = None, + repo_path: str | None = None, + databricks_conn_id: str = "databricks_default", databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, **kwargs, @@ -213,26 +216,27 @@ def __init__( self.branch = branch self.tag = tag - def _get_hook(self) -> DatabricksHook: + @cached_property + def _hook(self) -> DatabricksHook: return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, + caller="DatabricksReposUpdateOperator", ) - def execute(self, context: 'Context'): - hook = self._get_hook() + def execute(self, context: Context): if self.repo_path is not None: - self.repo_id = hook.get_repo_by_path(self.repo_path) + self.repo_id = self._hook.get_repo_by_path(self.repo_path) if self.repo_id is None: raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'") if self.branch is not None: - payload = {'branch': str(self.branch)} + payload = {"branch": str(self.branch)} else: - payload = {'tag': str(self.tag)} + payload = {"tag": str(self.tag)} - result = hook.update_repo(str(self.repo_id), payload) - return result['head_commit_id'] + result = self._hook.update_repo(str(self.repo_id), payload) + return result["head_commit_id"] class DatabricksReposDeleteOperator(BaseOperator): @@ -246,7 +250,7 @@ class DatabricksReposDeleteOperator(BaseOperator): :param databricks_conn_id: Reference to the :ref:`Databricks connection `. By default and in the common case this will be ``databricks_default``. 
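The Repos operators address a repo either by ``repo_id`` or by its workspace path. A hedged sketch using a hypothetical ``/Repos/...`` path::

    from airflow.providers.databricks.operators.databricks_repos import (
        DatabricksReposDeleteOperator,
        DatabricksReposUpdateOperator,
    )

    # Made-up path in the /Repos/{folder}/{repo-name} form used throughout these operators.
    REPO_PATH = "/Repos/user@example.com/airflow-demo"

    checkout_main = DatabricksReposUpdateOperator(
        task_id="checkout_main",
        repo_path=REPO_PATH,
        branch="main",  # alternatively pass tag=... to pin a release
    )

    drop_repo = DatabricksReposDeleteOperator(
        task_id="drop_repo",
        repo_path=REPO_PATH,
    )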
To use token based authentication, provide the key ``token`` in the extra field for the - connection and create the key ``host`` and leave the ``host`` field empty. + connection and create the key ``host`` and leave the ``host`` field empty. (templated) :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. :param databricks_retry_delay: Number of seconds to wait between retries (it @@ -254,14 +258,14 @@ class DatabricksReposDeleteOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ('repo_path',) + template_fields: Sequence[str] = ("repo_path", "databricks_conn_id") def __init__( self, *, - repo_id: Optional[str] = None, - repo_path: Optional[str] = None, - databricks_conn_id: str = 'databricks_default', + repo_id: str | None = None, + repo_path: str | None = None, + databricks_conn_id: str = "databricks_default", databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, **kwargs, @@ -278,18 +282,19 @@ def __init__( self.repo_path = repo_path self.repo_id = repo_id - def _get_hook(self) -> DatabricksHook: + @cached_property + def _hook(self) -> DatabricksHook: return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, + caller="DatabricksReposDeleteOperator", ) - def execute(self, context: 'Context'): - hook = self._get_hook() + def execute(self, context: Context): if self.repo_path is not None: - self.repo_id = hook.get_repo_by_path(self.repo_path) + self.repo_id = self._hook.get_repo_by_path(self.repo_path) if self.repo_id is None: raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'") - hook.delete_repo(str(self.repo_id)) + self._hook.delete_repo(str(self.repo_id)) diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index 9e6298bc21263..379b0fd2c9bbb 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -15,24 +15,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Databricks operators.""" +from __future__ import annotations import csv import json -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Sequence from databricks.sql.utils import ParamEscaper from airflow.exceptions import AirflowException from airflow.models import BaseOperator +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook if TYPE_CHECKING: from airflow.utils.context import Context -class DatabricksSqlOperator(BaseOperator): +class DatabricksSqlOperator(SQLExecuteQueryOperator): """ Executes SQL code in a Databricks SQL endpoint or a Databricks cluster @@ -41,7 +42,7 @@ class DatabricksSqlOperator(BaseOperator): :ref:`howto/operator:DatabricksSqlOperator` :param databricks_conn_id: Reference to - :ref:`Databricks connection id` + :ref:`Databricks connection id` (templated) :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster. If not specified, it should be either specified in the Databricks connection's extra parameters, or ``sql_endpoint_name`` must be specified. 
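``DatabricksSqlOperator`` now builds on ``SQLExecuteQueryOperator`` and only adds the Databricks-specific pieces (endpoint selection and optional output files). A short usage sketch with made-up names::

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    select_into_csv = DatabricksSqlOperator(
        task_id="select_into_csv",
        databricks_conn_id="databricks_default",
        sql_endpoint_name="my-sql-endpoint",   # or pass http_path, or set it in the connection
        sql="SELECT * FROM default.my_table LIMIT 100",
        output_path="/tmp/my_table.csv",       # optional: also write the result to a file
        output_format="csv",
    )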
@@ -64,66 +65,78 @@ class DatabricksSqlOperator(BaseOperator): :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data. """ - template_fields: Sequence[str] = ('sql', '_output_path', 'schema', 'catalog', 'http_headers') - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} + template_fields: Sequence[str] = ( + "sql", + "_output_path", + "schema", + "catalog", + "http_headers", + "databricks_conn_id", + ) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} def __init__( self, *, - sql: Union[str, List[str]], databricks_conn_id: str = DatabricksSqlHook.default_conn_name, - http_path: Optional[str] = None, - sql_endpoint_name: Optional[str] = None, - parameters: Optional[Union[Mapping, Iterable]] = None, + http_path: str | None = None, + sql_endpoint_name: str | None = None, session_configuration=None, - http_headers: Optional[List[Tuple[str, str]]] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, - do_xcom_push: bool = False, - output_path: Optional[str] = None, - output_format: str = 'csv', - csv_params: Optional[Dict[str, Any]] = None, - client_parameters: Optional[Dict[str, Any]] = None, + http_headers: list[tuple[str, str]] | None = None, + catalog: str | None = None, + schema: str | None = None, + output_path: str | None = None, + output_format: str = "csv", + csv_params: dict[str, Any] | None = None, + client_parameters: dict[str, Any] | None = None, **kwargs, ) -> None: - """Creates a new ``DatabricksSqlOperator``.""" - super().__init__(**kwargs) + super().__init__(conn_id=databricks_conn_id, **kwargs) self.databricks_conn_id = databricks_conn_id - self.sql = sql - self._http_path = http_path - self._sql_endpoint_name = sql_endpoint_name self._output_path = output_path self._output_format = output_format self._csv_params = csv_params - self.parameters = parameters - self.do_xcom_push = do_xcom_push - self.session_config = session_configuration + self.http_path = http_path + self.sql_endpoint_name = sql_endpoint_name + self.session_configuration = session_configuration + self.client_parameters = {} if client_parameters is None else client_parameters + self.hook_params = kwargs.pop("hook_params", {}) self.http_headers = http_headers self.catalog = catalog self.schema = schema - self.client_parameters = client_parameters or {} - def _get_hook(self) -> DatabricksSqlHook: - return DatabricksSqlHook( - self.databricks_conn_id, - http_path=self._http_path, - session_configuration=self.session_config, - sql_endpoint_name=self._sql_endpoint_name, - http_headers=self.http_headers, - catalog=self.catalog, - schema=self.schema, + def get_db_hook(self) -> DatabricksSqlHook: + hook_params = { + "http_path": self.http_path, + "session_configuration": self.session_configuration, + "sql_endpoint_name": self.sql_endpoint_name, + "http_headers": self.http_headers, + "catalog": self.catalog, + "schema": self.schema, + "caller": "DatabricksSqlOperator", **self.client_parameters, - ) + **self.hook_params, + } + return DatabricksSqlHook(self.databricks_conn_id, **hook_params) - def _format_output(self, schema, results): + def _process_output( + self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool + ) -> Any: if not self._output_path: - return + return description, results if not self._output_format: raise AirflowException("Output format should be specified!") - field_names = [field[0] for field in schema] + if description is 
None: + self.log.warning("Description of the cursor is missing. Will not process the output") + return description, results + field_names = [field[0] for field in description] + if scalar_results: + list_results: list[Any] = [results] + else: + list_results = results if self._output_format.lower() == "csv": - with open(self._output_path, "w", newline='') as file: + with open(self._output_path, "w", newline="") as file: if self._csv_params: csv_params = self._csv_params else: @@ -134,28 +147,19 @@ def _format_output(self, schema, results): writer = csv.DictWriter(file, fieldnames=field_names, **csv_params) if write_header: writer.writeheader() - for row in results: + for row in list_results: writer.writerow(row.asDict()) elif self._output_format.lower() == "json": with open(self._output_path, "w") as file: - file.write(json.dumps([row.asDict() for row in results])) + file.write(json.dumps([row.asDict() for row in list_results])) elif self._output_format.lower() == "jsonl": with open(self._output_path, "w") as file: - for row in results: + for row in list_results: file.write(json.dumps(row.asDict())) file.write("\n") else: raise AirflowException(f"Unsupported output format: '{self._output_format}'") - - def execute(self, context: 'Context') -> Any: - self.log.info('Executing: %s', self.sql) - hook = self._get_hook() - schema, results = hook.run(self.sql, parameters=self.parameters) - # self.log.info('Schema: %s', schema) - # self.log.info('Results: %s', results) - self._format_output(schema, results) - if self.do_xcom_push: - return results + return description, results COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"] @@ -176,7 +180,7 @@ class DatabricksCopyIntoOperator(BaseOperator): :param file_format: Required file format. Supported formats are ``CSV``, ``JSON``, ``AVRO``, ``ORC``, ``PARQUET``, ``TEXT``, ``BINARYFILE``. :param databricks_conn_id: Reference to - :ref:`Databricks connection id` + :ref:`Databricks connection id` (templated) :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster. If not specified, it should be either specified in the Databricks connection's extra parameters, or ``sql_endpoint_name`` must be specified. 
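A minimal sketch of how the refactored ``DatabricksSqlOperator`` above is used now that it subclasses ``SQLExecuteQueryOperator``: ``sql`` handling and XCom push come from the common-sql base class, while the Databricks-specific options (endpoint selection, optional file output handled by ``_process_output``) stay on the operator. The endpoint name, table, and output path below are placeholders, not values from this diff.

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

with DAG(
    dag_id="example_databricks_sql",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Besides being pushed to XCom by the base class, the result is written
    # to a local CSV file via _process_output().
    select_to_csv = DatabricksSqlOperator(
        task_id="select_to_csv",
        databricks_conn_id="databricks_default",
        sql_endpoint_name="my-endpoint",  # placeholder SQL endpoint name
        sql="SELECT * FROM my_catalog.my_schema.my_table LIMIT 10",  # placeholder table
        output_path="/tmp/query_results.csv",
        output_format="csv",
    )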
@@ -204,9 +208,10 @@ class DatabricksCopyIntoOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_file_location', - '_files', - '_table_name', + "_file_location", + "_files", + "_table_name", + "databricks_conn_id", ) def __init__( @@ -216,23 +221,23 @@ def __init__( file_location: str, file_format: str, databricks_conn_id: str = DatabricksSqlHook.default_conn_name, - http_path: Optional[str] = None, - sql_endpoint_name: Optional[str] = None, + http_path: str | None = None, + sql_endpoint_name: str | None = None, session_configuration=None, - http_headers: Optional[List[Tuple[str, str]]] = None, - client_parameters: Optional[Dict[str, Any]] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, - files: Optional[List[str]] = None, - pattern: Optional[str] = None, - expression_list: Optional[str] = None, - credential: Optional[Dict[str, str]] = None, - storage_credential: Optional[str] = None, - encryption: Optional[Dict[str, str]] = None, - format_options: Optional[Dict[str, str]] = None, - force_copy: Optional[bool] = None, - copy_options: Optional[Dict[str, str]] = None, - validate: Optional[Union[bool, int]] = None, + http_headers: list[tuple[str, str]] | None = None, + client_parameters: dict[str, Any] | None = None, + catalog: str | None = None, + schema: str | None = None, + files: list[str] | None = None, + pattern: str | None = None, + expression_list: str | None = None, + credential: dict[str, str] | None = None, + storage_credential: str | None = None, + encryption: dict[str, str] | None = None, + format_options: dict[str, str] | None = None, + force_copy: bool | None = None, + copy_options: dict[str, str] | None = None, + validate: bool | int | None = None, **kwargs, ) -> None: """Creates a new ``DatabricksSqlOperator``.""" @@ -266,7 +271,7 @@ def __init__( self._http_headers = http_headers self._client_parameters = client_parameters or {} if force_copy is not None: - self._copy_options["force"] = 'true' if force_copy else 'false' + self._copy_options["force"] = "true" if force_copy else "false" def _get_hook(self) -> DatabricksSqlHook: return DatabricksSqlHook( @@ -277,6 +282,7 @@ def _get_hook(self) -> DatabricksSqlHook: http_headers=self._http_headers, catalog=self._catalog, schema=self._schema, + caller="DatabricksCopyIntoOperator", **self._client_parameters, ) @@ -284,7 +290,7 @@ def _get_hook(self) -> DatabricksSqlHook: def _generate_options( name: str, escaper: ParamEscaper, - opts: Optional[Dict[str, str]] = None, + opts: dict[str, str] | None = None, escape_key: bool = True, ) -> str: formatted_opts = "" @@ -344,8 +350,8 @@ def _create_sql_query(self) -> str: """ return sql.strip() - def execute(self, context: 'Context') -> Any: + def execute(self, context: Context) -> Any: sql = self._create_sql_query() - self.log.info('Executing: %s', sql) + self.log.info("Executing: %s", sql) hook = self._get_hook() hook.run(sql) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 470209604572d..742a594235ffc 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -22,6 +22,11 @@ description: | `Databricks `__ versions: + - 3.4.0 + - 3.3.0 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.7.0 - 2.6.0 - 2.5.0 @@ -35,8 +40,12 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - requests>=2.27,<3 + - databricks-sql-connector>=2.0.0, <3.0.0 + - 
aiohttp>=3.6.3, <4 integrations: - integration-name: Databricks @@ -82,14 +91,10 @@ hooks: python-modules: - airflow.providers.databricks.hooks.databricks_sql -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.databricks.hooks.databricks.DatabricksHook connection-types: - hook-class-name: airflow.providers.databricks.hooks.databricks.DatabricksHook connection-type: databricks - - hook-class-name: airflow.providers.databricks.hooks.databricks_sql.DatabricksSqlHook - connection-type: databricks extra-links: - airflow.providers.databricks.operators.databricks.DatabricksJobRunLink diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py index 5f50f5aff29ee..cd2421c376989 100644 --- a/airflow/providers/databricks/triggers/databricks.py +++ b/airflow/providers/databricks/triggers/databricks.py @@ -15,21 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import asyncio -import logging -from typing import Any, Dict, Tuple +from typing import Any from airflow.providers.databricks.hooks.databricks import DatabricksHook - -try: - from airflow.triggers.base import BaseTrigger, TriggerEvent -except ImportError: - logging.getLogger(__name__).warning( - 'Deferrable Operators only work starting Airflow 2.2', - exc_info=True, - ) - BaseTrigger = object # type: ignore - TriggerEvent = None # type: ignore +from airflow.triggers.base import BaseTrigger, TriggerEvent class DatabricksExecutionTrigger(BaseTrigger): @@ -49,13 +41,13 @@ def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: self.polling_period_seconds = polling_period_seconds self.hook = DatabricksHook(databricks_conn_id) - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: return ( - 'airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger', + "airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger", { - 'run_id': self.run_id, - 'databricks_conn_id': self.databricks_conn_id, - 'polling_period_seconds': self.polling_period_seconds, + "run_id": self.run_id, + "databricks_conn_id": self.databricks_conn_id, + "polling_period_seconds": self.polling_period_seconds, }, ) @@ -67,9 +59,9 @@ async def run(self): if run_state.is_terminal: yield TriggerEvent( { - 'run_id': self.run_id, - 'run_state': run_state.to_json(), - 'run_page_url': run_page_url, + "run_id": self.run_id, + "run_state": run_state.to_json(), + "run_page_url": run_page_url, } ) break diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py index 96935d806344e..9548a6c61466b 100644 --- a/airflow/providers/databricks/utils/databricks.py +++ b/airflow/providers/databricks/utils/databricks.py @@ -15,24 +15,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# - -from typing import Union +from __future__ import annotations from airflow.exceptions import AirflowException from airflow.providers.databricks.hooks.databricks import RunState -def deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dict]: +def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict: """ - Coerces content or all values of content if it is a dict to a string. The - function will throw if content contains non-string or non-numeric types. + Normalize content or all values of content if it is a dict to a string. The + function will throw if content contains non-string or non-numeric non-boolean types. The reason why we have this function is because the ``self.json`` field must be a dict with only string values. This is because ``render_template`` will fail for numerical values. + + The only one exception is when we have boolean values, they can not be converted + to string type because databricks does not understand 'True' or 'False' values. """ - coerce = deep_string_coerce - if isinstance(content, str): + normalise = normalise_json_content + if isinstance(content, (str, bool)): return content elif isinstance( content, @@ -44,12 +45,12 @@ def deep_string_coerce(content, json_path: str = 'json') -> Union[str, list, dic # Databricks can tolerate either numeric or string types in the API backend. return str(content) elif isinstance(content, (list, tuple)): - return [coerce(e, f'{json_path}[{i}]') for i, e in enumerate(content)] + return [normalise(e, f"{json_path}[{i}]") for i, e in enumerate(content)] elif isinstance(content, dict): - return {k: coerce(v, f'{json_path}[{k}]') for k, v in list(content.items())} + return {k: normalise(v, f"{json_path}[{k}]") for k, v in list(content.items())} else: param_type = type(content) - msg = f'Type {param_type} used for parameter {json_path} is not a number or a string' + msg = f"Type {param_type} used for parameter {json_path} is not a number or a string" raise AirflowException(msg) @@ -58,12 +59,12 @@ def validate_trigger_event(event: dict): Validates correctness of the event received from :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` """ - keys_to_check = ['run_id', 'run_page_url', 'run_state'] + keys_to_check = ["run_id", "run_page_url", "run_state"] for key in keys_to_check: if key not in event: - raise AirflowException(f'Could not find `{key}` in the event: {event}') + raise AirflowException(f"Could not find `{key}` in the event: {event}") try: - RunState.from_json(event['run_state']) + RunState.from_json(event["run_state"]) except Exception: raise AirflowException(f'Run state returned by the Trigger is incorrect: {event["run_state"]}') diff --git a/airflow/providers/datadog/.latest-doc-only-change.txt b/airflow/providers/datadog/.latest-doc-only-change.txt index ab24993f57139..ff7136e07d744 100644 --- a/airflow/providers/datadog/.latest-doc-only-change.txt +++ b/airflow/providers/datadog/.latest-doc-only-change.txt @@ -1 +1 @@ -8b6b0848a3cacf9999477d6af4d2a87463f03026 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/datadog/CHANGELOG.rst b/airflow/providers/datadog/CHANGELOG.rst index c6d15452d1209..a73fa654033cd 100644 --- a/airflow/providers/datadog/CHANGELOG.rst +++ b/airflow/providers/datadog/CHANGELOG.rst @@ -16,9 +16,55 @@ under the License. +.. 
NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Remove "bad characters" from our codebase (#24841)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Support host_name on Datadog provider (#23784)`` + * ``Fix new MyPy errors in main (#22884)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breeze for building, pulling and verifying the images. (#23104)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/datadog/hooks/datadog.py b/airflow/providers/datadog/hooks/datadog.py index 574b85efc0bb0..14f2b664d015a 100644 --- a/airflow/providers/datadog/hooks/datadog.py +++ b/airflow/providers/datadog/hooks/datadog.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import time -from typing import Any, Dict, List, Optional, Union +from typing import Any from datadog import api, initialize # type: ignore[attr-defined] @@ -38,13 +39,13 @@ class DatadogHook(BaseHook, LoggingMixin): :param datadog_conn_id: The connection to datadog, containing metadata for api keys. 
""" - def __init__(self, datadog_conn_id: str = 'datadog_default') -> None: + def __init__(self, datadog_conn_id: str = "datadog_default") -> None: super().__init__() conn = self.get_connection(datadog_conn_id) - self.api_key = conn.extra_dejson.get('api_key', None) - self.app_key = conn.extra_dejson.get('app_key', None) - self.api_host = conn.extra_dejson.get('api_host', None) - self.source_type_name = conn.extra_dejson.get('source_type_name', None) + self.api_key = conn.extra_dejson.get("api_key", None) + self.app_key = conn.extra_dejson.get("app_key", None) + self.api_host = conn.extra_dejson.get("api_host", None) + self.source_type_name = conn.extra_dejson.get("source_type_name", None) # If the host is populated, it will use that hostname instead. # for all metric submissions. @@ -56,20 +57,20 @@ def __init__(self, datadog_conn_id: str = 'datadog_default') -> None: self.log.info("Setting up api keys for Datadog") initialize(api_key=self.api_key, app_key=self.app_key, api_host=self.api_host) - def validate_response(self, response: Dict[str, Any]) -> None: + def validate_response(self, response: dict[str, Any]) -> None: """Validate Datadog response""" - if response['status'] != 'ok': + if response["status"] != "ok": self.log.error("Datadog returned: %s", response) raise AirflowException("Error status received from Datadog") def send_metric( self, metric_name: str, - datapoint: Union[float, int], - tags: Optional[List[str]] = None, - type_: Optional[str] = None, - interval: Optional[int] = None, - ) -> Dict[str, Any]: + datapoint: float | int, + tags: list[str] | None = None, + type_: str | None = None, + interval: int | None = None, + ) -> dict[str, Any]: """ Sends a single datapoint metric to DataDog @@ -86,7 +87,7 @@ def send_metric( self.validate_response(response) return response - def query_metric(self, query: str, from_seconds_ago: int, to_seconds_ago: int) -> Dict[str, Any]: + def query_metric(self, query: str, from_seconds_ago: int, to_seconds_ago: int) -> dict[str, Any]: """ Queries datadog for a specific metric, potentially with some function applied to it and returns the results. 
@@ -106,15 +107,15 @@ def post_event( self, title: str, text: str, - aggregation_key: Optional[str] = None, - alert_type: Optional[str] = None, - date_happened: Optional[int] = None, - handle: Optional[str] = None, - priority: Optional[str] = None, - related_event_id: Optional[int] = None, - tags: Optional[List[str]] = None, - device_name: Optional[List[str]] = None, - ) -> Dict[str, Any]: + aggregation_key: str | None = None, + alert_type: str | None = None, + date_happened: int | None = None, + handle: str | None = None, + priority: str | None = None, + related_event_id: int | None = None, + tags: list[str] | None = None, + device_name: list[str] | None = None, + ) -> dict[str, Any]: """ Posts an event to datadog (processing finished, potentially alerts, other issues) Think about this as a means to maintain persistence of alerts, rather than diff --git a/airflow/providers/datadog/provider.yaml b/airflow/providers/datadog/provider.yaml index 9276349e242ad..3bd1734a53aae 100644 --- a/airflow/providers/datadog/provider.yaml +++ b/airflow/providers/datadog/provider.yaml @@ -22,6 +22,8 @@ description: | `Datadog `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +32,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - datadog>=0.14.0 integrations: - integration-name: Datadog diff --git a/airflow/providers/datadog/sensors/datadog.py b/airflow/providers/datadog/sensors/datadog.py index 7dbcec80676d6..a94f6ccc3c7af 100644 --- a/airflow/providers/datadog/sensors/datadog.py +++ b/airflow/providers/datadog/sensors/datadog.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable from datadog import api @@ -42,25 +44,25 @@ class DatadogSensor(BaseSensorOperator): :param sources: A comma separated list indicating what tags, if any, should be used to filter the list of monitors by scope :param tags: Get datadog events from specific sources. - :param response_check: A check against the ‘requests’ response object. The callable takes + :param response_check: A check against the 'requests' response object. The callable takes the response object as the first positional argument and optionally any number of keyword arguments available in the context dictionary. It should return True for - ‘pass’ and False otherwise. - :param response_check: Optional[Callable[[Dict[str, Any]], bool]] + 'pass' and False otherwise. 
+ :param response_check: Callable[[dict[str, Any]], bool] | None """ - ui_color = '#66c3dd' + ui_color = "#66c3dd" def __init__( self, *, - datadog_conn_id: str = 'datadog_default', + datadog_conn_id: str = "datadog_default", from_seconds_ago: int = 3600, up_to_seconds_from_now: int = 0, - priority: Optional[str] = None, - sources: Optional[str] = None, - tags: Optional[List[str]] = None, - response_check: Optional[Callable[[Dict[str, Any]], bool]] = None, + priority: str | None = None, + sources: str | None = None, + tags: list[str] | None = None, + response_check: Callable[[dict[str, Any]], bool] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -72,7 +74,7 @@ def __init__( self.tags = tags self.response_check = response_check - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: # This instantiates the hook, but doesn't need it further, # because the API authenticates globally (unfortunately), # but for airflow this shouldn't matter too much, because each @@ -87,7 +89,7 @@ def poke(self, context: 'Context') -> bool: tags=self.tags, ) - if isinstance(response, dict) and response.get('status', 'ok') != 'ok': + if isinstance(response, dict) and response.get("status", "ok") != "ok": self.log.error("Unexpected Datadog result: %s", response) raise AirflowException("Datadog returned unexpected result") diff --git a/airflow/providers/dbt/cloud/CHANGELOG.rst b/airflow/providers/dbt/cloud/CHANGELOG.rst index f23047bdfa90a..6ee467d87245b 100644 --- a/airflow/providers/dbt/cloud/CHANGELOG.rst +++ b/airflow/providers/dbt/cloud/CHANGELOG.rst @@ -16,9 +16,93 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +2.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +2.2.0 +..... + +Features +~~~~~~~~ + +* ``Add 'DbtCloudListJobsOperator' (#26475)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +2.1.0 +..... + +Features +~~~~~~~~ + +* ``Improve taskflow type hints with ParamSpec (#25173)`` + +2.0.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +2.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Enable dbt Cloud provider to interact with single tenant instances (#24264)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix typo in dbt Cloud provider description (#23179)`` +* ``Fix new MyPy errors in main (#22884)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``AIP-47 - Migrate dbt DAGs to new design #22472 (#24202)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Replace usage of 'DummyOperator' with 'EmptyOperator' (#22974)`` + * ``Update dbt.py (#24218)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 1.0.2 ..... diff --git a/airflow/providers/dbt/cloud/example_dags/example_dbt_cloud.py b/airflow/providers/dbt/cloud/example_dags/example_dbt_cloud.py deleted file mode 100644 index aa7f220bd2715..0000000000000 --- a/airflow/providers/dbt/cloud/example_dags/example_dbt_cloud.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime - -from airflow.models import DAG, BaseOperator - -try: - from airflow.operators.empty import EmptyOperator -except ModuleNotFoundError: - from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore -from airflow.providers.dbt.cloud.operators.dbt import ( - DbtCloudGetJobRunArtifactOperator, - DbtCloudRunJobOperator, -) -from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor -from airflow.utils.edgemodifier import Label - -with DAG( - dag_id="example_dbt_cloud", - default_args={"dbt_cloud_conn_id": "dbt", "account_id": 39151}, - start_date=datetime(2021, 1, 1), - schedule_interval=None, - catchup=False, -) as dag: - begin = EmptyOperator(task_id="begin") - end = EmptyOperator(task_id="end") - - # [START howto_operator_dbt_cloud_run_job] - trigger_job_run1 = DbtCloudRunJobOperator( - task_id="trigger_job_run1", - job_id=48617, - check_interval=10, - timeout=300, - ) - # [END howto_operator_dbt_cloud_run_job] - - # [START howto_operator_dbt_cloud_get_artifact] - get_run_results_artifact: BaseOperator = DbtCloudGetJobRunArtifactOperator( - task_id="get_run_results_artifact", run_id=trigger_job_run1.output, path="run_results.json" - ) - # [END howto_operator_dbt_cloud_get_artifact] - - # [START howto_operator_dbt_cloud_run_job_async] - trigger_job_run2 = DbtCloudRunJobOperator( - task_id="trigger_job_run2", - job_id=48617, - wait_for_termination=False, - additional_run_config={"threads_override": 8}, - ) - # [END howto_operator_dbt_cloud_run_job_async] - - # [START howto_operator_dbt_cloud_run_job_sensor] - job_run_sensor: BaseOperator = DbtCloudJobRunSensor( - task_id="job_run_sensor", run_id=trigger_job_run2.output, timeout=20 - ) - # [END howto_operator_dbt_cloud_run_job_sensor] - - begin >> Label("No async wait") >> trigger_job_run1 - begin >> Label("Do async wait with sensor") >> trigger_job_run2 - [get_run_results_artifact, job_run_sensor] >> end - - # Task dependency created via `XComArgs`: - # trigger_job_run1 >> get_run_results_artifact - # trigger_job_run2 >> job_run_sensor diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index 13d10dbc6283c..6197335c9260b 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -14,29 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import json -import sys import time from enum import Enum from functools import wraps from inspect import signature -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Callable, Sequence, Set from requests import PreparedRequest, Session from requests.auth import AuthBase from requests.models import Response +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.http.hooks.http import HttpHook from airflow.typing_compat import TypedDict -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - def fallback_to_default_account(func: Callable) -> Callable: """ @@ -64,7 +60,7 @@ def wrapper(*args, **kwargs) -> Callable: return wrapper -def _get_provider_info() -> Tuple[str, str]: +def _get_provider_info() -> tuple[str, str]: from airflow.providers_manager import ProvidersManager manager = ProvidersManager() @@ -92,7 +88,7 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: class JobRunInfo(TypedDict): """Type class for the ``job_run_info`` dictionary.""" - account_id: int + account_id: int | None run_id: int @@ -108,7 +104,7 @@ class DbtCloudJobRunStatus(Enum): TERMINAL_STATUSES = (SUCCESS, ERROR, CANCELLED) @classmethod - def check_is_valid(cls, statuses: Union[int, Sequence[int], Set[int]]): + def check_is_valid(cls, statuses: int | Sequence[int] | set[int]): """Validates input statuses are a known value.""" if isinstance(statuses, (Sequence, Set)): for status in statuses: @@ -141,17 +137,20 @@ class DbtCloudHook(HttpHook): hook_name = "dbt Cloud" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Builds custom field behavior for the dbt Cloud connection form in the Airflow UI.""" return { - "hidden_fields": ["host", "port", "schema", "extra"], - "relabeling": {"login": "Account ID", "password": "API Token"}, + "hidden_fields": ["host", "port", "extra"], + "relabeling": {"login": "Account ID", "password": "API Token", "schema": "Tenant"}, + "placeholders": {"schema": "Defaults to 'cloud'."}, } def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(auth_type=TokenAuth) self.dbt_cloud_conn_id = dbt_cloud_conn_id - self.base_url = "https://cloud.getdbt.com/api/v2/accounts/" + tenant = self.connection.schema if self.connection.schema else "cloud" + + self.base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/" @cached_property def connection(self) -> Connection: @@ -167,36 +166,30 @@ def get_conn(self, *args, **kwargs) -> Session: return session - def _paginate(self, endpoint: str, payload: Optional[Dict[str, Any]] = None) -> List[Response]: - results = [] + def _paginate(self, endpoint: str, payload: dict[str, Any] | None = None) -> list[Response]: response = self.run(endpoint=endpoint, data=payload) resp_json = response.json() limit = resp_json["extra"]["filters"]["limit"] num_total_results = resp_json["extra"]["pagination"]["total_count"] num_current_results = resp_json["extra"]["pagination"]["count"] - results.append(response) - - if not num_current_results == num_total_results: + results = [response] + if num_current_results != num_total_results: _paginate_payload = payload.copy() if payload else {} _paginate_payload["offset"] = limit - while True: - if num_current_results < 
num_total_results: - response = self.run(endpoint=endpoint, data=_paginate_payload) - resp_json = response.json() - results.append(response) - num_current_results += resp_json["extra"]["pagination"]["count"] - _paginate_payload["offset"] += limit - else: - break - + while not num_current_results >= num_total_results: + response = self.run(endpoint=endpoint, data=_paginate_payload) + resp_json = response.json() + results.append(response) + num_current_results += resp_json["extra"]["pagination"]["count"] + _paginate_payload["offset"] += limit return results def _run_and_get_response( self, method: str = "GET", - endpoint: Optional[str] = None, - payload: Union[str, Dict[str, Any], None] = None, + endpoint: str | None = None, + payload: str | dict[str, Any] | None = None, paginate: bool = False, ) -> Any: self.method = method @@ -212,7 +205,7 @@ def _run_and_get_response( return self.run(endpoint=endpoint, data=payload) - def list_accounts(self) -> List[Response]: + def list_accounts(self) -> list[Response]: """ Retrieves all of the dbt Cloud accounts the configured API token is authorized to access. @@ -221,7 +214,7 @@ def list_accounts(self) -> List[Response]: return self._run_and_get_response() @fallback_to_default_account - def get_account(self, account_id: Optional[int] = None) -> Response: + def get_account(self, account_id: int | None = None) -> Response: """ Retrieves metadata for a specific dbt Cloud account. @@ -231,7 +224,7 @@ def get_account(self, account_id: Optional[int] = None) -> Response: return self._run_and_get_response(endpoint=f"{account_id}/") @fallback_to_default_account - def list_projects(self, account_id: Optional[int] = None) -> List[Response]: + def list_projects(self, account_id: int | None = None) -> list[Response]: """ Retrieves metadata for all projects tied to a specified dbt Cloud account. @@ -241,7 +234,7 @@ def list_projects(self, account_id: Optional[int] = None) -> List[Response]: return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True) @fallback_to_default_account - def get_project(self, project_id: int, account_id: Optional[int] = None) -> Response: + def get_project(self, project_id: int, account_id: int | None = None) -> Response: """ Retrieves metadata for a specific project. @@ -254,13 +247,13 @@ def get_project(self, project_id: int, account_id: Optional[int] = None) -> Resp @fallback_to_default_account def list_jobs( self, - account_id: Optional[int] = None, - order_by: Optional[str] = None, - project_id: Optional[int] = None, - ) -> List[Response]: + account_id: int | None = None, + order_by: str | None = None, + project_id: int | None = None, + ) -> list[Response]: """ Retrieves metadata for all jobs tied to a specified dbt Cloud account. If a ``project_id`` is - supplied, only jobs pertaining to this job will be retrieved. + supplied, only jobs pertaining to this project will be retrieved. :param account_id: Optional. The ID of a dbt Cloud account. :param order_by: Optional. Field to order the result by. Use '-' to indicate reverse order. @@ -275,7 +268,7 @@ def list_jobs( ) @fallback_to_default_account - def get_job(self, job_id: int, account_id: Optional[int] = None) -> Response: + def get_job(self, job_id: int, account_id: int | None = None) -> Response: """ Retrieves metadata for a specific job. 
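A standalone sketch of the offset pagination that the rewritten ``_paginate`` a few hunks above implements, with a stubbed ``fetch`` callable standing in for ``HttpHook.run`` and the same ``extra`` metadata fields used there. This only illustrates the loop structure, it is not the provider's code.

def paginate(fetch, payload=None):
    # The first page establishes the page size (limit) and the total row count.
    response = fetch(dict(payload or {}))
    limit = response["extra"]["filters"]["limit"]
    total = response["extra"]["pagination"]["total_count"]
    current = response["extra"]["pagination"]["count"]
    results = [response]
    if current != total:
        next_payload = dict(payload or {})
        next_payload["offset"] = limit
        while current < total:
            response = fetch(dict(next_payload))
            results.append(response)
            current += response["extra"]["pagination"]["count"]
            next_payload["offset"] += limit
    return results


def fake_fetch(payload):
    # Pretend the API holds 250 records served in pages of 100.
    offset = payload.get("offset", 0)
    count = min(100, 250 - offset)
    return {"extra": {"filters": {"limit": 100}, "pagination": {"total_count": 250, "count": count}}}


print(len(paginate(fake_fetch)))  # 3 pages: 100 + 100 + 50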
@@ -290,10 +283,10 @@ def trigger_job_run( self, job_id: int, cause: str, - account_id: Optional[int] = None, - steps_override: Optional[List[str]] = None, - schema_override: Optional[str] = None, - additional_run_config: Optional[Dict[str, Any]] = None, + account_id: int | None = None, + steps_override: list[str] | None = None, + schema_override: str | None = None, + additional_run_config: dict[str, Any] | None = None, ) -> Response: """ Triggers a run of a dbt Cloud job. @@ -328,11 +321,11 @@ def trigger_job_run( @fallback_to_default_account def list_job_runs( self, - account_id: Optional[int] = None, - include_related: Optional[List[str]] = None, - job_definition_id: Optional[int] = None, - order_by: Optional[str] = None, - ) -> List[Response]: + account_id: int | None = None, + include_related: list[str] | None = None, + job_definition_id: int | None = None, + order_by: str | None = None, + ) -> list[Response]: """ Retrieves metadata for all of the dbt Cloud job runs for an account. If a ``job_definition_id`` is supplied, only metadata for runs of that specific job are pulled. @@ -357,7 +350,7 @@ def list_job_runs( @fallback_to_default_account def get_job_run( - self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None + self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None ) -> Response: """ Retrieves metadata for a specific run of a dbt Cloud job. @@ -373,7 +366,7 @@ def get_job_run( payload={"include_related": include_related}, ) - def get_job_run_status(self, run_id: int, account_id: Optional[int] = None) -> int: + def get_job_run_status(self, run_id: int, account_id: int | None = None) -> int: """ Retrieves the status for a specific run of a dbt Cloud job. @@ -395,8 +388,8 @@ def get_job_run_status(self, run_id: int, account_id: Optional[int] = None) -> i def wait_for_job_run_status( self, run_id: int, - account_id: Optional[int] = None, - expected_statuses: Union[int, Sequence[int], Set[int]] = DbtCloudJobRunStatus.SUCCESS.value, + account_id: int | None = None, + expected_statuses: int | Sequence[int] | set[int] = DbtCloudJobRunStatus.SUCCESS.value, check_interval: int = 60, timeout: int = 60 * 60 * 24 * 7, ) -> bool: @@ -438,7 +431,7 @@ def wait_for_job_run_status( return job_run_status in expected_statuses @fallback_to_default_account - def cancel_job_run(self, run_id: int, account_id: Optional[int] = None) -> None: + def cancel_job_run(self, run_id: int, account_id: int | None = None) -> None: """ Cancel a specific dbt Cloud job run. @@ -449,8 +442,8 @@ def cancel_job_run(self, run_id: int, account_id: Optional[int] = None) -> None: @fallback_to_default_account def list_job_run_artifacts( - self, run_id: int, account_id: Optional[int] = None, step: Optional[int] = None - ) -> List[Response]: + self, run_id: int, account_id: int | None = None, step: int | None = None + ) -> list[Response]: """ Retrieves a list of the available artifact files generated for a completed run of a dbt Cloud job. By default, this returns artifacts from the last step in the run. To list artifacts from other steps in @@ -469,7 +462,7 @@ def list_job_run_artifacts( @fallback_to_default_account def get_job_run_artifact( - self, run_id: int, path: str, account_id: Optional[int] = None, step: Optional[int] = None + self, run_id: int, path: str, account_id: int | None = None, step: int | None = None ) -> Response: """ Retrieves a list of the available artifact files generated for a completed run of a dbt Cloud job. 
By @@ -490,7 +483,7 @@ def get_job_run_artifact( endpoint=f"{account_id}/runs/{run_id}/artifacts/{path}", payload={"step": step} ) - def test_connection(self) -> Tuple[bool, str]: + def test_connection(self) -> tuple[bool, str]: """Test dbt Cloud connection.""" try: self._run_and_get_response() diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py index e26eaaed2325f..6f8080097bf04 100644 --- a/airflow/providers/dbt/cloud/operators/dbt.py +++ b/airflow/providers/dbt/cloud/operators/dbt.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus @@ -33,16 +34,8 @@ class DbtCloudRunJobOperatorLink(BaseOperatorLink): name = "Monitor Job Run" - def get_link(self, operator, dttm=None, *, ti_key=None): - if ti_key is not None: - job_run_url = XCom.get_value(key="job_run_url", ti_key=ti_key) - else: - assert dttm - job_run_url = XCom.get_one( - dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm, key="job_run_url" - ) - - return job_run_url + def get_link(self, operator: BaseOperator, *, ti_key=None): + return XCom.get_value(key="job_run_url", ti_key=ti_key) class DbtCloudRunJobOperator(BaseOperator): @@ -89,14 +82,14 @@ def __init__( *, dbt_cloud_conn_id: str = DbtCloudHook.default_conn_name, job_id: int, - account_id: Optional[int] = None, - trigger_reason: Optional[str] = None, - steps_override: Optional[List[str]] = None, - schema_override: Optional[str] = None, + account_id: int | None = None, + trigger_reason: str | None = None, + steps_override: list[str] | None = None, + schema_override: str | None = None, wait_for_termination: bool = True, timeout: int = 60 * 60 * 24 * 7, check_interval: int = 60, - additional_run_config: Optional[Dict[str, Any]] = None, + additional_run_config: dict[str, Any] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -113,7 +106,7 @@ def __init__( self.hook: DbtCloudHook self.run_id: int - def execute(self, context: "Context") -> int: + def execute(self, context: Context) -> int: if self.trigger_reason is None: self.trigger_reason = ( f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG." 
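The operator above pairs with the deferred-wait pattern from the example DAG this PR removes from the package: trigger the dbt Cloud job without waiting, then gate on ``DbtCloudJobRunSensor``. A condensed sketch, with the connection id and job id as placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor

with DAG(
    dag_id="example_dbt_cloud_async",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Fire the job and return immediately instead of polling inside the task.
    trigger_job = DbtCloudRunJobOperator(
        task_id="trigger_job",
        dbt_cloud_conn_id="dbt",  # placeholder connection id
        job_id=12345,             # placeholder job id
        wait_for_termination=False,
    )

    # Wait for the run triggered above; run_id is wired through XComArgs.
    wait_for_job = DbtCloudJobRunSensor(
        task_id="wait_for_job",
        dbt_cloud_conn_id="dbt",
        run_id=trigger_job.output,
        timeout=3600,
    )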
@@ -193,9 +186,9 @@ def __init__( dbt_cloud_conn_id: str = DbtCloudHook.default_conn_name, run_id: int, path: str, - account_id: Optional[int] = None, - step: Optional[int] = None, - output_file_name: Optional[str] = None, + account_id: int | None = None, + step: int | None = None, + output_file_name: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -206,7 +199,7 @@ def __init__( self.step = step self.output_file_name = output_file_name or f"{self.run_id}_{self.path}".replace("/", "-") - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = DbtCloudHook(self.dbt_cloud_conn_id) response = hook.get_job_run_artifact( run_id=self.run_id, path=self.path, account_id=self.account_id, step=self.step @@ -217,3 +210,54 @@ def execute(self, context: "Context") -> None: json.dump(response.json(), file) else: file.write(response.text) + + +class DbtCloudListJobsOperator(BaseOperator): + """ + List jobs in a dbt Cloud project. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DbtCloudListJobsOperator` + + Retrieves metadata for all jobs tied to a specified dbt Cloud account. If a ``project_id`` is + supplied, only jobs pertaining to this project id will be retrieved. + + :param account_id: Optional. If an account ID is not provided explicitly, + the account ID from the dbt Cloud connection will be used. + :param order_by: Optional. Field to order the result by. Use '-' to indicate reverse order. + For example, to use reverse order by the run ID use ``order_by=-id``. + :param project_id: Optional. The ID of a dbt Cloud project. + """ + + template_fields = ( + "account_id", + "project_id", + ) + + def __init__( + self, + *, + dbt_cloud_conn_id: str = DbtCloudHook.default_conn_name, + account_id: int | None = None, + project_id: int | None = None, + order_by: str | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dbt_cloud_conn_id = dbt_cloud_conn_id + self.account_id = account_id + self.project_id = project_id + self.order_by = order_by + + def execute(self, context: Context) -> list: + hook = DbtCloudHook(self.dbt_cloud_conn_id) + list_jobs_response = hook.list_jobs( + account_id=self.account_id, order_by=self.order_by, project_id=self.project_id + ) + buffer = [] + for job_metadata in list_jobs_response: + for job in job_metadata.json()["data"]: + buffer.append(job["id"]) + self.log.info("Jobs in the specified dbt Cloud account are: %s", ", ".join(map(str, buffer))) + return buffer diff --git a/airflow/providers/dbt/cloud/provider.yaml b/airflow/providers/dbt/cloud/provider.yaml index 2114d4a5d03e1..4526746043c26 100644 --- a/airflow/providers/dbt/cloud/provider.yaml +++ b/airflow/providers/dbt/cloud/provider.yaml @@ -22,11 +22,17 @@ description: | `dbt Cloud `__ versions: + - 2.3.0 + - 2.2.0 + - 2.1.0 + - 2.0.1 + - 2.0.0 - 1.0.2 - 1.0.1 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-http integrations: - integration-name: dbt Cloud @@ -51,9 +57,6 @@ hooks: python-modules: - airflow.providers.dbt.cloud.hooks.dbt -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook - connection-types: - hook-class-name: airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook connection-type: dbt_cloud diff --git a/airflow/providers/dbt/cloud/sensors/dbt.py 
b/airflow/providers/dbt/cloud/sensors/dbt.py index 34df4f1ccf2bc..14df6910c9749 100644 --- a/airflow/providers/dbt/cloud/sensors/dbt.py +++ b/airflow/providers/dbt/cloud/sensors/dbt.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus from airflow.sensors.base import BaseSensorOperator @@ -44,7 +45,7 @@ def __init__( *, dbt_cloud_conn_id: str = DbtCloudHook.default_conn_name, run_id: int, - account_id: Optional[int] = None, + account_id: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -52,7 +53,7 @@ def __init__( self.run_id = run_id self.account_id = account_id - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: hook = DbtCloudHook(self.dbt_cloud_conn_id) job_run_status = hook.get_job_run_status(run_id=self.run_id, account_id=self.account_id) diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json deleted file mode 100644 index 79a58e5bc2ef1..0000000000000 --- a/airflow/providers/dependencies.json +++ /dev/null @@ -1,94 +0,0 @@ -{ - "airbyte": [ - "http" - ], - "amazon": [ - "apache.hive", - "cncf.kubernetes", - "exasol", - "ftp", - "google", - "imap", - "mongo", - "salesforce", - "ssh" - ], - "apache.beam": [ - "google" - ], - "apache.druid": [ - "apache.hive" - ], - "apache.hive": [ - "amazon", - "microsoft.mssql", - "mysql", - "presto", - "samba", - "vertica" - ], - "apache.livy": [ - "http" - ], - "dbt.cloud": [ - "http" - ], - "dingding": [ - "http" - ], - "discord": [ - "http" - ], - "google": [ - "amazon", - "apache.beam", - "apache.cassandra", - "cncf.kubernetes", - "facebook", - "microsoft.azure", - "microsoft.mssql", - "mysql", - "oracle", - "postgres", - "presto", - "salesforce", - "sftp", - "ssh", - "trino" - ], - "hashicorp": [ - "google" - ], - "microsoft.azure": [ - "google", - "oracle", - "sftp" - ], - "mysql": [ - "amazon", - "presto", - "trino", - "vertica" - ], - "postgres": [ - "amazon" - ], - "presto": [ - "google" - ], - "salesforce": [ - "tableau" - ], - "sftp": [ - "ssh" - ], - "slack": [ - "http" - ], - "snowflake": [ - "slack" - ], - "trino": [ - "google" - ] -} diff --git a/airflow/providers/dingding/.latest-doc-only-change.txt b/airflow/providers/dingding/.latest-doc-only-change.txt index 029fd1fd22aec..ff7136e07d744 100644 --- a/airflow/providers/dingding/.latest-doc-only-change.txt +++ b/airflow/providers/dingding/.latest-doc-only-change.txt @@ -1 +1 @@ -2d109401b3566aef613501691d18cf7e4c776cd2 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/dingding/CHANGELOG.rst b/airflow/providers/dingding/CHANGELOG.rst index ef9f03584cc14..27c4697f88a1d 100644 --- a/airflow/providers/dingding/CHANGELOG.rst +++ b/airflow/providers/dingding/CHANGELOG.rst @@ -16,9 +16,55 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... 
+ +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate DingTalk example DAGs to new design #22443 (#24133)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Bump pre-commit hook versions (#22887)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Update tree doc references to grid (#22966)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/dingding/example_dags/example_dingding.py b/airflow/providers/dingding/example_dags/example_dingding.py deleted file mode 100644 index e57409e740126..0000000000000 --- a/airflow/providers/dingding/example_dags/example_dingding.py +++ /dev/null @@ -1,203 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example dag for using the DingdingOperator. -""" -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.providers.dingding.operators.dingding import DingdingOperator - - -# [START howto_operator_dingding_failure_callback] -def failure_callback(context): - """ - The function that will be executed on failure. - - :param context: The context of the executed task. 
- """ - message = ( - f"AIRFLOW TASK FAILURE TIPS:\n" - f"DAG: {context['task_instance'].dag_id}\n" - f"TASKS: {context['task_instance'].task_id}\n" - f"Reason: {context['exception']}\n" - ) - return DingdingOperator( - task_id='dingding_success_callback', - message_type='text', - message=message, - at_all=True, - ).execute(context) - - -# [END howto_operator_dingding_failure_callback] - -with DAG( - dag_id='example_dingding_operator', - default_args={'retries': 3, 'on_failure_callback': failure_callback}, - schedule_interval='@once', - dagrun_timeout=timedelta(minutes=60), - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) as dag: - - # [START howto_operator_dingding] - text_msg_remind_none = DingdingOperator( - task_id='text_msg_remind_none', - message_type='text', - message='Airflow dingding text message remind none', - at_mobiles=None, - at_all=False, - ) - # [END howto_operator_dingding] - - text_msg_remind_specific = DingdingOperator( - task_id='text_msg_remind_specific', - message_type='text', - message='Airflow dingding text message remind specific users', - at_mobiles=['156XXXXXXXX', '130XXXXXXXX'], - at_all=False, - ) - - text_msg_remind_include_invalid = DingdingOperator( - task_id='text_msg_remind_include_invalid', - message_type='text', - message='Airflow dingding text message remind users including invalid', - # 123 is invalid user or user not in the group - at_mobiles=['156XXXXXXXX', '123'], - at_all=False, - ) - - # [START howto_operator_dingding_remind_users] - text_msg_remind_all = DingdingOperator( - task_id='text_msg_remind_all', - message_type='text', - message='Airflow dingding text message remind all users in group', - # list of user phone/email here in the group - # when at_all is specific will cover at_mobiles - at_mobiles=['156XXXXXXXX', '130XXXXXXXX'], - at_all=True, - ) - # [END howto_operator_dingding_remind_users] - - link_msg = DingdingOperator( - task_id='link_msg', - message_type='link', - message={ - 'title': 'Airflow dingding link message', - 'text': 'Airflow official documentation link', - 'messageUrl': 'https://airflow.apache.org', - 'picURL': 'https://airflow.apache.org/_images/pin_large.png', - }, - ) - - # [START howto_operator_dingding_rich_text] - markdown_msg = DingdingOperator( - task_id='markdown_msg', - message_type='markdown', - message={ - 'title': 'Airflow dingding markdown message', - 'text': '# Markdown message title\n' - 'content content .. 
\n' - '### sub-title\n' - '![logo](https://airflow.apache.org/_images/pin_large.png)', - }, - at_mobiles=['156XXXXXXXX'], - at_all=False, - ) - # [END howto_operator_dingding_rich_text] - - single_action_card_msg = DingdingOperator( - task_id='single_action_card_msg', - message_type='actionCard', - message={ - 'title': 'Airflow dingding single actionCard message', - 'text': 'Airflow dingding single actionCard message\n' - '![logo](https://airflow.apache.org/_images/pin_large.png)\n' - 'This is a official logo in Airflow website.', - 'hideAvatar': '0', - 'btnOrientation': '0', - 'singleTitle': 'read more', - 'singleURL': 'https://airflow.apache.org', - }, - ) - - multi_action_card_msg = DingdingOperator( - task_id='multi_action_card_msg', - message_type='actionCard', - message={ - 'title': 'Airflow dingding multi actionCard message', - 'text': 'Airflow dingding multi actionCard message\n' - '![logo](https://airflow.apache.org/_images/pin_large.png)\n' - 'Airflow documentation and GitHub', - 'hideAvatar': '0', - 'btnOrientation': '0', - 'btns': [ - {'title': 'Airflow Documentation', 'actionURL': 'https://airflow.apache.org'}, - {'title': 'Airflow GitHub', 'actionURL': 'https://github.com/apache/airflow'}, - ], - }, - ) - - feed_card_msg = DingdingOperator( - task_id='feed_card_msg', - message_type='feedCard', - message={ - "links": [ - { - "title": "Airflow DAG feed card", - "messageURL": "https://airflow.apache.org/docs/apache-airflow/stable/ui.html", - "picURL": "https://airflow.apache.org/_images/dags.png", - }, - { - "title": "Airflow grid feed card", - "messageURL": "https://airflow.apache.org/docs/apache-airflow/stable/ui.html", - "picURL": "https://airflow.apache.org/_images/grid.png", - }, - { - "title": "Airflow graph feed card", - "messageURL": "https://airflow.apache.org/docs/apache-airflow/stable/ui.html", - "picURL": "https://airflow.apache.org/_images/graph.png", - }, - ] - }, - ) - - msg_failure_callback = DingdingOperator( - task_id='msg_failure_callback', - message_type='not_support_msg_type', - message="", - ) - - ( - [ - text_msg_remind_none, - text_msg_remind_specific, - text_msg_remind_include_invalid, - text_msg_remind_all, - ] - >> link_msg - >> markdown_msg - >> [ - single_action_card_msg, - multi_action_card_msg, - ] - >> feed_card_msg - >> msg_failure_callback - ) diff --git a/airflow/providers/dingding/hooks/dingding.py b/airflow/providers/dingding/hooks/dingding.py index c21c00ea0e2b2..832e8724e55a5 100644 --- a/airflow/providers/dingding/hooks/dingding.py +++ b/airflow/providers/dingding/hooks/dingding.py @@ -15,9 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json -from typing import List, Optional, Union import requests from requests import Session @@ -43,17 +43,17 @@ class DingdingHook(HttpHook): :param at_all: Remind all people in group or not. 
If True, will overwrite ``at_mobiles`` """ - conn_name_attr = 'dingding_conn_id' - default_conn_name = 'dingding_default' - conn_type = 'dingding' - hook_name = 'Dingding' + conn_name_attr = "dingding_conn_id" + default_conn_name = "dingding_default" + conn_type = "dingding" + hook_name = "Dingding" def __init__( self, - dingding_conn_id='dingding_default', - message_type: str = 'text', - message: Optional[Union[str, dict]] = None, - at_mobiles: Optional[List[str]] = None, + dingding_conn_id="dingding_default", + message_type: str = "text", + message: str | dict | None = None, + at_mobiles: list[str] | None = None, at_all: bool = False, *args, **kwargs, @@ -70,9 +70,9 @@ def _get_endpoint(self) -> str: token = conn.password if not token: raise AirflowException( - 'Dingding token is requests but get nothing, check you conn_id configuration.' + "Dingding token is requests but get nothing, check you conn_id configuration." ) - return f'robot/send?access_token={token}' + return f"robot/send?access_token={token}" def _build_message(self) -> str: """ @@ -80,17 +80,17 @@ def _build_message(self) -> str: As most commonly used type, text message just need post message content rather than a dict like ``{'content': 'message'}`` """ - if self.message_type in ['text', 'markdown']: + if self.message_type in ["text", "markdown"]: data = { - 'msgtype': self.message_type, - self.message_type: {'content': self.message} if self.message_type == 'text' else self.message, - 'at': {'atMobiles': self.at_mobiles, 'isAtAll': self.at_all}, + "msgtype": self.message_type, + self.message_type: {"content": self.message} if self.message_type == "text" else self.message, + "at": {"atMobiles": self.at_mobiles, "isAtAll": self.at_all}, } else: - data = {'msgtype': self.message_type, self.message_type: self.message} + data = {"msgtype": self.message_type, self.message_type: self.message} return json.dumps(data) - def get_conn(self, headers: Optional[dict] = None) -> Session: + def get_conn(self, headers: dict | None = None) -> Session: """ Overwrite HttpHook get_conn because just need base_url and headers and not don't need generic params @@ -98,7 +98,7 @@ def get_conn(self, headers: Optional[dict] = None) -> Session: :param headers: additional headers to be passed through as a dictionary """ conn = self.get_connection(self.http_conn_id) - self.base_url = conn.host if conn.host else 'https://oapi.dingtalk.com' + self.base_url = conn.host if conn.host else "https://oapi.dingtalk.com" session = requests.Session() if headers: session.headers.update(headers) @@ -106,19 +106,19 @@ def get_conn(self, headers: Optional[dict] = None) -> Session: def send(self) -> None: """Send Dingding message""" - support_type = ['text', 'link', 'markdown', 'actionCard', 'feedCard'] + support_type = ["text", "link", "markdown", "actionCard", "feedCard"] if self.message_type not in support_type: raise ValueError( - f'DingdingWebhookHook only support {support_type} so far, but receive {self.message_type}' + f"DingdingWebhookHook only support {support_type} so far, but receive {self.message_type}" ) data = self._build_message() - self.log.info('Sending Dingding type %s message %s', self.message_type, data) + self.log.info("Sending Dingding type %s message %s", self.message_type, data) resp = self.run( - endpoint=self._get_endpoint(), data=data, headers={'Content-Type': 'application/json'} + endpoint=self._get_endpoint(), data=data, headers={"Content-Type": "application/json"} ) # Dingding success send message will with errcode equal to 0 - if 
int(resp.json().get('errcode')) != 0: - raise AirflowException(f'Send Dingding message failed, receive error message {resp.text}') - self.log.info('Success Send Dingding message') + if int(resp.json().get("errcode")) != 0: + raise AirflowException(f"Send Dingding message failed, receive error message {resp.text}") + self.log.info("Success Send Dingding message") diff --git a/airflow/providers/dingding/operators/dingding.py b/airflow/providers/dingding/operators/dingding.py index 23e4d144c146f..6152ddefa6e83 100644 --- a/airflow/providers/dingding/operators/dingding.py +++ b/airflow/providers/dingding/operators/dingding.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.dingding.hooks.dingding import DingdingHook @@ -41,16 +43,16 @@ class DingdingOperator(BaseOperator): :param at_all: Remind all people in group or not. If True, will overwrite ``at_mobiles`` """ - template_fields: Sequence[str] = ('message',) - ui_color = '#4ea4d4' # Dingding icon color + template_fields: Sequence[str] = ("message",) + ui_color = "#4ea4d4" # Dingding icon color def __init__( self, *, - dingding_conn_id: str = 'dingding_default', - message_type: str = 'text', - message: Union[str, dict, None] = None, - at_mobiles: Optional[List[str]] = None, + dingding_conn_id: str = "dingding_default", + message_type: str = "text", + message: str | dict | None = None, + at_mobiles: list[str] | None = None, at_all: bool = False, **kwargs, ) -> None: @@ -61,8 +63,8 @@ def __init__( self.at_mobiles = at_mobiles self.at_all = at_all - def execute(self, context: 'Context') -> None: - self.log.info('Sending Dingding message.') + def execute(self, context: Context) -> None: + self.log.info("Sending Dingding message.") hook = DingdingHook( self.dingding_conn_id, self.message_type, self.message, self.at_mobiles, self.at_all ) diff --git a/airflow/providers/dingding/provider.yaml b/airflow/providers/dingding/provider.yaml index 9aac526f3cf98..37ddc69b8706e 100644 --- a/airflow/providers/dingding/provider.yaml +++ b/airflow/providers/dingding/provider.yaml @@ -22,6 +22,8 @@ description: | `Dingding `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-http integrations: - integration-name: Dingding @@ -52,8 +55,6 @@ hooks: python-modules: - airflow.providers.dingding.hooks.dingding -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.dingding.hooks.dingding.DingdingHook connection-types: - hook-class-name: airflow.providers.dingding.hooks.dingding.DingdingHook diff --git a/airflow/providers/discord/.latest-doc-only-change.txt b/airflow/providers/discord/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/discord/.latest-doc-only-change.txt +++ b/airflow/providers/discord/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/discord/CHANGELOG.rst b/airflow/providers/discord/CHANGELOG.rst index b57469160c6bd..ba68a18789f3c 100644 --- 
a/airflow/providers/discord/CHANGELOG.rst +++ b/airflow/providers/discord/CHANGELOG.rst @@ -16,9 +16,53 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Bug Fixes +~~~~~~~~~ + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + * ``Add documentation for July 2022 Provider's release (#25030)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.4 ..... diff --git a/airflow/providers/discord/hooks/discord_webhook.py b/airflow/providers/discord/hooks/discord_webhook.py index 8f1931a170a59..e4533c02f84f0 100644 --- a/airflow/providers/discord/hooks/discord_webhook.py +++ b/airflow/providers/discord/hooks/discord_webhook.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
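A minimal usage sketch of the webhook operator changed below (assumptions: a `discord_default` connection whose extra contains a `webhook_endpoint` in the required `webhooks/{webhook.id}/{webhook.token}` form; the DAG itself is only illustrative):

```python
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.discord.operators.discord_webhook import DiscordWebhookOperator

with DAG(
    dag_id="discord_webhook_example",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    notify = DiscordWebhookOperator(
        task_id="discord_msg",
        http_conn_id="discord_default",  # required; the operator raises AirflowException without it
        # webhook_endpoint="webhooks/11111/abcdef",  # optional per-task override of the connection extra
        message="Hello from Airflow",  # must be 2000 characters or fewer
        username="airflow-bot",
        tts=False,
    )
```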
-# +from __future__ import annotations + import json import re -from typing import Any, Dict, Optional +from typing import Any from airflow.exceptions import AirflowException from airflow.providers.http.hooks.http import HttpHook @@ -47,20 +48,20 @@ class DiscordWebhookHook(HttpHook): :param proxy: Proxy to use to make the Discord webhook call """ - conn_name_attr = 'http_conn_id' - default_conn_name = 'discord_default' - conn_type = 'discord' - hook_name = 'Discord' + conn_name_attr = "http_conn_id" + default_conn_name = "discord_default" + conn_type = "discord" + hook_name = "Discord" def __init__( self, - http_conn_id: Optional[str] = None, - webhook_endpoint: Optional[str] = None, + http_conn_id: str | None = None, + webhook_endpoint: str | None = None, message: str = "", - username: Optional[str] = None, - avatar_url: Optional[str] = None, + username: str | None = None, + avatar_url: str | None = None, tts: bool = False, - proxy: Optional[str] = None, + proxy: str | None = None, *args: Any, **kwargs: Any, ) -> None: @@ -73,7 +74,7 @@ def __init__( self.tts = tts self.proxy = proxy - def _get_webhook_endpoint(self, http_conn_id: Optional[str], webhook_endpoint: Optional[str]) -> str: + def _get_webhook_endpoint(self, http_conn_id: str | None, webhook_endpoint: str | None) -> str: """ Given a Discord http_conn_id, return the default webhook endpoint or override if a webhook_endpoint is manually supplied. @@ -87,14 +88,14 @@ def _get_webhook_endpoint(self, http_conn_id: Optional[str], webhook_endpoint: O elif http_conn_id: conn = self.get_connection(http_conn_id) extra = conn.extra_dejson - endpoint = extra.get('webhook_endpoint', '') + endpoint = extra.get("webhook_endpoint", "") else: raise AirflowException( - 'Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied.' + "Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied." ) # make sure endpoint matches the expected Discord webhook format - if not re.match('^webhooks/[0-9]+/[a-zA-Z0-9_-]+$', endpoint): + if not re.match("^webhooks/[0-9]+/[a-zA-Z0-9_-]+$", endpoint): raise AirflowException( 'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".' 
) @@ -108,19 +109,19 @@ def _build_discord_payload(self) -> str: :return: Discord payload (str) to send """ - payload: Dict[str, Any] = {} + payload: dict[str, Any] = {} if self.username: - payload['username'] = self.username + payload["username"] = self.username if self.avatar_url: - payload['avatar_url'] = self.avatar_url + payload["avatar_url"] = self.avatar_url - payload['tts'] = self.tts + payload["tts"] = self.tts if len(self.message) <= 2000: - payload['content'] = self.message + payload["content"] = self.message else: - raise AirflowException('Discord message length must be 2000 or fewer characters.') + raise AirflowException("Discord message length must be 2000 or fewer characters.") return json.dumps(payload) @@ -129,13 +130,13 @@ def execute(self) -> None: proxies = {} if self.proxy: # we only need https proxy for Discord - proxies = {'https': self.proxy} + proxies = {"https": self.proxy} discord_payload = self._build_discord_payload() self.run( endpoint=self.webhook_endpoint, data=discord_payload, - headers={'Content-type': 'application/json'}, - extra_options={'proxies': proxies}, + headers={"Content-type": "application/json"}, + extra_options={"proxies": proxies}, ) diff --git a/airflow/providers/discord/operators/discord_webhook.py b/airflow/providers/discord/operators/discord_webhook.py index 81129b678860b..3ab9d8fb0c0c6 100644 --- a/airflow/providers/discord/operators/discord_webhook.py +++ b/airflow/providers/discord/operators/discord_webhook.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook @@ -49,24 +50,24 @@ class DiscordWebhookOperator(SimpleHttpOperator): :param proxy: Proxy to use to make the Discord webhook call """ - template_fields: Sequence[str] = ('username', 'message', 'webhook_endpoint') + template_fields: Sequence[str] = ("username", "message", "webhook_endpoint") def __init__( self, *, - http_conn_id: Optional[str] = None, - webhook_endpoint: Optional[str] = None, + http_conn_id: str | None = None, + webhook_endpoint: str | None = None, message: str = "", - username: Optional[str] = None, - avatar_url: Optional[str] = None, + username: str | None = None, + avatar_url: str | None = None, tts: bool = False, - proxy: Optional[str] = None, + proxy: str | None = None, **kwargs, ) -> None: super().__init__(endpoint=webhook_endpoint, **kwargs) if not http_conn_id: - raise AirflowException('No valid Discord http_conn_id supplied.') + raise AirflowException("No valid Discord http_conn_id supplied.") self.http_conn_id = http_conn_id self.webhook_endpoint = webhook_endpoint @@ -75,9 +76,9 @@ def __init__( self.avatar_url = avatar_url self.tts = tts self.proxy = proxy - self.hook: Optional[DiscordWebhookHook] = None + self.hook: DiscordWebhookHook | None = None - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Call the DiscordWebhookHook to post message""" self.hook = DiscordWebhookHook( self.http_conn_id, diff --git a/airflow/providers/discord/provider.yaml b/airflow/providers/discord/provider.yaml index 409a585e47967..3dc1a30ac9165 100644 --- a/airflow/providers/discord/provider.yaml +++ b/airflow/providers/discord/provider.yaml @@ -22,6 +22,8 @@ 
description: | `Discord `__ versions: + - 3.1.0 + - 3.0.0 - 2.1.4 - 2.0.4 - 2.0.3 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-http integrations: - integration-name: Discord @@ -50,9 +53,6 @@ hooks: python-modules: - airflow.providers.discord.hooks.discord_webhook -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.discord.hooks.discord_webhook.DiscordWebhookHook - connection-types: - hook-class-name: airflow.providers.discord.hooks.discord_webhook.DiscordWebhookHook connection-type: discord diff --git a/airflow/providers/docker/CHANGELOG.rst b/airflow/providers/docker/CHANGELOG.rst index 991bd10c7bb6d..3e402d4ef29e7 100644 --- a/airflow/providers/docker/CHANGELOG.rst +++ b/airflow/providers/docker/CHANGELOG.rst @@ -16,9 +16,93 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add ipc_mode for DockerOperator (#27553)`` +* ``Add env-file parameter to Docker Operator (#26951)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Add logging options to docker operator (#26653)`` +* ``Add pre-commit hook for custom_operator_name (#25786)`` +* ``Implement ExternalPythonOperator (#25780)`` + +Bug Fixes +~~~~~~~~~ + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Force-remove container after DockerOperator execution (#23160)`` + +Bug Fixes +~~~~~~~~~ + +* ``'DockerOperator' fix cli.logs giving character array instead of string (#24726)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + * ``Clean up task decorator type hints and docstrings (#24667)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* ``Remove 'xcom_push' from 'DockerOperator' (#23981)`` +* ``docker new system test (#23167)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.7.0 ..... diff --git a/airflow/providers/docker/decorators/docker.py b/airflow/providers/docker/decorators/docker.py index db3293a5f569f..8fb868ab8e4fc 100644 --- a/airflow/providers/docker/decorators/docker.py +++ b/airflow/providers/docker/decorators/docker.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import base64 import inspect @@ -21,13 +22,21 @@ import pickle from tempfile import TemporaryDirectory from textwrap import dedent -from typing import TYPE_CHECKING, Callable, Optional, Sequence, TypeVar +from typing import TYPE_CHECKING, Callable, Sequence import dill from airflow.decorators.base import DecoratedOperator, task_decorator_factory from airflow.providers.docker.operators.docker import DockerOperator -from airflow.utils.python_virtualenv import remove_task_decorator, write_python_script + +try: + from airflow.utils.decorators import remove_task_decorator + + # This can be removed after we move to Airflow 2.4+ +except ImportError: + from airflow.utils.python_virtualenv import remove_task_decorator + +from airflow.utils.python_virtualenv import write_python_script if TYPE_CHECKING: from airflow.decorators.base import TaskDecorator @@ -38,7 +47,7 @@ def _generate_decode_command(env_var, file, python_command): # We don't need `f.close()` as the interpreter is about to exit anyway return ( f'{python_command} -c "import base64, os;' - rf'x = base64.b64decode(os.environ[\"{env_var}\"]);' + rf"x = base64.b64decode(os.environ[\"{env_var}\"]);" rf'f = open(\"{file}\", \"wb\"); f.write(x);"' ) @@ -53,6 +62,10 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator): Wraps a Python callable and captures args/kwargs when called for execution. :param python_callable: A reference to an object that is callable + :param python: Python binary name to use + :param use_dill: Whether dill should be used to serialize the callable + :param expect_airflow: whether to expect airflow to be installed in the docker environment. if this + one is specified, the script to run callable will attempt to load Airflow macros. :param op_kwargs: a dictionary of keyword arguments that will get unpacked in your function (templated) :param op_args: a list of positional arguments that will get unpacked when @@ -62,20 +75,24 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator): Defaults to False. """ - template_fields: Sequence[str] = ('op_args', 'op_kwargs') + custom_operator_name = "@task.docker" + + template_fields: Sequence[str] = ("op_args", "op_kwargs") # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects (e.g protobuf). 
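A minimal sketch of how the decorator diffed here is used, including the new `expect_airflow` flag (assumptions: Docker is reachable via the default socket, and the image tag is only illustrative):

```python
from __future__ import annotations

from datetime import datetime

from airflow.decorators import dag, task


@dag(schedule_interval=None, start_date=datetime(2022, 1, 1), catchup=False)
def docker_decorator_example():
    # expect_airflow=False tells the generated bootstrap script not to try to
    # load Airflow macros inside the container image.
    @task.docker(image="python:3.9-slim-bullseye", expect_airflow=False, use_dill=False)
    def add(x: int, y: int) -> int:
        return x + y

    add(1, 2)


docker_decorator_example()
```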
- shallow_copy_attrs: Sequence[str] = ('python_callable',) + shallow_copy_attrs: Sequence[str] = ("python_callable",) def __init__( self, use_dill=False, - python_command='python3', + python_command="python3", + expect_airflow: bool = True, **kwargs, ) -> None: command = "dummy command" self.python_command = python_command + self.expect_airflow = expect_airflow self.pickling_library = dill if use_dill else pickle super().__init__( command=command, retrieve_output=True, retrieve_output_path="/tmp/script.out", **kwargs @@ -86,17 +103,17 @@ def generate_command(self): f"""bash -cx '{_generate_decode_command("__PYTHON_SCRIPT", "/tmp/script.py", self.python_command)} &&""" f'{_generate_decode_command("__PYTHON_INPUT", "/tmp/script.in", self.python_command)} &&' - f'{self.python_command} /tmp/script.py /tmp/script.in /tmp/script.out\'' + f"{self.python_command} /tmp/script.py /tmp/script.in /tmp/script.out'" ) - def execute(self, context: 'Context'): - with TemporaryDirectory(prefix='venv') as tmp_dir: - input_filename = os.path.join(tmp_dir, 'script.in') - script_filename = os.path.join(tmp_dir, 'script.py') + def execute(self, context: Context): + with TemporaryDirectory(prefix="venv") as tmp_dir: + input_filename = os.path.join(tmp_dir, "script.in") + script_filename = os.path.join(tmp_dir, "script.py") - with open(input_filename, 'wb') as file: + with open(input_filename, "wb") as file: if self.op_args or self.op_kwargs: - self.pickling_library.dump({'args': self.op_args, 'kwargs': self.op_kwargs}, file) + self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file) py_source = self._get_python_source() write_python_script( jinja_context=dict( @@ -105,6 +122,7 @@ def execute(self, context: 'Context'): pickling_library=self.pickling_library.__name__, python_callable=self.python_callable.__name__, python_callable_source=py_source, + expect_airflow=self.expect_airflow, string_args_global=False, ), filename=script_filename, @@ -129,14 +147,11 @@ def _get_python_source(self): return res -T = TypeVar("T", bound=Callable) - - def docker_task( - python_callable: Optional[Callable] = None, - multiple_outputs: Optional[bool] = None, + python_callable: Callable | None = None, + multiple_outputs: bool | None = None, **kwargs, -) -> "TaskDecorator": +) -> TaskDecorator: """ Python operator decorator. Wraps a function into an Airflow operator. Also accepts any argument that DockerOperator will via ``kwargs``. Can be reused in a single DAG. diff --git a/airflow/providers/docker/example_dags/example_docker.py b/airflow/providers/docker/example_dags/example_docker.py deleted file mode 100644 index 83f6744883b07..0000000000000 --- a/airflow/providers/docker/example_dags/example_docker.py +++ /dev/null @@ -1,51 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.operators.bash import BashOperator -from airflow.providers.docker.operators.docker import DockerOperator - -dag = DAG( - 'docker_sample', - default_args={'retries': 1}, - schedule_interval=timedelta(minutes=10), - start_date=datetime(2021, 1, 1), - catchup=False, -) - -t1 = BashOperator(task_id='print_date', bash_command='date', dag=dag) - -t2 = BashOperator(task_id='sleep', bash_command='sleep 5', retries=3, dag=dag) - -t3 = DockerOperator( - docker_url='tcp://localhost:2375', # Set your docker URL - command='/bin/sleep 30', - image='centos:latest', - network_mode='bridge', - task_id='docker_op_tester', - dag=dag, -) - - -t4 = BashOperator(task_id='print_hello', bash_command='echo "hello world!!!"', dag=dag) - - -t1 >> t2 -t1 >> t3 -t3 >> t4 diff --git a/airflow/providers/docker/example_dags/example_docker_copy_data.py b/airflow/providers/docker/example_dags/example_docker_copy_data.py deleted file mode 100644 index 5ce78d02cd201..0000000000000 --- a/airflow/providers/docker/example_dags/example_docker_copy_data.py +++ /dev/null @@ -1,101 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -This sample "listen to directory". move the new file and print it, -using docker-containers. -The following operators are being used: DockerOperator, -BashOperator & ShortCircuitOperator. -TODO: Review the workflow, change it accordingly to - your environment & enable the code. 
-""" - -from datetime import datetime, timedelta - -from docker.types import Mount - -from airflow import DAG -from airflow.operators.bash import BashOperator -from airflow.operators.python import ShortCircuitOperator -from airflow.providers.docker.operators.docker import DockerOperator - -dag = DAG( - "docker_sample_copy_data", - default_args={"retries": 1}, - schedule_interval=timedelta(minutes=10), - start_date=datetime(2021, 1, 1), - catchup=False, -) - -locate_file_cmd = """ - sleep 10 - find {{params.source_location}} -type f -printf "%f\n" | head -1 -""" - -t_view = BashOperator( - task_id="view_file", - bash_command=locate_file_cmd, - do_xcom_push=True, - params={"source_location": "/your/input_dir/path"}, - dag=dag, -) - -t_is_data_available = ShortCircuitOperator( - task_id="check_if_data_available", - python_callable=lambda task_output: not task_output == "", - op_kwargs=dict(task_output=t_view.output), - dag=dag, -) - -t_move = DockerOperator( - api_version="1.19", - docker_url="tcp://localhost:2375", # replace it with swarm/docker endpoint - image="centos:latest", - network_mode="bridge", - mounts=[ - Mount(source="/your/host/input_dir/path", target="/your/input_dir/path", type="bind"), - Mount(source="/your/host/output_dir/path", target="/your/output_dir/path", type="bind"), - ], - command=[ - "/bin/bash", - "-c", - "/bin/sleep 30; " - "/bin/mv {{ params.source_location }}/" + str(t_view.output) + " {{ params.target_location }};" - "/bin/echo '{{ params.target_location }}/" + f"{t_view.output}';", - ], - task_id="move_data", - do_xcom_push=True, - params={"source_location": "/your/input_dir/path", "target_location": "/your/output_dir/path"}, - dag=dag, -) - -t_print = DockerOperator( - api_version="1.19", - docker_url="tcp://localhost:2375", - image="centos:latest", - mounts=[Mount(source="/your/host/output_dir/path", target="/your/output_dir/path", type="bind")], - command=f"cat {t_move.output}", - task_id="print", - dag=dag, -) - -t_is_data_available.set_downstream(t_move) -t_move.set_downstream(t_print) - -# Task dependencies created via `XComArgs`: -# t_view >> t_is_data_available diff --git a/airflow/providers/docker/example_dags/example_docker_swarm.py b/airflow/providers/docker/example_dags/example_docker_swarm.py deleted file mode 100644 index 365a4b44a97bb..0000000000000 --- a/airflow/providers/docker/example_dags/example_docker_swarm.py +++ /dev/null @@ -1,38 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from datetime import datetime, timedelta - -from airflow import DAG -from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator - -dag = DAG( - 'docker_swarm_sample', - schedule_interval=timedelta(minutes=10), - start_date=datetime(2021, 1, 1), - catchup=False, -) - -with dag as dag: - t1 = DockerSwarmOperator( - api_version='auto', - docker_url='tcp://localhost:2375', # Set your docker URL - command='/bin/sleep 10', - image='centos:latest', - auto_remove=True, - task_id='sleep_with_swarm', - ) diff --git a/airflow/providers/docker/example_dags/tutorial_taskflow_api_etl_docker_virtualenv.py b/airflow/providers/docker/example_dags/tutorial_taskflow_api_etl_docker_virtualenv.py deleted file mode 100644 index c16588c10c149..0000000000000 --- a/airflow/providers/docker/example_dags/tutorial_taskflow_api_etl_docker_virtualenv.py +++ /dev/null @@ -1,111 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -# [START tutorial] -# [START import_module] -from datetime import datetime - -from airflow.decorators import dag, task - -# [END import_module] - - -# [START instantiate_dag] -@dag(schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, tags=['example']) -def tutorial_taskflow_api_etl_docker_virtualenv(): - """ - ### TaskFlow API Tutorial Documentation - This is a simple ETL data pipeline example which demonstrates the use of - the TaskFlow API using three simple tasks for Extract, Transform, and Load. - Documentation that goes along with the Airflow TaskFlow API tutorial is - located - [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html) - """ - # [END instantiate_dag] - - # [START extract_virtualenv] - @task.virtualenv( - use_dill=True, - system_site_packages=False, - requirements=['funcsigs'], - ) - def extract(): - """ - #### Extract task - A simple Extract task to get data ready for the rest of the data - pipeline. In this case, getting data is simulated by reading from a - hardcoded JSON string. - """ - import json - - data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}' - - order_data_dict = json.loads(data_string) - return order_data_dict - - # [END extract_virtualenv] - - # [START transform_docker] - @task.docker(image='python:3.9-slim-bullseye', multiple_outputs=True) - def transform(order_data_dict: dict): - """ - #### Transform task - A simple Transform task which takes in the collection of order data and - computes the total order value. 
- """ - total_order_value = 0 - - for value in order_data_dict.values(): - total_order_value += value - - return {"total_order_value": total_order_value} - - # [END transform_docker] - - # [START load] - @task() - def load(total_order_value: float): - """ - #### Load task - A simple Load task which takes in the result of the Transform task and - instead of saving it to end user review, just prints it out. - """ - - print(f"Total order value is: {total_order_value:.2f}") - - # [END load] - - # [START main_flow] - order_data = extract() - order_summary = transform(order_data) - load(order_summary["total_order_value"]) - # [END main_flow] - - -# The try/except here is because Airflow versions less than 2.2.0 doesn't support -# @task.docker decorator and we use this dag in CI test. Thus, in order not to -# break the CI test, we added this try/except here. -try: - # [START dag_invocation] - tutorial_etl_dag = tutorial_taskflow_api_etl_docker_virtualenv() - # [END dag_invocation] -except AttributeError: - pass - -# [END tutorial] diff --git a/airflow/providers/docker/hooks/docker.py b/airflow/providers/docker/hooks/docker.py index c5c339c85fd5a..981e4b6a48a52 100644 --- a/airflow/providers/docker/hooks/docker.py +++ b/airflow/providers/docker/hooks/docker.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any from docker import APIClient # type: ignore[attr-defined] from docker.constants import DEFAULT_TIMEOUT_SECONDS # type: ignore[attr-defined] @@ -34,45 +36,45 @@ class DockerHook(BaseHook, LoggingMixin): where credentials and extra configuration are stored """ - conn_name_attr = 'docker_conn_id' - default_conn_name = 'docker_default' - conn_type = 'docker' - hook_name = 'Docker' + conn_name_attr = "docker_conn_id" + default_conn_name = "docker_default" + conn_type = "docker" + hook_name = "Docker" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema'], + "hidden_fields": ["schema"], "relabeling": { - 'host': 'Registry URL', - 'login': 'Username', + "host": "Registry URL", + "login": "Username", }, } def __init__( self, - docker_conn_id: Optional[str] = default_conn_name, - base_url: Optional[str] = None, - version: Optional[str] = None, - tls: Optional[str] = None, + docker_conn_id: str | None = default_conn_name, + base_url: str | None = None, + version: str | None = None, + tls: str | None = None, timeout: int = DEFAULT_TIMEOUT_SECONDS, ) -> None: super().__init__() if not base_url: - raise AirflowException('No Docker base URL provided') + raise AirflowException("No Docker base URL provided") if not version: - raise AirflowException('No Docker API version provided') + raise AirflowException("No Docker API version provided") if not docker_conn_id: - raise AirflowException('No Docker connection id provided') + raise AirflowException("No Docker connection id provided") conn = self.get_connection(docker_conn_id) if not conn.host: - raise AirflowException('No Docker URL provided') + raise AirflowException("No Docker URL provided") if not conn.login: - raise AirflowException('No username provided') + raise AirflowException("No username provided") extra_options = conn.extra_dejson self.__base_url = base_url @@ -85,8 +87,8 @@ def __init__( self.__registry = 
conn.host self.__username = conn.login self.__password = conn.password - self.__email = extra_options.get('email') - self.__reauth = extra_options.get('reauth') != 'no' + self.__email = extra_options.get("email") + self.__reauth = extra_options.get("reauth") != "no" def get_conn(self) -> APIClient: client = APIClient( @@ -96,7 +98,7 @@ def get_conn(self) -> APIClient: return client def __login(self, client) -> None: - self.log.debug('Logging into Docker') + self.log.debug("Logging into Docker") try: client.login( username=self.__username, @@ -105,7 +107,7 @@ def __login(self, client) -> None: email=self.__email, reauth=self.__reauth, ) - self.log.debug('Login successful') + self.log.debug("Login successful") except APIError as docker_error: - self.log.error('Docker login failed: %s', str(docker_error)) - raise AirflowException(f'Docker login failed: {docker_error}') + self.log.error("Docker login failed: %s", str(docker_error)) + raise AirflowException(f"Docker login failed: {docker_error}") diff --git a/airflow/providers/docker/operators/docker.py b/airflow/providers/docker/operators/docker.py index 61f5f9e393f94..509713680a766 100644 --- a/airflow/providers/docker/operators/docker.py +++ b/airflow/providers/docker/operators/docker.py @@ -16,17 +16,21 @@ # specific language governing permissions and limitations # under the License. """Implements Docker operator""" +from __future__ import annotations + import ast -import io import pickle import tarfile +import warnings +from io import BytesIO, StringIO from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Sequence from docker import APIClient, tls # type: ignore[attr-defined] from docker.constants import DEFAULT_TIMEOUT_SECONDS # type: ignore[attr-defined] from docker.errors import APIError # type: ignore[attr-defined] -from docker.types import DeviceRequest, Mount # type: ignore[attr-defined] +from docker.types import DeviceRequest, LogConfig, Mount # type: ignore[attr-defined] +from dotenv import dotenv_values from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -36,11 +40,11 @@ from airflow.utils.context import Context -def stringify(line: Union[str, bytes]): +def stringify(line: str | bytes): """Make sure string is returned even if bytes are passed. Docker stream can return bytes.""" - decode_method = getattr(line, 'decode', None) + decode_method = getattr(line, "decode", None) if decode_method: - return decode_method(encoding='utf-8', errors='surrogateescape') + return decode_method(encoding="utf-8", errors="surrogateescape") else: return line @@ -84,6 +88,8 @@ class DockerOperator(BaseOperator): :param environment: Environment variables to set in the container. (templated) :param private_environment: Private environment variables to set in the container. These are not templated, and hidden from the website. + :param env_file: Relative path to the .env file with environment variables to set in the container. + Overridden by variables in the environment parameter. (templated) :param force_pull: Pull the docker image on every run. Default is False. :param mem_limit: Maximum amount of memory the container can use. Either a float value, which represents the limit in bytes, @@ -125,7 +131,7 @@ class DockerOperator(BaseOperator): :param dns_search: Docker custom DNS search domain :param auto_remove: Auto-removal of the container on daemon side when the container's process exits. 
- The default is False. + The default is never. :param shm_size: Size of ``/dev/shm`` in bytes. The size must be greater than 0. If omitted uses system default. :param tty: Allocate pseudo-TTY to the container @@ -137,59 +143,85 @@ class DockerOperator(BaseOperator): output that is not posted to logs :param retrieve_output_path: path for output file that will be retrieved and passed to xcom :param device_requests: Expose host resources such as GPUs to the container. + :param log_opts_max_size: The maximum size of the log before it is rolled. + A positive integer plus a modifier representing the unit of measure (k, m, or g). + Eg: 10m or 1g Defaults to -1 (unlimited). + :param log_opts_max_file: The maximum number of log files that can be present. + If rolling the logs creates excess files, the oldest file is removed. + Only effective when max-size is also set. A positive integer. Defaults to 1. + :param ipc_mode: Set the IPC mode for the container. """ - template_fields: Sequence[str] = ('image', 'command', 'environment', 'container_name') + template_fields: Sequence[str] = ("image", "command", "environment", "env_file", "container_name") + template_fields_renderers = {"env_file": "yaml"} template_ext: Sequence[str] = ( - '.sh', - '.bash', + ".sh", + ".bash", + ".env", ) def __init__( self, *, image: str, - api_version: Optional[str] = None, - command: Optional[Union[str, List[str]]] = None, - container_name: Optional[str] = None, + api_version: str | None = None, + command: str | list[str] | None = None, + container_name: str | None = None, cpus: float = 1.0, - docker_url: str = 'unix://var/run/docker.sock', - environment: Optional[Dict] = None, - private_environment: Optional[Dict] = None, + docker_url: str = "unix://var/run/docker.sock", + environment: dict | None = None, + private_environment: dict | None = None, + env_file: str | None = None, force_pull: bool = False, - mem_limit: Optional[Union[float, str]] = None, - host_tmp_dir: Optional[str] = None, - network_mode: Optional[str] = None, - tls_ca_cert: Optional[str] = None, - tls_client_cert: Optional[str] = None, - tls_client_key: Optional[str] = None, - tls_hostname: Optional[Union[str, bool]] = None, - tls_ssl_version: Optional[str] = None, + mem_limit: float | str | None = None, + host_tmp_dir: str | None = None, + network_mode: str | None = None, + tls_ca_cert: str | None = None, + tls_client_cert: str | None = None, + tls_client_key: str | None = None, + tls_hostname: str | bool | None = None, + tls_ssl_version: str | None = None, mount_tmp_dir: bool = True, - tmp_dir: str = '/tmp/airflow', - user: Optional[Union[str, int]] = None, - mounts: Optional[List[Mount]] = None, - entrypoint: Optional[Union[str, List[str]]] = None, - working_dir: Optional[str] = None, + tmp_dir: str = "/tmp/airflow", + user: str | int | None = None, + mounts: list[Mount] | None = None, + entrypoint: str | list[str] | None = None, + working_dir: str | None = None, xcom_all: bool = False, - docker_conn_id: Optional[str] = None, - dns: Optional[List[str]] = None, - dns_search: Optional[List[str]] = None, - auto_remove: bool = False, - shm_size: Optional[int] = None, + docker_conn_id: str | None = None, + dns: list[str] | None = None, + dns_search: list[str] | None = None, + auto_remove: str = "never", + shm_size: int | None = None, tty: bool = False, privileged: bool = False, - cap_add: Optional[Iterable[str]] = None, - extra_hosts: Optional[Dict[str, str]] = None, + cap_add: Iterable[str] | None = None, + extra_hosts: dict[str, str] | None = None, 
retrieve_output: bool = False, - retrieve_output_path: Optional[str] = None, + retrieve_output_path: str | None = None, timeout: int = DEFAULT_TIMEOUT_SECONDS, - device_requests: Optional[List[DeviceRequest]] = None, + device_requests: list[DeviceRequest] | None = None, + log_opts_max_size: str | None = None, + log_opts_max_file: str | None = None, + ipc_mode: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.api_version = api_version - self.auto_remove = auto_remove + if type(auto_remove) == bool: + warnings.warn( + "bool value for auto_remove is deprecated, please use 'never', 'success', or 'force' instead", + DeprecationWarning, + stacklevel=2, + ) + if str(auto_remove) == "False": + self.auto_remove = "never" + elif str(auto_remove) == "True": + self.auto_remove = "success" + elif str(auto_remove) in ("never", "success", "force"): + self.auto_remove = auto_remove + else: + raise ValueError("unsupported auto_remove option, use 'never', 'success', or 'force' instead") self.command = command self.container_name = container_name self.cpus = cpus @@ -198,6 +230,7 @@ def __init__( self.docker_url = docker_url self.environment = environment or {} self._private_environment = private_environment or {} + self.env_file = env_file self.force_pull = force_pull self.image = image self.mem_limit = mem_limit @@ -228,6 +261,9 @@ def __init__( self.retrieve_output_path = retrieve_output_path self.timeout = timeout self.device_requests = device_requests + self.log_opts_max_size = log_opts_max_size + self.log_opts_max_file = log_opts_max_file + self.ipc_mode = ipc_mode def get_hook(self) -> DockerHook: """ @@ -243,13 +279,13 @@ def get_hook(self) -> DockerHook: timeout=self.timeout, ) - def _run_image(self) -> Optional[Union[List[str], str]]: + def _run_image(self) -> list[str] | str | None: """Run a Docker container with the provided image""" - self.log.info('Starting docker container from image %s', self.image) + self.log.info("Starting docker container from image %s", self.image) if not self.cli: raise Exception("The 'cli' should be initialized before!") if self.mount_tmp_dir: - with TemporaryDirectory(prefix='airflowtmp', dir=self.host_tmp_dir) as host_tmp_dir_generated: + with TemporaryDirectory(prefix="airflowtmp", dir=self.host_tmp_dir) as host_tmp_dir_generated: tmp_mount = Mount(self.tmp_dir, host_tmp_dir_generated, "bind") try: return self._run_image_with_mounts(self.mounts + [tmp_mount], add_tmp_variable=True) @@ -266,19 +302,25 @@ def _run_image(self) -> Optional[Union[List[str], str]]: else: return self._run_image_with_mounts(self.mounts, add_tmp_variable=False) - def _run_image_with_mounts( - self, target_mounts, add_tmp_variable: bool - ) -> Optional[Union[List[str], str]]: + def _run_image_with_mounts(self, target_mounts, add_tmp_variable: bool) -> list[str] | str | None: if add_tmp_variable: - self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir + self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir else: - self.environment.pop('AIRFLOW_TMP_DIR', None) + self.environment.pop("AIRFLOW_TMP_DIR", None) if not self.cli: raise Exception("The 'cli' should be initialized before!") + docker_log_config = {} + if self.log_opts_max_size is not None: + docker_log_config["max-size"] = self.log_opts_max_size + if self.log_opts_max_file is not None: + docker_log_config["max-file"] = self.log_opts_max_file + env_file_vars = {} + if self.env_file is not None: + env_file_vars = self.unpack_environment_variables(self.env_file) self.container = self.cli.create_container( 
command=self.format_command(self.command), name=self.container_name, - environment={**self.environment, **self._private_environment}, + environment={**env_file_vars, **self.environment, **self._private_environment}, host_config=self.cli.create_host_config( auto_remove=False, mounts=target_mounts, @@ -292,6 +334,8 @@ def _run_image_with_mounts( extra_hosts=self.extra_hosts, privileged=self.privileged, device_requests=self.device_requests, + log_config=LogConfig(config=docker_log_config), + ipc_mode=self.ipc_mode, ), image=self.image, user=self.user, @@ -299,9 +343,9 @@ def _run_image_with_mounts( working_dir=self.working_dir, tty=self.tty, ) - logstream = self.cli.attach(container=self.container['Id'], stdout=True, stderr=True, stream=True) + logstream = self.cli.attach(container=self.container["Id"], stdout=True, stderr=True, stream=True) try: - self.cli.start(self.container['Id']) + self.cli.start(self.container["Id"]) log_lines = [] for log_chunk in logstream: @@ -309,33 +353,30 @@ def _run_image_with_mounts( log_lines.append(log_chunk) self.log.info("%s", log_chunk) - result = self.cli.wait(self.container['Id']) - if result['StatusCode'] != 0: + result = self.cli.wait(self.container["Id"]) + if result["StatusCode"] != 0: joined_log_lines = "\n".join(log_lines) - raise AirflowException(f'Docker container failed: {repr(result)} lines {joined_log_lines}') + raise AirflowException(f"Docker container failed: {repr(result)} lines {joined_log_lines}") if self.retrieve_output: return self._attempt_to_retrieve_result() elif self.do_xcom_push: - log_parameters = { - 'container': self.container['Id'], - 'stdout': True, - 'stderr': True, - 'stream': True, - } + if len(log_lines) == 0: + return None try: if self.xcom_all: - return [stringify(line).strip() for line in self.cli.logs(**log_parameters)] + return log_lines else: - lines = [stringify(line).strip() for line in self.cli.logs(**log_parameters, tail=1)] - return lines[-1] if lines else None + return log_lines[-1] except StopIteration: # handle the case when there is not a single line to iterate on return None return None finally: - if self.auto_remove: - self.cli.remove_container(self.container['Id']) + if self.auto_remove == "success": + self.cli.remove_container(self.container["Id"]) + elif self.auto_remove == "force": + self.cli.remove_container(self.container["Id"], force=True) def _attempt_to_retrieve_result(self): """ @@ -347,22 +388,22 @@ def _attempt_to_retrieve_result(self): def copy_from_docker(container_id, src): archived_result, stat = self.cli.get_archive(container_id, src) - if stat['size'] == 0: + if stat["size"] == 0: # 0 byte file, it can't be anything else than None return None # no need to port to a file since we intend to deserialize - file_standin = io.BytesIO(b"".join(archived_result)) + file_standin = BytesIO(b"".join(archived_result)) tar = tarfile.open(fileobj=file_standin) - file = tar.extractfile(stat['name']) - lib = getattr(self, 'pickling_library', pickle) + file = tar.extractfile(stat["name"]) + lib = getattr(self, "pickling_library", pickle) return lib.loads(file.read()) try: - return copy_from_docker(self.container['Id'], self.retrieve_output_path) + return copy_from_docker(self.container["Id"], self.retrieve_output_path) except APIError: return None - def execute(self, context: 'Context') -> Optional[str]: + def execute(self, context: Context) -> str | None: self.cli = self._get_cli() if not self.cli: raise Exception("The 'cli' should be initialized before!") @@ -370,15 +411,15 @@ def execute(self, context: 
'Context') -> Optional[str]: # Pull the docker image if `force_pull` is set or image does not exist locally if self.force_pull or not self.cli.images(name=self.image): - self.log.info('Pulling docker image %s', self.image) + self.log.info("Pulling docker image %s", self.image) latest_status = {} for output in self.cli.pull(self.image, stream=True, decode=True): if isinstance(output, str): self.log.info("%s", output) continue - if isinstance(output, dict) and 'status' in output: + if isinstance(output, dict) and "status" in output: output_status = output["status"] - if 'id' not in output: + if "id" not in output: self.log.info("%s", output_status) continue @@ -398,28 +439,27 @@ def _get_cli(self) -> APIClient: ) @staticmethod - def format_command(command: Union[str, List[str]]) -> Union[List[str], str]: + def format_command(command: str | list[str]) -> list[str] | str: """ Retrieve command(s). if command string starts with [, it returns the command list) :param command: Docker command or entrypoint :return: the command (or commands) - :rtype: str | List[str] """ - if isinstance(command, str) and command.strip().find('[') == 0: + if isinstance(command, str) and command.strip().find("[") == 0: return ast.literal_eval(command) return command def on_kill(self) -> None: if self.cli is not None: - self.log.info('Stopping docker container') + self.log.info("Stopping docker container") if self.container is None: - self.log.info('Not attempting to kill container as it was not created') + self.log.info("Not attempting to kill container as it was not created") return - self.cli.stop(self.container['Id']) + self.cli.stop(self.container["Id"]) - def __get_tls_config(self) -> Optional[tls.TLSConfig]: + def __get_tls_config(self) -> tls.TLSConfig | None: tls_config = None if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key: # Ignore type error on SSL version here - it is deprecated and type annotation is wrong @@ -431,5 +471,16 @@ def __get_tls_config(self) -> Optional[tls.TLSConfig]: ssl_version=self.tls_ssl_version, assert_hostname=self.tls_hostname, ) - self.docker_url = self.docker_url.replace('tcp://', 'https://') + self.docker_url = self.docker_url.replace("tcp://", "https://") return tls_config + + @staticmethod + def unpack_environment_variables(env_str: str) -> dict: + r""" + Parse environment variables from the string + + :param env_str: environment variables in key=value format separated by '\n' + + :return: dictionary containing parsed environment variables + """ + return dotenv_values(stream=StringIO(env_str)) diff --git a/airflow/providers/docker/operators/docker_swarm.py b/airflow/providers/docker/operators/docker_swarm.py index 1fbb35d8d8af7..d92b0c036c428 100644 --- a/airflow/providers/docker/operators/docker_swarm.py +++ b/airflow/providers/docker/operators/docker_swarm.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. 
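A minimal sketch pulling together the DockerOperator options introduced above: `env_file` (rendered as a template and parsed with `python-dotenv` via `unpack_environment_variables`), the string-valued `auto_remove`, the log rotation options, and `ipc_mode`. The image, command, and `my_vars.env` path are illustrative assumptions:

```python
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.docker.operators.docker import DockerOperator

with DAG(
    dag_id="docker_operator_example",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    print_env = DockerOperator(
        task_id="print_env",
        image="python:3.9-slim-bullseye",
        command="env",
        docker_url="unix://var/run/docker.sock",
        # ".env" is now a template extension, so the file content is rendered
        # and then parsed into a dict of environment variables.
        env_file="my_vars.env",
        environment={"STAGE": "dev"},  # values here override duplicates from env_file
        auto_remove="success",  # bool is deprecated; use "never", "success" or "force"
        log_opts_max_size="10m",  # roll the container log once it reaches 10 MB
        log_opts_max_file="3",  # keep at most three rolled log files
        ipc_mode=None,  # e.g. "host" to share the host IPC namespace
    )
```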
"""Run ephemeral Docker Swarm services""" -from typing import TYPE_CHECKING, List, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING from docker import types @@ -95,11 +97,11 @@ def __init__( *, image: str, enable_logging: bool = True, - configs: Optional[List[types.ConfigReference]] = None, - secrets: Optional[List[types.SecretReference]] = None, - mode: Optional[types.ServiceMode] = None, - networks: Optional[List[Union[str, types.NetworkAttachmentConfig]]] = None, - placement: Optional[Union[types.Placement, List[types.Placement]]] = None, + configs: list[types.ConfigReference] | None = None, + secrets: list[types.SecretReference] | None = None, + mode: types.ServiceMode | None = None, + networks: list[str | types.NetworkAttachmentConfig] | None = None, + placement: types.Placement | list[types.Placement] | None = None, **kwargs, ) -> None: super().__init__(image=image, **kwargs) @@ -112,15 +114,15 @@ def __init__( self.networks = networks self.placement = placement - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self.cli = self._get_cli() - self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir + self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir return self._run_service() def _run_service(self) -> None: - self.log.info('Starting docker service from image %s', self.image) + self.log.info("Starting docker service from image %s", self.image) if not self.cli: raise Exception("The 'cli' should be initialized before!") self.service = self.cli.create_service( @@ -135,21 +137,21 @@ def _run_service(self) -> None: configs=self.configs, secrets=self.secrets, ), - restart_policy=types.RestartPolicy(condition='none'), + restart_policy=types.RestartPolicy(condition="none"), resources=types.Resources(mem_limit=self.mem_limit), networks=self.networks, placement=self.placement, ), - name=f'airflow-{get_random_string()}', - labels={'name': f'airflow__{self.dag_id}__{self.task_id}'}, + name=f"airflow-{get_random_string()}", + labels={"name": f"airflow__{self.dag_id}__{self.task_id}"}, mode=self.mode, ) if self.service is None: raise Exception("Service should be set here") - self.log.info('Service started: %s', str(self.service)) + self.log.info("Service started: %s", str(self.service)) # wait for the service to start the task - while not self.cli.tasks(filters={'service': self.service['ID']}): + while not self.cli.tasks(filters={"service": self.service["ID"]}): continue if self.enable_logging: @@ -157,28 +159,29 @@ def _run_service(self) -> None: while True: if self._has_service_terminated(): - self.log.info('Service status before exiting: %s', self._service_status()) + self.log.info("Service status before exiting: %s", self._service_status()) break - if self.service and self._service_status() != 'complete': - if self.auto_remove: - self.cli.remove_service(self.service['ID']) - raise AirflowException('Service did not complete: ' + repr(self.service)) - elif self.auto_remove: + self.log.info("auto_removeauto_removeauto_removeauto_removeauto_remove : %s", str(self.auto_remove)) + if self.service and self._service_status() != "complete": + if self.auto_remove == "success": + self.cli.remove_service(self.service["ID"]) + raise AirflowException("Service did not complete: " + repr(self.service)) + elif self.auto_remove == "success": if not self.service: raise Exception("The 'service' should be initialized before!") - self.cli.remove_service(self.service['ID']) + self.cli.remove_service(self.service["ID"]) - def 
_service_status(self) -> Optional[str]: + def _service_status(self) -> str | None: if not self.cli: raise Exception("The 'cli' should be initialized before!") if not self.service: raise Exception("The 'service' should be initialized before!") - return self.cli.tasks(filters={'service': self.service['ID']})[0]['Status']['State'] + return self.cli.tasks(filters={"service": self.service["ID"]})[0]["Status"]["State"] def _has_service_terminated(self) -> bool: status = self._service_status() - return status in ['complete', 'failed', 'shutdown', 'rejected', 'orphaned', 'remove'] + return status in ["complete", "failed", "shutdown", "rejected", "orphaned", "remove"] def _stream_logs_to_output(self) -> None: if not self.cli: @@ -186,9 +189,9 @@ def _stream_logs_to_output(self) -> None: if not self.service: raise Exception("The 'service' should be initialized before!") logs = self.cli.service_logs( - self.service['ID'], follow=True, stdout=True, stderr=True, is_tty=self.tty + self.service["ID"], follow=True, stdout=True, stderr=True, is_tty=self.tty ) - line = '' + line = "" while True: try: log = next(logs) @@ -200,9 +203,9 @@ def _stream_logs_to_output(self) -> None: log = log.decode() except UnicodeDecodeError: continue - if log == '\n': + if log == "\n": self.log.info(line) - line = '' + line = "" else: line += log # flush any remaining log stream @@ -211,5 +214,5 @@ def _stream_logs_to_output(self) -> None: def on_kill(self) -> None: if self.cli is not None and self.service is not None: - self.log.info('Removing docker service: %s', self.service['ID']) - self.cli.remove_service(self.service['ID']) + self.log.info("Removing docker service: %s", self.service["ID"]) + self.cli.remove_service(self.service["ID"]) diff --git a/airflow/providers/docker/provider.yaml b/airflow/providers/docker/provider.yaml index 9a4cc2bab5dca..623222ae253ac 100644 --- a/airflow/providers/docker/provider.yaml +++ b/airflow/providers/docker/provider.yaml @@ -22,6 +22,10 @@ description: | `Docker `__ versions: + - 3.3.0 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.7.0 - 2.6.0 - 2.5.2 @@ -40,8 +44,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.4.0 + - docker>=5.0.3 + - python-dotenv>=0.21.0 integrations: - integration-name: Docker @@ -66,9 +72,6 @@ hooks: python-modules: - airflow.providers.docker.hooks.docker -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.docker.hooks.docker.DockerHook - connection-types: - hook-class-name: airflow.providers.docker.hooks.docker.DockerHook connection-type: docker diff --git a/airflow/providers/elasticsearch/CHANGELOG.rst b/airflow/providers/elasticsearch/CHANGELOG.rst index db50b07729bb6..64c806be67a4f 100644 --- a/airflow/providers/elasticsearch/CHANGELOG.rst +++ b/airflow/providers/elasticsearch/CHANGELOG.rst @@ -16,9 +16,95 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. 
Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +4.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.2.0 +..... + +Features +~~~~~~~~ + +* ``Improve ElasticsearchTaskHandler (#21942)`` + + +4.1.0 +..... + +Features +~~~~~~~~ + +* ``Adding ElasticserachPythonHook - ES Hook With The Python Client (#24895)`` +* ``Move all SQL classes to common-sql provider (#24836)`` + +Bug Fixes +~~~~~~~~~ + +* ``Move fallible ti.task.dag assignment back inside try/except block (#24533) (#24592)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Only assert stuff for mypy when type checking (#24937)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* ``Apply per-run log templates to log handlers (#24153)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Fix new MyPy errors in main (#22884)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``removed old files (#24172)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.0.3 ..... diff --git a/airflow/providers/elasticsearch/example_dags/example_elasticsearch_query.py b/airflow/providers/elasticsearch/example_dags/example_elasticsearch_query.py deleted file mode 100644 index d4a7b100f00f6..0000000000000 --- a/airflow/providers/elasticsearch/example_dags/example_elasticsearch_query.py +++ /dev/null @@ -1,50 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.decorators import task -from airflow.providers.elasticsearch.hooks.elasticsearch import ElasticsearchHook - - -@task(task_id='es_print_tables') -def show_tables(): - """ - show_tables queries elasticsearch to list available tables - """ - es = ElasticsearchHook(elasticsearch_conn_id='production-es') - - # Handle ES conn with context manager - with es.get_conn() as es_conn: - tables = es_conn.execute('SHOW TABLES') - for table, *_ in tables: - print(f"table: {table}") - return True - - -# Using a DAG context manager, you don't have to specify the dag property of each task -with DAG( - 'elasticsearch_dag', - start_date=datetime(2021, 8, 30), - max_active_runs=1, - schedule_interval=timedelta(days=1), - default_args={'retries': 1}, # Default setting applied to all tasks - catchup=False, -) as dag: - - show_tables() diff --git a/airflow/providers/elasticsearch/hooks/elasticsearch.py b/airflow/providers/elasticsearch/hooks/elasticsearch.py index b48511670ffe5..7556817693118 100644 --- a/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -15,16 +15,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Optional +import warnings +from typing import Any +from elasticsearch import Elasticsearch from es.elastic.api import Connection as ESConnection, connect -from airflow.hooks.dbapi import DbApiHook +from airflow.compat.functools import cached_property +from airflow.hooks.base import BaseHook from airflow.models.connection import Connection as AirflowConnection +from airflow.providers.common.sql.hooks.sql import DbApiHook -class ElasticsearchHook(DbApiHook): +class ElasticsearchSQLHook(DbApiHook): """ Interact with Elasticsearch through the elasticsearch-dbapi. @@ -34,12 +39,12 @@ class ElasticsearchHook(DbApiHook): used for Elasticsearch credentials. 
""" - conn_name_attr = 'elasticsearch_conn_id' - default_conn_name = 'elasticsearch_default' - conn_type = 'elasticsearch' - hook_name = 'Elasticsearch' + conn_name_attr = "elasticsearch_conn_id" + default_conn_name = "elasticsearch_default" + conn_type = "elasticsearch" + hook_name = "Elasticsearch" - def __init__(self, schema: str = "http", connection: Optional[AirflowConnection] = None, *args, **kwargs): + def __init__(self, schema: str = "http", connection: AirflowConnection | None = None, *args, **kwargs): super().__init__(*args, **kwargs) self.schema = schema self.connection = connection @@ -57,10 +62,10 @@ def get_conn(self) -> ESConnection: scheme=conn.schema or "http", ) - if conn.extra_dejson.get('http_compress', False): + if conn.extra_dejson.get("http_compress", False): conn_args["http_compress"] = bool(["http_compress"]) - if conn.extra_dejson.get('timeout', False): + if conn.extra_dejson.get("timeout", False): conn_args["timeout"] = conn.extra_dejson["timeout"] conn = connect(**conn_args) @@ -71,25 +76,80 @@ def get_uri(self) -> str: conn_id = getattr(self, self.conn_name_attr) conn = self.connection or self.get_connection(conn_id) - login = '' + login = "" if conn.login: - login = '{conn.login}:{conn.password}@'.format(conn=conn) + login = "{conn.login}:{conn.password}@".format(conn=conn) host = conn.host if conn.port is not None: - host += f':{conn.port}' - uri = '{conn.conn_type}+{conn.schema}://{login}{host}/'.format(conn=conn, login=login, host=host) + host += f":{conn.port}" + uri = "{conn.conn_type}+{conn.schema}://{login}{host}/".format(conn=conn, login=login, host=host) extras_length = len(conn.extra_dejson) if not extras_length: return uri - uri += '?' + uri += "?" for arg_key, arg_value in conn.extra_dejson.items(): extras_length -= 1 uri += f"{arg_key}={arg_value}" if extras_length: - uri += '&' + uri += "&" return uri + + +class ElasticsearchHook(ElasticsearchSQLHook): + """ + This class is deprecated and was renamed to ElasticsearchSQLHook. + Please use `airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchSQLHook`. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchSQLHook`.""", + DeprecationWarning, + stacklevel=3, + ) + super().__init__(*args, **kwargs) + + +class ElasticsearchPythonHook(BaseHook): + """ + Interacts with Elasticsearch. This hook uses the official Elasticsearch Python Client. + + :param hosts: list: A list of a single or many Elasticsearch instances. Example: ["http://localhost:9200"] + :param es_conn_args: dict: Additional arguments you might need to enter to connect to Elasticsearch. 
+ Example: {"ca_cert":"/path/to/cert", "basic_auth": "(user, pass)"} + """ + + def __init__(self, hosts: list[Any], es_conn_args: dict | None = None): + super().__init__() + self.hosts = hosts + self.es_conn_args = es_conn_args if es_conn_args else {} + + def _get_elastic_connection(self): + """Returns the Elasticsearch client""" + client = Elasticsearch(self.hosts, **self.es_conn_args) + + return client + + @cached_property + def get_conn(self): + """Returns the Elasticsearch client (cached)""" + return self._get_elastic_connection() + + def search(self, query: dict[Any, Any], index: str = "_all") -> dict: + """ + Returns results matching a query using Elasticsearch DSL + + :param index: str: The index you want to query + :param query: dict: The query you want to run + + :returns: dict: The response 'hits' object from Elasticsearch + """ + es_client = self.get_conn + result = es_client.search(index=index, body=query) + return result["hits"] diff --git a/airflow/providers/elasticsearch/log/es_json_formatter.py b/airflow/providers/elasticsearch/log/es_json_formatter.py new file mode 100644 index 0000000000000..bfc68fe20e93f --- /dev/null +++ b/airflow/providers/elasticsearch/log/es_json_formatter.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pendulum + +from airflow.utils.log.json_formatter import JSONFormatter + + +class ElasticsearchJSONFormatter(JSONFormatter): + """ + ElasticsearchJSONFormatter instances are used to convert a log record + to json with ISO 8601 date and time format + """ + + default_time_format = "%Y-%m-%dT%H:%M:%S" + default_msec_format = "%s.%03d" + default_tz_format = "%z" + + def formatTime(self, record, datefmt=None): + """ + Returns the creation time of the specified LogRecord in ISO 8601 date and time format + in the local time zone. + """ + dt = pendulum.from_timestamp(record.created, tz=pendulum.local_timezone()) + if datefmt: + s = dt.strftime(datefmt) + else: + s = dt.strftime(self.default_time_format) + + if self.default_msec_format: + s = self.default_msec_format % (s, record.msecs) + if self.default_tz_format: + s += dt.strftime(self.default_tz_format) + return s diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py index 83c1163d80c87..ecd964a1472f8 100644 --- a/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/airflow/providers/elasticsearch/log/es_task_handler.py @@ -15,14 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
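# --- Illustrative usage sketch for the ElasticsearchPythonHook introduced above; a minimal
# example only, not part of the patch hunks. The host URL and index name are placeholder
# assumptions, and es_conn_args (omitted here) is forwarded unchanged to elasticsearch.Elasticsearch.
from airflow.providers.elasticsearch.hooks.elasticsearch import ElasticsearchPythonHook

hook = ElasticsearchPythonHook(hosts=["http://localhost:9200"])
# get_conn is a cached_property, so it is accessed without parentheses and returns the raw client.
es_client = hook.get_conn
# search() runs the DSL query and returns only the "hits" object of the response.
hits = hook.search(query={"query": {"match_all": {}}}, index="example-index")
print(hits["total"], len(hits["hits"]))
# --- end of sketch ---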
+from __future__ import annotations import logging import sys +import warnings from collections import defaultdict from datetime import datetime from operator import attrgetter from time import time -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple from urllib.parse import quote # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test. @@ -31,15 +33,23 @@ from elasticsearch_dsl import Search from airflow.configuration import conf -from airflow.models import TaskInstance +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.providers.elasticsearch.log.es_json_formatter import ElasticsearchJSONFormatter from airflow.utils import timezone from airflow.utils.log.file_task_handler import FileTaskHandler -from airflow.utils.log.json_formatter import JSONFormatter from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin +from airflow.utils.session import create_session +LOG_LINE_DEFAULTS = {"exc_text": "", "stack_info": ""} # Elasticsearch hosted log type EsLogMsgType = List[Tuple[str, str]] +# Compatibility: Airflow 2.3.3 and up uses this method, which accesses the +# LogTemplate model to record the log ID template used. If this function does +# not exist, the task handler should use the log_id_template attribute instead. +USE_PER_RUN_LOG_ID = hasattr(DagRun, "get_log_template") + class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMixin): """ @@ -60,13 +70,11 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix PAGE = 0 MAX_LINE_PER_PAGE = 1000 - LOG_NAME = 'Elasticsearch' + LOG_NAME = "Elasticsearch" def __init__( self, base_log_folder: str, - filename_template: str, - log_id_template: str, end_of_log_mark: str, write_stdout: bool, json_format: bool, @@ -75,7 +83,10 @@ def __init__( offset_field: str = "offset", host: str = "localhost:9200", frontend: str = "localhost:5601", - es_kwargs: Optional[dict] = conf.getsection("elasticsearch_configs"), + es_kwargs: dict | None = conf.getsection("elasticsearch_configs"), + *, + filename_template: str | None = None, + log_id_template: str | None = None, ): """ :param base_log_folder: base folder to store logs locally @@ -86,12 +97,18 @@ def __init__( super().__init__(base_log_folder, filename_template) self.closed = False - self.client = elasticsearch.Elasticsearch([host], **es_kwargs) # type: ignore[attr-defined] + self.client = elasticsearch.Elasticsearch(host.split(";"), **es_kwargs) # type: ignore[attr-defined] + + if USE_PER_RUN_LOG_ID and log_id_template is not None: + warnings.warn( + "Passing log_id_template to ElasticsearchTaskHandler is deprecated and has no effect", + DeprecationWarning, + ) - self.log_id_template = log_id_template + self.log_id_template = log_id_template # Only used on Airflow < 2.3.2. 
self.frontend = frontend self.mark_end_on_close = True - self.end_of_log_mark = end_of_log_mark + self.end_of_log_mark = end_of_log_mark.strip() self.write_stdout = write_stdout self.json_format = json_format self.json_fields = [label.strip() for label in json_fields.split(",")] @@ -100,16 +117,24 @@ def __init__( self.context_set = False self.formatter: logging.Formatter - self.handler: Union[logging.FileHandler, logging.StreamHandler] # type: ignore[assignment] + self.handler: logging.FileHandler | logging.StreamHandler # type: ignore[assignment] def _render_log_id(self, ti: TaskInstance, try_number: int) -> str: - dag_run = ti.get_dagrun() - dag = ti.task.dag - assert dag is not None # For Mypy. + with create_session() as session: + dag_run = ti.get_dagrun(session=session) + if USE_PER_RUN_LOG_ID: + log_id_template = dag_run.get_log_template(session=session).elasticsearch_id + else: + log_id_template = self.log_id_template + try: - data_interval: Tuple[datetime, datetime] = dag.get_run_data_interval(dag_run) + dag = ti.task.dag except AttributeError: # ti.task is not always set. data_interval = (dag_run.data_interval_start, dag_run.data_interval_end) + else: + if TYPE_CHECKING: + assert dag is not None + data_interval = dag.get_run_data_interval(dag_run) if self.json_format: data_interval_start = self._clean_date(data_interval[0]) @@ -126,7 +151,7 @@ def _render_log_id(self, ti: TaskInstance, try_number: int) -> str: data_interval_end = "" execution_date = dag_run.execution_date.isoformat() - return self.log_id_template.format( + return log_id_template.format( dag_id=ti.dag_id, task_id=ti.task_id, run_id=getattr(ti, "run_id", ""), @@ -138,7 +163,7 @@ def _render_log_id(self, ti: TaskInstance, try_number: int) -> str: ) @staticmethod - def _clean_date(value: Optional[datetime]) -> str: + def _clean_date(value: datetime | None) -> str: """ Clean up a date value so that it is safe to query in elasticsearch by removing reserved characters. @@ -152,20 +177,17 @@ def _clean_date(value: Optional[datetime]) -> str: def _group_logs_by_host(self, logs): grouped_logs = defaultdict(list) for log in logs: - key = getattr(log, self.host_field, 'default_host') + key = getattr(log, self.host_field, "default_host") grouped_logs[key].append(log) - # return items sorted by timestamp. - result = sorted(grouped_logs.items(), key=lambda kv: getattr(kv[1][0], 'message', '_')) - - return result + return grouped_logs def _read_grouped_logs(self): return True def _read( - self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None - ) -> Tuple[EsLogMsgType, dict]: + self, ti: TaskInstance, try_number: int, metadata: dict | None = None + ) -> tuple[EsLogMsgType, dict]: """ Endpoint for streaming log. @@ -176,11 +198,11 @@ def _read( :return: a list of tuple with host and log documents, metadata. """ if not metadata: - metadata = {'offset': 0} - if 'offset' not in metadata: - metadata['offset'] = 0 + metadata = {"offset": 0} + if "offset" not in metadata: + metadata["offset"] = 0 - offset = metadata['offset'] + offset = metadata["offset"] log_id = self._render_log_id(ti, try_number) logs = self.es_read(log_id, offset, metadata) @@ -191,47 +213,47 @@ def _read( # Ensure a string here. Large offset numbers will get JSON.parsed incorrectly # on the client. Sending as a string prevents this issue. 
# https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER - metadata['offset'] = str(next_offset) + metadata["offset"] = str(next_offset) # end_of_log_mark may contain characters like '\n' which is needed to # have the log uploaded but will not be stored in elasticsearch. - loading_hosts = [ - item[0] for item in logs_by_host if item[-1][-1].message != self.end_of_log_mark.strip() - ] - metadata['end_of_log'] = False if not logs else len(loading_hosts) == 0 + metadata["end_of_log"] = False + for logs in logs_by_host.values(): + if logs[-1].message == self.end_of_log_mark: + metadata["end_of_log"] = True cur_ts = pendulum.now() - if 'last_log_timestamp' in metadata: - last_log_ts = timezone.parse(metadata['last_log_timestamp']) + if "last_log_timestamp" in metadata: + last_log_ts = timezone.parse(metadata["last_log_timestamp"]) # if we are not getting any logs at all after more than N seconds of trying, # assume logs do not exist if int(next_offset) == 0 and cur_ts.diff(last_log_ts).in_seconds() > 5: - metadata['end_of_log'] = True + metadata["end_of_log"] = True missing_log_message = ( f"*** Log {log_id} not found in Elasticsearch. " "If your task started recently, please wait a moment and reload this page. " "Otherwise, the logs for this task instance may have been removed." ) - return [('', missing_log_message)], metadata + return [("", missing_log_message)], metadata if ( # Assume end of log after not receiving new log for N min, cur_ts.diff(last_log_ts).in_minutes() >= 5 # if max_offset specified, respect it - or ('max_offset' in metadata and int(offset) >= int(metadata['max_offset'])) + or ("max_offset" in metadata and int(offset) >= int(metadata["max_offset"])) ): - metadata['end_of_log'] = True + metadata["end_of_log"] = True - if int(offset) != int(next_offset) or 'last_log_timestamp' not in metadata: - metadata['last_log_timestamp'] = str(cur_ts) + if int(offset) != int(next_offset) or "last_log_timestamp" not in metadata: + metadata["last_log_timestamp"] = str(cur_ts) # If we hit the end of the log, remove the actual end_of_log message # to prevent it from showing in the UI. def concat_logs(lines): - log_range = (len(lines) - 1) if lines[-1].message == self.end_of_log_mark.strip() else len(lines) - return '\n'.join(self._format_msg(lines[i]) for i in range(log_range)) + log_range = (len(lines) - 1) if lines[-1].message == self.end_of_log_mark else len(lines) + return "\n".join(self._format_msg(lines[i]) for i in range(log_range)) - message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host] + message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host.items()] return message, metadata @@ -241,8 +263,9 @@ def _format_msg(self, log_line): # if we change the formatter style from '%' to '{' or '$', this will still work if self.json_format: try: - - return self.formatter._style.format(_ESJsonLogFmt(self.json_fields, **log_line.to_dict())) + return self.formatter._style.format( + logging.makeLogRecord({**LOG_LINE_DEFAULTS, **log_line.to_dict()}) + ) except Exception: pass @@ -259,20 +282,20 @@ def es_read(self, log_id: str, offset: str, metadata: dict) -> list: :param metadata: log metadata, used for steaming log download. """ # Offset is the unique key for sorting logs given log_id. 
- search = Search(using=self.client).query('match_phrase', log_id=log_id).sort(self.offset_field) + search = Search(using=self.client).query("match_phrase", log_id=log_id).sort(self.offset_field) - search = search.filter('range', **{self.offset_field: {'gt': int(offset)}}) + search = search.filter("range", **{self.offset_field: {"gt": int(offset)}}) max_log_line = search.count() - if 'download_logs' in metadata and metadata['download_logs'] and 'max_offset' not in metadata: + if "download_logs" in metadata and metadata["download_logs"] and "max_offset" not in metadata: try: if max_log_line > 0: - metadata['max_offset'] = attrgetter(self.offset_field)( + metadata["max_offset"] = attrgetter(self.offset_field)( search[max_log_line - 1].execute()[-1] ) else: - metadata['max_offset'] = 0 + metadata["max_offset"] = 0 except Exception: - self.log.exception('Could not get current log size with log_id: %s', log_id) + self.log.exception("Could not get current log size with log_id: %s", log_id) logs = [] if max_log_line != 0: @@ -280,13 +303,13 @@ def es_read(self, log_id: str, offset: str, metadata: dict) -> list: logs = search[self.MAX_LINE_PER_PAGE * self.PAGE : self.MAX_LINE_PER_PAGE].execute() except Exception: - self.log.exception('Could not read log with log_id: %s', log_id) + self.log.exception("Could not read log with log_id: %s", log_id) return logs def emit(self, record): if self.handler: - record.offset = int(time() * (10**9)) + setattr(record, self.offset_field, int(time() * (10**9))) self.handler.emit(record) def set_context(self, ti: TaskInstance) -> None: @@ -298,15 +321,15 @@ def set_context(self, ti: TaskInstance) -> None: self.mark_end_on_close = not ti.raw if self.json_format: - self.formatter = JSONFormatter( + self.formatter = ElasticsearchJSONFormatter( fmt=self.formatter._fmt, json_fields=self.json_fields + [self.offset_field], extras={ - 'dag_id': str(ti.dag_id), - 'task_id': str(ti.task_id), - 'execution_date': self._clean_date(ti.execution_date), - 'try_number': str(ti.try_number), - 'log_id': self._render_log_id(ti, ti.try_number), + "dag_id": str(ti.dag_id), + "task_id": str(ti.task_id), + "execution_date": self._clean_date(ti.execution_date), + "try_number": str(ti.try_number), + "log_id": self._render_log_id(ti, ti.try_number), }, ) @@ -347,7 +370,7 @@ def close(self) -> None: # Mark the end of file using end of log mark, # so we know where to stop while auto-tailing. - self.handler.stream.write(self.end_of_log_mark) + self.emit(logging.makeLogRecord({"msg": self.end_of_log_mark})) if self.write_stdout: self.handler.close() @@ -369,23 +392,12 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) -> :param task_instance: task instance object :param try_number: task instance try_number to read logs from. 
:return: URL to the external log collection service - :rtype: str """ log_id = self._render_log_id(task_instance, try_number) - scheme = '' if '://' in self.frontend else 'https://' + scheme = "" if "://" in self.frontend else "https://" return scheme + self.frontend.format(log_id=quote(log_id)) @property def supports_external_link(self) -> bool: """Whether we can support external links""" return bool(self.frontend) - - -class _ESJsonLogFmt: - """Helper class to read ES Logs and re-format it to match settings.LOG_FORMAT""" - - # A separate class is needed because 'self.formatter._style.format' uses '.__dict__' - def __init__(self, json_fields: List, **kwargs): - for field in json_fields: - self.__setattr__(field, '') - self.__dict__.update(kwargs) diff --git a/airflow/providers/elasticsearch/provider.yaml b/airflow/providers/elasticsearch/provider.yaml index d2234d542e4f3..32663a92920a6 100644 --- a/airflow/providers/elasticsearch/provider.yaml +++ b/airflow/providers/elasticsearch/provider.yaml @@ -22,6 +22,11 @@ description: | `Elasticsearch `__ versions: + - 4.3.0 + - 4.2.1 + - 4.2.0 + - 4.1.0 + - 4.0.0 - 3.0.3 - 3.0.2 - 3.0.1 @@ -37,8 +42,12 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - elasticsearch>7 + - elasticsearch-dbapi + - elasticsearch-dsl>=5.0.0 integrations: - integration-name: Elasticsearch @@ -51,9 +60,6 @@ hooks: python-modules: - airflow.providers.elasticsearch.hooks.elasticsearch -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchHook - connection-types: - hook-class-name: airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchHook connection-type: elasticsearch diff --git a/airflow/providers/exasol/CHANGELOG.rst b/airflow/providers/exasol/CHANGELOG.rst index 5cd54d10a1def..d3d58f64b618b 100644 --- a/airflow/providers/exasol/CHANGELOG.rst +++ b/airflow/providers/exasol/CHANGELOG.rst @@ -16,9 +16,95 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` +* ``Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +4.0.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)`` + +Features +~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Bug Fixes +~~~~~~~~~ + +* ``Fix UnboundLocalError when sql is empty list in ExasolHook (#23812)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 2233ce1e2c347..49289df37a314 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -15,15 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from contextlib import closing -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Iterable, Mapping import pandas as pd import pyexasol from pyexasol import ExaConnection -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook class ExasolHook(DbApiHook): @@ -37,10 +38,10 @@ class ExasolHook(DbApiHook): for more details. 
""" - conn_name_attr = 'exasol_conn_id' - default_conn_name = 'exasol_default' - conn_type = 'exasol' - hook_name = 'Exasol' + conn_name_attr = "exasol_conn_id" + default_conn_name = "exasol_default" + conn_type = "exasol" + hook_name = "Exasol" supports_autocommit = True def __init__(self, *args, **kwargs) -> None: @@ -51,22 +52,20 @@ def get_conn(self) -> ExaConnection: conn_id = getattr(self, self.conn_name_attr) conn = self.get_connection(conn_id) conn_args = dict( - dsn=f'{conn.host}:{conn.port}', + dsn=f"{conn.host}:{conn.port}", user=conn.login, password=conn.password, schema=self.schema or conn.schema, ) # check for parameters in conn.extra for arg_name, arg_val in conn.extra_dejson.items(): - if arg_name in ['compression', 'encryption', 'json_lib', 'client_name']: + if arg_name in ["compression", "encryption", "json_lib", "client_name"]: conn_args[arg_name] = arg_val conn = pyexasol.connect(**conn_args) return conn - def get_pandas_df( - self, sql: Union[str, list], parameters: Optional[dict] = None, **kwargs - ) -> pd.DataFrame: + def get_pandas_df(self, sql: str, parameters: dict | None = None, **kwargs) -> pd.DataFrame: """ Executes the sql and returns a pandas dataframe @@ -80,8 +79,10 @@ def get_pandas_df( return df def get_records( - self, sql: Union[str, list], parameters: Optional[dict] = None - ) -> List[Union[dict, Tuple[Any, ...]]]: + self, + sql: str | list[str], + parameters: Iterable | Mapping | None = None, + ) -> list[dict | tuple[Any, ...]]: """ Executes the sql and returns a set of records. @@ -93,7 +94,7 @@ def get_records( with closing(conn.execute(sql, parameters)) as cur: return cur.fetchall() - def get_first(self, sql: Union[str, list], parameters: Optional[dict] = None) -> Optional[Any]: + def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any: """ Executes the sql and returns the first resulting row. @@ -109,8 +110,8 @@ def export_to_file( self, filename: str, query_or_table: str, - query_params: Optional[Dict] = None, - export_params: Optional[Dict] = None, + query_params: dict | None = None, + export_params: dict | None = None, ) -> None: """ Exports data to a file. @@ -133,8 +134,14 @@ def export_to_file( self.log.info("Data saved to %s", filename) def run( - self, sql: Union[str, list], autocommit: bool = False, parameters: Optional[dict] = None, handler=None - ) -> Optional[list]: + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping | None = None, + handler: Callable | None = None, + split_statements: bool = False, + return_last: bool = True, + ) -> Any | list[Any] | None: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -146,38 +153,44 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the LAST SQL expression if handler was provided. 
""" + self.scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): - sql = [sql] + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [self.strip_sql_string(sql)] if sql: - self.log.debug("Executing %d statements against Exasol DB", len(sql)) + self.log.debug("Executing following statements against Exasol DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, autocommit) - - for query in sql: - self.log.info(query) - with closing(conn.execute(query, parameters)) as cur: - results = [] - + self.set_autocommit(conn, autocommit) + results = [] + for sql_statement in sql: + with closing(conn.execute(sql_statement, parameters)) as cur: + self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) if handler is not None: - cur = handler(cur) + result = handler(cur) + results.append(result) - for row in cur: - self.log.info("Statement execution info - %s", row) - results.append(row) + self.log.info("Rows affected: %s", cur.rowcount) - self.log.info(cur.row_count) - # If autocommit was set to False for db that supports autocommit, - # or if db does not support autocommit, we do a manual commit. + # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): conn.commit() - return results + if handler is None: + return None + elif self.scalar_return_last: + return results[-1] + else: + return results def set_autocommit(self, conn, autocommit: bool) -> None: """ @@ -202,15 +215,14 @@ def get_autocommit(self, conn) -> bool: :param conn: Connection to get autocommit setting from. :return: connection autocommit setting. - :rtype: bool """ - autocommit = conn.attr.get('autocommit') + autocommit = conn.attr.get("autocommit") if autocommit is None: autocommit = super().get_autocommit(conn) return autocommit @staticmethod - def _serialize_cell(cell, conn=None) -> object: + def _serialize_cell(cell, conn=None) -> Any: """ Exasol will adapt all arguments to the execute() method internally, hence we return cell without any conversion. @@ -218,6 +230,5 @@ def _serialize_cell(cell, conn=None) -> object: :param cell: The cell to insert into the table :param conn: The database connection :return: The cell - :rtype: object """ return cell diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py index eecf44885ec42..253e443b8ee6c 100644 --- a/airflow/providers/exasol/operators/exasol.py +++ b/airflow/providers/exasol/operators/exasol.py @@ -15,16 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations -from airflow.models import BaseOperator -from airflow.providers.exasol.hooks.exasol import ExasolHook +import warnings +from typing import Sequence -if TYPE_CHECKING: - from airflow.utils.context import Context +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -class ExasolOperator(BaseOperator): +class ExasolOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Exasol database @@ -38,29 +37,22 @@ class ExasolOperator(BaseOperator): :param schema: (optional) name of the schema which overwrite defined one in connection """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#ededed' + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#ededed" def __init__( - self, - *, - sql: str, - exasol_conn_id: str = 'exasol_default', - autocommit: bool = False, - parameters: Optional[dict] = None, - schema: Optional[str] = None, - **kwargs, + self, *, exasol_conn_id: str = "exasol_default", schema: str | None = None, **kwargs ) -> None: - super().__init__(**kwargs) - self.exasol_conn_id = exasol_conn_id - self.sql = sql - self.autocommit = autocommit - self.parameters = parameters - self.schema = schema - - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) - hook = ExasolHook(exasol_conn_id=self.exasol_conn_id, schema=self.schema) - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + if schema is not None: + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = {"schema": schema, **hook_params} + + super().__init__(conn_id=exasol_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. 
+ Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/exasol/provider.yaml b/airflow/providers/exasol/provider.yaml index 6e6d8cbd7ca92..de64998877825 100644 --- a/airflow/providers/exasol/provider.yaml +++ b/airflow/providers/exasol/provider.yaml @@ -22,6 +22,11 @@ description: | `Exasol `__ versions: + - 4.1.0 + - 4.0.1 + - 4.0.0 + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -32,8 +37,11 @@ versions: - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - pyexasol>=0.5.1 + - pandas>=0.17.1 integrations: - integration-name: Exasol @@ -51,9 +59,6 @@ hooks: python-modules: - airflow.providers.exasol.hooks.exasol -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.exasol.hooks.exasol.ExasolHook - connection-types: - hook-class-name: airflow.providers.exasol.hooks.exasol.ExasolHook connection-type: exasol diff --git a/airflow/providers/facebook/.latest-doc-only-change.txt b/airflow/providers/facebook/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/facebook/.latest-doc-only-change.txt +++ b/airflow/providers/facebook/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/facebook/CHANGELOG.rst b/airflow/providers/facebook/CHANGELOG.rst index e0a8dbe15bc2f..6ed7e7fd1e380 100644 --- a/airflow/providers/facebook/CHANGELOG.rst +++ b/airflow/providers/facebook/CHANGELOG.rst @@ -16,9 +16,60 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.0.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/facebook/ads/hooks/ads.py b/airflow/providers/facebook/ads/hooks/ads.py index 96f2db9737351..86048712c965b 100644 --- a/airflow/providers/facebook/ads/hooks/ads.py +++ b/airflow/providers/facebook/ads/hooks/ads.py @@ -16,21 +16,18 @@ # specific language governing permissions and limitations # under the License. """This module contains Facebook Ads Reporting hooks""" -import sys +from __future__ import annotations + import time from enum import Enum -from typing import Any, Dict, List, Optional, Union - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from typing import Any from facebook_business.adobjects.adaccount import AdAccount from facebook_business.adobjects.adreportrun import AdReportRun from facebook_business.adobjects.adsinsights import AdsInsights from facebook_business.api import FacebookAdsApi +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook @@ -38,11 +35,11 @@ class JobStatus(Enum): """Available options for facebook async task status""" - COMPLETED = 'Job Completed' - STARTED = 'Job Started' - RUNNING = 'Job Running' - FAILED = 'Job Failed' - SKIPPED = 'Job Skipped' + COMPLETED = "Job Completed" + STARTED = "Job Started" + RUNNING = "Job Running" + FAILED = "Job Failed" + SKIPPED = "Job Skipped" class FacebookAdsReportingHook(BaseHook): @@ -59,15 +56,15 @@ class FacebookAdsReportingHook(BaseHook): """ - conn_name_attr = 'facebook_conn_id' - default_conn_name = 'facebook_default' - conn_type = 'facebook_social' - hook_name = 'Facebook Ads' + conn_name_attr = "facebook_conn_id" + default_conn_name = "facebook_default" + conn_type = "facebook_social" + hook_name = "Facebook Ads" def __init__( self, facebook_conn_id: str = default_conn_name, - api_version: Optional[str] = None, + api_version: str | None = None, ) -> None: super().__init__() self.facebook_conn_id = facebook_conn_id @@ -90,7 +87,7 @@ def multiple_accounts(self) -> bool: return isinstance(self.facebook_ads_config["account_id"], list) @cached_property - def facebook_ads_config(self) -> Dict: + def facebook_ads_config(self) -> dict: """ Gets Facebook ads connection from meta db and sets facebook_ads_config attribute with returned config file @@ -106,10 +103,10 @@ def facebook_ads_config(self) -> Dict: def bulk_facebook_report( self, - params: Optional[Dict[str, Any]], - fields: List[str], + params: dict[str, Any] | None, + fields: list[str], sleep_time: int = 5, - ) -> Union[List[AdsInsights], Dict[str, List[AdsInsights]]]: + ) -> list[AdsInsights] | dict[str, list[AdsInsights]]: """Pulls data from the Facebook Ads API regarding Account ID with matching return type. The return type and value depends on the ``account_id`` configuration. 
If the @@ -126,7 +123,6 @@ def bulk_facebook_report( :return: Facebook Ads API response, converted to Facebook Ads Row objects regarding given Account ID type - :rtype: List[AdsInsights] or Dict[str, List[AdsInsights]] """ api = self._get_service() if self.multiple_accounts: @@ -152,10 +148,10 @@ def _facebook_report( self, account_id: str, api: FacebookAdsApi, - params: Optional[Dict[str, Any]], - fields: List[str], + params: dict[str, Any] | None, + fields: list[str], sleep_time: int = 5, - ) -> List[AdsInsights]: + ) -> list[AdsInsights]: """ Pulls data from the Facebook Ads API with given account_id diff --git a/airflow/providers/facebook/provider.yaml b/airflow/providers/facebook/provider.yaml index 8b2471e64b0ee..2e05cca631562 100644 --- a/airflow/providers/facebook/provider.yaml +++ b/airflow/providers/facebook/provider.yaml @@ -22,6 +22,9 @@ description: | `Facebook Ads `__ versions: + - 3.1.0 + - 3.0.1 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -33,8 +36,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - facebook-business>=6.0.2 integrations: - integration-name: Facebook Ads @@ -47,9 +51,6 @@ hooks: python-modules: - airflow.providers.facebook.ads.hooks.ads -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.facebook.ads.hooks.ads.FacebookAdsReportingHook - connection-types: - hook-class-name: airflow.providers.facebook.ads.hooks.ads.FacebookAdsReportingHook connection-type: facebook_social diff --git a/airflow/providers/ftp/.latest-doc-only-change.txt b/airflow/providers/ftp/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/ftp/.latest-doc-only-change.txt +++ b/airflow/providers/ftp/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/ftp/CHANGELOG.rst b/airflow/providers/ftp/CHANGELOG.rst index cda0e643a1a3d..93f6a3ab59877 100644 --- a/airflow/providers/ftp/CHANGELOG.rst +++ b/airflow/providers/ftp/CHANGELOG.rst @@ -16,9 +16,60 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Add blocksize arg for ftp hook (#24860)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.2 ..... diff --git a/airflow/providers/ftp/hooks/ftp.py b/airflow/providers/ftp/hooks/ftp.py index 5e6e5c12c20c1..28e7e6a175b03 100644 --- a/airflow/providers/ftp/hooks/ftp.py +++ b/airflow/providers/ftp/hooks/ftp.py @@ -15,12 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations import datetime import ftplib import os.path -from typing import Any, List, Optional, Tuple +from typing import Any, Callable from airflow.hooks.base import BaseHook @@ -37,15 +37,15 @@ class FTPHook(BaseHook): reference. """ - conn_name_attr = 'ftp_conn_id' - default_conn_name = 'ftp_default' - conn_type = 'ftp' - hook_name = 'FTP' + conn_name_attr = "ftp_conn_id" + default_conn_name = "ftp_default" + conn_type = "ftp" + hook_name = "FTP" def __init__(self, ftp_conn_id: str = default_conn_name) -> None: super().__init__() self.ftp_conn_id = ftp_conn_id - self.conn: Optional[ftplib.FTP] = None + self.conn: ftplib.FTP | None = None def __enter__(self): return self @@ -85,7 +85,7 @@ def describe_directory(self, path: str) -> dict: files = dict(conn.mlsd()) return files - def list_directory(self, path: str) -> List[str]: + def list_directory(self, path: str) -> list[str]: """ Returns a list of files on the remote system. @@ -115,7 +115,13 @@ def delete_directory(self, path: str) -> None: conn = self.get_conn() conn.rmd(path) - def retrieve_file(self, remote_full_path, local_full_path_or_buffer, callback=None): + def retrieve_file( + self, + remote_full_path: str, + local_full_path_or_buffer: Any, + callback: Callable | None = None, + block_size: int = 8192, + ) -> None: """ Transfers the remote file to a local location. @@ -132,6 +138,8 @@ def retrieve_file(self, remote_full_path, local_full_path_or_buffer, callback=No that writing to a file or buffer will need to be handled inside the callback. [default: output_handle.write()] + :param block_size: file is transferred in chunks of default size 8192 + or as set by user .. 
code-block:: python @@ -164,31 +172,30 @@ def write_to_file_with_progress(data): """ conn = self.get_conn() - is_path = isinstance(local_full_path_or_buffer, str) # without a callback, default to writing to a user-provided file or # file-like buffer if not callback: if is_path: - - output_handle = open(local_full_path_or_buffer, 'wb') + output_handle = open(local_full_path_or_buffer, "wb") else: output_handle = local_full_path_or_buffer + callback = output_handle.write - else: - output_handle = None remote_path, remote_file_name = os.path.split(remote_full_path) conn.cwd(remote_path) - self.log.info('Retrieving file from FTP: %s', remote_full_path) - conn.retrbinary(f'RETR {remote_file_name}', callback) - self.log.info('Finished retrieving file from FTP: %s', remote_full_path) + self.log.info("Retrieving file from FTP: %s", remote_full_path) + conn.retrbinary(f"RETR {remote_file_name}", callback, block_size) + self.log.info("Finished retrieving file from FTP: %s", remote_full_path) if is_path and output_handle: output_handle.close() - def store_file(self, remote_full_path: str, local_full_path_or_buffer: Any) -> None: + def store_file( + self, remote_full_path: str, local_full_path_or_buffer: Any, block_size: int = 8192 + ) -> None: """ Transfers a local file to the remote location. @@ -199,19 +206,19 @@ def store_file(self, remote_full_path: str, local_full_path_or_buffer: Any) -> N :param remote_full_path: full path to the remote file :param local_full_path_or_buffer: full path to the local file or a file-like buffer + :param block_size: file is transferred in chunks of default size 8192 + or as set by user """ conn = self.get_conn() - is_path = isinstance(local_full_path_or_buffer, str) if is_path: - - input_handle = open(local_full_path_or_buffer, 'rb') + input_handle = open(local_full_path_or_buffer, "rb") else: input_handle = local_full_path_or_buffer remote_path, remote_file_name = os.path.split(remote_full_path) conn.cwd(remote_path) - conn.storbinary(f'STOR {remote_file_name}', input_handle) + conn.storbinary(f"STOR {remote_file_name}", input_handle, block_size) if is_path: input_handle.close() @@ -242,15 +249,15 @@ def get_mod_time(self, path: str) -> datetime.datetime: :param path: remote file path """ conn = self.get_conn() - ftp_mdtm = conn.sendcmd('MDTM ' + path) + ftp_mdtm = conn.sendcmd("MDTM " + path) time_val = ftp_mdtm[4:] # time_val optionally has microseconds try: return datetime.datetime.strptime(time_val, "%Y%m%d%H%M%S.%f") except ValueError: - return datetime.datetime.strptime(time_val, '%Y%m%d%H%M%S') + return datetime.datetime.strptime(time_val, "%Y%m%d%H%M%S") - def get_size(self, path: str) -> Optional[int]: + def get_size(self, path: str) -> int | None: """ Returns the size of a file (in bytes) @@ -260,7 +267,7 @@ def get_size(self, path: str) -> Optional[int]: size = conn.size(path) return int(size) if size else None - def test_connection(self) -> Tuple[bool, str]: + def test_connection(self) -> tuple[bool, str]: """Test the FTP connection by calling path with directory""" try: conn = self.get_conn() diff --git a/airflow/providers/ftp/provider.yaml b/airflow/providers/ftp/provider.yaml index bc8c8a32616a6..a60dfcb8a3634 100644 --- a/airflow/providers/ftp/provider.yaml +++ b/airflow/providers/ftp/provider.yaml @@ -22,6 +22,9 @@ description: | `File Transfer Protocol (FTP) `__ versions: + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.2 - 2.1.1 - 2.1.0 @@ -31,6 +34,8 @@ versions: - 1.0.1 - 1.0.0 +dependencies: [] + integrations: - integration-name: File Transfer Protocol 
(FTP) external-doc-url: https://tools.ietf.org/html/rfc114 @@ -47,9 +52,6 @@ hooks: python-modules: - airflow.providers.ftp.hooks.ftp -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.ftp.hooks.ftp.FTPHook - connection-types: - hook-class-name: airflow.providers.ftp.hooks.ftp.FTPHook connection-type: ftp diff --git a/airflow/providers/ftp/sensors/ftp.py b/airflow/providers/ftp/sensors/ftp.py index faa9c5c315af1..5998c1b8f8db1 100644 --- a/airflow/providers/ftp/sensors/ftp.py +++ b/airflow/providers/ftp/sensors/ftp.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import ftplib import re from typing import TYPE_CHECKING, Sequence @@ -37,7 +39,7 @@ class FTPSensor(BaseSensorOperator): reference to run the sensor against. """ - template_fields: Sequence[str] = ('path',) + template_fields: Sequence[str] = ("path",) """Errors that are transient in nature, and where action can be retried""" transient_errors = [421, 425, 426, 434, 450, 451, 452] @@ -45,7 +47,7 @@ class FTPSensor(BaseSensorOperator): error_code_pattern = re.compile(r"([\d]+)") def __init__( - self, *, path: str, ftp_conn_id: str = 'ftp_default', fail_on_transient_errors: bool = True, **kwargs + self, *, path: str, ftp_conn_id: str = "ftp_default", fail_on_transient_errors: bool = True, **kwargs ) -> None: super().__init__(**kwargs) @@ -66,15 +68,15 @@ def _get_error_code(self, e): except ValueError: return e - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: with self._create_hook() as hook: - self.log.info('Poking for %s', self.path) + self.log.info("Poking for %s", self.path) try: mod_time = hook.get_mod_time(self.path) - self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time)) + self.log.info("Found File %s last modified: %s", str(self.path), str(mod_time)) except ftplib.error_perm as e: - self.log.error('Ftp error encountered: %s', str(e)) + self.log.error("Ftp error encountered: %s", str(e)) error_code = self._get_error_code(e) if (error_code != 550) and ( self.fail_on_transient_errors or (error_code not in self.transient_errors) diff --git a/airflow/providers/github/.latest-doc-only-change.txt b/airflow/providers/github/.latest-doc-only-change.txt index ab24993f57139..ff7136e07d744 100644 --- a/airflow/providers/github/.latest-doc-only-change.txt +++ b/airflow/providers/github/.latest-doc-only-change.txt @@ -1 +1 @@ -8b6b0848a3cacf9999477d6af4d2a87463f03026 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/github/CHANGELOG.rst b/airflow/providers/github/CHANGELOG.rst index c800f7904856d..ebbdad7cfac95 100644 --- a/airflow/providers/github/CHANGELOG.rst +++ b/airflow/providers/github/CHANGELOG.rst @@ -18,6 +18,61 @@ under the License. Changelog +--------- + +2.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +2.1.0 +..... + +Features +~~~~~~~~ + +* ``Add test connection functionality to 'GithubHook' (#24903)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Only assert stuff for mypy when type checking (#24937)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Bug Fixes +~~~~~~~~~ + + * ``Remove 'GithubOperator' use in 'GithubSensor.__init__()'' (#24214)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate GitHub example DAGs to new design #22446 (#24134)`` + * ``Fix new MyPy errors in main (#22884)`` + * ``Change 'Github' to 'GitHub' (#23764)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` 1.0.3 ..... diff --git a/airflow/providers/github/example_dags/example_github.py b/airflow/providers/github/example_dags/example_github.py deleted file mode 100644 index 642bdd43e26dc..0000000000000 --- a/airflow/providers/github/example_dags/example_github.py +++ /dev/null @@ -1,101 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import logging -from datetime import datetime -from typing import Any, Optional - -from github import GithubException - -from airflow import AirflowException -from airflow.models.dag import DAG -from airflow.providers.github.operators.github import GithubOperator -from airflow.providers.github.sensors.github import GithubSensor, GithubTagSensor - -dag = DAG( - 'example_github_operator', - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) - -# [START howto_tag_sensor_github] - -tag_sensor = GithubTagSensor( - task_id='example_tag_sensor', - tag_name='v1.0', - repository_name="apache/airflow", - timeout=60, - poke_interval=10, - dag=dag, -) - - -# [END howto_tag_sensor_github] - -# [START howto_sensor_github] - - -def tag_checker(repo: Any, tag_name: str) -> Optional[bool]: - result = None - try: - if repo is not None and tag_name is not None: - all_tags = [x.name for x in repo.get_tags()] - result = tag_name in all_tags - - except GithubException as github_error: # type: ignore[misc] - raise AirflowException(f"Failed to execute GithubSensor, error: {str(github_error)}") - except Exception as e: - raise AirflowException(f"GitHub operator error: {str(e)}") - return result - - -github_sensor = GithubSensor( - task_id='example_sensor', - method_name="get_repo", - method_params={'full_name_or_id': "apache/airflow"}, - result_processor=lambda repo: tag_checker(repo, 'v1.0'), - timeout=60, - poke_interval=10, - dag=dag, -) - -# [END howto_sensor_github] - - -# [START howto_operator_list_repos_github] - -github_list_repos = GithubOperator( - task_id='github_list_repos', - github_method="get_user", - github_method_args={}, - result_processor=lambda user: logging.info(list(user.get_repos())), - dag=dag, -) - -# [END howto_operator_list_repos_github] - -# [START howto_operator_list_tags_github] - -list_repo_tags = GithubOperator( - task_id='list_repo_tags', - github_method="get_repo", - github_method_args={'full_name_or_id': 'apache/airflow'}, - result_processor=lambda repo: logging.info(list(repo.get_tags())), - dag=dag, -) - -# [END howto_operator_list_tags_github] diff --git a/airflow/providers/github/hooks/github.py b/airflow/providers/github/hooks/github.py index 07a8566a7f575..d7dfd1465a10b 100644 --- a/airflow/providers/github/hooks/github.py +++ b/airflow/providers/github/hooks/github.py @@ -15,40 +15,39 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""This module allows you to connect to GitHub.""" +from __future__ import annotations -"""This module allows to connect to a Github.""" -from typing import Dict, Optional +from typing import TYPE_CHECKING from github import Github as GithubClient +from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook class GithubHook(BaseHook): """ - Interact with Github. + Interact with GitHub. Performs a connection to GitHub and retrieves client. :param github_conn_id: Reference to :ref:`GitHub connection id `. 
""" - conn_name_attr = 'github_conn_id' - default_conn_name = 'github_default' - conn_type = 'github' - hook_name = 'Github' + conn_name_attr = "github_conn_id" + default_conn_name = "github_default" + conn_type = "github" + hook_name = "GitHub" def __init__(self, github_conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.github_conn_id = github_conn_id - self.client: Optional[GithubClient] = None + self.client: GithubClient | None = None self.get_conn() def get_conn(self) -> GithubClient: - """ - Function that initiates a new GitHub connection - with token and hostname ( for GitHub Enterprise ) - """ + """Function that initiates a new GitHub connection with token and hostname (for GitHub Enterprise).""" if self.client is not None: return self.client @@ -56,6 +55,12 @@ def get_conn(self) -> GithubClient: access_token = conn.password host = conn.host + # Currently the only method of authenticating to GitHub in Airflow is via a token. This is not the + # only means available, but raising an exception to enforce this method for now. + # TODO: When/If other auth methods are implemented this exception should be removed/modified. + if not access_token: + raise AirflowException("An access token is required to authenticate to GitHub.") + if not host: self.client = GithubClient(login_or_token=access_token) else: @@ -64,16 +69,20 @@ def get_conn(self) -> GithubClient: return self.client @staticmethod - def get_ui_field_behaviour() -> Dict: + def get_ui_field_behaviour() -> dict: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'login', 'extra'], - "relabeling": { - 'host': 'GitHub Enterprise Url (Optional)', - 'password': 'GitHub Access Token', - }, - "placeholders": { - 'host': 'https://{hostname}/api/v3 (for GitHub Enterprise Connection)', - 'password': 'token credentials auth', - }, + "hidden_fields": ["schema", "port", "login", "extra"], + "relabeling": {"host": "GitHub Enterprise URL (Optional)", "password": "GitHub Access Token"}, + "placeholders": {"host": "https://{hostname}/api/v3 (for GitHub Enterprise)"}, } + + def test_connection(self) -> tuple[bool, str]: + """Test GitHub connection.""" + try: + if TYPE_CHECKING: + assert self.client + self.client.get_user().id + return True, "Successfully connected to GitHub." + except Exception as e: + return False, str(e) diff --git a/airflow/providers/github/operators/github.py b/airflow/providers/github/operators/github.py index 73e11810714c6..1de3916f26888 100644 --- a/airflow/providers/github/operators/github.py +++ b/airflow/providers/github/operators/github.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable from github import GithubException @@ -49,9 +50,9 @@ def __init__( self, *, github_method: str, - github_conn_id: str = 'github_default', - github_method_args: Optional[dict] = None, - result_processor: Optional[Callable] = None, + github_conn_id: str = "github_default", + github_method_args: dict | None = None, + result_processor: Callable | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -60,7 +61,7 @@ def __init__( self.github_method_args = github_method_args self.result_processor = result_processor - def execute(self, context: 'Context') -> Any: + def execute(self, context: Context) -> Any: try: # Default method execution is on the top level GitHub client hook = GithubHook(github_conn_id=self.github_conn_id) @@ -75,4 +76,4 @@ def execute(self, context: 'Context') -> Any: except GithubException as github_error: raise AirflowException(f"Failed to execute GithubOperator, error: {str(github_error)}") except Exception as e: - raise AirflowException(f'GitHub operator error: {str(e)}') + raise AirflowException(f"GitHub operator error: {str(e)}") diff --git a/airflow/providers/github/provider.yaml b/airflow/providers/github/provider.yaml index 090176fcfdfc1..a03ec3ad5600e 100644 --- a/airflow/providers/github/provider.yaml +++ b/airflow/providers/github/provider.yaml @@ -18,9 +18,18 @@ --- package-name: apache-airflow-providers-github name: Github + description: | `GitHub `__ + +dependencies: + - apache-airflow>=2.3.0 + - pygithub + versions: + - 2.2.0 + - 2.1.0 + - 2.0.0 - 1.0.3 - 1.0.2 - 1.0.1 diff --git a/airflow/providers/github/sensors/github.py b/airflow/providers/github/sensors/github.py index f0501e055b173..602f71f518c5a 100644 --- a/airflow/providers/github/sensors/github.py +++ b/airflow/providers/github/sensors/github.py @@ -15,13 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
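A hedged sketch of ``GithubOperator`` usage: the named method is resolved on the PyGithub client and the result is optionally passed through ``result_processor``. The task id and processor below are illustrative:

.. code-block:: python

    from airflow.providers.github.operators.github import GithubOperator

    # Intended for use inside a DAG definition.
    list_repo_tags = GithubOperator(
        task_id="list_repo_tags",
        github_method="get_repo",
        github_method_args={"full_name_or_id": "apache/airflow"},
        result_processor=lambda repo: [tag.name for tag in repo.get_tags()],
    )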
+from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable from github import GithubException from airflow import AirflowException -from airflow.providers.github.operators.github import GithubOperator +from airflow.providers.github.hooks.github import GithubHook from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -42,9 +43,9 @@ def __init__( self, *, method_name: str, - github_conn_id: str = 'github_default', - method_params: Optional[dict] = None, - result_processor: Optional[Callable] = None, + github_conn_id: str = "github_default", + method_params: dict | None = None, + result_processor: Callable | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -54,16 +55,15 @@ def __init__( self.result_processor = result_processor self.method_name = method_name self.method_params = method_params - self.github_operator = GithubOperator( - task_id=self.task_id, - github_conn_id=self.github_conn_id, - github_method=self.method_name, - github_method_args=self.method_params, - result_processor=self.result_processor, - ) - def poke(self, context: 'Context') -> bool: - return self.github_operator.execute(context=context) + def poke(self, context: Context) -> bool: + hook = GithubHook(github_conn_id=self.github_conn_id) + github_result = getattr(hook.client, self.method_name)(**self.method_params) + + if self.result_processor: + return self.result_processor(github_result) + + return github_result class BaseGithubRepositorySensor(GithubSensor): @@ -77,25 +77,25 @@ class BaseGithubRepositorySensor(GithubSensor): def __init__( self, *, - github_conn_id: str = 'github_default', - repository_name: Optional[str] = None, - result_processor: Optional[Callable] = None, + github_conn_id: str = "github_default", + repository_name: str | None = None, + result_processor: Callable | None = None, **kwargs, ) -> None: super().__init__( github_conn_id=github_conn_id, result_processor=result_processor, method_name="get_repo", - method_params={'full_name_or_id': repository_name}, + method_params={"full_name_or_id": repository_name}, **kwargs, ) - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: """ Function that the sensors defined while deriving this class should override. 
""" - raise AirflowException('Override me.') + raise AirflowException("Override me.") class GithubTagSensor(BaseGithubRepositorySensor): @@ -112,9 +112,9 @@ class GithubTagSensor(BaseGithubRepositorySensor): def __init__( self, *, - github_conn_id: str = 'github_default', - tag_name: Optional[str] = None, - repository_name: Optional[str] = None, + github_conn_id: str = "github_default", + tag_name: str | None = None, + repository_name: str | None = None, **kwargs, ) -> None: self.repository_name = repository_name @@ -126,11 +126,11 @@ def __init__( **kwargs, ) - def poke(self, context: 'Context') -> bool: - self.log.info('Poking for tag: %s in repository: %s', self.tag_name, self.repository_name) + def poke(self, context: Context) -> bool: + self.log.info("Poking for tag: %s in repository: %s", self.tag_name, self.repository_name) return GithubSensor.poke(self, context=context) - def tag_checker(self, repo: Any) -> Optional[bool]: + def tag_checker(self, repo: Any) -> bool | None: """Checking existence of Tag in a Repository""" result = None try: diff --git a/airflow/providers/google/CHANGELOG.rst b/airflow/providers/google/CHANGELOG.rst index 8971de875c1f1..f3b27e33887cb 100644 --- a/airflow/providers/google/CHANGELOG.rst +++ b/airflow/providers/google/CHANGELOG.rst @@ -15,9 +15,300 @@ specific language governing permissions and limitations under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +8.5.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Rename hook bigquery function '_bq_cast' to 'bq_cast' (#27543)`` +* ``Use non-deprecated method for on_kill in BigQueryHook (#27547)`` +* ``Typecast biquery job response col value (#27236)`` +* ``Remove <2 limit on google-cloud-storage (#26922)`` + +Features +~~~~~~~~ + +* ``Add backward compatibility with old versions of Apache Beam (#27263)`` +* ``Add deferrable mode to GCPToBigQueryOperator + tests (#27052)`` +* ``Add system tests for Vertex AI operators in new approach (#27053)`` +* ``Dataform operators, links, update system tests and docs (#27144)`` +* ``Allow values in WorkflowsCreateExecutionOperator execution argument to be dicts (#27361)`` +* ``DataflowStopJobOperator Operator (#27033)`` +* ``Allow for the overriding of stringify_dict for json/jsonb column data type in Postgres #26875 (#26876)`` + +Bug Fixes +~~~~~~~~~ + +* ``Add new Compute Engine Operators and fix system tests (#25608)`` +* ``Allow and prefer non-prefixed extra fields for dataprep hook (#27039)`` +* ``Common sql bugfixes and improvements (#26761)`` +* ``Update google hooks to prefer non-prefixed extra fields (#27023)`` +* ``Fix delay in Dataproc CreateBatch operator (#26126)`` +* ``Remove unnecessary newlines around single arg in signature (#27525)`` +* ``set project_id and location when canceling BigQuery job (#27521)`` +* ``use the proper key to retrieve the dataflow job_id (#27336)`` +* ``Make GSheetsHook return an empty list when there are no values (#27261)`` +* ``Cloud ML Engine operators assets (#26836)`` + +.. Below changes are excluded from the changelog. 
Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Change dataprep system tests assets (#26488)`` + * ``Upgrade dependencies in order to avoid backtracking (#27531)`` + * ``Migration of System Tests: Cloud Composer (AIP-47) (#27227)`` + * ``Rewrite system tests for ML Engine service (#26915)`` + * ``Migration of System Tests: Cloud BigQuery Data Transfer (AIP-47) (#27312)`` + * ``Migration of System Tests: Dataplex (AIP-47) (#26989)`` + * ``Migration of System Tests: Cloud Vision Operators (AIP-47) (#26963)`` + * ``Google Drive to local - system tests migrations (AIP-47) (#26798)`` + * ``Migrate Bigtable operators system tests according to AIP-47 (#26911)`` + * ``Migrate Dataproc Metastore system tests according to AIP-47 (#26858)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Local filesystem to Google Drive Operator - system tests migration (AIP-47) (#26797)`` + * ``SFTP to Google Cloud Storage Transfer system tests migration (AIP-47) (#26799)`` + +.. Review and move the new changes to one of the sections above: + * ``Replace urlparse with urlsplit (#27389)`` + +8.4.0 +..... + +Features +~~~~~~~~ + +* ``Add BigQuery Column and Table Check Operators (#26368)`` +* ``Add deferrable big query operators and sensors (#26156)`` +* ``Add 'output' property to MappedOperator (#25604)`` +* ``Added append_job_name parameter to DataflowTemplatedJobStartOperator (#25746)`` +* ``Adding a parameter for exclusion of trashed files in GoogleDriveHook (#25675)`` +* ``Cloud Data Loss Prevention Operators assets (#26618)`` +* ``Cloud Storage Transfer Operators assets & system tests migration (AIP-47) (#26072)`` +* ``Merge deferrable BigQuery operators to exisitng one (#26433)`` +* ``specifying project id when calling wait_for_operation in delete/create cluster (#26418)`` +* ``Auto tail file logs in Web UI (#26169)`` +* ``Cloud Functions Operators assets & system tests migration (AIP-47) (#26073)`` +* ``GCSToBigQueryOperator Resolve 'max_id_key' job retrieval and xcom return (#26285)`` +* ``Allow for the overriding of 'stringify_dict' for json export format on BaseSQLToGCSOperator (#26277)`` +* ``Append GoogleLink base in the link class (#26057)`` +* ``Cloud Video Intelligence Operators assets & system tests migration (AIP-47) (#26132)`` +* ``Life Science assets & system tests migration (AIP-47) (#25548)`` +* ``GCSToBigQueryOperator allow for schema_object in alternate GCS Bucket (#26190)`` +* ``Use AsyncClient for Composer Operators in deferrable mode (#25951)`` +* ``Use project_id to get authenticated client (#25984)`` +* ``Cloud Build assets & system tests migration (AIP-47) (#25895)`` +* ``Dataproc submit job operator async (#25302)`` +* ``Support project_id argument in BigQueryGetDataOperator (#25782)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix JSONDecodeError in Datafusion operators (#26202)`` +* ``Fixed never ending loop to in CreateWorkflowInvocation (#25737)`` +* ``Update gcs.py (#26570)`` +* ``Don't throw an exception when a BQ cusor job has no schema (#26096)`` +* ``Google Cloud Tasks Sensor for queue being empty (#25622)`` +* ``Correcting the transfer config name. 
(#25719)`` +* ``Fix parsing of optional 'mode' field in BigQuery Result Schema (#26786)`` +* ``Fix MaxID logic for GCSToBigQueryOperator (#26768)`` + +Misc +~~~~ + +* ``Sql to GSC operators update docs for parquet format (#25878)`` +* ``Limit Google Protobuf for compatibility with biggtable client (#25886)`` +* ``Make GoogleBaseHook credentials functions public (#25785)`` +* ``Consolidate to one 'schedule' param (#25410)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Data Loss Prevention system tests according to AIP-47 (#26060)`` + * ``Google Drive to Google Cloud Storage Transfer Operator - system tests migration (AIP-47) (#26487)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to core airflow (#26290)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Replace SQL with Common SQL in pre commit (#26058)`` + * ``Hook into Mypy to get rid of those cast() (#26023)`` + * ``Work around pyupgrade edge cases (#26384)`` + * ``D400 first line should end with period batch02 (#25268)`` + * ``Fix GCS sensor system tests failing with DebugExecutor (#26742)`` + * ``Update docs for September Provider's release (#26731)`` + +8.3.0 +..... + +Features +~~~~~~~~ + +* ``add description method in BigQueryCursor class (#25366)`` +* ``Add project_id as a templated variable in two BQ operators (#24768)`` +* ``Remove deprecated modules in Amazon provider (#25543)`` +* ``Move all "old" SQL operators to common.sql providers (#25350)`` +* ``Improve taskflow type hints with ParamSpec (#25173)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` +* ``Bump typing-extensions and mypy for ParamSpec (#25088)`` +* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)`` +* ``Dataform operators (#25587)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix GCSListObjectsOperator docstring (#25614)`` +* ``Fix BigQueryInsertJobOperator cancel_on_kill (#25342)`` +* ``Fix BaseSQLToGCSOperator approx_max_file_size_bytes (#25469)`` +* ``Fix PostgresToGCSOperat bool dtype (#25475)`` +* ``Fix Vertex AI Custom Job training issue (#25367)`` +* ``Fix Flask Login user setting for Flask 2.2 and Flask-Login 0.6.2 (#25318)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Google example trino_to_gcs to new design AIP-47 (#25420)`` + * ``Migrate Google example automl_nl_text_extraction to new design AIP-47 (#25418)`` + * ``Memorystore assets & system tests migration (AIP-47) (#25361)`` + * ``Translate system tests migration (AIP-47) (#25340)`` + * ``Migrate Google example life_sciences to new design AIP-47 (#25264)`` + * ``Migrate Google example natural_language to new design AIP-47 (#25262)`` + * ``Delete redundant system test bigquery_to_bigquery (#25261)`` + * ``Migrate Google example bigquery_to_mssql to new design AIP-47 (#25174)`` + * ``Migrate Google example compute_igm to new design AIP-47 (#25132)`` + * ``Migrate Google example automl_vision to new design AIP-47 (#25152)`` + * ``Migrate Google example gcs_to_sftp to new design AIP-47 (#25107)`` + * ``Migrate Google campaign manager example to new design AIP-47 (#25069)`` + * ``Migrate Google analytics example to new design AIP-47 (#25006)`` + +8.2.0 +..... 
+ +Features +~~~~~~~~ + +* ``PubSub assets & system tests migration (AIP-47) (#24867)`` +* ``Add handling state of existing Dataproc batch (#24924)`` +* ``Add links for Google Kubernetes Engine operators (#24786)`` +* ``Add test_connection method to 'GoogleBaseHook' (#24682)`` +* ``Add gcp_conn_id argument to GoogleDriveToLocalOperator (#24622)`` +* ``Add DeprecationWarning for column_transformations parameter in AutoML (#24467)`` +* ``Modify BigQueryCreateExternalTableOperator to use updated hook function (#24363)`` +* ``Move all SQL classes to common-sql provider (#24836)`` +* ``Datacatalog assets & system tests migration (AIP-47) (#24600)`` +* ``Upgrade FAB to 4.1.1 (#24399)`` + +Bug Fixes +~~~~~~~~~ + +* ``GCSDeleteObjectsOperator empty prefix bug fix (#24353)`` +* ``perf(BigQuery): pass table_id as str type (#23141)`` +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Google sheets example to new design AIP-47 (#24975)`` + * ``Migrate Google ads example to new design AIP-47 (#24941)`` + * ``Migrate Google example gcs_to_gdrive to new design AIP-47 (#24949)`` + * ``Migrate Google firestore example to new design AIP-47 (#24830)`` + * ``Automatically detect if non-lazy logging interpolation is used (#24910)`` + * ``Migrate Google example sql_to_sheets to new design AIP-47 (#24814)`` + * ``Remove "bad characters" from our codebase (#24841)`` + * ``Migrate Google example DAG mssql_to_gcs to new design AIP-47 (#24541)`` + * ``Align Black and blacken-docs configs (#24785)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Use our yaml util in all providers (#24720)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + * ``Migrate Google example DAG s3_to_gcs to new design AIP-47 (#24641)`` + * ``Migrate Google example DAG bigquery_transfer to new design AIP-47 (#24543)`` + * ``Migrate Google example DAG oracle_to_gcs to new design AIP-47 (#24542)`` + * ``Migrate Google example DAG mysql_to_gcs to new design AIP-47 (#24540)`` + * ``Migrate Google search_ads DAG to new design AIP-47 (#24298)`` + * ``Migrate Google gcs_to_sheets DAG to new design AIP-47 (#24501)`` + +8.1.0 +..... + +Features +~~~~~~~~ + +* ``Update Oracle library to latest version (#24311)`` +* ``Expose SQL to GCS Metadata (#24382)`` + +Bug Fixes +~~~~~~~~~ + +* ``fix typo in google provider additional extras (#24431)`` +* ``Use insert_job in the BigQueryToGCPOpertor and adjust links (#24416)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Fix links to sources for examples (#24386)`` + * ``Deprecate remaining occurrences of 'bigquery_conn_id' in favor of 'gcp_conn_id' (#24376)`` + * ``Migrate Google calendar example DAG to new design AIP-47 (#24333)`` + * ``Migrate Google azure_fileshare example DAG to new design AIP-47 (#24349)`` + * ``Remove bigquery example already migrated to AIP-47 (#24379)`` + * ``Migrate Google sheets example DAG to new design AIP-47 (#24351)`` + +8.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Add key_secret_project_id parameter which specifies a project with KeyFile (#23930)`` +* ``Added impersonation_chain for DataflowStartFlexTemplateOperator and DataflowStartSqlJobOperator (#24046)`` +* ``Add fields to CLOUD_SQL_EXPORT_VALIDATION. (#23724)`` +* ``Update credentials when using ADC in Compute Engine (#23773)`` +* ``set color to operators in cloud_sql.py (#24000)`` +* ``Sql to gcs with exclude columns (#23695)`` +* ``[Issue#22846] allow option to encode or not encode UUID when uploading from Cassandra to GCS (#23766)`` +* ``Workflows assets & system tests migration (AIP-47) (#24105)`` +* ``Spanner assets & system tests migration (AIP-47) (#23957)`` +* ``Speech To Text assets & system tests migration (AIP-47) (#23643)`` +* ``Cloud SQL assets & system tests migration (AIP-47) (#23583)`` +* ``Cloud Storage assets & StorageLink update (#23865)`` + +Bug Fixes +~~~~~~~~~ + +* ``fix BigQueryInsertJobOperator (#24165)`` +* ``Fix the link to google workplace (#24080)`` +* ``Fix DataprocJobBaseOperator not being compatible with dotted names (#23439). (#23791)`` +* ``Remove hack from BigQuery DTS hook (#23887)`` +* ``Fix GCSToGCSOperator cannot copy a single file/folder without copying other files/folders with that prefix (#24039)`` +* ``Workaround job race bug on biguery to gcs transfer (#24330)`` + +Misc +~~~~ + +* ``Fix BigQuery system tests (#24013)`` +* ``Ensure @contextmanager decorates generator func (#23103)`` +* ``Migrate Dataproc to new system tests design (#22777)`` +* ``AIP-47 - Migrate google leveldb DAGs to new design ##22447 (#24233)`` +* ``Apply per-run log templates to log handlers (#24153)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Introduce 'flake8-implicit-str-concat' plugin to static checks (#23873)`` + * ``Clean up f-strings in logging calls (#23597)`` + * ``pydocstyle D202 added (#24221)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 7.0.0 ..... @@ -172,7 +463,6 @@ Misc * ``migrate system test gcs_to_bigquery into new design (#22753)`` * ``Add example DAG for demonstrating usage of GCS sensors (#22808)`` -.. Review and move the new changes to one of the sections above: * ``Clean up in-line f-string concatenation (#23591)`` * ``Bump pre-commit hook versions (#22887)`` * ``Use new Breese for building, pulling and verifying the images. (#23104)`` @@ -816,8 +1106,7 @@ now the snake_case convention is used. set_acl_permission = GCSBucketCreateAclEntryOperator( task_id="gcs-set-acl-permission", bucket=BUCKET_NAME, - entity="user-{{ task_instance.xcom_pull('get-instance')['persistenceIamIdentity']" - ".split(':', 2)[1] }}", + entity="user-{{ task_instance.xcom_pull('get-instance')['persistenceIamIdentity'].split(':', 2)[1] }}", role="OWNER", ) diff --git a/airflow/providers/google/__init__.py b/airflow/providers/google/__init__.py index de7c3eef9373f..5eccc44ea6fe0 100644 --- a/airflow/providers/google/__init__.py +++ b/airflow/providers/google/__init__.py @@ -14,15 +14,17 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import importlib import logging # HACK: # Sphinx-autoapi doesn't like imports to excluded packages in the main module. -conf = importlib.import_module('airflow.configuration').conf # type: ignore[attr-defined] +conf = importlib.import_module("airflow.configuration").conf # type: ignore[attr-defined] PROVIDERS_GOOGLE_VERBOSE_LOGGING: bool = conf.getboolean( - 'providers_google', 'VERBOSE_LOGGING', fallback=False + "providers_google", "VERBOSE_LOGGING", fallback=False ) if PROVIDERS_GOOGLE_VERBOSE_LOGGING: for logger_name in ["google_auth_httplib2", "httplib2", "googleapiclient"]: diff --git a/airflow/providers/google/ads/example_dags/example_ads.py b/airflow/providers/google/ads/example_dags/example_ads.py deleted file mode 100644 index 85446b563d563..0000000000000 --- a/airflow/providers/google/ads/example_dags/example_ads.py +++ /dev/null @@ -1,88 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG that shows how to use GoogleAdsToGcsOperator. 
-""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.ads.operators.ads import GoogleAdsListAccountsOperator -from airflow.providers.google.ads.transfers.ads_to_gcs import GoogleAdsToGcsOperator - -# [START howto_google_ads_env_variables] -CLIENT_IDS = ["1111111111", "2222222222"] -BUCKET = os.environ.get("GOOGLE_ADS_BUCKET", "gs://INVALID BUCKET NAME") -GCS_OBJ_PATH = "folder_name/google-ads-api-results.csv" -GCS_ACCOUNTS_CSV = "folder_name/accounts.csv" -QUERY = """ - SELECT - segments.date, - customer.id, - campaign.id, - ad_group.id, - ad_group_ad.ad.id, - metrics.impressions, - metrics.clicks, - metrics.conversions, - metrics.all_conversions, - metrics.cost_micros - FROM - ad_group_ad - WHERE - segments.date >= '2020-02-01' - AND segments.date <= '2020-02-29' - """ - -FIELDS_TO_EXTRACT = [ - "segments.date.value", - "customer.id.value", - "campaign.id.value", - "ad_group.id.value", - "ad_group_ad.ad.id.value", - "metrics.impressions.value", - "metrics.clicks.value", - "metrics.conversions.value", - "metrics.all_conversions.value", - "metrics.cost_micros.value", -] - -# [END howto_google_ads_env_variables] - -with models.DAG( - "example_google_ads", - schedule_interval=None, # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - # [START howto_google_ads_to_gcs_operator] - run_operator = GoogleAdsToGcsOperator( - client_ids=CLIENT_IDS, - query=QUERY, - attributes=FIELDS_TO_EXTRACT, - obj=GCS_OBJ_PATH, - bucket=BUCKET, - task_id="run_operator", - ) - # [END howto_google_ads_to_gcs_operator] - - # [START howto_ads_list_accounts_operator] - list_accounts = GoogleAdsListAccountsOperator( - task_id="list_accounts", bucket=BUCKET, object_name=GCS_ACCOUNTS_CSV - ) - # [END howto_ads_list_accounts_operator] diff --git a/airflow/providers/google/ads/hooks/ads.py b/airflow/providers/google/ads/hooks/ads.py index c94997fa2f9f5..20bd75348c7ab 100644 --- a/airflow/providers/google/ads/hooks/ads.py +++ b/airflow/providers/google/ads/hooks/ads.py @@ -16,14 +16,10 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Ad hook.""" -import sys -from tempfile import NamedTemporaryFile -from typing import IO, Any, Dict, List, Optional +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from tempfile import NamedTemporaryFile +from typing import IO, Any from google.ads.googleads.client import GoogleAdsClient from google.ads.googleads.errors import GoogleAdsException @@ -34,7 +30,9 @@ from google.auth.exceptions import GoogleAuthError from airflow import AirflowException +from airflow.compat.functools import cached_property from airflow.hooks.base import BaseHook +from airflow.providers.google.common.hooks.base_google import get_field class GoogleAdsHook(BaseHook): @@ -73,14 +71,13 @@ class GoogleAdsHook(BaseHook): :param api_version: The Google Ads API version to use. 
:return: list of Google Ads Row object(s) - :rtype: list[GoogleAdsRow] """ default_api_version = "v10" def __init__( self, - api_version: Optional[str], + api_version: str | None, gcp_conn_id: str = "google_cloud_default", google_ads_conn_id: str = "google_ads_default", ) -> None: @@ -88,11 +85,11 @@ def __init__( self.api_version = api_version or self.default_api_version self.gcp_conn_id = gcp_conn_id self.google_ads_conn_id = google_ads_conn_id - self.google_ads_config: Dict[str, Any] = {} + self.google_ads_config: dict[str, Any] = {} def search( - self, client_ids: List[str], query: str, page_size: int = 10000, **kwargs - ) -> List[GoogleAdsRow]: + self, client_ids: list[str], query: str, page_size: int = 10000, **kwargs + ) -> list[GoogleAdsRow]: """ Pulls data from the Google Ads API and returns it as native protobuf message instances (those seen in versions prior to 10.0.0 of the @@ -109,7 +106,6 @@ def search( :param query: Google Ads Query Language query. :param page_size: Number of results to return per page. Max 10000. :return: Google Ads API response, converted to Google Ads Row objects - :rtype: list[GoogleAdsRow] """ data_proto_plus = self._search(client_ids, query, page_size, **kwargs) data_native_pb = [row._pb for row in data_proto_plus] @@ -117,8 +113,8 @@ def search( return data_native_pb def search_proto_plus( - self, client_ids: List[str], query: str, page_size: int = 10000, **kwargs - ) -> List[GoogleAdsRow]: + self, client_ids: list[str], query: str, page_size: int = 10000, **kwargs + ) -> list[GoogleAdsRow]: """ Pulls data from the Google Ads API and returns it as proto-plus-python message instances that behave more like conventional python objects. @@ -127,11 +123,10 @@ def search_proto_plus( :param query: Google Ads Query Language query. :param page_size: Number of results to return per page. Max 10000. :return: Google Ads API response, converted to Google Ads Row objects - :rtype: list[GoogleAdsRow] """ return self._search(client_ids, query, page_size, **kwargs) - def list_accessible_customers(self) -> List[str]: + def list_accessible_customers(self) -> list[str]: """ Returns resource names of customers directly accessible by the user authenticating the call. The resulting list of customers is based on your OAuth credentials. 
The request returns a list @@ -152,7 +147,7 @@ def list_accessible_customers(self) -> List[str]: self.log.error('\tError with message "%s".', error.message) if error.location: for field_path_element in error.location.field_path_elements: - self.log.error('\t\tOn field: %s', field_path_element.field_name) + self.log.error("\t\tOn field: %s", field_path_element.field_name) raise @cached_property @@ -203,16 +198,18 @@ def _update_config_with_secret(self, secrets_temp: IO[str]) -> None: Updates google ads config with file path of the temp file containing the secret Note, the secret must be passed as a file path for Google Ads API """ - secret_conn = self.get_connection(self.gcp_conn_id) - secret = secret_conn.extra_dejson["extra__google_cloud_platform__keyfile_dict"] + extras = self.get_connection(self.gcp_conn_id).extra_dejson + secret = get_field(extras, "keyfile_dict") + if not secret: + raise KeyError("secret_conn.extra_dejson does not contain keyfile_dict") secrets_temp.write(secret) secrets_temp.flush() self.google_ads_config["json_key_file_path"] = secrets_temp.name def _search( - self, client_ids: List[str], query: str, page_size: int = 10000, **kwargs - ) -> List[GoogleAdsRow]: + self, client_ids: list[str], query: str, page_size: int = 10000, **kwargs + ) -> list[GoogleAdsRow]: """ Pulls data from the Google Ads API @@ -221,13 +218,12 @@ def _search( :param page_size: Number of results to return per page. Max 10000. :return: Google Ads API response, converted to Google Ads Row objects - :rtype: list[GoogleAdsRow] """ service = self._get_service iterators = [] for client_id in client_ids: - request = self._get_client.get_type("SearchGoogleAdsRequest") # type: SearchGoogleAdsRequest + request: SearchGoogleAdsRequest = self._get_client.get_type("SearchGoogleAdsRequest") request.customer_id = client_id request.query = query request.page_size = page_size @@ -239,14 +235,13 @@ def _search( return self._extract_rows(iterators) - def _extract_rows(self, iterators: List[GRPCIterator]) -> List[GoogleAdsRow]: + def _extract_rows(self, iterators: list[GRPCIterator]) -> list[GoogleAdsRow]: """ Convert Google Page Iterator (GRPCIterator) objects to Google Ads Rows :param iterators: List of Google Page Iterator (GRPCIterator) objects :return: API response for all clients in the form of Google Ads Row object(s) - :rtype: list[GoogleAdsRow] """ try: self.log.info("Extracting data from returned Google Ads Iterators") diff --git a/airflow/providers/google/ads/operators/ads.py b/airflow/providers/google/ads/operators/ads.py index 702359e36488d..fc05b1ad54e44 100644 --- a/airflow/providers/google/ads/operators/ads.py +++ b/airflow/providers/google/ads/operators/ads.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. 
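A minimal sketch of ``GoogleAdsHook.search`` as changed above; the hook now reads the service-account key from the non-prefixed ``keyfile_dict`` extra of the GCP connection. Connection ids, the client id and the GAQL query are placeholders:

.. code-block:: python

    from airflow.providers.google.ads.hooks.ads import GoogleAdsHook

    hook = GoogleAdsHook(
        api_version="v10",  # falls back to the hook's default_api_version when None
        gcp_conn_id="google_cloud_default",
        google_ads_conn_id="google_ads_default",
    )
    rows = hook.search(
        client_ids=["1111111111"],
        query="SELECT campaign.id, metrics.clicks FROM campaign",
        page_size=1000,
    )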
"""This module contains Google Ad to GCS operators.""" +from __future__ import annotations + import csv from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.ads.hooks.ads import GoogleAdsHook @@ -75,8 +77,8 @@ def __init__( gcp_conn_id: str = "google_cloud_default", google_ads_conn_id: str = "google_ads_default", gzip: bool = False, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - api_version: Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + api_version: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -88,7 +90,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.api_version = api_version - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: uri = f"gs://{self.bucket}/{self.object_name}" ads_hook = GoogleAdsHook( diff --git a/airflow/providers/google/ads/transfers/ads_to_gcs.py b/airflow/providers/google/ads/transfers/ads_to_gcs.py index ffce93940c0d8..7e92de94e8e66 100644 --- a/airflow/providers/google/ads/transfers/ads_to_gcs.py +++ b/airflow/providers/google/ads/transfers/ads_to_gcs.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import csv from operator import attrgetter from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.ads.hooks.ads import GoogleAdsHook @@ -74,17 +75,17 @@ class GoogleAdsToGcsOperator(BaseOperator): def __init__( self, *, - client_ids: List[str], + client_ids: list[str], query: str, - attributes: List[str], + attributes: list[str], bucket: str, obj: str, gcp_conn_id: str = "google_cloud_default", google_ads_conn_id: str = "google_ads_default", page_size: int = 10000, gzip: bool = False, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - api_version: Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + api_version: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -100,7 +101,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.api_version = api_version - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: service = GoogleAdsHook( gcp_conn_id=self.gcp_conn_id, google_ads_conn_id=self.google_ads_conn_id, diff --git a/airflow/providers/google/cloud/_internal_client/secret_manager_client.py b/airflow/providers/google/cloud/_internal_client/secret_manager_client.py index 203cc7ac0bbfb..4a6fe127cc360 100644 --- a/airflow/providers/google/cloud/_internal_client/secret_manager_client.py +++ b/airflow/providers/google/cloud/_internal_client/secret_manager_client.py @@ -14,23 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import re -import sys -from typing import Optional import google - -from airflow.providers.google.common.consts import CLIENT_INFO - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - from google.api_core.exceptions import InvalidArgument, NotFound, PermissionDenied from google.cloud.secretmanager_v1 import SecretManagerServiceClient +from airflow.compat.functools import cached_property +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.utils.log.logging_mixin import LoggingMixin SECRET_ID_PATTERN = r"^[a-zA-Z0-9-_]*$" @@ -68,7 +61,7 @@ def client(self) -> SecretManagerServiceClient: _client = SecretManagerServiceClient(credentials=self.credentials, client_info=CLIENT_INFO) return _client - def get_secret(self, secret_id: str, project_id: str, secret_version: str = 'latest') -> Optional[str]: + def get_secret(self, secret_id: str, project_id: str, secret_version: str = "latest") -> str | None: """ Get secret value from the Secret Manager. @@ -79,7 +72,7 @@ def get_secret(self, secret_id: str, project_id: str, secret_version: str = 'lat name = self.client.secret_version_path(project_id, secret_id, secret_version) try: response = self.client.access_secret_version(name) - value = response.payload.data.decode('UTF-8') + value = response.payload.data.decode("UTF-8") return value except NotFound: self.log.error("Google Cloud API Call Error (NotFound): Secret ID %s not found.", secret_id) diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py index 28935e7b286f5..afe28a9ff82ec 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py @@ -15,14 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that uses Google AutoML services. 
""" +from __future__ import annotations + import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -59,16 +62,15 @@ # Example DAG for AutoML Natural Language Text Classification with models.DAG( "example_automl_text_cls", - schedule_interval=None, # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as example_dag: create_dataset_task = AutoMLCreateDatasetOperator( task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -81,7 +83,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output['model_id'] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py deleted file mode 100644 index 0367ee3114cc1..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py +++ /dev/null @@ -1,107 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that uses Google AutoML services. 
-""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook -from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, - AutoMLDeleteDatasetOperator, - AutoMLDeleteModelOperator, - AutoMLImportDataOperator, - AutoMLTrainModelOperator, -) - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") -GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_TEXT_BUCKET = os.environ.get( - "GCP_AUTOML_TEXT_BUCKET", "gs://INVALID BUCKET NAME/NL-entity/dataset.csv" -) - -# Example values -DATASET_ID = "" - -# Example model -MODEL = { - "display_name": "auto_model_1", - "dataset_id": DATASET_ID, - "text_extraction_model_metadata": {}, -} - -# Example dataset -DATASET = {"display_name": "test_text_dataset", "text_extraction_dataset_metadata": {}} - -IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TEXT_BUCKET]}} - -extract_object_id = CloudAutoMLHook.extract_object_id - -# Example DAG for AutoML Natural Language Entities Extraction -with models.DAG( - "example_automl_text", - schedule_interval=None, # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - user_defined_macros={"extract_object_id": extract_object_id}, - tags=['example'], -) as example_dag: - create_dataset_task = AutoMLCreateDatasetOperator( - task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION - ) - - dataset_id = create_dataset_task.output['dataset_id'] - - import_dataset_task = AutoMLImportDataOperator( - task_id="import_dataset_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - input_config=IMPORT_INPUT_CONFIG, - ) - - MODEL["dataset_id"] = dataset_id - - create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - - model_id = create_model.output['model_id'] - - delete_model_task = AutoMLDeleteModelOperator( - task_id="delete_model_task", - model_id=model_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - delete_datasets_task = AutoMLDeleteDatasetOperator( - task_id="delete_datasets_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - import_dataset_task >> create_model - delete_model_task >> delete_datasets_task - - # Task dependencies created via `XComArgs`: - # create_dataset_task >> import_dataset_task - # create_dataset_task >> create_model - # create_model >> delete_model_task - # create_dataset_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py index 47a4a7695a98a..a823b8af89538 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py @@ -15,14 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that uses Google AutoML services. 
""" +from __future__ import annotations + import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -59,17 +62,16 @@ # Example DAG for AutoML Natural Language Text Sentiment with models.DAG( "example_automl_text_sentiment", - schedule_interval=None, # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, user_defined_macros={"extract_object_id": extract_object_id}, - tags=['example'], + tags=["example"], ) as example_dag: create_dataset_task = AutoMLCreateDatasetOperator( task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -82,7 +84,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output['model_id'] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_tables.py b/airflow/providers/google/cloud/example_dags/example_automl_tables.py index 9ba0314dae777..89006402f7e25 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_tables.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_tables.py @@ -15,16 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that uses Google AutoML services. """ +from __future__ import annotations + import os from copy import deepcopy from datetime import datetime -from typing import Dict, List +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLBatchPredictOperator, @@ -73,7 +75,7 @@ extract_object_id = CloudAutoMLHook.extract_object_id -def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: +def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str: """ Using column name returns spec of the column. """ @@ -86,7 +88,6 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: # Example DAG to create dataset, train model_id and deploy it. 
with models.DAG( "example_create_and_deploy", - schedule_interval='@once', # Override to match your needs start_date=START_DATE, catchup=False, user_defined_macros={ @@ -94,7 +95,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: "target": TARGET, "extract_object_id": extract_object_id, }, - tags=['example'], + tags=["example"], ) as create_deploy_dag: # [START howto_operator_automl_create_dataset] create_dataset_task = AutoMLCreateDatasetOperator( @@ -104,7 +105,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: project_id=GCP_PROJECT_ID, ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) # [END howto_operator_automl_create_dataset] MODEL["dataset_id"] = dataset_id @@ -159,7 +160,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: project_id=GCP_PROJECT_ID, ) - model_id = create_model_task.output['model_id'] + model_id = cast(str, XComArg(create_model_task, key="model_id")) # [END howto_operator_automl_create_model] # [START howto_operator_automl_delete_model] @@ -199,19 +200,18 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: # Example DAG for AutoML datasets operations with models.DAG( "example_automl_dataset", - schedule_interval='@once', # Override to match your needs start_date=START_DATE, catchup=False, user_defined_macros={"extract_object_id": extract_object_id}, ) as example_dag: - create_dataset_task = AutoMLCreateDatasetOperator( + create_dataset_task2 = AutoMLCreateDatasetOperator( task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION, project_id=GCP_PROJECT_ID, ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task2, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -268,7 +268,6 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: with models.DAG( "example_gcp_get_deploy", - schedule_interval='@once', # Override to match your needs start_date=START_DATE, catchup=False, tags=["example"], @@ -294,7 +293,6 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: with models.DAG( "example_gcp_predict", - schedule_interval='@once', # Override to match your needs start_date=START_DATE, catchup=False, tags=["example"], diff --git a/airflow/providers/google/cloud/example_dags/example_automl_translation.py b/airflow/providers/google/cloud/example_dags/example_automl_translation.py index ae90458b5869a..2bef20caabee3 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_translation.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_translation.py @@ -15,14 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that uses Google AutoML services. 
""" +from __future__ import annotations + import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -65,17 +68,16 @@ # Example DAG for AutoML Translation with models.DAG( "example_automl_translation", - schedule_interval=None, # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, user_defined_macros={"extract_object_id": extract_object_id}, - tags=['example'], + tags=["example"], ) as example_dag: create_dataset_task = AutoMLCreateDatasetOperator( task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -88,7 +90,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py index 2ecaebf871f95..2b55c42a8ab96 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py @@ -15,14 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that uses Google AutoML services. 
""" +from __future__ import annotations + import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -62,17 +65,16 @@ # Example DAG for AutoML Video Intelligence Classification with models.DAG( "example_automl_video", - schedule_interval=None, # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, user_defined_macros={"extract_object_id": extract_object_id}, - tags=['example'], + tags=["example"], ) as example_dag: create_dataset_task = AutoMLCreateDatasetOperator( task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -85,7 +87,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py index f8ba82da98823..daed5748b2399 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py @@ -15,14 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that uses Google AutoML services. 
""" +from __future__ import annotations + import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -63,17 +66,16 @@ # Example DAG for AutoML Video Intelligence Object Tracking with models.DAG( "example_automl_video_tracking", - schedule_interval=None, # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, user_defined_macros={"extract_object_id": extract_object_id}, - tags=['example'], + tags=["example"], ) as example_dag: create_dataset_task = AutoMLCreateDatasetOperator( task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -86,7 +88,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_vision_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_vision_classification.py deleted file mode 100644 index 66df48f64ef93..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_automl_vision_classification.py +++ /dev/null @@ -1,109 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that uses Google AutoML services. 
-""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook -from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, - AutoMLDeleteDatasetOperator, - AutoMLDeleteModelOperator, - AutoMLImportDataOperator, - AutoMLTrainModelOperator, -) - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") -GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_VISION_BUCKET = os.environ.get("GCP_AUTOML_VISION_BUCKET", "gs://INVALID BUCKET NAME") - -# Example values -DATASET_ID = "ICN123455678" - -# Example model -MODEL = { - "display_name": "auto_model_2", - "dataset_id": DATASET_ID, - "image_classification_model_metadata": {"train_budget": 1}, -} - -# Example dataset -DATASET = { - "display_name": "test_vision_dataset", - "image_classification_dataset_metadata": {"classification_type": "MULTILABEL"}, -} - -IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_VISION_BUCKET]}} - -extract_object_id = CloudAutoMLHook.extract_object_id - - -# Example DAG for AutoML Vision Classification -with models.DAG( - "example_automl_vision", - schedule_interval=None, # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - user_defined_macros={"extract_object_id": extract_object_id}, - tags=['example'], -) as example_dag: - create_dataset_task = AutoMLCreateDatasetOperator( - task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION - ) - - dataset_id = create_dataset_task.output["dataset_id"] - - import_dataset_task = AutoMLImportDataOperator( - task_id="import_dataset_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - input_config=IMPORT_INPUT_CONFIG, - ) - - MODEL["dataset_id"] = dataset_id - - create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - - model_id = create_model.output["model_id"] - - delete_model_task = AutoMLDeleteModelOperator( - task_id="delete_model_task", - model_id=model_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - delete_datasets_task = AutoMLDeleteDatasetOperator( - task_id="delete_datasets_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - import_dataset_task >> create_model - delete_model_task >> delete_datasets_task - - # Task dependencies created via `XComArgs`: - # create_dataset_task >> import_dataset_task - # create_dataset_task >> create_model - # create_model >> delete_model_task - # create_dataset_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py b/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py index 1d897f2d6a9e4..8b9ae271d1126 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py @@ -15,14 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that uses Google AutoML services. 
""" +from __future__ import annotations + import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -62,17 +65,16 @@ # Example DAG for AutoML Vision Object Detection with models.DAG( "example_automl_vision_detection", - schedule_interval=None, # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, user_defined_macros={"extract_object_id": extract_object_id}, - tags=['example'], + tags=["example"], ) as example_dag: create_dataset_task = AutoMLCreateDatasetOperator( task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -85,7 +87,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py deleted file mode 100644 index 680b43bd01eb4..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_azure_fileshare_to_gcs.py +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import os -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import AzureFileShareToGCSOperator - -DEST_GCS_BUCKET = os.environ.get('GCP_GCS_BUCKET', 'gs://INVALID BUCKET NAME') -AZURE_SHARE_NAME = os.environ.get('AZURE_SHARE_NAME', 'test-azure-share') -AZURE_DIRECTORY_NAME = "test-azure-dir" - - -with DAG( - dag_id='azure_fileshare_to_gcs_example', - default_args={ - 'owner': 'airflow', - 'depends_on_past': False, - 'email': ['airflow@example.com'], - 'email_on_failure': False, - 'email_on_retry': False, - 'retries': 1, - 'retry_delay': timedelta(minutes=5), - }, - schedule_interval='@once', - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_azure_fileshare_to_gcs_basic] - sync_azure_files_with_gcs = AzureFileShareToGCSOperator( - task_id='sync_azure_files_with_gcs', - share_name=AZURE_SHARE_NAME, - dest_gcs=DEST_GCS_BUCKET, - directory_name=AZURE_DIRECTORY_NAME, - replace=False, - gzip=True, - google_impersonation_chain=None, - ) - # [END howto_operator_azure_fileshare_to_gcs_basic] diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py deleted file mode 100644 index ac584eb197496..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py +++ /dev/null @@ -1,111 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that creates and deletes Bigquery data transfer configurations. 
-""" -import os -import time -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.bigquery_dts import ( - BigQueryCreateDataTransferOperator, - BigQueryDataTransferServiceStartTransferRunsOperator, - BigQueryDeleteDataTransferConfigOperator, -) -from airflow.providers.google.cloud.sensors.bigquery_dts import BigQueryDataTransferServiceTransferRunSensor - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -BUCKET_URI = os.environ.get("GCP_DTS_BUCKET_URI", "gs://INVALID BUCKET NAME/bank-marketing.csv") -GCP_DTS_BQ_DATASET = os.environ.get("GCP_DTS_BQ_DATASET", "test_dts") -GCP_DTS_BQ_TABLE = os.environ.get("GCP_DTS_BQ_TABLE", "GCS_Test") - -# [START howto_bigquery_dts_create_args] - -# In the case of Airflow, the customer needs to create a transfer -# config with the automatic scheduling disabled and then trigger -# a transfer run using a specialized Airflow operator -schedule_options = {"disable_auto_scheduling": True} - -PARAMS = { - "field_delimiter": ",", - "max_bad_records": "0", - "skip_leading_rows": "1", - "data_path_template": BUCKET_URI, - "destination_table_name_template": GCP_DTS_BQ_TABLE, - "file_format": "CSV", -} - -TRANSFER_CONFIG = { - "destination_dataset_id": GCP_DTS_BQ_DATASET, - "display_name": "GCS Test Config", - "data_source_id": "google_cloud_storage", - "schedule_options": schedule_options, - "params": PARAMS, -} - -# [END howto_bigquery_dts_create_args] - -with models.DAG( - "example_gcp_bigquery_dts", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_bigquery_create_data_transfer] - gcp_bigquery_create_transfer = BigQueryCreateDataTransferOperator( - transfer_config=TRANSFER_CONFIG, - project_id=GCP_PROJECT_ID, - task_id="gcp_bigquery_create_transfer", - ) - - transfer_config_id = gcp_bigquery_create_transfer.output["transfer_config_id"] - # [END howto_bigquery_create_data_transfer] - - # [START howto_bigquery_start_transfer] - gcp_bigquery_start_transfer = BigQueryDataTransferServiceStartTransferRunsOperator( - task_id="gcp_bigquery_start_transfer", - transfer_config_id=transfer_config_id, - requested_run_time={"seconds": int(time.time() + 60)}, - ) - # [END howto_bigquery_start_transfer] - - # [START howto_bigquery_dts_sensor] - gcp_run_sensor = BigQueryDataTransferServiceTransferRunSensor( - task_id="gcp_run_sensor", - transfer_config_id=transfer_config_id, - run_id=gcp_bigquery_start_transfer.output["run_id"], - expected_statuses={"SUCCEEDED"}, - ) - # [END howto_bigquery_dts_sensor] - - # [START howto_bigquery_delete_data_transfer] - gcp_bigquery_delete_transfer = BigQueryDeleteDataTransferConfigOperator( - transfer_config_id=transfer_config_id, task_id="gcp_bigquery_delete_transfer" - ) - # [END howto_bigquery_delete_data_transfer] - - gcp_run_sensor >> gcp_bigquery_delete_transfer - - # Task dependencies created via `XComArgs`: - # gcp_bigquery_create_transfer >> gcp_bigquery_start_transfer - # gcp_bigquery_create_transfer >> gcp_run_sensor - # gcp_bigquery_start_transfer >> gcp_run_sensor - # gcp_bigquery_create_transfer >> gcp_bigquery_delete_transfer diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py b/airflow/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py deleted file mode 100644 index 0d6a5bb2d5e0a..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py +++ /dev/null @@ 
-1,69 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG for Google BigQuery service. -""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, - BigQueryCreateEmptyTableOperator, - BigQueryDeleteDatasetOperator, -) -from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") -ORIGIN = "origin" -TARGET = "target" - -with models.DAG( - "example_bigquery_to_bigquery", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - copy_selected_data = BigQueryToBigQueryOperator( - task_id="copy_selected_data", - source_project_dataset_tables=f"{DATASET_NAME}.{ORIGIN}", - destination_project_dataset_table=f"{DATASET_NAME}.{TARGET}", - ) - - create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME) - - for table in [ORIGIN, TARGET]: - create_table = BigQueryCreateEmptyTableOperator( - task_id=f"create_{table}_table", - dataset_id=DATASET_NAME, - table_id=table, - schema_fields=[ - {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, - ], - ) - create_dataset >> create_table >> copy_selected_data - - delete_dataset = BigQueryDeleteDatasetOperator( - task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True - ) - - copy_selected_data >> delete_dataset diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_bigquery_to_gcs.py deleted file mode 100644 index ba66e10e44a35..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_to_gcs.py +++ /dev/null @@ -1,68 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG for Google BigQuery service. -""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, - BigQueryCreateEmptyTableOperator, - BigQueryDeleteDatasetOperator, -) -from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") -DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "INVALID BUCKET NAME") -TABLE = "table_42" - -with models.DAG( - "example_bigquery_to_gcs", - schedule_interval=None, # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - bigquery_to_gcs = BigQueryToGCSOperator( - task_id="bigquery_to_gcs", - source_project_dataset_table=f"{DATASET_NAME}.{TABLE}", - destination_cloud_storage_uris=[f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv"], - ) - - create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME) - - create_table = BigQueryCreateEmptyTableOperator( - task_id="create_table", - dataset_id=DATASET_NAME, - table_id=TABLE, - schema_fields=[ - {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, - ], - ) - create_dataset >> create_table >> bigquery_to_gcs - - delete_dataset = BigQueryDeleteDatasetOperator( - task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True - ) - - bigquery_to_gcs >> delete_dataset diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_to_mssql.py b/airflow/providers/google/cloud/example_dags/example_bigquery_to_mssql.py deleted file mode 100644 index 64a18ad07e007..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_to_mssql.py +++ /dev/null @@ -1,70 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG for Google BigQuery service. 
-""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, - BigQueryCreateEmptyTableOperator, - BigQueryDeleteDatasetOperator, -) -from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") -DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "INVALID BUCKET NAME") -TABLE = "table_42" -destination_table = "mssql_table_test" - -with models.DAG( - "example_bigquery_to_mssql", - schedule_interval=None, # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - bigquery_to_mssql = BigQueryToMsSqlOperator( - task_id="bigquery_to_mssql", - source_project_dataset_table=f'{PROJECT_ID}.{DATASET_NAME}.{TABLE}', - mssql_table=destination_table, - replace=False, - ) - - create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME) - - create_table = BigQueryCreateEmptyTableOperator( - task_id="create_table", - dataset_id=DATASET_NAME, - table_id=TABLE, - schema_fields=[ - {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, - ], - ) - create_dataset >> create_table >> bigquery_to_mssql - - delete_dataset = BigQueryDeleteDatasetOperator( - task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True - ) - - bigquery_to_mssql >> delete_dataset diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_transfer.py b/airflow/providers/google/cloud/example_dags/example_bigquery_transfer.py deleted file mode 100644 index c12934dc31c9a..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_transfer.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG for Google BigQuery service. 
-""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, - BigQueryCreateEmptyTableOperator, - BigQueryDeleteDatasetOperator, -) -from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator -from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") -DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "INVALID BUCKET NAME") -ORIGIN = "origin" -TARGET = "target" - -with models.DAG( - "example_bigquery_transfer", - schedule_interval=None, # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - copy_selected_data = BigQueryToBigQueryOperator( - task_id="copy_selected_data", - source_project_dataset_tables=f"{DATASET_NAME}.{ORIGIN}", - destination_project_dataset_table=f"{DATASET_NAME}.{TARGET}", - ) - - bigquery_to_gcs = BigQueryToGCSOperator( - task_id="bigquery_to_gcs", - source_project_dataset_table=f"{DATASET_NAME}.{ORIGIN}", - destination_cloud_storage_uris=[f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv"], - ) - - create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME) - - for table in [ORIGIN, TARGET]: - create_table = BigQueryCreateEmptyTableOperator( - task_id=f"create_{table}_table", - dataset_id=DATASET_NAME, - table_id=table, - schema_fields=[ - {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, - ], - ) - create_dataset >> create_table >> [copy_selected_data, bigquery_to_gcs] - - delete_dataset = BigQueryDeleteDatasetOperator( - task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True - ) - - [copy_selected_data, bigquery_to_gcs] >> delete_dataset diff --git a/airflow/providers/google/cloud/example_dags/example_bigtable.py b/airflow/providers/google/cloud/example_dags/example_bigtable.py deleted file mode 100644 index 2bfc145f877ba..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_bigtable.py +++ /dev/null @@ -1,210 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" -Example Airflow DAG that creates and performs following operations on Cloud Bigtable: -- creates an Instance -- creates a Table -- updates Cluster -- waits for Table replication completeness -- deletes the Table -- deletes the Instance - -This DAG relies on the following environment variables: - -* GCP_PROJECT_ID - Google Cloud project -* CBT_INSTANCE_ID - desired ID of a Cloud Bigtable instance -* CBT_INSTANCE_DISPLAY_NAME - desired human-readable display name of the Instance -* CBT_INSTANCE_TYPE - type of the Instance, e.g. 1 for DEVELOPMENT - See https://googleapis.github.io/google-cloud-python/latest/bigtable/instance.html#google.cloud.bigtable.instance.Instance # noqa E501 -* CBT_INSTANCE_LABELS - labels to add for the Instance -* CBT_CLUSTER_ID - desired ID of the main Cluster created for the Instance -* CBT_CLUSTER_ZONE - zone in which main Cluster will be created. e.g. europe-west1-b - See available zones: https://cloud.google.com/bigtable/docs/locations -* CBT_CLUSTER_NODES - initial amount of nodes of the Cluster -* CBT_CLUSTER_NODES_UPDATED - amount of nodes for BigtableClusterUpdateOperator -* CBT_CLUSTER_STORAGE_TYPE - storage for the Cluster, e.g. 1 for SSD - See https://googleapis.github.io/google-cloud-python/latest/bigtable/instance.html#google.cloud.bigtable.instance.Instance.cluster # noqa E501 -* CBT_TABLE_ID - desired ID of the Table -* CBT_POKE_INTERVAL - number of seconds between every attempt of Sensor check - -""" - -import json -from datetime import datetime -from os import getenv - -from airflow import models -from airflow.providers.google.cloud.operators.bigtable import ( - BigtableCreateInstanceOperator, - BigtableCreateTableOperator, - BigtableDeleteInstanceOperator, - BigtableDeleteTableOperator, - BigtableUpdateClusterOperator, - BigtableUpdateInstanceOperator, -) -from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor - -GCP_PROJECT_ID = getenv('GCP_PROJECT_ID', 'example-project') -CBT_INSTANCE_ID = getenv('GCP_BIG_TABLE_INSTANCE_ID', 'some-instance-id') -CBT_INSTANCE_DISPLAY_NAME = getenv('GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME', 'Human-readable name') -CBT_INSTANCE_DISPLAY_NAME_UPDATED = getenv( - "GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME_UPDATED", f"{CBT_INSTANCE_DISPLAY_NAME} - updated" -) -CBT_INSTANCE_TYPE = getenv('GCP_BIG_TABLE_INSTANCE_TYPE', '2') -CBT_INSTANCE_TYPE_PROD = getenv('GCP_BIG_TABLE_INSTANCE_TYPE_PROD', '1') -CBT_INSTANCE_LABELS = getenv('GCP_BIG_TABLE_INSTANCE_LABELS', '{}') -CBT_INSTANCE_LABELS_UPDATED = getenv('GCP_BIG_TABLE_INSTANCE_LABELS_UPDATED', '{"env": "prod"}') -CBT_CLUSTER_ID = getenv('GCP_BIG_TABLE_CLUSTER_ID', 'some-cluster-id') -CBT_CLUSTER_ZONE = getenv('GCP_BIG_TABLE_CLUSTER_ZONE', 'europe-west1-b') -CBT_CLUSTER_NODES = getenv('GCP_BIG_TABLE_CLUSTER_NODES', '3') -CBT_CLUSTER_NODES_UPDATED = getenv('GCP_BIG_TABLE_CLUSTER_NODES_UPDATED', '5') -CBT_CLUSTER_STORAGE_TYPE = getenv('GCP_BIG_TABLE_CLUSTER_STORAGE_TYPE', '2') -CBT_TABLE_ID = getenv('GCP_BIG_TABLE_TABLE_ID', 'some-table-id') -CBT_POKE_INTERVAL = getenv('GCP_BIG_TABLE_POKE_INTERVAL', '60') - - -with models.DAG( - 'example_gcp_bigtable_operators', - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_gcp_bigtable_instance_create] - create_instance_task = BigtableCreateInstanceOperator( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - main_cluster_id=CBT_CLUSTER_ID, - 
main_cluster_zone=CBT_CLUSTER_ZONE, - instance_display_name=CBT_INSTANCE_DISPLAY_NAME, - instance_type=int(CBT_INSTANCE_TYPE), - instance_labels=json.loads(CBT_INSTANCE_LABELS), - cluster_nodes=None, - cluster_storage_type=int(CBT_CLUSTER_STORAGE_TYPE), - task_id='create_instance_task', - ) - create_instance_task2 = BigtableCreateInstanceOperator( - instance_id=CBT_INSTANCE_ID, - main_cluster_id=CBT_CLUSTER_ID, - main_cluster_zone=CBT_CLUSTER_ZONE, - instance_display_name=CBT_INSTANCE_DISPLAY_NAME, - instance_type=int(CBT_INSTANCE_TYPE), - instance_labels=json.loads(CBT_INSTANCE_LABELS), - cluster_nodes=int(CBT_CLUSTER_NODES), - cluster_storage_type=int(CBT_CLUSTER_STORAGE_TYPE), - task_id='create_instance_task2', - ) - create_instance_task >> create_instance_task2 - # [END howto_operator_gcp_bigtable_instance_create] - - # [START howto_operator_gcp_bigtable_instance_update] - update_instance_task = BigtableUpdateInstanceOperator( - instance_id=CBT_INSTANCE_ID, - instance_display_name=CBT_INSTANCE_DISPLAY_NAME_UPDATED, - instance_type=int(CBT_INSTANCE_TYPE_PROD), - instance_labels=json.loads(CBT_INSTANCE_LABELS_UPDATED), - task_id='update_instance_task', - ) - # [END howto_operator_gcp_bigtable_instance_update] - - # [START howto_operator_gcp_bigtable_cluster_update] - cluster_update_task = BigtableUpdateClusterOperator( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - cluster_id=CBT_CLUSTER_ID, - nodes=int(CBT_CLUSTER_NODES_UPDATED), - task_id='update_cluster_task', - ) - cluster_update_task2 = BigtableUpdateClusterOperator( - instance_id=CBT_INSTANCE_ID, - cluster_id=CBT_CLUSTER_ID, - nodes=int(CBT_CLUSTER_NODES_UPDATED), - task_id='update_cluster_task2', - ) - cluster_update_task >> cluster_update_task2 - # [END howto_operator_gcp_bigtable_cluster_update] - - # [START howto_operator_gcp_bigtable_instance_delete] - delete_instance_task = BigtableDeleteInstanceOperator( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - task_id='delete_instance_task', - ) - delete_instance_task2 = BigtableDeleteInstanceOperator( - instance_id=CBT_INSTANCE_ID, - task_id='delete_instance_task2', - ) - # [END howto_operator_gcp_bigtable_instance_delete] - - # [START howto_operator_gcp_bigtable_table_create] - create_table_task = BigtableCreateTableOperator( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - task_id='create_table', - ) - create_table_task2 = BigtableCreateTableOperator( - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - task_id='create_table_task2', - ) - create_table_task >> create_table_task2 - # [END howto_operator_gcp_bigtable_table_create] - - # [START howto_operator_gcp_bigtable_table_wait_for_replication] - wait_for_table_replication_task = BigtableTableReplicationCompletedSensor( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - poke_interval=int(CBT_POKE_INTERVAL), - timeout=180, - task_id='wait_for_table_replication_task', - ) - wait_for_table_replication_task2 = BigtableTableReplicationCompletedSensor( - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - poke_interval=int(CBT_POKE_INTERVAL), - timeout=180, - task_id='wait_for_table_replication_task2', - ) - # [END howto_operator_gcp_bigtable_table_wait_for_replication] - - # [START howto_operator_gcp_bigtable_table_delete] - delete_table_task = BigtableDeleteTableOperator( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - task_id='delete_table_task', - ) - delete_table_task2 = 
BigtableDeleteTableOperator( - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - task_id='delete_table_task2', - ) - # [END howto_operator_gcp_bigtable_table_delete] - - wait_for_table_replication_task >> delete_table_task - wait_for_table_replication_task2 >> delete_table_task - wait_for_table_replication_task >> delete_table_task2 - wait_for_table_replication_task2 >> delete_table_task2 - create_instance_task >> create_table_task >> cluster_update_task - cluster_update_task >> update_instance_task >> delete_table_task - create_instance_task2 >> create_table_task2 >> cluster_update_task2 >> delete_table_task2 - - # Only delete instances after all tables are deleted - [delete_table_task, delete_table_task2] >> delete_instance_task >> delete_instance_task2 diff --git a/airflow/providers/google/cloud/example_dags/example_calendar_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_calendar_to_gcs.py deleted file mode 100644 index aab3a1fe65ae4..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_calendar_to_gcs.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.calendar_to_gcs import GoogleCalendarToGCSOperator - -BUCKET = os.environ.get("GCP_GCS_BUCKET", "test28397yeo") -CALENDAR_ID = os.environ.get("CALENDAR_ID", "1234567890qwerty") -API_VERSION = "v3" - -with models.DAG( - "example_calendar_to_gcs", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2022, 1, 1), - catchup=False, - tags=["example"], -) as dag: - # [START upload_calendar_to_gcs] - upload_calendar_to_gcs = GoogleCalendarToGCSOperator( - task_id="upload_calendar_to_gcs", - destination_bucket=BUCKET, - calendar_id=CALENDAR_ID, - api_version=API_VERSION, - ) - # [END upload_calendar_to_gcs] diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_build.py b/airflow/providers/google/cloud/example_dags/example_cloud_build.py deleted file mode 100644 index f78e1eb47891a..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_cloud_build.py +++ /dev/null @@ -1,267 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that displays interactions with Google Cloud Build. - -This DAG relies on the following OS environment variables: - -* GCP_PROJECT_ID - Google Cloud Project to use for the Cloud Function. -* GCP_CLOUD_BUILD_ARCHIVE_URL - Path to the zipped source in Google Cloud Storage. - This object must be a gzipped archive file (.tar.gz) containing source to build. -* GCP_CLOUD_BUILD_REPOSITORY_NAME - Name of the Cloud Source Repository. - -""" - -import os -from datetime import datetime -from pathlib import Path -from typing import Any, Dict - -import yaml -from future.backports.urllib.parse import urlparse - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.cloud_build import ( - CloudBuildCancelBuildOperator, - CloudBuildCreateBuildOperator, - CloudBuildCreateBuildTriggerOperator, - CloudBuildDeleteBuildTriggerOperator, - CloudBuildGetBuildOperator, - CloudBuildGetBuildTriggerOperator, - CloudBuildListBuildsOperator, - CloudBuildListBuildTriggersOperator, - CloudBuildRetryBuildOperator, - CloudBuildRunBuildTriggerOperator, - CloudBuildUpdateBuildTriggerOperator, -) - -START_DATE = datetime(2021, 1, 1) - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "aitflow-test-project") - -GCP_SOURCE_ARCHIVE_URL = os.environ.get("GCP_CLOUD_BUILD_ARCHIVE_URL", "gs://airflow-test-bucket/file.tar.gz") -GCP_SOURCE_REPOSITORY_NAME = os.environ.get("GCP_CLOUD_BUILD_REPOSITORY_NAME", "airflow-test-repo") - -GCP_SOURCE_ARCHIVE_URL_PARTS = urlparse(GCP_SOURCE_ARCHIVE_URL) -GCP_SOURCE_BUCKET_NAME = GCP_SOURCE_ARCHIVE_URL_PARTS.netloc - -CURRENT_FOLDER = Path(__file__).parent - -# [START howto_operator_gcp_create_build_trigger_body] -create_build_trigger_body = { - "name": "test-cloud-build-trigger", - "trigger_template": { - "project_id": GCP_PROJECT_ID, - "repo_name": GCP_SOURCE_REPOSITORY_NAME, - "branch_name": "master", - }, - "filename": "cloudbuild.yaml", -} -# [END howto_operator_gcp_create_build_trigger_body] - -update_build_trigger_body = { - "name": "test-cloud-build-trigger", - "trigger_template": { - "project_id": GCP_PROJECT_ID, - "repo_name": GCP_SOURCE_REPOSITORY_NAME, - "branch_name": "dev", - }, - "filename": "cloudbuild.yaml", -} - -# [START howto_operator_gcp_create_build_from_storage_body] -create_build_from_storage_body = { - "source": {"storage_source": GCP_SOURCE_ARCHIVE_URL}, - "steps": [ - { - "name": "gcr.io/cloud-builders/docker", - "args": ["build", "-t", f"gcr.io/$PROJECT_ID/{GCP_SOURCE_BUCKET_NAME}", "."], - } - ], - "images": [f"gcr.io/$PROJECT_ID/{GCP_SOURCE_BUCKET_NAME}"], -} -# [END howto_operator_gcp_create_build_from_storage_body] - -# [START howto_operator_create_build_from_repo_body] -create_build_from_repo_body: Dict[str, Any] = { - "source": {"repo_source": {"repo_name": GCP_SOURCE_REPOSITORY_NAME, "branch_name": "main"}}, - "steps": [ - { - "name": "gcr.io/cloud-builders/docker", - "args": ["build", "-t", "gcr.io/$PROJECT_ID/$REPO_NAME", "."], - } - ], - "images": ["gcr.io/$PROJECT_ID/$REPO_NAME"], -} -# [END 
howto_operator_create_build_from_repo_body] - - -with models.DAG( - "example_gcp_cloud_build", - schedule_interval='@once', - start_date=START_DATE, - catchup=False, - tags=["example"], -) as build_dag: - - # [START howto_operator_create_build_from_storage] - create_build_from_storage = CloudBuildCreateBuildOperator( - task_id="create_build_from_storage", project_id=GCP_PROJECT_ID, build=create_build_from_storage_body - ) - # [END howto_operator_create_build_from_storage] - - # [START howto_operator_create_build_from_storage_result] - create_build_from_storage_result = BashOperator( - bash_command=f"echo { create_build_from_storage.output['results'] }", - task_id="create_build_from_storage_result", - ) - # [END howto_operator_create_build_from_storage_result] - - # [START howto_operator_create_build_from_repo] - create_build_from_repo = CloudBuildCreateBuildOperator( - task_id="create_build_from_repo", project_id=GCP_PROJECT_ID, build=create_build_from_repo_body - ) - # [END howto_operator_create_build_from_repo] - - # [START howto_operator_create_build_from_repo_result] - create_build_from_repo_result = BashOperator( - bash_command=f"echo { create_build_from_repo.output['results'] }", - task_id="create_build_from_repo_result", - ) - # [END howto_operator_create_build_from_repo_result] - - # [START howto_operator_list_builds] - list_builds = CloudBuildListBuildsOperator( - task_id="list_builds", project_id=GCP_PROJECT_ID, location="global" - ) - # [END howto_operator_list_builds] - - # [START howto_operator_create_build_without_wait] - create_build_without_wait = CloudBuildCreateBuildOperator( - task_id="create_build_without_wait", - project_id=GCP_PROJECT_ID, - build=create_build_from_repo_body, - wait=False, - ) - # [END howto_operator_create_build_without_wait] - - # [START howto_operator_cancel_build] - cancel_build = CloudBuildCancelBuildOperator( - task_id="cancel_build", - id_=create_build_without_wait.output['id'], - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_cancel_build] - - # [START howto_operator_retry_build] - retry_build = CloudBuildRetryBuildOperator( - task_id="retry_build", - id_=cancel_build.output['id'], - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_retry_build] - - # [START howto_operator_get_build] - get_build = CloudBuildGetBuildOperator( - task_id="get_build", - id_=retry_build.output['id'], - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_get_build] - - # [START howto_operator_gcp_create_build_from_yaml_body] - create_build_from_file = CloudBuildCreateBuildOperator( - task_id="create_build_from_file", - project_id=GCP_PROJECT_ID, - build=yaml.safe_load((Path(CURRENT_FOLDER) / 'example_cloud_build.yaml').read_text()), - params={'name': 'Airflow'}, - ) - # [END howto_operator_gcp_create_build_from_yaml_body] - - create_build_from_storage >> create_build_from_storage_result - create_build_from_storage_result >> list_builds - create_build_from_repo >> create_build_from_repo_result - create_build_from_repo_result >> list_builds - list_builds >> create_build_without_wait >> cancel_build - cancel_build >> retry_build >> get_build - -with models.DAG( - "example_gcp_cloud_build_trigger", - schedule_interval='@once', - start_date=START_DATE, - catchup=False, - tags=["example"], -) as build_trigger_dag: - - # [START howto_operator_create_build_trigger] - create_build_trigger = CloudBuildCreateBuildTriggerOperator( - task_id="create_build_trigger", project_id=GCP_PROJECT_ID, trigger=create_build_trigger_body - ) - # [END 
howto_operator_create_build_trigger] - - # [START howto_operator_run_build_trigger] - run_build_trigger = CloudBuildRunBuildTriggerOperator( - task_id="run_build_trigger", - project_id=GCP_PROJECT_ID, - trigger_id=create_build_trigger.output['id'], - source=create_build_from_repo_body['source']['repo_source'], - ) - # [END howto_operator_run_build_trigger] - - # [START howto_operator_create_build_trigger] - update_build_trigger = CloudBuildUpdateBuildTriggerOperator( - task_id="update_build_trigger", - project_id=GCP_PROJECT_ID, - trigger_id=create_build_trigger.output['id'], - trigger=update_build_trigger_body, - ) - # [END howto_operator_create_build_trigger] - - # [START howto_operator_get_build_trigger] - get_build_trigger = CloudBuildGetBuildTriggerOperator( - task_id="get_build_trigger", - project_id=GCP_PROJECT_ID, - trigger_id=create_build_trigger.output['id'], - ) - # [END howto_operator_get_build_trigger] - - # [START howto_operator_delete_build_trigger] - delete_build_trigger = CloudBuildDeleteBuildTriggerOperator( - task_id="delete_build_trigger", - project_id=GCP_PROJECT_ID, - trigger_id=create_build_trigger.output['id'], - ) - # [END howto_operator_delete_build_trigger] - - # [START howto_operator_list_build_triggers] - list_build_triggers = CloudBuildListBuildTriggersOperator( - task_id="list_build_triggers", project_id=GCP_PROJECT_ID, location="global", page_size=5 - ) - # [END howto_operator_list_build_triggers] - - chain( - create_build_trigger, - run_build_trigger, - update_build_trigger, - get_build_trigger, - delete_build_trigger, - list_build_triggers, - ) diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_composer.py b/airflow/providers/google/cloud/example_dags/example_cloud_composer.py deleted file mode 100644 index eda597023e3fb..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_cloud_composer.py +++ /dev/null @@ -1,163 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
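The removed Cloud Build trigger DAG above, and the Cloud Composer examples that follow, end with a call to chain() from airflow.models.baseoperator rather than a run of bitshift operators. A minimal sketch of what that helper does, with placeholder tasks in place of the real operators:

    # Sketch only, not part of this diff. chain(a, b, c) wires a >> b >> c,
    # which saves repeating the bitshift operator for long linear sequences
    # such as the build-trigger lifecycle above.
    from __future__ import annotations

    from datetime import datetime

    from airflow import models
    from airflow.models.baseoperator import chain
    from airflow.operators.empty import EmptyOperator

    with models.DAG("chain_sketch", start_date=datetime(2021, 1, 1), catchup=False):
        create_trigger = EmptyOperator(task_id="create_trigger")
        run_trigger = EmptyOperator(task_id="run_trigger")
        delete_trigger = EmptyOperator(task_id="delete_trigger")

        chain(create_trigger, run_trigger, delete_trigger)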
- -import os -from datetime import datetime - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.google.cloud.operators.cloud_composer import ( - CloudComposerCreateEnvironmentOperator, - CloudComposerDeleteEnvironmentOperator, - CloudComposerGetEnvironmentOperator, - CloudComposerListEnvironmentsOperator, - CloudComposerListImageVersionsOperator, - CloudComposerUpdateEnvironmentOperator, -) - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "") -REGION = os.environ.get("GCP_REGION", "") - -# [START howto_operator_composer_simple_environment] -ENVIRONMENT_ID = os.environ.get("ENVIRONMENT_ID", "ENVIRONMENT_ID>") -ENVIRONMENT = { - "config": { - "node_count": 3, - "software_config": {"image_version": "composer-1.17.7-airflow-2.1.4"}, - } -} -# [END howto_operator_composer_simple_environment] - -# [START howto_operator_composer_update_environment] -UPDATED_ENVIRONMENT = { - "labels": { - "label1": "testing", - } -} -UPDATE_MASK = {"paths": ["labels.label1"]} -# [END howto_operator_composer_update_environment] - - -with models.DAG( - "composer_dag1", - schedule_interval="@once", # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_composer_image_list] - image_versions = CloudComposerListImageVersionsOperator( - task_id="image_versions", - project_id=PROJECT_ID, - region=REGION, - ) - # [END howto_operator_composer_image_list] - - # [START howto_operator_create_composer_environment] - create_env = CloudComposerCreateEnvironmentOperator( - task_id="create_env", - project_id=PROJECT_ID, - region=REGION, - environment_id=ENVIRONMENT_ID, - environment=ENVIRONMENT, - ) - # [END howto_operator_create_composer_environment] - - # [START howto_operator_list_composer_environments] - list_envs = CloudComposerListEnvironmentsOperator( - task_id="list_envs", project_id=PROJECT_ID, region=REGION - ) - # [END howto_operator_list_composer_environments] - - # [START howto_operator_get_composer_environment] - get_env = CloudComposerGetEnvironmentOperator( - task_id="get_env", - project_id=PROJECT_ID, - region=REGION, - environment_id=ENVIRONMENT_ID, - ) - # [END howto_operator_get_composer_environment] - - # [START howto_operator_update_composer_environment] - update_env = CloudComposerUpdateEnvironmentOperator( - task_id="update_env", - project_id=PROJECT_ID, - region=REGION, - environment_id=ENVIRONMENT_ID, - update_mask=UPDATE_MASK, - environment=UPDATED_ENVIRONMENT, - ) - # [END howto_operator_update_composer_environment] - - # [START howto_operator_delete_composer_environment] - delete_env = CloudComposerDeleteEnvironmentOperator( - task_id="delete_env", - project_id=PROJECT_ID, - region=REGION, - environment_id=ENVIRONMENT_ID, - ) - # [END howto_operator_delete_composer_environment] - - chain(image_versions, create_env, list_envs, get_env, update_env, delete_env) - - -with models.DAG( - "composer_dag_deferrable1", - schedule_interval="@once", # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as defer_dag: - # [START howto_operator_create_composer_environment_deferrable_mode] - defer_create_env = CloudComposerCreateEnvironmentOperator( - task_id="defer_create_env", - project_id=PROJECT_ID, - region=REGION, - environment_id=ENVIRONMENT_ID, - environment=ENVIRONMENT, - deferrable=True, - ) - # [END howto_operator_create_composer_environment_deferrable_mode] - - # [START 
howto_operator_update_composer_environment_deferrable_mode] - defer_update_env = CloudComposerUpdateEnvironmentOperator( - task_id="defer_update_env", - project_id=PROJECT_ID, - region=REGION, - environment_id=ENVIRONMENT_ID, - update_mask=UPDATE_MASK, - environment=UPDATED_ENVIRONMENT, - deferrable=True, - ) - # [END howto_operator_update_composer_environment_deferrable_mode] - - # [START howto_operator_delete_composer_environment_deferrable_mode] - defer_delete_env = CloudComposerDeleteEnvironmentOperator( - task_id="defer_delete_env", - project_id=PROJECT_ID, - region=REGION, - environment_id=ENVIRONMENT_ID, - deferrable=True, - ) - # [END howto_operator_delete_composer_environment_deferrable_mode] - - chain( - defer_create_env, - defer_update_env, - defer_delete_env, - ) diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py deleted file mode 100644 index 6f9dd09ff3d36..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py +++ /dev/null @@ -1,337 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG for Google Cloud Memorystore service. 
-""" -import os -from datetime import datetime - -from google.cloud.redis_v1 import FailoverInstanceRequest, Instance -from google.protobuf.field_mask_pb2 import FieldMask - -from airflow import models -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.cloud_memorystore import ( - CloudMemorystoreCreateInstanceAndImportOperator, - CloudMemorystoreCreateInstanceOperator, - CloudMemorystoreDeleteInstanceOperator, - CloudMemorystoreExportAndDeleteInstanceOperator, - CloudMemorystoreExportInstanceOperator, - CloudMemorystoreFailoverInstanceOperator, - CloudMemorystoreGetInstanceOperator, - CloudMemorystoreImportOperator, - CloudMemorystoreListInstancesOperator, - CloudMemorystoreMemcachedApplyParametersOperator, - CloudMemorystoreMemcachedCreateInstanceOperator, - CloudMemorystoreMemcachedDeleteInstanceOperator, - CloudMemorystoreMemcachedGetInstanceOperator, - CloudMemorystoreMemcachedListInstancesOperator, - CloudMemorystoreMemcachedUpdateInstanceOperator, - CloudMemorystoreMemcachedUpdateParametersOperator, - CloudMemorystoreScaleInstanceOperator, - CloudMemorystoreUpdateInstanceOperator, -) -from airflow.providers.google.cloud.operators.gcs import GCSBucketCreateAclEntryOperator - -START_DATE = datetime(2021, 1, 1) - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") - -MEMORYSTORE_REDIS_INSTANCE_NAME = os.environ.get( - "GCP_MEMORYSTORE_REDIS_INSTANCE_NAME", "test-memorystore-redis" -) -MEMORYSTORE_REDIS_INSTANCE_NAME_2 = os.environ.get( - "GCP_MEMORYSTORE_REDIS_INSTANCE_NAME_2", "test-memorystore-redis-2" -) -MEMORYSTORE_REDIS_INSTANCE_NAME_3 = os.environ.get( - "GCP_MEMORYSTORE_REDIS_INSTANCE_NAME_3", "test-memorystore-redis-3" -) -MEMORYSTORE_MEMCACHED_INSTANCE_NAME = os.environ.get( - "GCP_MEMORYSTORE_MEMCACHED_INSTANCE_NAME", "test-memorystore-memcached-1" -) - -BUCKET_NAME = os.environ.get("GCP_MEMORYSTORE_BUCKET", "INVALID BUCKET NAME") -EXPORT_GCS_URL = f"gs://{BUCKET_NAME}/my-export.rdb" - -# [START howto_operator_instance] -FIRST_INSTANCE = {"tier": Instance.Tier.BASIC, "memory_size_gb": 1} -# [END howto_operator_instance] - -SECOND_INSTANCE = {"tier": Instance.Tier.STANDARD_HA, "memory_size_gb": 3} - -# [START howto_operator_memcached_instance] -MEMCACHED_INSTANCE = {"name": "", "node_count": 1, "node_config": {"cpu_count": 1, "memory_size_mb": 1024}} -# [END howto_operator_memcached_instance] - - -with models.DAG( - "gcp_cloud_memorystore_redis", - schedule_interval='@once', # Override to match your needs - start_date=START_DATE, - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_create_instance] - create_instance = CloudMemorystoreCreateInstanceOperator( - task_id="create-instance", - location="europe-north1", - instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME, - instance=FIRST_INSTANCE, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_create_instance] - - # [START howto_operator_create_instance_result] - create_instance_result = BashOperator( - task_id="create-instance-result", - bash_command=f"echo {create_instance.output}", - ) - # [END howto_operator_create_instance_result] - - create_instance_2 = CloudMemorystoreCreateInstanceOperator( - task_id="create-instance-2", - location="europe-north1", - instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME_2, - instance=SECOND_INSTANCE, - project_id=GCP_PROJECT_ID, - ) - - # [START howto_operator_get_instance] - get_instance = CloudMemorystoreGetInstanceOperator( - task_id="get-instance", - location="europe-north1", - 
instance=MEMORYSTORE_REDIS_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - do_xcom_push=True, - ) - # [END howto_operator_get_instance] - - # [START howto_operator_get_instance_result] - get_instance_result = BashOperator( - task_id="get-instance-result", bash_command=f"echo {get_instance.output}" - ) - # [END howto_operator_get_instance_result] - - # [START howto_operator_failover_instance] - failover_instance = CloudMemorystoreFailoverInstanceOperator( - task_id="failover-instance", - location="europe-north1", - instance=MEMORYSTORE_REDIS_INSTANCE_NAME_2, - data_protection_mode=FailoverInstanceRequest.DataProtectionMode( - FailoverInstanceRequest.DataProtectionMode.LIMITED_DATA_LOSS - ), - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_failover_instance] - - # [START howto_operator_list_instances] - list_instances = CloudMemorystoreListInstancesOperator( - task_id="list-instances", location="-", page_size=100, project_id=GCP_PROJECT_ID - ) - # [END howto_operator_list_instances] - - # [START howto_operator_list_instances_result] - list_instances_result = BashOperator( - task_id="list-instances-result", bash_command=f"echo {get_instance.output}" - ) - # [END howto_operator_list_instances_result] - - # [START howto_operator_update_instance] - update_instance = CloudMemorystoreUpdateInstanceOperator( - task_id="update-instance", - location="europe-north1", - instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - update_mask={"paths": ["memory_size_gb"]}, - instance={"memory_size_gb": 2}, - ) - # [END howto_operator_update_instance] - - # [START howto_operator_set_acl_permission] - set_acl_permission = GCSBucketCreateAclEntryOperator( - task_id="gcs-set-acl-permission", - bucket=BUCKET_NAME, - entity="user-{{ task_instance.xcom_pull('get-instance')['persistence_iam_identity']" - ".split(':', 2)[1] }}", - role="OWNER", - ) - # [END howto_operator_set_acl_permission] - - # [START howto_operator_export_instance] - export_instance = CloudMemorystoreExportInstanceOperator( - task_id="export-instance", - location="europe-north1", - instance=MEMORYSTORE_REDIS_INSTANCE_NAME, - output_config={"gcs_destination": {"uri": EXPORT_GCS_URL}}, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_export_instance] - - # [START howto_operator_import_instance] - import_instance = CloudMemorystoreImportOperator( - task_id="import-instance", - location="europe-north1", - instance=MEMORYSTORE_REDIS_INSTANCE_NAME_2, - input_config={"gcs_source": {"uri": EXPORT_GCS_URL}}, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_import_instance] - - # [START howto_operator_delete_instance] - delete_instance = CloudMemorystoreDeleteInstanceOperator( - task_id="delete-instance", - location="europe-north1", - instance=MEMORYSTORE_REDIS_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_delete_instance] - - delete_instance_2 = CloudMemorystoreDeleteInstanceOperator( - task_id="delete-instance-2", - location="europe-north1", - instance=MEMORYSTORE_REDIS_INSTANCE_NAME_2, - project_id=GCP_PROJECT_ID, - ) - - # [END howto_operator_create_instance_and_import] - create_instance_and_import = CloudMemorystoreCreateInstanceAndImportOperator( - task_id="create-instance-and-import", - location="europe-north1", - instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME_3, - instance=FIRST_INSTANCE, - input_config={"gcs_source": {"uri": EXPORT_GCS_URL}}, - project_id=GCP_PROJECT_ID, - ) - # [START howto_operator_create_instance_and_import] - - # [START howto_operator_scale_instance] - 
scale_instance = CloudMemorystoreScaleInstanceOperator( - task_id="scale-instance", - location="europe-north1", - instance_id=MEMORYSTORE_REDIS_INSTANCE_NAME_3, - project_id=GCP_PROJECT_ID, - memory_size_gb=3, - ) - # [END howto_operator_scale_instance] - - # [END howto_operator_export_and_delete_instance] - export_and_delete_instance = CloudMemorystoreExportAndDeleteInstanceOperator( - task_id="export-and-delete-instance", - location="europe-north1", - instance=MEMORYSTORE_REDIS_INSTANCE_NAME_3, - output_config={"gcs_destination": {"uri": EXPORT_GCS_URL}}, - project_id=GCP_PROJECT_ID, - ) - # [START howto_operator_export_and_delete_instance] - - create_instance >> get_instance >> get_instance_result - create_instance >> update_instance - create_instance >> export_instance - create_instance_2 >> import_instance - create_instance >> list_instances >> list_instances_result - list_instances >> delete_instance - export_instance >> update_instance - update_instance >> delete_instance - create_instance >> create_instance_result - get_instance >> set_acl_permission >> export_instance - get_instance >> list_instances_result - export_instance >> import_instance - export_instance >> delete_instance - failover_instance >> delete_instance_2 - import_instance >> failover_instance - - export_instance >> create_instance_and_import >> scale_instance >> export_and_delete_instance - - -with models.DAG( - "gcp_cloud_memorystore_memcached", - schedule_interval='@once', # Override to match your needs - start_date=START_DATE, - catchup=False, - tags=['example'], -) as dag_memcache: - # [START howto_operator_create_instance_memcached] - create_memcached_instance = CloudMemorystoreMemcachedCreateInstanceOperator( - task_id="create-instance", - location="europe-north1", - instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, - instance=MEMCACHED_INSTANCE, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_create_instance_memcached] - - # [START howto_operator_delete_instance_memcached] - delete_memcached_instance = CloudMemorystoreMemcachedDeleteInstanceOperator( - task_id="delete-instance", - location="europe-north1", - instance=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_delete_instance_memcached] - - # [START howto_operator_get_instance_memcached] - get_memcached_instance = CloudMemorystoreMemcachedGetInstanceOperator( - task_id="get-instance", - location="europe-north1", - instance=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_get_instance_memcached] - - # [START howto_operator_list_instances_memcached] - list_memcached_instances = CloudMemorystoreMemcachedListInstancesOperator( - task_id="list-instances", location="-", project_id=GCP_PROJECT_ID - ) - # [END howto_operator_list_instances_memcached] - - # # [START howto_operator_update_instance_memcached] - update_memcached_instance = CloudMemorystoreMemcachedUpdateInstanceOperator( - task_id="update-instance", - location="europe-north1", - instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - update_mask=FieldMask(paths=["node_count"]), - instance={"node_count": 2}, - ) - # [END howto_operator_update_instance_memcached] - - # [START howto_operator_update_and_apply_parameters_memcached] - update_memcached_parameters = CloudMemorystoreMemcachedUpdateParametersOperator( - task_id="update-parameters", - location="europe-north1", - instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - update_mask={"paths": 
["params"]}, - parameters={"params": {"protocol": "ascii", "hash_algorithm": "jenkins"}}, - ) - - apply_memcached_parameters = CloudMemorystoreMemcachedApplyParametersOperator( - task_id="apply-parameters", - location="europe-north1", - instance_id=MEMORYSTORE_MEMCACHED_INSTANCE_NAME, - project_id=GCP_PROJECT_ID, - node_ids=["node-a-1"], - apply_all=False, - ) - - # update_parameters >> apply_parameters - # [END howto_operator_update_and_apply_parameters_memcached] - - create_memcached_instance >> [list_memcached_instances, get_memcached_instance] - create_memcached_instance >> update_memcached_instance >> update_memcached_parameters - update_memcached_parameters >> apply_memcached_parameters >> delete_memcached_instance diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py b/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py index 307de77ff3ed2..68b329f557415 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that performs query in a Cloud SQL instance. @@ -36,6 +35,8 @@ * GCSQL_MYSQL_PUBLIC_IP - Public IP of the mysql database * GCSQL_MYSQL_PUBLIC_PORT - Port of the mysql database """ +from __future__ import annotations + import os import subprocess from datetime import datetime @@ -45,42 +46,42 @@ from airflow import models from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExecuteQueryOperator -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCP_REGION = os.environ.get('GCP_REGION', 'europe-west1') +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCP_REGION = os.environ.get("GCP_REGION", "europe-west1") GCSQL_POSTGRES_INSTANCE_NAME_QUERY = os.environ.get( - 'GCSQL_POSTGRES_INSTANCE_NAME_QUERY', 'test-postgres-query' + "GCSQL_POSTGRES_INSTANCE_NAME_QUERY", "test-postgres-query" ) -GCSQL_POSTGRES_DATABASE_NAME = os.environ.get('GCSQL_POSTGRES_DATABASE_NAME', 'postgresdb') -GCSQL_POSTGRES_USER = os.environ.get('GCSQL_POSTGRES_USER', 'postgres_user') -GCSQL_POSTGRES_PASSWORD = os.environ.get('GCSQL_POSTGRES_PASSWORD', 'JoxHlwrPzwch0gz9') -GCSQL_POSTGRES_PUBLIC_IP = os.environ.get('GCSQL_POSTGRES_PUBLIC_IP', '0.0.0.0') -GCSQL_POSTGRES_PUBLIC_PORT = os.environ.get('GCSQL_POSTGRES_PUBLIC_PORT', 5432) +GCSQL_POSTGRES_DATABASE_NAME = os.environ.get("GCSQL_POSTGRES_DATABASE_NAME", "postgresdb") +GCSQL_POSTGRES_USER = os.environ.get("GCSQL_POSTGRES_USER", "postgres_user") +GCSQL_POSTGRES_PASSWORD = os.environ.get("GCSQL_POSTGRES_PASSWORD", "JoxHlwrPzwch0gz9") +GCSQL_POSTGRES_PUBLIC_IP = os.environ.get("GCSQL_POSTGRES_PUBLIC_IP", "0.0.0.0") +GCSQL_POSTGRES_PUBLIC_PORT = os.environ.get("GCSQL_POSTGRES_PUBLIC_PORT", 5432) GCSQL_POSTGRES_CLIENT_CERT_FILE = os.environ.get( - 'GCSQL_POSTGRES_CLIENT_CERT_FILE', ".key/postgres-client-cert.pem" + "GCSQL_POSTGRES_CLIENT_CERT_FILE", ".key/postgres-client-cert.pem" ) GCSQL_POSTGRES_CLIENT_KEY_FILE = os.environ.get( - 'GCSQL_POSTGRES_CLIENT_KEY_FILE', ".key/postgres-client-key.pem" + "GCSQL_POSTGRES_CLIENT_KEY_FILE", ".key/postgres-client-key.pem" ) -GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get('GCSQL_POSTGRES_SERVER_CA_FILE', ".key/postgres-server-ca.pem") - -GCSQL_MYSQL_INSTANCE_NAME_QUERY = os.environ.get('GCSQL_MYSQL_INSTANCE_NAME_QUERY', 
'test-mysql-query') -GCSQL_MYSQL_DATABASE_NAME = os.environ.get('GCSQL_MYSQL_DATABASE_NAME', 'mysqldb') -GCSQL_MYSQL_USER = os.environ.get('GCSQL_MYSQL_USER', 'mysql_user') -GCSQL_MYSQL_PASSWORD = os.environ.get('GCSQL_MYSQL_PASSWORD', 'JoxHlwrPzwch0gz9') -GCSQL_MYSQL_PUBLIC_IP = os.environ.get('GCSQL_MYSQL_PUBLIC_IP', '0.0.0.0') -GCSQL_MYSQL_PUBLIC_PORT = os.environ.get('GCSQL_MYSQL_PUBLIC_PORT', 3306) -GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_CERT_FILE', ".key/mysql-client-cert.pem") -GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_KEY_FILE', ".key/mysql-client-key.pem") -GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get('GCSQL_MYSQL_SERVER_CA_FILE', ".key/mysql-server-ca.pem") +GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get("GCSQL_POSTGRES_SERVER_CA_FILE", ".key/postgres-server-ca.pem") + +GCSQL_MYSQL_INSTANCE_NAME_QUERY = os.environ.get("GCSQL_MYSQL_INSTANCE_NAME_QUERY", "test-mysql-query") +GCSQL_MYSQL_DATABASE_NAME = os.environ.get("GCSQL_MYSQL_DATABASE_NAME", "mysqldb") +GCSQL_MYSQL_USER = os.environ.get("GCSQL_MYSQL_USER", "mysql_user") +GCSQL_MYSQL_PASSWORD = os.environ.get("GCSQL_MYSQL_PASSWORD", "JoxHlwrPzwch0gz9") +GCSQL_MYSQL_PUBLIC_IP = os.environ.get("GCSQL_MYSQL_PUBLIC_IP", "0.0.0.0") +GCSQL_MYSQL_PUBLIC_PORT = os.environ.get("GCSQL_MYSQL_PUBLIC_PORT", 3306) +GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get("GCSQL_MYSQL_CLIENT_CERT_FILE", ".key/mysql-client-cert.pem") +GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get("GCSQL_MYSQL_CLIENT_KEY_FILE", ".key/mysql-client-key.pem") +GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get("GCSQL_MYSQL_SERVER_CA_FILE", ".key/mysql-server-ca.pem") SQL = [ - 'CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)', - 'CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)', # shows warnings logged - 'INSERT INTO TABLE_TEST VALUES (0)', - 'CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)', - 'DROP TABLE TABLE_TEST', - 'DROP TABLE TABLE_TEST2', + "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)", + "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)", # shows warnings logged + "INSERT INTO TABLE_TEST VALUES (0)", + "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)", + "DROP TABLE TABLE_TEST", + "DROP TABLE TABLE_TEST2", ] @@ -118,7 +119,7 @@ def get_absolute_path(path): # of AIRFLOW (using command line or UI). # Postgres: connect via proxy over TCP -os.environ['AIRFLOW_CONN_PROXY_POSTGRES_TCP'] = ( +os.environ["AIRFLOW_CONN_PROXY_POSTGRES_TCP"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=postgres&" "project_id={project_id}&" @@ -129,7 +130,7 @@ def get_absolute_path(path): ) # Postgres: connect via proxy over UNIX socket (specific proxy version) -os.environ['AIRFLOW_CONN_PROXY_POSTGRES_SOCKET'] = ( +os.environ["AIRFLOW_CONN_PROXY_POSTGRES_SOCKET"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=postgres&" "project_id={project_id}&" @@ -141,7 +142,7 @@ def get_absolute_path(path): ) # Postgres: connect directly via TCP (non-SSL) -os.environ['AIRFLOW_CONN_PUBLIC_POSTGRES_TCP'] = ( +os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=postgres&" "project_id={project_id}&" @@ -152,7 +153,7 @@ def get_absolute_path(path): ) # Postgres: connect directly via TCP (SSL) -os.environ['AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL'] = ( +os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" 
"database_type=postgres&" "project_id={project_id}&" @@ -180,7 +181,7 @@ def get_absolute_path(path): ) # MySQL: connect via proxy over TCP (specific proxy version) -os.environ['AIRFLOW_CONN_PROXY_MYSQL_TCP'] = ( +os.environ["AIRFLOW_CONN_PROXY_MYSQL_TCP"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=mysql&" "project_id={project_id}&" @@ -193,11 +194,11 @@ def get_absolute_path(path): # MySQL: connect via proxy over UNIX socket using pre-downloaded Cloud Sql Proxy binary try: - sql_proxy_binary_path = subprocess.check_output(['which', 'cloud_sql_proxy']).decode('utf-8').rstrip() + sql_proxy_binary_path = subprocess.check_output(["which", "cloud_sql_proxy"]).decode("utf-8").rstrip() except subprocess.CalledProcessError: sql_proxy_binary_path = "/tmp/anyhow_download_cloud_sql_proxy" -os.environ['AIRFLOW_CONN_PROXY_MYSQL_SOCKET'] = ( +os.environ["AIRFLOW_CONN_PROXY_MYSQL_SOCKET"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=mysql&" "project_id={project_id}&" @@ -209,7 +210,7 @@ def get_absolute_path(path): ) # MySQL: connect directly via TCP (non-SSL) -os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP'] = ( +os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=mysql&" "project_id={project_id}&" @@ -220,7 +221,7 @@ def get_absolute_path(path): ) # MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql Proxy binary path -os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL'] = ( +os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=mysql&" "project_id={project_id}&" @@ -236,7 +237,7 @@ def get_absolute_path(path): # Special case: MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql # Proxy binary path AND with missing project_id -os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL_NO_PROJECT_ID'] = ( +os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL_NO_PROJECT_ID"] = ( "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" "database_type=mysql&" "location={location}&" @@ -269,11 +270,10 @@ def get_absolute_path(path): with models.DAG( - dag_id='example_gcp_sql_query', - schedule_interval='@once', + dag_id="example_gcp_sql_query", start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as dag: prev_task = None diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py b/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py index be858c4018753..6a5f93bffcc84 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py @@ -15,19 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ -Example Airflow DAG that demonstrates interactions with Google Cloud Transfer. - +Example Airflow DAG that demonstrates interactions with Google Cloud Transfer. This DAG relies on +the following OS environment variables -This DAG relies on the following OS environment variables +Note that you need to provide a large enough set of data so that operations do not execute too quickly. +Otherwise, DAG will fail. * GCP_PROJECT_ID - Google Cloud Project to use for the Google Cloud Transfer Service. 
* GCP_DESCRIPTION - Description of transfer job * GCP_TRANSFER_SOURCE_AWS_BUCKET - Amazon Web Services Storage bucket from which files are copied. - .. warning:: - You need to provide a large enough set of data so that operations do not execute too quickly. - Otherwise, DAG will fail. * GCP_TRANSFER_SECOND_TARGET_BUCKET - Google Cloud Storage bucket to which files are copied * WAIT_FOR_OPERATION_POKE_INTERVAL - interval of what to check the status of the operation A smaller value than the default value accelerates the system test and ensures its correct execution with @@ -35,6 +32,7 @@ Look at documentation of :class:`~airflow.operators.sensors.BaseSensorOperator` for more information """ +from __future__ import annotations import os from datetime import datetime, timedelta @@ -74,17 +72,17 @@ CloudDataTransferServiceJobStatusSensor, ) -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCP_DESCRIPTION = os.environ.get('GCP_DESCRIPTION', 'description') -GCP_TRANSFER_TARGET_BUCKET = os.environ.get('GCP_TRANSFER_TARGET_BUCKET') -WAIT_FOR_OPERATION_POKE_INTERVAL = int(os.environ.get('WAIT_FOR_OPERATION_POKE_INTERVAL', 5)) +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCP_DESCRIPTION = os.environ.get("GCP_DESCRIPTION", "description") +GCP_TRANSFER_TARGET_BUCKET = os.environ.get("GCP_TRANSFER_TARGET_BUCKET") +WAIT_FOR_OPERATION_POKE_INTERVAL = int(os.environ.get("WAIT_FOR_OPERATION_POKE_INTERVAL", 5)) -GCP_TRANSFER_SOURCE_AWS_BUCKET = os.environ.get('GCP_TRANSFER_SOURCE_AWS_BUCKET') +GCP_TRANSFER_SOURCE_AWS_BUCKET = os.environ.get("GCP_TRANSFER_SOURCE_AWS_BUCKET") GCP_TRANSFER_FIRST_TARGET_BUCKET = os.environ.get( - 'GCP_TRANSFER_FIRST_TARGET_BUCKET', 'gcp-transfer-first-target' + "GCP_TRANSFER_FIRST_TARGET_BUCKET", "gcp-transfer-first-target" ) -GCP_TRANSFER_JOB_NAME = os.environ.get('GCP_TRANSFER_JOB_NAME', 'transferJobs/sampleJob') +GCP_TRANSFER_JOB_NAME = os.environ.get("GCP_TRANSFER_JOB_NAME", "transferJobs/sampleJob") # [START howto_operator_gcp_transfer_create_job_body_aws] aws_to_gcs_transfer_body = { @@ -107,11 +105,10 @@ with models.DAG( - 'example_gcp_transfer_aws', - schedule_interval=None, # Override to match your needs + "example_gcp_transfer_aws", start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as dag: # [START howto_operator_gcp_transfer_create_job] diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py b/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py deleted file mode 100644 index cdf38e0ca1fae..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py +++ /dev/null @@ -1,155 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that demonstrates interactions with Google Cloud Transfer. - - -This DAG relies on the following OS environment variables - -* GCP_PROJECT_ID - Google Cloud Project to use for the Google Cloud Transfer Service. -* GCP_TRANSFER_FIRST_TARGET_BUCKET - Google Cloud Storage bucket to which files are copied from AWS. - It is also a source bucket in next step -* GCP_TRANSFER_SECOND_TARGET_BUCKET - Google Cloud Storage bucket to which files are copied -""" - -import os -from datetime import datetime, timedelta - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( - ALREADY_EXISTING_IN_SINK, - BUCKET_NAME, - DESCRIPTION, - FILTER_JOB_NAMES, - FILTER_PROJECT_ID, - GCS_DATA_SINK, - GCS_DATA_SOURCE, - PROJECT_ID, - SCHEDULE, - SCHEDULE_END_DATE, - SCHEDULE_START_DATE, - START_TIME_OF_DAY, - STATUS, - TRANSFER_JOB, - TRANSFER_JOB_FIELD_MASK, - TRANSFER_OPTIONS, - TRANSFER_SPEC, - GcpTransferJobsStatus, - GcpTransferOperationStatus, -) -from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( - CloudDataTransferServiceCreateJobOperator, - CloudDataTransferServiceDeleteJobOperator, - CloudDataTransferServiceGetOperationOperator, - CloudDataTransferServiceListOperationsOperator, - CloudDataTransferServiceUpdateJobOperator, -) -from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import ( - CloudDataTransferServiceJobStatusSensor, -) - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -GCP_TRANSFER_FIRST_TARGET_BUCKET = os.environ.get( - "GCP_TRANSFER_FIRST_TARGET_BUCKET", "gcp-transfer-first-target" -) -GCP_TRANSFER_SECOND_TARGET_BUCKET = os.environ.get( - "GCP_TRANSFER_SECOND_TARGET_BUCKET", "gcp-transfer-second-target" -) - -# [START howto_operator_gcp_transfer_create_job_body_gcp] -gcs_to_gcs_transfer_body = { - DESCRIPTION: "description", - STATUS: GcpTransferJobsStatus.ENABLED, - PROJECT_ID: GCP_PROJECT_ID, - SCHEDULE: { - SCHEDULE_START_DATE: datetime(2015, 1, 1).date(), - SCHEDULE_END_DATE: datetime(2030, 1, 1).date(), - START_TIME_OF_DAY: (datetime.utcnow() + timedelta(seconds=120)).time(), - }, - TRANSFER_SPEC: { - GCS_DATA_SOURCE: {BUCKET_NAME: GCP_TRANSFER_FIRST_TARGET_BUCKET}, - GCS_DATA_SINK: {BUCKET_NAME: GCP_TRANSFER_SECOND_TARGET_BUCKET}, - TRANSFER_OPTIONS: {ALREADY_EXISTING_IN_SINK: True}, - }, -} -# [END howto_operator_gcp_transfer_create_job_body_gcp] - -# [START howto_operator_gcp_transfer_update_job_body] -update_body = { - PROJECT_ID: GCP_PROJECT_ID, - TRANSFER_JOB: {DESCRIPTION: "description_updated"}, - TRANSFER_JOB_FIELD_MASK: "description", -} -# [END howto_operator_gcp_transfer_update_job_body] - -with models.DAG( - "example_gcp_transfer", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - - create_transfer = CloudDataTransferServiceCreateJobOperator( - task_id="create_transfer", body=gcs_to_gcs_transfer_body - ) - - # [START howto_operator_gcp_transfer_update_job] - update_transfer = CloudDataTransferServiceUpdateJobOperator( - task_id="update_transfer", - job_name="{{task_instance.xcom_pull('create_transfer')['name']}}", - body=update_body, - ) - # [END howto_operator_gcp_transfer_update_job] - - wait_for_transfer = CloudDataTransferServiceJobStatusSensor( - 
task_id="wait_for_transfer", - job_name="{{task_instance.xcom_pull('create_transfer')['name']}}", - project_id=GCP_PROJECT_ID, - expected_statuses={GcpTransferOperationStatus.SUCCESS}, - ) - - list_operations = CloudDataTransferServiceListOperationsOperator( - task_id="list_operations", - request_filter={ - FILTER_PROJECT_ID: GCP_PROJECT_ID, - FILTER_JOB_NAMES: ["{{task_instance.xcom_pull('create_transfer')['name']}}"], - }, - ) - - get_operation = CloudDataTransferServiceGetOperationOperator( - task_id="get_operation", - operation_name="{{task_instance.xcom_pull('list_operations')[0]['name']}}", - ) - - delete_transfer = CloudDataTransferServiceDeleteJobOperator( - task_id="delete_transfer_from_gcp_job", - job_name="{{task_instance.xcom_pull('create_transfer')['name']}}", - project_id=GCP_PROJECT_ID, - ) - - chain( - create_transfer, - wait_for_transfer, - update_transfer, - list_operations, - get_operation, - delete_transfer, - ) diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_task.py b/airflow/providers/google/cloud/example_dags/example_cloud_task.py new file mode 100644 index 0000000000000..877eac15f3453 --- /dev/null +++ b/airflow/providers/google/cloud/example_dags/example_cloud_task.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that sense a cloud task queue being empty. + +This DAG relies on the following OS environment variables + +* GCP_PROJECT_ID - Google Cloud project where the Compute Engine instance exists. +* GCP_ZONE - Google Cloud zone where the cloud task queue exists. +* QUEUE_NAME - Name of the cloud task queue. +""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow import models +from airflow.providers.google.cloud.sensors.tasks import TaskQueueEmptySensor + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCP_ZONE = os.environ.get("GCE_ZONE", "europe-west1-b") +QUEUE_NAME = os.environ.get("GCP_QUEUE_NAME", "testqueue") + + +with models.DAG( + "example_gcp_cloud_tasks_sensor", + start_date=datetime(2022, 8, 8), + catchup=False, + tags=["example"], +) as dag: + # [START cloud_tasks_empty_sensor] + gcp_cloud_tasks_sensor = TaskQueueEmptySensor( + project_id=GCP_PROJECT_ID, + location=GCP_ZONE, + task_id="gcp_sense_cloud_tasks_empty", + queue_name=QUEUE_NAME, + ) + # [END cloud_tasks_empty_sensor] diff --git a/airflow/providers/google/cloud/example_dags/example_compute.py b/airflow/providers/google/cloud/example_dags/example_compute.py index 6d81e3a232a02..e42cb6d962ed2 100644 --- a/airflow/providers/google/cloud/example_dags/example_compute.py +++ b/airflow/providers/google/cloud/example_dags/example_compute.py @@ -15,7 +15,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that starts, stops and sets the machine type of a Google Compute Engine instance. @@ -28,6 +27,7 @@ * GCE_SHORT_MACHINE_TYPE_NAME - Machine type resource name to set, e.g. 'n1-standard-1'. See https://cloud.google.com/compute/docs/machine-types """ +from __future__ import annotations import os from datetime import datetime @@ -41,42 +41,41 @@ ) # [START howto_operator_gce_args_common] -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCE_ZONE = os.environ.get('GCE_ZONE', 'europe-west1-b') -GCE_INSTANCE = os.environ.get('GCE_INSTANCE', 'testinstance') +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCE_ZONE = os.environ.get("GCE_ZONE", "europe-west1-b") +GCE_INSTANCE = os.environ.get("GCE_INSTANCE", "testinstance") # [END howto_operator_gce_args_common] -GCE_SHORT_MACHINE_TYPE_NAME = os.environ.get('GCE_SHORT_MACHINE_TYPE_NAME', 'n1-standard-1') +GCE_SHORT_MACHINE_TYPE_NAME = os.environ.get("GCE_SHORT_MACHINE_TYPE_NAME", "n1-standard-1") with models.DAG( - 'example_gcp_compute', - schedule_interval='@once', # Override to match your needs + "example_gcp_compute", start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as dag: # [START howto_operator_gce_start] gce_instance_start = ComputeEngineStartInstanceOperator( - project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_start_task' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_start_task" ) # [END howto_operator_gce_start] # Duplicate start for idempotence testing # [START howto_operator_gce_start_no_project_id] gce_instance_start2 = ComputeEngineStartInstanceOperator( - zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_start_task2' + zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_start_task2" ) # [END howto_operator_gce_start_no_project_id] # [START howto_operator_gce_stop] gce_instance_stop = ComputeEngineStopInstanceOperator( - project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_stop_task' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_stop_task" ) # [END howto_operator_gce_stop] # Duplicate stop for idempotence testing # [START howto_operator_gce_stop_no_project_id] gce_instance_stop2 = ComputeEngineStopInstanceOperator( - zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_stop_task2' + zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id="gcp_compute_stop_task2" ) # [END howto_operator_gce_stop_no_project_id] # [START howto_operator_gce_set_machine_type] @@ -84,8 +83,8 @@ project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, - body={'machineType': f'zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}'}, - task_id='gcp_compute_set_machine_type', + body={"machineType": f"zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}"}, + task_id="gcp_compute_set_machine_type", ) # [END howto_operator_gce_set_machine_type] # Duplicate set machine type for idempotence testing @@ -93,8 +92,8 @@ gce_set_machine_type2 = ComputeEngineSetMachineTypeOperator( zone=GCE_ZONE, resource_id=GCE_INSTANCE, - body={'machineType': f'zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}'}, - task_id='gcp_compute_set_machine_type2', + body={"machineType": f"zones/{GCE_ZONE}/machineTypes/{GCE_SHORT_MACHINE_TYPE_NAME}"}, + 
task_id="gcp_compute_set_machine_type2", ) # [END howto_operator_gce_set_machine_type_no_project_id] diff --git a/airflow/providers/google/cloud/example_dags/example_compute_igm.py b/airflow/providers/google/cloud/example_dags/example_compute_igm.py deleted file mode 100644 index 7cad62ceef535..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_compute_igm.py +++ /dev/null @@ -1,143 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that uses IGM-type compute operations: -* copy of Instance Template -* update template in Instance Group Manager - -This DAG relies on the following OS environment variables - -* GCP_PROJECT_ID - the Google Cloud project where the Compute Engine instance exists -* GCE_ZONE - the zone where the Compute Engine instance exists - -Variables for copy template operator: -* GCE_TEMPLATE_NAME - name of the template to copy -* GCE_NEW_TEMPLATE_NAME - name of the new template -* GCE_NEW_DESCRIPTION - description added to the template - -Variables for update template in Group Manager: - -* GCE_INSTANCE_GROUP_MANAGER_NAME - name of the Instance Group Manager -* SOURCE_TEMPLATE_URL - url of the template to replace in the Instance Group Manager -* DESTINATION_TEMPLATE_URL - url of the new template to set in the Instance Group Manager -""" - -import os -from datetime import datetime - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.google.cloud.operators.compute import ( - ComputeEngineCopyInstanceTemplateOperator, - ComputeEngineInstanceGroupUpdateManagerTemplateOperator, -) - -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCE_ZONE = os.environ.get('GCE_ZONE', 'europe-west1-b') - -# [START howto_operator_compute_template_copy_args] -GCE_TEMPLATE_NAME = os.environ.get('GCE_TEMPLATE_NAME', 'instance-template-test') -GCE_NEW_TEMPLATE_NAME = os.environ.get('GCE_NEW_TEMPLATE_NAME', 'instance-template-test-new') -GCE_NEW_DESCRIPTION = os.environ.get('GCE_NEW_DESCRIPTION', 'Test new description') -GCE_INSTANCE_TEMPLATE_BODY_UPDATE = { - "name": GCE_NEW_TEMPLATE_NAME, - "description": GCE_NEW_DESCRIPTION, - "properties": {"machineType": "n1-standard-2"}, -} -# [END howto_operator_compute_template_copy_args] - -# [START howto_operator_compute_igm_update_template_args] -GCE_INSTANCE_GROUP_MANAGER_NAME = os.environ.get('GCE_INSTANCE_GROUP_MANAGER_NAME', 'instance-group-test') - -SOURCE_TEMPLATE_URL = os.environ.get( - 'SOURCE_TEMPLATE_URL', - "https://www.googleapis.com/compute/beta/projects/" - + GCP_PROJECT_ID - + "/global/instanceTemplates/instance-template-test", -) - -DESTINATION_TEMPLATE_URL = os.environ.get( - 'DESTINATION_TEMPLATE_URL', - "https://www.googleapis.com/compute/beta/projects/" - + GCP_PROJECT_ID - + 
"/global/instanceTemplates/" - + GCE_NEW_TEMPLATE_NAME, -) - -UPDATE_POLICY = { - "type": "OPPORTUNISTIC", - "minimalAction": "RESTART", - "maxSurge": {"fixed": 1}, - "minReadySec": 1800, -} - -# [END howto_operator_compute_igm_update_template_args] - - -with models.DAG( - 'example_gcp_compute_igm', - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_gce_igm_copy_template] - gce_instance_template_copy = ComputeEngineCopyInstanceTemplateOperator( - project_id=GCP_PROJECT_ID, - resource_id=GCE_TEMPLATE_NAME, - body_patch=GCE_INSTANCE_TEMPLATE_BODY_UPDATE, - task_id='gcp_compute_igm_copy_template_task', - ) - # [END howto_operator_gce_igm_copy_template] - # Added to check for idempotence - # [START howto_operator_gce_igm_copy_template_no_project_id] - gce_instance_template_copy2 = ComputeEngineCopyInstanceTemplateOperator( - resource_id=GCE_TEMPLATE_NAME, - body_patch=GCE_INSTANCE_TEMPLATE_BODY_UPDATE, - task_id='gcp_compute_igm_copy_template_task_2', - ) - # [END howto_operator_gce_igm_copy_template_no_project_id] - # [START howto_operator_gce_igm_update_template] - gce_instance_group_manager_update_template = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( - project_id=GCP_PROJECT_ID, - resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, - zone=GCE_ZONE, - source_template=SOURCE_TEMPLATE_URL, - destination_template=DESTINATION_TEMPLATE_URL, - update_policy=UPDATE_POLICY, - task_id='gcp_compute_igm_group_manager_update_template', - ) - # [END howto_operator_gce_igm_update_template] - # Added to check for idempotence (and without UPDATE_POLICY) - # [START howto_operator_gce_igm_update_template_no_project_id] - gce_instance_group_manager_update_template2 = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( - resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, - zone=GCE_ZONE, - source_template=SOURCE_TEMPLATE_URL, - destination_template=DESTINATION_TEMPLATE_URL, - task_id='gcp_compute_igm_group_manager_update_template_2', - ) - # [END howto_operator_gce_igm_update_template_no_project_id] - - chain( - gce_instance_template_copy, - gce_instance_template_copy2, - gce_instance_group_manager_update_template, - gce_instance_group_manager_update_template2, - ) diff --git a/airflow/providers/google/cloud/example_dags/example_compute_ssh.py b/airflow/providers/google/cloud/example_dags/example_compute_ssh.py index c20743725342c..044789aec2a7e 100644 --- a/airflow/providers/google/cloud/example_dags/example_compute_ssh.py +++ b/airflow/providers/google/cloud/example_dags/example_compute_ssh.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import os from datetime import datetime @@ -23,17 +24,16 @@ from airflow.providers.ssh.operators.ssh import SSHOperator # [START howto_operator_gce_args_common] -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCE_ZONE = os.environ.get('GCE_ZONE', 'europe-west2-a') -GCE_INSTANCE = os.environ.get('GCE_INSTANCE', 'target-instance') +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +GCE_ZONE = os.environ.get("GCE_ZONE", "europe-west2-a") +GCE_INSTANCE = os.environ.get("GCE_INSTANCE", "target-instance") # [END howto_operator_gce_args_common] with models.DAG( - 'example_compute_ssh', - schedule_interval='@once', # Override to match your needs + "example_compute_ssh", start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as dag: # # [START howto_execute_command_on_remote1] os_login_without_iap_tunnel = SSHOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_datacatalog.py b/airflow/providers/google/cloud/example_dags/example_datacatalog.py deleted file mode 100644 index 848cd5ef903ce..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_datacatalog.py +++ /dev/null @@ -1,452 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" -Example Airflow DAG that interacts with Google Data Catalog service -""" -import os -from datetime import datetime - -from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField -from google.protobuf.field_mask_pb2 import FieldMask - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.datacatalog import ( - CloudDataCatalogCreateEntryGroupOperator, - CloudDataCatalogCreateEntryOperator, - CloudDataCatalogCreateTagOperator, - CloudDataCatalogCreateTagTemplateFieldOperator, - CloudDataCatalogCreateTagTemplateOperator, - CloudDataCatalogDeleteEntryGroupOperator, - CloudDataCatalogDeleteEntryOperator, - CloudDataCatalogDeleteTagOperator, - CloudDataCatalogDeleteTagTemplateFieldOperator, - CloudDataCatalogDeleteTagTemplateOperator, - CloudDataCatalogGetEntryGroupOperator, - CloudDataCatalogGetEntryOperator, - CloudDataCatalogGetTagTemplateOperator, - CloudDataCatalogListTagsOperator, - CloudDataCatalogLookupEntryOperator, - CloudDataCatalogRenameTagTemplateFieldOperator, - CloudDataCatalogSearchCatalogOperator, - CloudDataCatalogUpdateEntryOperator, - CloudDataCatalogUpdateTagOperator, - CloudDataCatalogUpdateTagTemplateFieldOperator, - CloudDataCatalogUpdateTagTemplateOperator, -) - -PROJECT_ID = os.getenv("GCP_PROJECT_ID") -BUCKET_ID = os.getenv("GCP_TEST_DATA_BUCKET", "INVALID BUCKET NAME") -LOCATION = "us-central1" -ENTRY_GROUP_ID = "important_data_jan_2019" -ENTRY_ID = "python_files" -TEMPLATE_ID = "template_id" -FIELD_NAME_1 = "first" -FIELD_NAME_2 = "second" -FIELD_NAME_3 = "first-rename" - -with models.DAG( - "example_gcp_datacatalog", - schedule_interval='@once', - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - # Create - # [START howto_operator_gcp_datacatalog_create_entry_group] - create_entry_group = CloudDataCatalogCreateEntryGroupOperator( - task_id="create_entry_group", - location=LOCATION, - entry_group_id=ENTRY_GROUP_ID, - entry_group={"display_name": "analytics data - jan 2011"}, - ) - # [END howto_operator_gcp_datacatalog_create_entry_group] - - # [START howto_operator_gcp_datacatalog_create_entry_group_result] - create_entry_group_result = BashOperator( - task_id="create_entry_group_result", - bash_command=f"echo {create_entry_group.output['entry_group_id']}", - ) - # [END howto_operator_gcp_datacatalog_create_entry_group_result] - - # [START howto_operator_gcp_datacatalog_create_entry_group_result2] - create_entry_group_result2 = BashOperator( - task_id="create_entry_group_result2", - bash_command=f"echo {create_entry_group.output}", - ) - # [END howto_operator_gcp_datacatalog_create_entry_group_result2] - - # [START howto_operator_gcp_datacatalog_create_entry_gcs] - create_entry_gcs = CloudDataCatalogCreateEntryOperator( - task_id="create_entry_gcs", - location=LOCATION, - entry_group=ENTRY_GROUP_ID, - entry_id=ENTRY_ID, - entry={ - "display_name": "Wizard", - "type_": "FILESET", - "gcs_fileset_spec": {"file_patterns": [f"gs://{BUCKET_ID}/**"]}, - }, - ) - # [END howto_operator_gcp_datacatalog_create_entry_gcs] - - # [START howto_operator_gcp_datacatalog_create_entry_gcs_result] - create_entry_gcs_result = BashOperator( - task_id="create_entry_gcs_result", - bash_command=f"echo {create_entry_gcs.output['entry_id']}", - ) - # [END howto_operator_gcp_datacatalog_create_entry_gcs_result] - - # [START howto_operator_gcp_datacatalog_create_entry_gcs_result2] - create_entry_gcs_result2 = BashOperator( - 
task_id="create_entry_gcs_result2", - bash_command=f"echo {create_entry_gcs.output}", - ) - # [END howto_operator_gcp_datacatalog_create_entry_gcs_result2] - - # [START howto_operator_gcp_datacatalog_create_tag] - create_tag = CloudDataCatalogCreateTagOperator( - task_id="create_tag", - location=LOCATION, - entry_group=ENTRY_GROUP_ID, - entry=ENTRY_ID, - template_id=TEMPLATE_ID, - tag={"fields": {FIELD_NAME_1: TagField(string_value="example-value-string")}}, - ) - # [END howto_operator_gcp_datacatalog_create_tag] - - # [START howto_operator_gcp_datacatalog_create_tag_result] - create_tag_result = BashOperator( - task_id="create_tag_result", - bash_command=f"echo {create_tag.output['tag_id']}", - ) - # [END howto_operator_gcp_datacatalog_create_tag_result] - - # [START howto_operator_gcp_datacatalog_create_tag_result2] - create_tag_result2 = BashOperator(task_id="create_tag_result2", bash_command=f"echo {create_tag.output}") - # [END howto_operator_gcp_datacatalog_create_tag_result2] - - # [START howto_operator_gcp_datacatalog_create_tag_template] - create_tag_template = CloudDataCatalogCreateTagTemplateOperator( - task_id="create_tag_template", - location=LOCATION, - tag_template_id=TEMPLATE_ID, - tag_template={ - "display_name": "Awesome Tag Template", - "fields": { - FIELD_NAME_1: TagTemplateField( - display_name="first-field", type_=dict(primitive_type="STRING") - ) - }, - }, - ) - # [END howto_operator_gcp_datacatalog_create_tag_template] - - # [START howto_operator_gcp_datacatalog_create_tag_template_result] - create_tag_template_result = BashOperator( - task_id="create_tag_template_result", - bash_command=f"echo {create_tag_template.output['tag_template_id']}", - ) - # [END howto_operator_gcp_datacatalog_create_tag_template_result] - - # [START howto_operator_gcp_datacatalog_create_tag_template_result2] - create_tag_template_result2 = BashOperator( - task_id="create_tag_template_result2", - bash_command=f"echo {create_tag_template.output}", - ) - # [END howto_operator_gcp_datacatalog_create_tag_template_result2] - - # [START howto_operator_gcp_datacatalog_create_tag_template_field] - create_tag_template_field = CloudDataCatalogCreateTagTemplateFieldOperator( - task_id="create_tag_template_field", - location=LOCATION, - tag_template=TEMPLATE_ID, - tag_template_field_id=FIELD_NAME_2, - tag_template_field=TagTemplateField( - display_name="second-field", type_=FieldType(primitive_type="STRING") - ), - ) - # [END howto_operator_gcp_datacatalog_create_tag_template_field] - - # [START howto_operator_gcp_datacatalog_create_tag_template_field_result] - create_tag_template_field_result = BashOperator( - task_id="create_tag_template_field_result", - bash_command=f"echo {create_tag_template_field.output['tag_template_field_id']}", - ) - # [END howto_operator_gcp_datacatalog_create_tag_template_field_result] - - # [START howto_operator_gcp_datacatalog_create_tag_template_field_result2] - create_tag_template_field_result2 = BashOperator( - task_id="create_tag_template_field_result2", - bash_command=f"echo {create_tag_template_field.output}", - ) - # [END howto_operator_gcp_datacatalog_create_tag_template_field_result2] - - # Delete - # [START howto_operator_gcp_datacatalog_delete_entry] - delete_entry = CloudDataCatalogDeleteEntryOperator( - task_id="delete_entry", location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID - ) - # [END howto_operator_gcp_datacatalog_delete_entry] - - # [START howto_operator_gcp_datacatalog_delete_entry_group] - delete_entry_group = 
CloudDataCatalogDeleteEntryGroupOperator( - task_id="delete_entry_group", location=LOCATION, entry_group=ENTRY_GROUP_ID - ) - # [END howto_operator_gcp_datacatalog_delete_entry_group] - - # [START howto_operator_gcp_datacatalog_delete_tag] - delete_tag = CloudDataCatalogDeleteTagOperator( - task_id="delete_tag", - location=LOCATION, - entry_group=ENTRY_GROUP_ID, - entry=ENTRY_ID, - tag=create_tag.output["tag_id"], - ) - # [END howto_operator_gcp_datacatalog_delete_tag] - - # [START howto_operator_gcp_datacatalog_delete_tag_template_field] - delete_tag_template_field = CloudDataCatalogDeleteTagTemplateFieldOperator( - task_id="delete_tag_template_field", - location=LOCATION, - tag_template=TEMPLATE_ID, - field=FIELD_NAME_2, - force=True, - ) - # [END howto_operator_gcp_datacatalog_delete_tag_template_field] - - # [START howto_operator_gcp_datacatalog_delete_tag_template] - delete_tag_template = CloudDataCatalogDeleteTagTemplateOperator( - task_id="delete_tag_template", location=LOCATION, tag_template=TEMPLATE_ID, force=True - ) - # [END howto_operator_gcp_datacatalog_delete_tag_template] - - # Get - # [START howto_operator_gcp_datacatalog_get_entry_group] - get_entry_group = CloudDataCatalogGetEntryGroupOperator( - task_id="get_entry_group", - location=LOCATION, - entry_group=ENTRY_GROUP_ID, - read_mask=FieldMask(paths=["name", "display_name"]), - ) - # [END howto_operator_gcp_datacatalog_get_entry_group] - - # [START howto_operator_gcp_datacatalog_get_entry_group_result] - get_entry_group_result = BashOperator( - task_id="get_entry_group_result", - bash_command=f"echo {get_entry_group.output}", - ) - # [END howto_operator_gcp_datacatalog_get_entry_group_result] - - # [START howto_operator_gcp_datacatalog_get_entry] - get_entry = CloudDataCatalogGetEntryOperator( - task_id="get_entry", location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID - ) - # [END howto_operator_gcp_datacatalog_get_entry] - - # [START howto_operator_gcp_datacatalog_get_entry_result] - get_entry_result = BashOperator(task_id="get_entry_result", bash_command=f"echo {get_entry.output}") - # [END howto_operator_gcp_datacatalog_get_entry_result] - - # [START howto_operator_gcp_datacatalog_get_tag_template] - get_tag_template = CloudDataCatalogGetTagTemplateOperator( - task_id="get_tag_template", location=LOCATION, tag_template=TEMPLATE_ID - ) - # [END howto_operator_gcp_datacatalog_get_tag_template] - - # [START howto_operator_gcp_datacatalog_get_tag_template_result] - get_tag_template_result = BashOperator( - task_id="get_tag_template_result", - bash_command=f"{get_tag_template.output}", - ) - # [END howto_operator_gcp_datacatalog_get_tag_template_result] - - # List - # [START howto_operator_gcp_datacatalog_list_tags] - list_tags = CloudDataCatalogListTagsOperator( - task_id="list_tags", location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID - ) - # [END howto_operator_gcp_datacatalog_list_tags] - - # [START howto_operator_gcp_datacatalog_list_tags_result] - list_tags_result = BashOperator(task_id="list_tags_result", bash_command=f"echo {list_tags.output}") - # [END howto_operator_gcp_datacatalog_list_tags_result] - - # Lookup - # [START howto_operator_gcp_datacatalog_lookup_entry_linked_resource] - current_entry_template = ( - "//datacatalog.googleapis.com/projects/{project_id}/locations/{location}/" - "entryGroups/{entry_group}/entries/{entry}" - ) - lookup_entry_linked_resource = CloudDataCatalogLookupEntryOperator( - task_id="lookup_entry", - linked_resource=current_entry_template.format( - 
project_id=PROJECT_ID, location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID - ), - ) - # [END howto_operator_gcp_datacatalog_lookup_entry_linked_resource] - - # [START howto_operator_gcp_datacatalog_lookup_entry_result] - lookup_entry_result = BashOperator( - task_id="lookup_entry_result", - bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['display_name'] }}\"", - ) - # [END howto_operator_gcp_datacatalog_lookup_entry_result] - - # Rename - # [START howto_operator_gcp_datacatalog_rename_tag_template_field] - rename_tag_template_field = CloudDataCatalogRenameTagTemplateFieldOperator( - task_id="rename_tag_template_field", - location=LOCATION, - tag_template=TEMPLATE_ID, - field=FIELD_NAME_1, - new_tag_template_field_id=FIELD_NAME_3, - ) - # [END howto_operator_gcp_datacatalog_rename_tag_template_field] - - # Search - # [START howto_operator_gcp_datacatalog_search_catalog] - search_catalog = CloudDataCatalogSearchCatalogOperator( - task_id="search_catalog", scope={"include_project_ids": [PROJECT_ID]}, query=f"projectid:{PROJECT_ID}" - ) - # [END howto_operator_gcp_datacatalog_search_catalog] - - # [START howto_operator_gcp_datacatalog_search_catalog_result] - search_catalog_result = BashOperator( - task_id="search_catalog_result", - bash_command=f"echo {search_catalog.output}", - ) - # [END howto_operator_gcp_datacatalog_search_catalog_result] - - # Update - # [START howto_operator_gcp_datacatalog_update_entry] - update_entry = CloudDataCatalogUpdateEntryOperator( - task_id="update_entry", - entry={"display_name": "New Wizard"}, - update_mask={"paths": ["display_name"]}, - location=LOCATION, - entry_group=ENTRY_GROUP_ID, - entry_id=ENTRY_ID, - ) - # [END howto_operator_gcp_datacatalog_update_entry] - - # [START howto_operator_gcp_datacatalog_update_tag] - update_tag = CloudDataCatalogUpdateTagOperator( - task_id="update_tag", - tag={"fields": {FIELD_NAME_1: TagField(string_value="new-value-string")}}, - update_mask={"paths": ["fields"]}, - location=LOCATION, - entry_group=ENTRY_GROUP_ID, - entry=ENTRY_ID, - tag_id=f"{create_tag.output['tag_id']}", - ) - # [END howto_operator_gcp_datacatalog_update_tag] - - # [START howto_operator_gcp_datacatalog_update_tag_template] - update_tag_template = CloudDataCatalogUpdateTagTemplateOperator( - task_id="update_tag_template", - tag_template={"display_name": "Awesome Tag Template"}, - update_mask={"paths": ["display_name"]}, - location=LOCATION, - tag_template_id=TEMPLATE_ID, - ) - # [END howto_operator_gcp_datacatalog_update_tag_template] - - # [START howto_operator_gcp_datacatalog_update_tag_template_field] - update_tag_template_field = CloudDataCatalogUpdateTagTemplateFieldOperator( - task_id="update_tag_template_field", - tag_template_field={"display_name": "Updated template field"}, - update_mask={"paths": ["display_name"]}, - location=LOCATION, - tag_template=TEMPLATE_ID, - tag_template_field_id=FIELD_NAME_1, - ) - # [END howto_operator_gcp_datacatalog_update_tag_template_field] - - # Create - create_tasks = [ - create_entry_group, - create_entry_gcs, - create_tag_template, - create_tag_template_field, - create_tag, - ] - chain(*create_tasks) - - create_entry_group >> delete_entry_group - create_entry_group >> create_entry_group_result - create_entry_group >> create_entry_group_result2 - - create_entry_gcs >> delete_entry - create_entry_gcs >> create_entry_gcs_result - create_entry_gcs >> create_entry_gcs_result2 - - create_tag_template >> delete_tag_template_field - create_tag_template >> create_tag_template_result - 
create_tag_template >> create_tag_template_result2 - - create_tag_template_field >> delete_tag_template_field - create_tag_template_field >> create_tag_template_field_result - create_tag_template_field >> create_tag_template_field_result2 - - create_tag >> delete_tag - create_tag >> create_tag_result - create_tag >> create_tag_result2 - - # Delete - delete_tasks = [ - delete_tag, - delete_tag_template_field, - delete_tag_template, - delete_entry, - delete_entry_group, - ] - chain(*delete_tasks) - - # Get - create_tag_template >> get_tag_template >> delete_tag_template - get_tag_template >> get_tag_template_result - - create_entry_gcs >> get_entry >> delete_entry - get_entry >> get_entry_result - - create_entry_group >> get_entry_group >> delete_entry_group - get_entry_group >> get_entry_group_result - - # List - create_tag >> list_tags >> delete_tag - list_tags >> list_tags_result - - # Lookup - create_entry_gcs >> lookup_entry_linked_resource >> delete_entry - lookup_entry_linked_resource >> lookup_entry_result - - # Rename - update_tag >> rename_tag_template_field - create_tag_template_field >> rename_tag_template_field >> delete_tag_template_field - - # Search - chain(create_tasks, search_catalog, delete_tasks) - search_catalog >> search_catalog_result - - # Update - create_entry_gcs >> update_entry >> delete_entry - create_tag >> update_tag >> delete_tag - create_tag_template >> update_tag_template >> delete_tag_template - create_tag_template_field >> update_tag_template_field >> rename_tag_template_field diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py index 8b1d01ff62735..f2a0860fc0d84 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataflow.py +++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py @@ -15,14 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """ Example Airflow DAG for Google Cloud Dataflow service """ +from __future__ import annotations + import os from datetime import datetime -from typing import Callable, Dict, List -from urllib.parse import urlparse +from typing import Callable +from urllib.parse import urlsplit from airflow import models from airflow.exceptions import AirflowException @@ -33,6 +34,7 @@ from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus from airflow.providers.google.cloud.operators.dataflow import ( CheckJobRunning, + DataflowStopJobOperator, DataflowTemplatedJobStartOperator, ) from airflow.providers.google.cloud.sensors.dataflow import ( @@ -45,29 +47,28 @@ START_DATE = datetime(2021, 1, 1) -GCS_TMP = os.environ.get('GCP_DATAFLOW_GCS_TMP', 'gs://INVALID BUCKET NAME/temp/') -GCS_STAGING = os.environ.get('GCP_DATAFLOW_GCS_STAGING', 'gs://INVALID BUCKET NAME/staging/') -GCS_OUTPUT = os.environ.get('GCP_DATAFLOW_GCS_OUTPUT', 'gs://INVALID BUCKET NAME/output') -GCS_JAR = os.environ.get('GCP_DATAFLOW_JAR', 'gs://INVALID BUCKET NAME/word-count-beam-bundled-0.1.jar') -GCS_PYTHON = os.environ.get('GCP_DATAFLOW_PYTHON', 'gs://INVALID BUCKET NAME/wordcount_debugging.py') +GCS_TMP = os.environ.get("GCP_DATAFLOW_GCS_TMP", "gs://INVALID BUCKET NAME/temp/") +GCS_STAGING = os.environ.get("GCP_DATAFLOW_GCS_STAGING", "gs://INVALID BUCKET NAME/staging/") +GCS_OUTPUT = os.environ.get("GCP_DATAFLOW_GCS_OUTPUT", "gs://INVALID BUCKET NAME/output") +GCS_JAR = os.environ.get("GCP_DATAFLOW_JAR", "gs://INVALID BUCKET NAME/word-count-beam-bundled-0.1.jar") +GCS_PYTHON = os.environ.get("GCP_DATAFLOW_PYTHON", "gs://INVALID BUCKET NAME/wordcount_debugging.py") -GCS_JAR_PARTS = urlparse(GCS_JAR) +GCS_JAR_PARTS = urlsplit(GCS_JAR) GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc GCS_JAR_OBJECT_NAME = GCS_JAR_PARTS.path[1:] default_args = { - 'dataflow_default_options': { - 'tempLocation': GCS_TMP, - 'stagingLocation': GCS_STAGING, + "dataflow_default_options": { + "tempLocation": GCS_TMP, + "stagingLocation": GCS_STAGING, } } with models.DAG( "example_gcp_dataflow_native_java", - schedule_interval='@once', # Override to match your needs start_date=START_DATE, catchup=False, - tags=['example'], + tags=["example"], ) as dag_native_java: # [START howto_operator_start_java_job_jar_on_gcs] @@ -75,12 +76,12 @@ task_id="start-java-job", jar=GCS_JAR, pipeline_options={ - 'output': GCS_OUTPUT, + "output": GCS_OUTPUT, }, - job_class='org.apache.beam.examples.WordCount', + job_class="org.apache.beam.examples.WordCount", dataflow_config={ "check_if_running": CheckJobRunning.IgnoreJob, - "location": 'europe-west3', + "location": "europe-west3", "poll_sleep": 10, }, ) @@ -98,12 +99,12 @@ task_id="start-java-job-local", jar="/tmp/dataflow-{{ ds_nodash }}.jar", pipeline_options={ - 'output': GCS_OUTPUT, + "output": GCS_OUTPUT, }, - job_class='org.apache.beam.examples.WordCount', + job_class="org.apache.beam.examples.WordCount", dataflow_config={ "check_if_running": CheckJobRunning.WaitForRun, - "location": 'europe-west3', + "location": "europe-west3", "poll_sleep": 10, }, ) @@ -115,8 +116,7 @@ default_args=default_args, start_date=START_DATE, catchup=False, - schedule_interval='@once', # Override to match your needs - tags=['example'], + tags=["example"], ) as dag_native_python: # [START howto_operator_start_python_job] @@ -125,24 +125,24 @@ py_file=GCS_PYTHON, py_options=[], pipeline_options={ - 'output': GCS_OUTPUT, + "output": GCS_OUTPUT, }, - py_requirements=['apache-beam[gcp]==2.21.0'], - py_interpreter='python3', + 
py_requirements=["apache-beam[gcp]==2.21.0"], + py_interpreter="python3", py_system_site_packages=False, - dataflow_config={'location': 'europe-west3'}, + dataflow_config={"location": "europe-west3"}, ) # [END howto_operator_start_python_job] start_python_job_local = BeamRunPythonPipelineOperator( task_id="start-python-job-local", - py_file='apache_beam.examples.wordcount', - py_options=['-m'], + py_file="apache_beam.examples.wordcount", + py_options=["-m"], pipeline_options={ - 'output': GCS_OUTPUT, + "output": GCS_OUTPUT, }, - py_requirements=['apache-beam[gcp]==2.14.0'], - py_interpreter='python3', + py_requirements=["apache-beam[gcp]==2.14.0"], + py_interpreter="python3", py_system_site_packages=False, ) @@ -151,8 +151,7 @@ default_args=default_args, start_date=START_DATE, catchup=False, - schedule_interval='@once', # Override to match your needs - tags=['example'], + tags=["example"], ) as dag_native_python_async: # [START howto_operator_start_python_job_async] start_python_job_async = BeamRunPythonPipelineOperator( @@ -161,14 +160,14 @@ py_file=GCS_PYTHON, py_options=[], pipeline_options={ - 'output': GCS_OUTPUT, + "output": GCS_OUTPUT, }, - py_requirements=['apache-beam[gcp]==2.25.0'], - py_interpreter='python3', + py_requirements=["apache-beam[gcp]==2.25.0"], + py_interpreter="python3", py_system_site_packages=False, dataflow_config={ "job_name": "start-python-job-async", - "location": 'europe-west3', + "location": "europe-west3", "wait_until_finished": False, }, ) @@ -177,9 +176,9 @@ # [START howto_sensor_wait_for_job_status] wait_for_python_job_async_done = DataflowJobStatusSensor( task_id="wait-for-python-job-async-done", - job_id="{{task_instance.xcom_pull('start-python-job-async')['dataflow_job_id']}}", + job_id="{{task_instance.xcom_pull('start-python-job-async')['id']}}", expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, - location='europe-west3', + location="europe-west3", ) # [END howto_sensor_wait_for_job_status] @@ -187,7 +186,7 @@ def check_metric_scalar_gte(metric_name: str, value: int) -> Callable: """Check is metric greater than equals to given value.""" - def callback(metrics: List[Dict]) -> bool: + def callback(metrics: list[dict]) -> bool: dag_native_python_async.log.info("Looking for '%s' >= %d", metric_name, value) for metric in metrics: context = metric.get("name", {}).get("context", {}) @@ -201,15 +200,15 @@ def callback(metrics: List[Dict]) -> bool: wait_for_python_job_async_metric = DataflowJobMetricsSensor( task_id="wait-for-python-job-async-metric", - job_id="{{task_instance.xcom_pull('start-python-job-async')['dataflow_job_id']}}", - location='europe-west3', + job_id="{{task_instance.xcom_pull('start-python-job-async')['id']}}", + location="europe-west3", callback=check_metric_scalar_gte(metric_name="Service-cpu_num_seconds", value=100), fail_on_terminal_state=False, ) # [END howto_sensor_wait_for_job_metric] # [START howto_sensor_wait_for_job_message] - def check_message(messages: List[dict]) -> bool: + def check_message(messages: list[dict]) -> bool: """Check message""" for message in messages: if "Adding workflow start and stop steps." 
in message.get("messageText", ""): @@ -218,15 +217,15 @@ def check_message(messages: List[dict]) -> bool: wait_for_python_job_async_message = DataflowJobMessagesSensor( task_id="wait-for-python-job-async-message", - job_id="{{task_instance.xcom_pull('start-python-job-async')['dataflow_job_id']}}", - location='europe-west3', + job_id="{{task_instance.xcom_pull('start-python-job-async')['id']}}", + location="europe-west3", callback=check_message, fail_on_terminal_state=False, ) # [END howto_sensor_wait_for_job_message] # [START howto_sensor_wait_for_job_autoscaling_event] - def check_autoscaling_event(autoscaling_events: List[dict]) -> bool: + def check_autoscaling_event(autoscaling_events: list[dict]) -> bool: """Check autoscaling event""" for autoscaling_event in autoscaling_events: if "Worker pool started." in autoscaling_event.get("description", {}).get("messageText", ""): @@ -235,8 +234,8 @@ def check_autoscaling_event(autoscaling_events: List[dict]) -> bool: wait_for_python_job_async_autoscaling_event = DataflowJobAutoScalingEventsSensor( task_id="wait-for-python-job-async-autoscaling-event", - job_id="{{task_instance.xcom_pull('start-python-job-async')['dataflow_job_id']}}", - location='europe-west3', + job_id="{{task_instance.xcom_pull('start-python-job-async')['id']}}", + location="europe-west3", callback=check_autoscaling_event, fail_on_terminal_state=False, ) @@ -253,14 +252,37 @@ def check_autoscaling_event(autoscaling_events: List[dict]) -> bool: default_args=default_args, start_date=START_DATE, catchup=False, - schedule_interval='@once', # Override to match your needs - tags=['example'], + tags=["example"], ) as dag_template: # [START howto_operator_start_template_job] start_template_job = DataflowTemplatedJobStartOperator( task_id="start-template-job", - template='gs://dataflow-templates/latest/Word_Count', - parameters={'inputFile': "gs://dataflow-samples/shakespeare/kinglear.txt", 'output': GCS_OUTPUT}, - location='europe-west3', + template="gs://dataflow-templates/latest/Word_Count", + parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT}, + location="europe-west3", ) # [END howto_operator_start_template_job] + +with models.DAG( + "example_gcp_stop_dataflow_job", + default_args=default_args, + start_date=START_DATE, + catchup=False, + tags=["example"], +) as dag_template: + # [START howto_operator_stop_dataflow_job] + stop_dataflow_job = DataflowStopJobOperator( + task_id="stop-dataflow-job", + location="europe-west3", + job_name_prefix="start-template-job", + ) + # [END howto_operator_stop_dataflow_job] + start_template_job = DataflowTemplatedJobStartOperator( + task_id="start-template-job", + template="gs://dataflow-templates/latest/Word_Count", + parameters={"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT}, + location="europe-west3", + append_job_name=False, + ) + + stop_dataflow_job >> start_template_job diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py b/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py index 43d9914850973..86de7014c5293 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py +++ b/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
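A side note on two changes in the example_dataflow.py hunks above: the JAR location is now parsed with urlsplit instead of urlparse, and the sensors read the Dataflow job id from the "id" key of the pushed XCom. A small, self-contained sketch of what urlsplit yields for a gs:// URI (the URI itself is illustrative):

from urllib.parse import urlsplit

gcs_jar = "gs://example-bucket/word-count-beam-bundled-0.1.jar"  # illustrative URI
parts = urlsplit(gcs_jar)

bucket_name = parts.netloc    # "example-bucket"
object_name = parts.path[1:]  # "word-count-beam-bundled-0.1.jar" (leading "/" stripped)
print(bucket_name, object_name)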
- """ Example Airflow DAG for Google Cloud Dataflow service """ +from __future__ import annotations + import os from datetime import datetime @@ -28,26 +29,25 @@ GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") DATAFLOW_FLEX_TEMPLATE_JOB_NAME = os.environ.get( - 'GCP_DATAFLOW_FLEX_TEMPLATE_JOB_NAME', "dataflow-flex-template" + "GCP_DATAFLOW_FLEX_TEMPLATE_JOB_NAME", "dataflow-flex-template" ) # For simplicity we use the same topic name as the subscription name. PUBSUB_FLEX_TEMPLATE_TOPIC = os.environ.get( - 'GCP_DATAFLOW_PUBSUB_FLEX_TEMPLATE_TOPIC', "dataflow-flex-template" + "GCP_DATAFLOW_PUBSUB_FLEX_TEMPLATE_TOPIC", "dataflow-flex-template" ) PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION = PUBSUB_FLEX_TEMPLATE_TOPIC GCS_FLEX_TEMPLATE_TEMPLATE_PATH = os.environ.get( - 'GCP_DATAFLOW_GCS_FLEX_TEMPLATE_TEMPLATE_PATH', + "GCP_DATAFLOW_GCS_FLEX_TEMPLATE_TEMPLATE_PATH", "gs://INVALID BUCKET NAME/samples/dataflow/templates/streaming-beam-sql.json", ) -BQ_FLEX_TEMPLATE_DATASET = os.environ.get('GCP_DATAFLOW_BQ_FLEX_TEMPLATE_DATASET', 'airflow_dataflow_samples') -BQ_FLEX_TEMPLATE_LOCATION = os.environ.get('GCP_DATAFLOW_BQ_FLEX_TEMPLATE_LOCATION>', 'us-west1') +BQ_FLEX_TEMPLATE_DATASET = os.environ.get("GCP_DATAFLOW_BQ_FLEX_TEMPLATE_DATASET", "airflow_dataflow_samples") +BQ_FLEX_TEMPLATE_LOCATION = os.environ.get("GCP_DATAFLOW_BQ_FLEX_TEMPLATE_LOCATION>", "us-west1") with models.DAG( dag_id="example_gcp_dataflow_flex_template_java", start_date=datetime(2021, 1, 1), catchup=False, - schedule_interval='@once', # Override to match your needs ) as dag_flex_template: # [START howto_operator_start_template_job] start_flex_template = DataflowStartFlexTemplateOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow_sql.py b/airflow/providers/google/cloud/example_dags/example_dataflow_sql.py index a74f5dedc1254..3ef0626f6db10 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataflow_sql.py +++ b/airflow/providers/google/cloud/example_dags/example_dataflow_sql.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG for Google Cloud Dataflow service """ +from __future__ import annotations + import os from datetime import datetime @@ -38,8 +39,7 @@ dag_id="example_gcp_dataflow_sql", start_date=datetime(2021, 1, 1), catchup=False, - schedule_interval='@once', # Override to match your needs - tags=['example'], + tags=["example"], ) as dag_sql: # [START howto_operator_start_sql_job] start_sql = DataflowStartSqlJobOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_datafusion.py b/airflow/providers/google/cloud/example_dags/example_datafusion.py index e442164fd686a..24b0a9b239d4f 100644 --- a/airflow/providers/google/cloud/example_dags/example_datafusion.py +++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that shows how to use DataFusion. 
""" +from __future__ import annotations + import os from datetime import datetime @@ -88,8 +89,8 @@ "filenameOnly": "false", "recursive": "false", "encrypted": "false", - "schema": "{\"type\":\"record\",\"name\":\"textfile\",\"fields\":[{\"name\"\ - :\"offset\",\"type\":\"long\"},{\"name\":\"body\",\"type\":\"string\"}]}", + "schema": '{"type":"record","name":"textfile","fields":[{"name"\ + :"offset","type":"long"},{"name":"body","type":"string"}]}', "path": BUCKET_1_URI, "referenceName": "foo_bucket", "useConnection": "false", @@ -98,8 +99,8 @@ "fileEncoding": "UTF-8", }, }, - "outputSchema": "{\"type\":\"record\",\"name\":\"textfile\",\"fields\"\ - :[{\"name\":\"offset\",\"type\":\"long\"},{\"name\":\"body\",\"type\":\"string\"}]}", + "outputSchema": '{"type":"record","name":"textfile","fields"\ + :[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', "id": "GCS", }, { @@ -115,21 +116,21 @@ "format": "json", "serviceFilePath": "auto-detect", "location": "us", - "schema": "{\"type\":\"record\",\"name\":\"textfile\",\"fields\":[{\"name\"\ - :\"offset\",\"type\":\"long\"},{\"name\":\"body\",\"type\":\"string\"}]}", + "schema": '{"type":"record","name":"textfile","fields":[{"name"\ + :"offset","type":"long"},{"name":"body","type":"string"}]}', "referenceName": "bar", "path": BUCKET_2_URI, "serviceAccountType": "filePath", "contentType": "application/octet-stream", }, }, - "outputSchema": "{\"type\":\"record\",\"name\":\"textfile\",\"fields\"\ - :[{\"name\":\"offset\",\"type\":\"long\"},{\"name\":\"body\",\"type\":\"string\"}]}", + "outputSchema": '{"type":"record","name":"textfile","fields"\ + :[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', "inputSchema": [ { "name": "GCS", - "schema": "{\"type\":\"record\",\"name\":\"textfile\",\"fields\":[{\"name\"\ - :\"offset\",\"type\":\"long\"},{\"name\":\"body\",\"type\":\"string\"}]}", + "schema": '{"type":"record","name":"textfile","fields":[{"name"\ + :"offset","type":"long"},{"name":"body","type":"string"}]}', } ], "id": "GCS2", @@ -147,7 +148,6 @@ with models.DAG( "example_data_fusion", - schedule_interval='@once', # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, ) as dag: diff --git a/airflow/providers/google/cloud/example_dags/example_dataplex.py b/airflow/providers/google/cloud/example_dags/example_dataplex.py deleted file mode 100644 index aabe17aed69a6..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_dataplex.py +++ /dev/null @@ -1,122 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that shows how to use Dataplex. 
-""" - -import datetime -import os - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.google.cloud.operators.dataplex import ( - DataplexCreateTaskOperator, - DataplexDeleteTaskOperator, - DataplexGetTaskOperator, - DataplexListTasksOperator, -) -from airflow.providers.google.cloud.sensors.dataplex import DataplexTaskStateSensor - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "INVALID PROJECT ID") -REGION = os.environ.get("GCP_REGION", "INVALID REGION") -LAKE_ID = os.environ.get("GCP_LAKE_ID", "INVALID LAKE ID") -SERVICE_ACC = os.environ.get("GCP_DATAPLEX_SERVICE_ACC", "XYZ@developer.gserviceaccount.com") -BUCKET = os.environ.get("GCP_DATAPLEX_BUCKET", "INVALID BUCKET NAME") -SPARK_FILE_NAME = os.environ.get("SPARK_FILE_NAME", "INVALID FILE NAME") -SPARK_FILE_FULL_PATH = f"gs://{BUCKET}/{SPARK_FILE_NAME}" -DATAPLEX_TASK_ID = "task001" -TRIGGER_SPEC_TYPE = "ON_DEMAND" - -# [START howto_dataplex_configuration] -EXAMPLE_TASK_BODY = { - "trigger_spec": {"type_": TRIGGER_SPEC_TYPE}, - "execution_spec": {"service_account": SERVICE_ACC}, - "spark": {"python_script_file": SPARK_FILE_FULL_PATH}, -} -# [END howto_dataplex_configuration] - -with models.DAG( - "example_dataplex", start_date=datetime.datetime(2021, 1, 1), schedule_interval="@once", catchup=False -) as dag: - # [START howto_dataplex_create_task_operator] - create_dataplex_task = DataplexCreateTaskOperator( - project_id=PROJECT_ID, - region=REGION, - lake_id=LAKE_ID, - body=EXAMPLE_TASK_BODY, - dataplex_task_id=DATAPLEX_TASK_ID, - task_id="create_dataplex_task", - ) - # [END howto_dataplex_create_task_operator] - - # [START howto_dataplex_async_create_task_operator] - create_dataplex_task_async = DataplexCreateTaskOperator( - project_id=PROJECT_ID, - region=REGION, - lake_id=LAKE_ID, - body=EXAMPLE_TASK_BODY, - dataplex_task_id=DATAPLEX_TASK_ID, - asynchronous=True, - task_id="create_dataplex_task_async", - ) - # [END howto_dataplex_async_create_task_operator] - - # [START howto_dataplex_delete_task_operator] - delete_dataplex_task = DataplexDeleteTaskOperator( - project_id=PROJECT_ID, - region=REGION, - lake_id=LAKE_ID, - dataplex_task_id=DATAPLEX_TASK_ID, - task_id="delete_dataplex_task", - ) - # [END howto_dataplex_delete_task_operator] - - # [START howto_dataplex_list_tasks_operator] - list_dataplex_task = DataplexListTasksOperator( - project_id=PROJECT_ID, region=REGION, lake_id=LAKE_ID, task_id="list_dataplex_task" - ) - # [END howto_dataplex_list_tasks_operator] - - # [START howto_dataplex_get_task_operator] - get_dataplex_task = DataplexGetTaskOperator( - project_id=PROJECT_ID, - region=REGION, - lake_id=LAKE_ID, - dataplex_task_id=DATAPLEX_TASK_ID, - task_id="get_dataplex_task", - ) - # [END howto_dataplex_get_task_operator] - - # [START howto_dataplex_task_state_sensor] - dataplex_task_state = DataplexTaskStateSensor( - project_id=PROJECT_ID, - region=REGION, - lake_id=LAKE_ID, - dataplex_task_id=DATAPLEX_TASK_ID, - task_id="dataplex_task_state", - ) - # [END howto_dataplex_task_state_sensor] - - chain( - create_dataplex_task, - get_dataplex_task, - list_dataplex_task, - delete_dataplex_task, - create_dataplex_task_async, - dataplex_task_state, - ) diff --git a/airflow/providers/google/cloud/example_dags/example_dataprep.py b/airflow/providers/google/cloud/example_dags/example_dataprep.py deleted file mode 100644 index 1bd460a370476..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_dataprep.py +++ /dev/null @@ -1,78 +0,0 @@ -# Licensed to the Apache 
Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG that shows how to use Google Dataprep. -""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.dataprep import ( - DataprepGetJobGroupOperator, - DataprepGetJobsForJobGroupOperator, - DataprepRunJobGroupOperator, -) - -DATAPREP_JOB_ID = int(os.environ.get('DATAPREP_JOB_ID', 12345677)) -DATAPREP_JOB_RECIPE_ID = int(os.environ.get('DATAPREP_JOB_RECIPE_ID', 12345677)) -DATAPREP_BUCKET = os.environ.get("DATAPREP_BUCKET", "gs://INVALID BUCKET NAME/name@email.com") - -DATA = { - "wrangledDataset": {"id": DATAPREP_JOB_RECIPE_ID}, - "overrides": { - "execution": "dataflow", - "profiler": False, - "writesettings": [ - { - "path": DATAPREP_BUCKET, - "action": "create", - "format": "csv", - "compression": "none", - "header": False, - "asSingleFile": False, - } - ], - }, -} - - -with models.DAG( - "example_dataprep", - schedule_interval='@once', - start_date=datetime(2021, 1, 1), # Override to match your needs - catchup=False, -) as dag: - # [START how_to_dataprep_run_job_group_operator] - run_job_group = DataprepRunJobGroupOperator(task_id="run_job_group", body_request=DATA) - # [END how_to_dataprep_run_job_group_operator] - - # [START how_to_dataprep_get_jobs_for_job_group_operator] - get_jobs_for_job_group = DataprepGetJobsForJobGroupOperator( - task_id="get_jobs_for_job_group", job_id=DATAPREP_JOB_ID - ) - # [END how_to_dataprep_get_jobs_for_job_group_operator] - - # [START how_to_dataprep_get_job_group_operator] - get_job_group = DataprepGetJobGroupOperator( - task_id="get_job_group", - job_group_id=DATAPREP_JOB_ID, - embed="", - include_deleted=False, - ) - # [END how_to_dataprep_get_job_group_operator] - - run_job_group >> [get_jobs_for_job_group, get_job_group] diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc_metastore.py b/airflow/providers/google/cloud/example_dags/example_dataproc_metastore.py deleted file mode 100644 index 91b067f8c7304..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_dataproc_metastore.py +++ /dev/null @@ -1,221 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG that show how to use various Dataproc Metastore -operators to manage a service. -""" - -import datetime -import os - -from google.cloud.metastore_v1 import MetadataImport -from google.protobuf.field_mask_pb2 import FieldMask - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.google.cloud.operators.dataproc_metastore import ( - DataprocMetastoreCreateBackupOperator, - DataprocMetastoreCreateMetadataImportOperator, - DataprocMetastoreCreateServiceOperator, - DataprocMetastoreDeleteBackupOperator, - DataprocMetastoreDeleteServiceOperator, - DataprocMetastoreExportMetadataOperator, - DataprocMetastoreGetServiceOperator, - DataprocMetastoreListBackupsOperator, - DataprocMetastoreRestoreServiceOperator, - DataprocMetastoreUpdateServiceOperator, -) - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "") -SERVICE_ID = os.environ.get("GCP_DATAPROC_METASTORE_SERVICE_ID", "dataproc-metastore-system-tests-service-1") -BACKUP_ID = os.environ.get("GCP_DATAPROC_METASTORE_BACKUP_ID", "dataproc-metastore-system-tests-backup-1") -REGION = os.environ.get("GCP_REGION", "") -BUCKET = os.environ.get("GCP_DATAPROC_METASTORE_BUCKET", "INVALID BUCKET NAME") -METADATA_IMPORT_FILE = os.environ.get("GCS_METADATA_IMPORT_FILE", None) -GCS_URI = os.environ.get("GCS_URI", f"gs://{BUCKET}/data/hive.sql") -METADATA_IMPORT_ID = "dataproc-metastore-system-tests-metadata-import-1" -TIMEOUT = 1200 -DB_TYPE = "MYSQL" -DESTINATION_GCS_FOLDER = f"gs://{BUCKET}/>" - -# Service definition -# Docs: https://cloud.google.com/dataproc-metastore/docs/reference/rest/v1/projects.locations.services#Service -# [START how_to_cloud_dataproc_metastore_create_service] -SERVICE = { - "name": "test-service", -} -# [END how_to_cloud_dataproc_metastore_create_service] - -# Update service -# [START how_to_cloud_dataproc_metastore_update_service] -SERVICE_TO_UPDATE = { - "labels": { - "mylocalmachine": "mylocalmachine", - "systemtest": "systemtest", - } -} -UPDATE_MASK = FieldMask(paths=["labels"]) -# [END how_to_cloud_dataproc_metastore_update_service] - -# Backup definition -# [START how_to_cloud_dataproc_metastore_create_backup] -BACKUP = { - "name": "test-backup", -} -# [END how_to_cloud_dataproc_metastore_create_backup] - -# Metadata import definition -# [START how_to_cloud_dataproc_metastore_create_metadata_import] -METADATA_IMPORT = MetadataImport( - { - "name": "test-metadata-import", - "database_dump": { - "gcs_uri": GCS_URI, - "database_type": DB_TYPE, - }, - } -) -# [END how_to_cloud_dataproc_metastore_create_metadata_import] - - -with models.DAG( - "example_gcp_dataproc_metastore", start_date=datetime.datetime(2021, 1, 1), schedule_interval="@once" -) as dag: - # [START how_to_cloud_dataproc_metastore_create_service_operator] - create_service = DataprocMetastoreCreateServiceOperator( - task_id="create_service", - region=REGION, - project_id=PROJECT_ID, - service=SERVICE, - service_id=SERVICE_ID, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_create_service_operator] - - # [START how_to_cloud_dataproc_metastore_get_service_operator] - 
get_service_details = DataprocMetastoreGetServiceOperator( - task_id="get_service", - region=REGION, - project_id=PROJECT_ID, - service_id=SERVICE_ID, - ) - # [END how_to_cloud_dataproc_metastore_get_service_operator] - - # [START how_to_cloud_dataproc_metastore_update_service_operator] - update_service = DataprocMetastoreUpdateServiceOperator( - task_id="update_service", - project_id=PROJECT_ID, - service_id=SERVICE_ID, - region=REGION, - service=SERVICE_TO_UPDATE, - update_mask=UPDATE_MASK, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_update_service_operator] - - # [START how_to_cloud_dataproc_metastore_create_metadata_import_operator] - import_metadata = DataprocMetastoreCreateMetadataImportOperator( - task_id="create_metadata_import", - project_id=PROJECT_ID, - region=REGION, - service_id=SERVICE_ID, - metadata_import=METADATA_IMPORT, - metadata_import_id=METADATA_IMPORT_ID, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_create_metadata_import_operator] - - # [START how_to_cloud_dataproc_metastore_export_metadata_operator] - export_metadata = DataprocMetastoreExportMetadataOperator( - task_id="export_metadata", - destination_gcs_folder=DESTINATION_GCS_FOLDER, - project_id=PROJECT_ID, - region=REGION, - service_id=SERVICE_ID, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_export_metadata_operator] - - # [START how_to_cloud_dataproc_metastore_create_backup_operator] - backup_service = DataprocMetastoreCreateBackupOperator( - task_id="create_backup", - project_id=PROJECT_ID, - region=REGION, - service_id=SERVICE_ID, - backup=BACKUP, - backup_id=BACKUP_ID, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_create_backup_operator] - - # [START how_to_cloud_dataproc_metastore_list_backups_operator] - list_backups = DataprocMetastoreListBackupsOperator( - task_id="list_backups", - project_id=PROJECT_ID, - region=REGION, - service_id=SERVICE_ID, - ) - # [END how_to_cloud_dataproc_metastore_list_backups_operator] - - # [START how_to_cloud_dataproc_metastore_delete_backup_operator] - delete_backup = DataprocMetastoreDeleteBackupOperator( - task_id="delete_backup", - project_id=PROJECT_ID, - region=REGION, - service_id=SERVICE_ID, - backup_id=BACKUP_ID, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_delete_backup_operator] - - # [START how_to_cloud_dataproc_metastore_restore_service_operator] - restore_service = DataprocMetastoreRestoreServiceOperator( - task_id="restore_metastore", - region=REGION, - project_id=PROJECT_ID, - service_id=SERVICE_ID, - backup_id=BACKUP_ID, - backup_region=REGION, - backup_project_id=PROJECT_ID, - backup_service_id=SERVICE_ID, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_restore_service_operator] - - # [START how_to_cloud_dataproc_metastore_delete_service_operator] - delete_service = DataprocMetastoreDeleteServiceOperator( - task_id="delete_service", - region=REGION, - project_id=PROJECT_ID, - service_id=SERVICE_ID, - timeout=TIMEOUT, - ) - # [END how_to_cloud_dataproc_metastore_delete_service_operator] - - chain( - create_service, - update_service, - get_service_details, - backup_service, - list_backups, - restore_service, - delete_backup, - export_metadata, - import_metadata, - delete_service, - ) diff --git a/airflow/providers/google/cloud/example_dags/example_dlp.py b/airflow/providers/google/cloud/example_dags/example_dlp.py deleted file mode 100644 index 480fda1f0d984..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_dlp.py +++ /dev/null 
@@ -1,219 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that execute the following tasks using -Cloud DLP service in the Google Cloud: -1) Creating a content inspect template; -2) Using the created template to inspect content; -3) Deleting the template from Google Cloud . -""" - -import os -from datetime import datetime - -from google.cloud.dlp_v2.types import ContentItem, InspectConfig, InspectTemplate - -from airflow import models -from airflow.providers.google.cloud.operators.dlp import ( - CloudDLPCreateInspectTemplateOperator, - CloudDLPCreateJobTriggerOperator, - CloudDLPCreateStoredInfoTypeOperator, - CloudDLPDeidentifyContentOperator, - CloudDLPDeleteInspectTemplateOperator, - CloudDLPDeleteJobTriggerOperator, - CloudDLPDeleteStoredInfoTypeOperator, - CloudDLPInspectContentOperator, - CloudDLPUpdateJobTriggerOperator, - CloudDLPUpdateStoredInfoTypeOperator, -) - -START_DATE = datetime(2021, 1, 1) - -GCP_PROJECT = os.environ.get("GCP_PROJECT_ID", "example-project") -TEMPLATE_ID = "dlp-inspect-838746" -ITEM = ContentItem( - table={ - "headers": [{"name": "column1"}], - "rows": [{"values": [{"string_value": "My phone number is (206) 555-0123"}]}], - } -) -INSPECT_CONFIG = InspectConfig(info_types=[{"name": "PHONE_NUMBER"}, {"name": "US_TOLLFREE_PHONE_NUMBER"}]) -INSPECT_TEMPLATE = InspectTemplate(inspect_config=INSPECT_CONFIG) -OUTPUT_BUCKET = os.environ.get("DLP_OUTPUT_BUCKET", "gs://INVALID BUCKET NAME") -OUTPUT_FILENAME = "test.txt" - -OBJECT_GCS_URI = os.path.join(OUTPUT_BUCKET, "tmp") -OBJECT_GCS_OUTPUT_URI = os.path.join(OUTPUT_BUCKET, "tmp", OUTPUT_FILENAME) - -with models.DAG( - "example_gcp_dlp", - schedule_interval='@once', # Override to match your needs - start_date=START_DATE, - catchup=False, - tags=['example'], -) as dag1: - # [START howto_operator_dlp_create_inspect_template] - create_template = CloudDLPCreateInspectTemplateOperator( - project_id=GCP_PROJECT, - inspect_template=INSPECT_TEMPLATE, - template_id=TEMPLATE_ID, - task_id="create_template", - do_xcom_push=True, - ) - # [END howto_operator_dlp_create_inspect_template] - - # [START howto_operator_dlp_use_inspect_template] - inspect_content = CloudDLPInspectContentOperator( - task_id="inspect_content", - project_id=GCP_PROJECT, - item=ITEM, - inspect_template_name="{{ task_instance.xcom_pull('create_template', key='return_value')['name'] }}", - ) - # [END howto_operator_dlp_use_inspect_template] - - # [START howto_operator_dlp_delete_inspect_template] - delete_template = CloudDLPDeleteInspectTemplateOperator( - task_id="delete_template", - template_id=TEMPLATE_ID, - project_id=GCP_PROJECT, - ) - # [END howto_operator_dlp_delete_inspect_template] - - create_template >> inspect_content >> delete_template - -CUSTOM_INFO_TYPE_ID = 
"custom_info_type" -CUSTOM_INFO_TYPES = { - "large_custom_dictionary": { - "output_path": {"path": OBJECT_GCS_OUTPUT_URI}, - "cloud_storage_file_set": {"url": OBJECT_GCS_URI + "/"}, - } -} -UPDATE_CUSTOM_INFO_TYPE = { - "large_custom_dictionary": { - "output_path": {"path": OBJECT_GCS_OUTPUT_URI}, - "cloud_storage_file_set": {"url": OBJECT_GCS_URI + "/"}, - } -} - -with models.DAG( - "example_gcp_dlp_info_types", - schedule_interval='@once', - start_date=START_DATE, - catchup=False, - tags=["example", "dlp", "info-types"], -) as dag2: - # [START howto_operator_dlp_create_info_type] - create_info_type = CloudDLPCreateStoredInfoTypeOperator( - project_id=GCP_PROJECT, - config=CUSTOM_INFO_TYPES, - stored_info_type_id=CUSTOM_INFO_TYPE_ID, - task_id="create_info_type", - ) - # [END howto_operator_dlp_create_info_type] - # [START howto_operator_dlp_update_info_type] - update_info_type = CloudDLPUpdateStoredInfoTypeOperator( - project_id=GCP_PROJECT, - stored_info_type_id=CUSTOM_INFO_TYPE_ID, - config=UPDATE_CUSTOM_INFO_TYPE, - task_id="update_info_type", - ) - # [END howto_operator_dlp_update_info_type] - # [START howto_operator_dlp_delete_info_type] - delete_info_type = CloudDLPDeleteStoredInfoTypeOperator( - project_id=GCP_PROJECT, - stored_info_type_id=CUSTOM_INFO_TYPE_ID, - task_id="delete_info_type", - ) - # [END howto_operator_dlp_delete_info_type] - create_info_type >> update_info_type >> delete_info_type - -JOB_TRIGGER = { - "inspect_job": { - "storage_config": { - "datastore_options": {"partition_id": {"project_id": GCP_PROJECT}, "kind": {"name": "test"}} - } - }, - "triggers": [{"schedule": {"recurrence_period_duration": {"seconds": 60 * 60 * 24}}}], - "status": "HEALTHY", -} - -TRIGGER_ID = "example_trigger" - -with models.DAG( - "example_gcp_dlp_job", - schedule_interval='@once', - start_date=START_DATE, - catchup=False, - tags=["example", "dlp_job"], -) as dag3: # [START howto_operator_dlp_create_job_trigger] - create_trigger = CloudDLPCreateJobTriggerOperator( - project_id=GCP_PROJECT, - job_trigger=JOB_TRIGGER, - trigger_id=TRIGGER_ID, - task_id="create_trigger", - ) - # [END howto_operator_dlp_create_job_trigger] - - JOB_TRIGGER["triggers"] = [{"schedule": {"recurrence_period_duration": {"seconds": 2 * 60 * 60 * 24}}}] - - # [START howto_operator_dlp_update_job_trigger] - update_trigger = CloudDLPUpdateJobTriggerOperator( - project_id=GCP_PROJECT, - job_trigger_id=TRIGGER_ID, - job_trigger=JOB_TRIGGER, - task_id="update_info_type", - ) - # [END howto_operator_dlp_update_job_trigger] - # [START howto_operator_dlp_delete_job_trigger] - delete_trigger = CloudDLPDeleteJobTriggerOperator( - project_id=GCP_PROJECT, job_trigger_id=TRIGGER_ID, task_id="delete_info_type" - ) - # [END howto_operator_dlp_delete_job_trigger] - create_trigger >> update_trigger >> delete_trigger - -# [START dlp_deidentify_config_example] -DEIDENTIFY_CONFIG = { - "info_type_transformations": { - "transformations": [ - { - "primitive_transformation": { - "replace_config": {"new_value": {"string_value": "[deidentified_number]"}} - } - } - ] - } -} -# [END dlp_deidentify_config_example] - -with models.DAG( - "example_gcp_dlp_deidentify_content", - schedule_interval='@once', - start_date=START_DATE, - catchup=False, - tags=["example", "dlp", "deidentify"], -) as dag4: - # [START _howto_operator_dlp_deidentify_content] - deidentify_content = CloudDLPDeidentifyContentOperator( - project_id=GCP_PROJECT, - item=ITEM, - deidentify_config=DEIDENTIFY_CONFIG, - inspect_config=INSPECT_CONFIG, - task_id="deidentify_content", 
- ) - # [END _howto_operator_dlp_deidentify_content] diff --git a/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py index bd80091e48dc5..96da2ad126b90 100644 --- a/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py @@ -18,6 +18,8 @@ """ Example Airflow DAG that shows how to use FacebookAdsReportToGcsOperator. """ +from __future__ import annotations + import os from datetime import datetime @@ -52,12 +54,11 @@ AdsInsights.Field.clicks, AdsInsights.Field.impressions, ] -PARAMETERS = {'level': 'ad', 'date_preset': 'yesterday'} +PARAMETERS = {"level": "ad", "date_preset": "yesterday"} # [END howto_FB_ADS_variables] with models.DAG( "example_facebook_ads_to_gcs", - schedule_interval='@once', # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, ) as dag: @@ -78,18 +79,18 @@ dataset_id=DATASET_NAME, table_id=TABLE_NAME, schema_fields=[ - {'name': 'campaign_name', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'campaign_id', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'ad_id', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'clicks', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'impressions', 'type': 'STRING', 'mode': 'NULLABLE'}, + {"name": "campaign_name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "campaign_id", "type": "STRING", "mode": "NULLABLE"}, + {"name": "ad_id", "type": "STRING", "mode": "NULLABLE"}, + {"name": "clicks", "type": "STRING", "mode": "NULLABLE"}, + {"name": "impressions", "type": "STRING", "mode": "NULLABLE"}, ], ) # [START howto_operator_facebook_ads_to_gcs] run_operator = FacebookAdsReportToGcsOperator( - task_id='run_fetch_data', - owner='airflow', + task_id="run_fetch_data", + owner="airflow", bucket_name=GCS_BUCKET, parameters=PARAMETERS, fields=FIELDS, @@ -99,11 +100,11 @@ # [END howto_operator_facebook_ads_to_gcs] load_csv = GCSToBigQueryOperator( - task_id='gcs_to_bq_example', + task_id="gcs_to_bq_example", bucket=GCS_BUCKET, source_objects=[GCS_OBJ_PATH], destination_project_dataset_table=f"{DATASET_NAME}.{TABLE_NAME}", - write_disposition='WRITE_TRUNCATE', + write_disposition="WRITE_TRUNCATE", ) read_data_from_gcs_many_chunks = BigQueryInsertJobOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_functions.py b/airflow/providers/google/cloud/example_dags/example_functions.py deleted file mode 100644 index 5ce7f5e218206..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_functions.py +++ /dev/null @@ -1,128 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
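In the Facebook Ads hunk above, the BigQuery schema_fields listing mirrors the AdsInsights fields requested from the Ads API. One way to keep the two in sync is to derive the schema from the same FIELDS list; a sketch, where typing every column as STRING simply follows the example above and is not a requirement of GCSToBigQueryOperator:

from facebook_business.adobjects.adsinsights import AdsInsights

FIELDS = [
    AdsInsights.Field.campaign_name,
    AdsInsights.Field.campaign_id,
    AdsInsights.Field.ad_id,
    AdsInsights.Field.clicks,
    AdsInsights.Field.impressions,
]

# AdsInsights.Field members are plain strings, so they can double as column names.
schema_fields = [{"name": field, "type": "STRING", "mode": "NULLABLE"} for field in FIELDS]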
- -""" -Example Airflow DAG that displays interactions with Google Cloud Functions. -It creates a function and then deletes it. - -This DAG relies on the following OS environment variables -https://airflow.apache.org/concepts.html#variables - -* GCP_PROJECT_ID - Google Cloud Project to use for the Cloud Function. -* GCP_LOCATION - Google Cloud Functions region where the function should be - created. -* GCF_ENTRYPOINT - Name of the executable function in the source code. -* and one of the below: - - * GCF_SOURCE_ARCHIVE_URL - Path to the zipped source in Google Cloud Storage - - * GCF_SOURCE_UPLOAD_URL - Generated upload URL for the zipped source and GCF_ZIP_PATH - Local path to - the zipped source archive - - * GCF_SOURCE_REPOSITORY - The URL pointing to the hosted repository where the function - is defined in a supported Cloud Source Repository URL format - https://cloud.google.com/functions/docs/reference/rest/v1/projects.locations.functions#SourceRepository - -""" - -import os -from datetime import datetime -from typing import Any, Dict - -from airflow import models -from airflow.providers.google.cloud.operators.functions import ( - CloudFunctionDeleteFunctionOperator, - CloudFunctionDeployFunctionOperator, - CloudFunctionInvokeFunctionOperator, -) - -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCP_LOCATION = os.environ.get('GCP_LOCATION', 'europe-west1') -# make sure there are no dashes in function name (!) -GCF_SHORT_FUNCTION_NAME = os.environ.get('GCF_SHORT_FUNCTION_NAME', 'hello').replace("-", "_") -FUNCTION_NAME = f'projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/functions/{GCF_SHORT_FUNCTION_NAME}' -GCF_SOURCE_ARCHIVE_URL = os.environ.get('GCF_SOURCE_ARCHIVE_URL', '') -GCF_SOURCE_UPLOAD_URL = os.environ.get('GCF_SOURCE_UPLOAD_URL', '') -GCF_SOURCE_REPOSITORY = os.environ.get( - 'GCF_SOURCE_REPOSITORY', - f'https://source.developers.google.com/projects/{GCP_PROJECT_ID}/' - f'repos/hello-world/moveable-aliases/master', -) -GCF_ZIP_PATH = os.environ.get('GCF_ZIP_PATH', '') -GCF_ENTRYPOINT = os.environ.get('GCF_ENTRYPOINT', 'helloWorld') -GCF_RUNTIME = 'nodejs14' -GCP_VALIDATE_BODY = os.environ.get('GCP_VALIDATE_BODY', "True") == "True" - -# [START howto_operator_gcf_deploy_body] -body = {"name": FUNCTION_NAME, "entryPoint": GCF_ENTRYPOINT, "runtime": GCF_RUNTIME, "httpsTrigger": {}} -# [END howto_operator_gcf_deploy_body] - -# [START howto_operator_gcf_default_args] -default_args: Dict[str, Any] = {'retries': 3} -# [END howto_operator_gcf_default_args] - -# [START howto_operator_gcf_deploy_variants] -if GCF_SOURCE_ARCHIVE_URL: - body['sourceArchiveUrl'] = GCF_SOURCE_ARCHIVE_URL -elif GCF_SOURCE_REPOSITORY: - body['sourceRepository'] = {'url': GCF_SOURCE_REPOSITORY} -elif GCF_ZIP_PATH: - body['sourceUploadUrl'] = '' - default_args['zip_path'] = GCF_ZIP_PATH -elif GCF_SOURCE_UPLOAD_URL: - body['sourceUploadUrl'] = GCF_SOURCE_UPLOAD_URL -else: - raise Exception("Please provide one of the source_code parameters") -# [END howto_operator_gcf_deploy_variants] - - -with models.DAG( - 'example_gcp_function', - default_args=default_args, - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_gcf_deploy] - deploy_task = CloudFunctionDeployFunctionOperator( - task_id="gcf_deploy_task", - project_id=GCP_PROJECT_ID, - location=GCP_LOCATION, - body=body, - validate_body=GCP_VALIDATE_BODY, - ) - # [END howto_operator_gcf_deploy] - # [START 
howto_operator_gcf_deploy_no_project_id] - deploy2_task = CloudFunctionDeployFunctionOperator( - task_id="gcf_deploy2_task", location=GCP_LOCATION, body=body, validate_body=GCP_VALIDATE_BODY - ) - # [END howto_operator_gcf_deploy_no_project_id] - # [START howto_operator_gcf_invoke_function] - invoke_task = CloudFunctionInvokeFunctionOperator( - task_id="invoke_task", - project_id=GCP_PROJECT_ID, - location=GCP_LOCATION, - input_data={}, - function_id=GCF_SHORT_FUNCTION_NAME, - ) - # [END howto_operator_gcf_invoke_function] - # [START howto_operator_gcf_delete] - delete_task = CloudFunctionDeleteFunctionOperator(task_id="gcf_delete_task", name=FUNCTION_NAME) - # [END howto_operator_gcf_delete] - deploy_task >> deploy2_task >> invoke_task >> delete_task diff --git a/airflow/providers/google/cloud/example_dags/example_gcs_to_sftp.py b/airflow/providers/google/cloud/example_dags/example_gcs_to_sftp.py deleted file mode 100644 index ff0431f0c3690..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_gcs_to_sftp.py +++ /dev/null @@ -1,121 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG for Google Cloud Storage to SFTP transfer operators. 
-""" - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.gcs_to_sftp import GCSToSFTPOperator -from airflow.providers.sftp.sensors.sftp import SFTPSensor - -SFTP_CONN_ID = "ssh_default" -BUCKET_SRC = os.environ.get("GCP_GCS_BUCKET_1_SRC", "test-gcs-sftp") -OBJECT_SRC_1 = "parent-1.bin" -OBJECT_SRC_2 = "dir-1/parent-2.bin" -OBJECT_SRC_3 = "dir-2/*" -DESTINATION_PATH_1 = "/tmp/single-file/" -DESTINATION_PATH_2 = "/tmp/dest-dir-1/" -DESTINATION_PATH_3 = "/tmp/dest-dir-2/" - - -with models.DAG( - "example_gcs_to_sftp", - schedule_interval='@once', - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_gcs_to_sftp_copy_single_file] - copy_file_from_gcs_to_sftp = GCSToSFTPOperator( - task_id="file-copy-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - source_bucket=BUCKET_SRC, - source_object=OBJECT_SRC_1, - destination_path=DESTINATION_PATH_1, - ) - # [END howto_operator_gcs_to_sftp_copy_single_file] - - check_copy_file_from_gcs_to_sftp = SFTPSensor( - task_id="check-file-copy-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - timeout=60, - path=os.path.join(DESTINATION_PATH_1, OBJECT_SRC_1), - ) - - # [START howto_operator_gcs_to_sftp_move_single_file_destination] - move_file_from_gcs_to_sftp = GCSToSFTPOperator( - task_id="file-move-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - source_bucket=BUCKET_SRC, - source_object=OBJECT_SRC_2, - destination_path=DESTINATION_PATH_1, - move_object=True, - ) - # [END howto_operator_gcs_to_sftp_move_single_file_destination] - - check_move_file_from_gcs_to_sftp = SFTPSensor( - task_id="check-file-move-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - timeout=60, - path=os.path.join(DESTINATION_PATH_1, OBJECT_SRC_2), - ) - - # [START howto_operator_gcs_to_sftp_copy_directory] - copy_dir_from_gcs_to_sftp = GCSToSFTPOperator( - task_id="dir-copy-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - source_bucket=BUCKET_SRC, - source_object=OBJECT_SRC_3, - destination_path=DESTINATION_PATH_2, - ) - # [END howto_operator_gcs_to_sftp_copy_directory] - - check_copy_dir_from_gcs_to_sftp = SFTPSensor( - task_id="check-dir-copy-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - timeout=60, - path=os.path.join(DESTINATION_PATH_2, "dir-2", OBJECT_SRC_1), - ) - - # [START howto_operator_gcs_to_sftp_move_specific_files] - move_dir_from_gcs_to_sftp = GCSToSFTPOperator( - task_id="dir-move-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - source_bucket=BUCKET_SRC, - source_object=OBJECT_SRC_3, - destination_path=DESTINATION_PATH_3, - keep_directory_structure=False, - ) - # [END howto_operator_gcs_to_sftp_move_specific_files] - - check_move_dir_from_gcs_to_sftp = SFTPSensor( - task_id="check-dir-move-gsc-to-sftp", - sftp_conn_id=SFTP_CONN_ID, - timeout=60, - path=os.path.join(DESTINATION_PATH_3, OBJECT_SRC_1), - ) - - move_file_from_gcs_to_sftp >> check_move_file_from_gcs_to_sftp - copy_dir_from_gcs_to_sftp >> check_copy_file_from_gcs_to_sftp - - copy_dir_from_gcs_to_sftp >> move_dir_from_gcs_to_sftp - copy_dir_from_gcs_to_sftp >> check_copy_dir_from_gcs_to_sftp - move_dir_from_gcs_to_sftp >> check_move_dir_from_gcs_to_sftp diff --git a/airflow/providers/google/cloud/example_dags/example_gdrive_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_gdrive_to_gcs.py deleted file mode 100644 index bb656340f1fba..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_gdrive_to_gcs.py +++ /dev/null @@ -1,52 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) 
under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.gdrive_to_gcs import GoogleDriveToGCSOperator -from airflow.providers.google.suite.sensors.drive import GoogleDriveFileExistenceSensor - -BUCKET = os.environ.get("GCP_GCS_BUCKET", "test28397yeo") -OBJECT = os.environ.get("GCP_GCS_OBJECT", "abc123xyz") -FOLDER_ID = os.environ.get("FILE_ID", "1234567890qwerty") -FILE_NAME = os.environ.get("FILE_NAME", "file.pdf") - -with models.DAG( - "example_gdrive_to_gcs_with_gdrive_sensor", - start_date=datetime(2021, 1, 1), - catchup=False, - schedule_interval='@once', # Override to match your needs - tags=["example"], -) as dag: - # [START detect_file] - detect_file = GoogleDriveFileExistenceSensor( - task_id="detect_file", folder_id=FOLDER_ID, file_name=FILE_NAME - ) - # [END detect_file] - # [START upload_gdrive_to_gcs] - upload_gdrive_to_gcs = GoogleDriveToGCSOperator( - task_id="upload_gdrive_object_to_gcs", - folder_id=FOLDER_ID, - file_name=FILE_NAME, - bucket_name=BUCKET, - object_name=OBJECT, - ) - # [END upload_gdrive_to_gcs] - detect_file >> upload_gdrive_to_gcs diff --git a/airflow/providers/google/cloud/example_dags/example_gdrive_to_local.py b/airflow/providers/google/cloud/example_dags/example_gdrive_to_local.py deleted file mode 100644 index 2ba38ea81311a..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_gdrive_to_local.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.gdrive_to_local import GoogleDriveToLocalOperator -from airflow.providers.google.suite.sensors.drive import GoogleDriveFileExistenceSensor - -FOLDER_ID = os.environ.get("FILE_ID", "1234567890qwerty") -FILE_NAME = os.environ.get("FILE_NAME", "file.pdf") -OUTPUT_FILE = os.environ.get("OUTPUT_FILE", "out_file.pdf") - -with models.DAG( - "example_gdrive_to_local_with_gdrive_sensor", - start_date=datetime(2021, 1, 1), - catchup=False, - schedule_interval=None, # Override to match your needs - tags=["example"], -) as dag: - # [START detect_file] - detect_file = GoogleDriveFileExistenceSensor( - task_id="detect_file", folder_id=FOLDER_ID, file_name=FILE_NAME - ) - # [END detect_file] - # [START download_from_gdrive_to_local] - download_from_gdrive_to_local = GoogleDriveToLocalOperator( - task_id="download_from_gdrive_to_local", - folder_id=FOLDER_ID, - file_name=FILE_NAME, - output_file=OUTPUT_FILE, - ) - # [END download_from_gdrive_to_local] - detect_file >> download_from_gdrive_to_local diff --git a/airflow/providers/google/cloud/example_dags/example_life_sciences.py b/airflow/providers/google/cloud/example_dags/example_life_sciences.py deleted file mode 100644 index 0503a8e1e340a..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_life_sciences.py +++ /dev/null @@ -1,97 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.life_sciences import LifeSciencesRunPipelineOperator - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project-id") -BUCKET = os.environ.get("GCP_GCS_LIFE_SCIENCES_BUCKET", "INVALID BUCKET NAME") -FILENAME = os.environ.get("GCP_GCS_LIFE_SCIENCES_FILENAME", 'input.in') -LOCATION = os.environ.get("GCP_LIFE_SCIENCES_LOCATION", 'us-central1') - - -# [START howto_configure_simple_action_pipeline] -SIMPLE_ACTION_PIPELINE = { - "pipeline": { - "actions": [ - {"imageUri": "bash", "commands": ["-c", "echo Hello, world"]}, - ], - "resources": { - "regions": [f"{LOCATION}"], - "virtualMachine": { - "machineType": "n1-standard-1", - }, - }, - }, -} -# [END howto_configure_simple_action_pipeline] - -# [START howto_configure_multiple_action_pipeline] -MULTI_ACTION_PIPELINE = { - "pipeline": { - "actions": [ - { - "imageUri": "google/cloud-sdk", - "commands": ["gsutil", "cp", f"gs://{BUCKET}/{FILENAME}", "/tmp"], - }, - {"imageUri": "bash", "commands": ["-c", "echo Hello, world"]}, - { - "imageUri": "google/cloud-sdk", - "commands": [ - "gsutil", - "cp", - f"gs://{BUCKET}/{FILENAME}", - f"gs://{BUCKET}/output.in", - ], - }, - ], - "resources": { - "regions": [f"{LOCATION}"], - "virtualMachine": { - "machineType": "n1-standard-1", - }, - }, - } -} -# [END howto_configure_multiple_action_pipeline] - -with models.DAG( - "example_gcp_life_sciences", - schedule_interval='@once', - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - - # [START howto_run_pipeline] - simple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( - task_id='simple-action-pipeline', - body=SIMPLE_ACTION_PIPELINE, - project_id=PROJECT_ID, - location=LOCATION, - ) - # [END howto_run_pipeline] - - multiple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( - task_id='multi-action-pipeline', body=MULTI_ACTION_PIPELINE, project_id=PROJECT_ID, location=LOCATION - ) - - simple_life_science_action_pipeline >> multiple_life_science_action_pipeline diff --git a/airflow/providers/google/cloud/example_dags/example_looker.py b/airflow/providers/google/cloud/example_dags/example_looker.py index ece60f3823398..36cc0f56280bd 100644 --- a/airflow/providers/google/cloud/example_dags/example_looker.py +++ b/airflow/providers/google/cloud/example_dags/example_looker.py @@ -19,6 +19,7 @@ Example Airflow DAG that show how to use various Looker operators to submit PDT materialization job and manage it. 
""" +from __future__ import annotations from datetime import datetime @@ -27,23 +28,22 @@ from airflow.providers.google.cloud.sensors.looker import LookerCheckPdtBuildSensor with models.DAG( - dag_id='example_gcp_looker', - schedule_interval=None, + dag_id="example_gcp_looker", start_date=datetime(2021, 1, 1), catchup=False, ) as dag: # [START cloud_looker_async_start_pdt_sensor] start_pdt_task_async = LookerStartPdtBuildOperator( - task_id='start_pdt_task_async', - looker_conn_id='your_airflow_connection_for_looker', - model='your_lookml_model', - view='your_lookml_view', + task_id="start_pdt_task_async", + looker_conn_id="your_airflow_connection_for_looker", + model="your_lookml_model", + view="your_lookml_view", asynchronous=True, ) check_pdt_task_async_sensor = LookerCheckPdtBuildSensor( - task_id='check_pdt_task_async_sensor', - looker_conn_id='your_airflow_connection_for_looker', + task_id="check_pdt_task_async_sensor", + looker_conn_id="your_airflow_connection_for_looker", materialization_id=start_pdt_task_async.output, poke_interval=10, ) @@ -51,10 +51,10 @@ # [START how_to_cloud_looker_start_pdt_build_operator] build_pdt_task = LookerStartPdtBuildOperator( - task_id='build_pdt_task', - looker_conn_id='your_airflow_connection_for_looker', - model='your_lookml_model', - view='your_lookml_view', + task_id="build_pdt_task", + looker_conn_id="your_airflow_connection_for_looker", + model="your_lookml_model", + view="your_lookml_view", ) # [END how_to_cloud_looker_start_pdt_build_operator] diff --git a/airflow/providers/google/cloud/example_dags/example_mlengine.py b/airflow/providers/google/cloud/example_dags/example_mlengine.py deleted file mode 100644 index 7db060a08319a..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_mlengine.py +++ /dev/null @@ -1,294 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG for Google ML Engine service. 
-""" -import os -from datetime import datetime -from typing import Any, Dict - -from airflow import models -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.mlengine import ( - MLEngineCreateModelOperator, - MLEngineCreateVersionOperator, - MLEngineDeleteModelOperator, - MLEngineDeleteVersionOperator, - MLEngineGetModelOperator, - MLEngineListVersionsOperator, - MLEngineSetDefaultVersionOperator, - MLEngineStartBatchPredictionJobOperator, - MLEngineStartTrainingJobOperator, -) -from airflow.providers.google.cloud.utils import mlengine_operator_utils - -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") - -MODEL_NAME = os.environ.get("GCP_MLENGINE_MODEL_NAME", "model_name") - -SAVED_MODEL_PATH = os.environ.get("GCP_MLENGINE_SAVED_MODEL_PATH", "gs://INVALID BUCKET NAME/saved-model/") -JOB_DIR = os.environ.get("GCP_MLENGINE_JOB_DIR", "gs://INVALID BUCKET NAME/keras-job-dir") -PREDICTION_INPUT = os.environ.get( - "GCP_MLENGINE_PREDICTION_INPUT", "gs://INVALID BUCKET NAME/prediction_input.json" -) -PREDICTION_OUTPUT = os.environ.get( - "GCP_MLENGINE_PREDICTION_OUTPUT", "gs://INVALID BUCKET NAME/prediction_output" -) -TRAINER_URI = os.environ.get("GCP_MLENGINE_TRAINER_URI", "gs://INVALID BUCKET NAME/trainer.tar.gz") -TRAINER_PY_MODULE = os.environ.get("GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task") - -SUMMARY_TMP = os.environ.get("GCP_MLENGINE_DATAFLOW_TMP", "gs://INVALID BUCKET NAME/tmp/") -SUMMARY_STAGING = os.environ.get("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/") - - -with models.DAG( - "example_gcp_mlengine", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], - params={"model_name": MODEL_NAME}, -) as dag: - hyperparams: Dict[str, Any] = { - 'goal': 'MAXIMIZE', - 'hyperparameterMetricTag': 'metric1', - 'maxTrials': 30, - 'maxParallelTrials': 1, - 'enableTrialEarlyStopping': True, - 'params': [], - } - - hyperparams['params'].append( - { - 'parameterName': 'hidden1', - 'type': 'INTEGER', - 'minValue': 40, - 'maxValue': 400, - 'scaleType': 'UNIT_LINEAR_SCALE', - } - ) - - hyperparams['params'].append( - {'parameterName': 'numRnnCells', 'type': 'DISCRETE', 'discreteValues': [1, 2, 3, 4]} - ) - - hyperparams['params'].append( - { - 'parameterName': 'rnnCellType', - 'type': 'CATEGORICAL', - 'categoricalValues': [ - 'BasicLSTMCell', - 'BasicRNNCell', - 'GRUCell', - 'LSTMCell', - 'LayerNormBasicLSTMCell', - ], - } - ) - # [START howto_operator_gcp_mlengine_training] - training = MLEngineStartTrainingJobOperator( - task_id="training", - project_id=PROJECT_ID, - region="us-central1", - job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}", - runtime_version="1.15", - python_version="3.7", - job_dir=JOB_DIR, - package_uris=[TRAINER_URI], - training_python_module=TRAINER_PY_MODULE, - training_args=[], - labels={"job_type": "training"}, - hyperparameters=hyperparams, - ) - # [END howto_operator_gcp_mlengine_training] - - # [START howto_operator_gcp_mlengine_create_model] - create_model = MLEngineCreateModelOperator( - task_id="create-model", - project_id=PROJECT_ID, - model={ - "name": MODEL_NAME, - }, - ) - # [END howto_operator_gcp_mlengine_create_model] - - # [START howto_operator_gcp_mlengine_get_model] - get_model = MLEngineGetModelOperator( - task_id="get-model", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - ) - # [END howto_operator_gcp_mlengine_get_model] - - # [START 
howto_operator_gcp_mlengine_print_model] - get_model_result = BashOperator( - bash_command=f"echo {get_model.output}", - task_id="get-model-result", - ) - # [END howto_operator_gcp_mlengine_print_model] - - # [START howto_operator_gcp_mlengine_create_version1] - create_version = MLEngineCreateVersionOperator( - task_id="create-version", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - version={ - "name": "v1", - "description": "First-version", - "deployment_uri": f'{JOB_DIR}/keras_export/', - "runtime_version": "1.15", - "machineType": "mls1-c1-m2", - "framework": "TENSORFLOW", - "pythonVersion": "3.7", - }, - ) - # [END howto_operator_gcp_mlengine_create_version1] - - # [START howto_operator_gcp_mlengine_create_version2] - create_version_2 = MLEngineCreateVersionOperator( - task_id="create-version-2", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - version={ - "name": "v2", - "description": "Second version", - "deployment_uri": SAVED_MODEL_PATH, - "runtime_version": "1.15", - "machineType": "mls1-c1-m2", - "framework": "TENSORFLOW", - "pythonVersion": "3.7", - }, - ) - # [END howto_operator_gcp_mlengine_create_version2] - - # [START howto_operator_gcp_mlengine_default_version] - set_defaults_version = MLEngineSetDefaultVersionOperator( - task_id="set-default-version", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - version_name="v2", - ) - # [END howto_operator_gcp_mlengine_default_version] - - # [START howto_operator_gcp_mlengine_list_versions] - list_version = MLEngineListVersionsOperator( - task_id="list-version", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - ) - # [END howto_operator_gcp_mlengine_list_versions] - - # [START howto_operator_gcp_mlengine_print_versions] - list_version_result = BashOperator( - bash_command=f"echo {list_version.output}", - task_id="list-version-result", - ) - # [END howto_operator_gcp_mlengine_print_versions] - - # [START howto_operator_gcp_mlengine_get_prediction] - prediction = MLEngineStartBatchPredictionJobOperator( - task_id="prediction", - project_id=PROJECT_ID, - job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}", - region="us-central1", - model_name=MODEL_NAME, - data_format="TEXT", - input_paths=[PREDICTION_INPUT], - output_path=PREDICTION_OUTPUT, - labels={"job_type": "prediction"}, - ) - # [END howto_operator_gcp_mlengine_get_prediction] - - # [START howto_operator_gcp_mlengine_delete_version] - delete_version = MLEngineDeleteVersionOperator( - task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1" - ) - # [END howto_operator_gcp_mlengine_delete_version] - - # [START howto_operator_gcp_mlengine_delete_model] - delete_model = MLEngineDeleteModelOperator( - task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True - ) - # [END howto_operator_gcp_mlengine_delete_model] - - training >> create_version - training >> create_version_2 - create_model >> get_model >> [get_model_result, delete_model] - create_model >> get_model >> delete_model - create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version - create_version >> prediction - create_version_2 >> prediction - prediction >> delete_version - list_version >> list_version_result - list_version >> delete_version - delete_version >> delete_model - - # [START howto_operator_gcp_mlengine_get_metric] - def get_metric_fn_and_keys(): - """ - Gets metric function and keys used to generate summary - """ - - def normalize_value(inst: Dict): - val = float(inst['dense_4'][0]) - return 
tuple([val]) # returns a tuple. - - return normalize_value, ['val'] # key order must match. - - # [END howto_operator_gcp_mlengine_get_metric] - - # [START howto_operator_gcp_mlengine_validate_error] - def validate_err_and_count(summary: Dict) -> Dict: - """ - Validate summary result - """ - if summary['val'] > 1: - raise ValueError(f'Too high val>1; summary={summary}') - if summary['val'] < 0: - raise ValueError(f'Too low val<0; summary={summary}') - if summary['count'] != 20: - raise ValueError(f'Invalid value val != 20; summary={summary}') - return summary - - # [END howto_operator_gcp_mlengine_validate_error] - - # [START howto_operator_gcp_mlengine_evaluate] - evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops( - task_prefix="evaluate-ops", - data_format="TEXT", - input_paths=[PREDICTION_INPUT], - prediction_path=PREDICTION_OUTPUT, - metric_fn_and_keys=get_metric_fn_and_keys(), - validate_fn=validate_err_and_count, - batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}", - project_id=PROJECT_ID, - region="us-central1", - dataflow_options={ - 'project': PROJECT_ID, - 'tempLocation': SUMMARY_TMP, - 'stagingLocation': SUMMARY_STAGING, - }, - model_name=MODEL_NAME, - version_name="v1", - py_interpreter="python3", - ) - # [END howto_operator_gcp_mlengine_evaluate] - - create_model >> create_version >> evaluate_prediction - evaluate_validation >> delete_version diff --git a/airflow/providers/google/cloud/example_dags/example_mssql_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_mssql_to_gcs.py deleted file mode 100644 index 46be2529d34a9..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_mssql_to_gcs.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator - -GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-airflow") -FILENAME = 'test_file' - -SQL_QUERY = "USE airflow SELECT * FROM Country;" - -with models.DAG( - 'example_mssql_to_gcs', - schedule_interval='@once', - start_date=datetime(2021, 12, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_mssql_to_gcs] - upload = MSSQLToGCSOperator( - task_id='mssql_to_gcs', - mssql_conn_id='airflow_mssql', - sql=SQL_QUERY, - bucket=GCS_BUCKET, - filename=FILENAME, - export_format='csv', - ) - # [END howto_operator_mssql_to_gcs] diff --git a/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py deleted file mode 100644 index c8c798bc89101..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_mysql_to_gcs.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator - -GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-airflow-mysql-gcs") -FILENAME = 'test_file' - -SQL_QUERY = "SELECT * from test_table" - -with models.DAG( - 'example_mysql_to_gcs', - schedule_interval='@once', - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_mysql_to_gcs] - upload = MySQLToGCSOperator( - task_id='mysql_to_gcs', sql=SQL_QUERY, bucket=GCS_BUCKET, filename=FILENAME, export_format='csv' - ) - # [END howto_operator_mysql_to_gcs] diff --git a/airflow/providers/google/cloud/example_dags/example_oracle_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_oracle_to_gcs.py deleted file mode 100644 index 2d5d5c59bfd1e..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_oracle_to_gcs.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.oracle_to_gcs import OracleToGCSOperator - -GCS_BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-airflow-oracle-gcs") -FILENAME = 'test_file' - -SQL_QUERY = "SELECT * from test_table" - -with models.DAG( - 'example_oracle_to_gcs', - schedule_interval=None, - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_oracle_to_gcs] - upload = OracleToGCSOperator( - task_id='oracle_to_gcs', sql=SQL_QUERY, bucket=GCS_BUCKET, filename=FILENAME, export_format='csv' - ) - # [END howto_operator_oracle_to_gcs] diff --git a/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py index 96ac71193998a..eca37996eee01 100644 --- a/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py @@ -18,6 +18,8 @@ """ Example DAG using PostgresToGoogleCloudStorageOperator. """ +from __future__ import annotations + import os from datetime import datetime @@ -30,11 +32,10 @@ SQL_QUERY = "select * from test_table;" with models.DAG( - dag_id='example_postgres_to_gcs', - schedule_interval='@once', # Override to match your needs + dag_id="example_postgres_to_gcs", start_date=datetime(2021, 1, 1), catchup=False, - tags=['example'], + tags=["example"], ) as dag: upload_data = PostgresToGCSOperator( task_id="get_data", sql=SQL_QUERY, bucket=GCS_BUCKET, filename=FILENAME, gzip=False diff --git a/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py index 6ac82c5b97995..5c5ce985b5965 100644 --- a/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py @@ -18,6 +18,8 @@ """ Example DAG using PrestoToGCSOperator. """ +from __future__ import annotations + import os import re from datetime import datetime @@ -31,7 +33,7 @@ ) from airflow.providers.google.cloud.transfers.presto_to_gcs import PrestoToGCSOperator -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", 'example-project') +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") GCS_BUCKET = os.environ.get("GCP_PRESTO_TO_GCS_BUCKET_NAME", "INVALID BUCKET NAME") DATASET_NAME = os.environ.get("GCP_PRESTO_TO_GCS_DATASET_NAME", "test_presto_to_gcs_dataset") @@ -48,7 +50,6 @@ def safe_name(s: str) -> str: with models.DAG( dag_id="example_presto_to_gcs", - schedule_interval='@once', # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, tags=["example"], diff --git a/airflow/providers/google/cloud/example_dags/example_pubsub.py b/airflow/providers/google/cloud/example_dags/example_pubsub.py deleted file mode 100644 index 8e3dd1fe8f01e..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_pubsub.py +++ /dev/null @@ -1,184 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that uses Google PubSub services. -""" -import os -from datetime import datetime - -from airflow import models -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.pubsub import ( - PubSubCreateSubscriptionOperator, - PubSubCreateTopicOperator, - PubSubDeleteSubscriptionOperator, - PubSubDeleteTopicOperator, - PubSubPublishMessageOperator, - PubSubPullOperator, -) -from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor - -START_DATE = datetime(2021, 1, 1) - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") -TOPIC_FOR_SENSOR_DAG = os.environ.get("GCP_PUBSUB_SENSOR_TOPIC", "PubSubSensorTestTopic") -TOPIC_FOR_OPERATOR_DAG = os.environ.get("GCP_PUBSUB_OPERATOR_TOPIC", "PubSubOperatorTestTopic") -MESSAGE = {"data": b"Tool", "attributes": {"name": "wrench", "mass": "1.3kg", "count": "3"}} - -# [START howto_operator_gcp_pubsub_pull_messages_result_cmd] -echo_cmd = """ -{% for m in task_instance.xcom_pull('pull_messages') %} - echo "AckID: {{ m.get('ackId') }}, Base64-Encoded: {{ m.get('message') }}" -{% endfor %} -""" -# [END howto_operator_gcp_pubsub_pull_messages_result_cmd] - -with models.DAG( - "example_gcp_pubsub_sensor", - schedule_interval='@once', # Override to match your needs - start_date=START_DATE, - catchup=False, -) as example_sensor_dag: - # [START howto_operator_gcp_pubsub_create_topic] - create_topic = PubSubCreateTopicOperator( - task_id="create_topic", topic=TOPIC_FOR_SENSOR_DAG, project_id=GCP_PROJECT_ID, fail_if_exists=False - ) - # [END howto_operator_gcp_pubsub_create_topic] - - # [START howto_operator_gcp_pubsub_create_subscription] - subscribe_task = PubSubCreateSubscriptionOperator( - task_id="subscribe_task", project_id=GCP_PROJECT_ID, topic=TOPIC_FOR_SENSOR_DAG - ) - # [END howto_operator_gcp_pubsub_create_subscription] - - # [START howto_operator_gcp_pubsub_pull_message_with_sensor] - subscription = subscribe_task.output - - pull_messages = PubSubPullSensor( - task_id="pull_messages", - ack_messages=True, - project_id=GCP_PROJECT_ID, - subscription=subscription, - ) - # [END howto_operator_gcp_pubsub_pull_message_with_sensor] - - # [START howto_operator_gcp_pubsub_pull_messages_result] - pull_messages_result = BashOperator(task_id="pull_messages_result", bash_command=echo_cmd) - # [END howto_operator_gcp_pubsub_pull_messages_result] - - # [START howto_operator_gcp_pubsub_publish] - publish_task = PubSubPublishMessageOperator( - task_id="publish_task", - project_id=GCP_PROJECT_ID, - topic=TOPIC_FOR_SENSOR_DAG, - messages=[MESSAGE] * 10, - ) - # [END howto_operator_gcp_pubsub_publish] - - # [START howto_operator_gcp_pubsub_unsubscribe] - unsubscribe_task = PubSubDeleteSubscriptionOperator( - task_id="unsubscribe_task", - project_id=GCP_PROJECT_ID, - subscription=subscription, - ) - # [END howto_operator_gcp_pubsub_unsubscribe] - - # [START howto_operator_gcp_pubsub_delete_topic] - delete_topic = PubSubDeleteTopicOperator( - task_id="delete_topic", topic=TOPIC_FOR_SENSOR_DAG, project_id=GCP_PROJECT_ID - ) - # [END howto_operator_gcp_pubsub_delete_topic] 
- - create_topic >> subscribe_task >> publish_task - pull_messages >> pull_messages_result >> unsubscribe_task >> delete_topic - - # Task dependencies created via `XComArgs`: - # subscribe_task >> pull_messages - # subscribe_task >> unsubscribe_task - - -with models.DAG( - "example_gcp_pubsub_operator", - schedule_interval='@once', # Override to match your needs - start_date=START_DATE, - catchup=False, -) as example_operator_dag: - # [START howto_operator_gcp_pubsub_create_topic] - create_topic = PubSubCreateTopicOperator( - task_id="create_topic", topic=TOPIC_FOR_OPERATOR_DAG, project_id=GCP_PROJECT_ID - ) - # [END howto_operator_gcp_pubsub_create_topic] - - # [START howto_operator_gcp_pubsub_create_subscription] - subscribe_task = PubSubCreateSubscriptionOperator( - task_id="subscribe_task", project_id=GCP_PROJECT_ID, topic=TOPIC_FOR_OPERATOR_DAG - ) - # [END howto_operator_gcp_pubsub_create_subscription] - - # [START howto_operator_gcp_pubsub_pull_message_with_operator] - subscription = subscribe_task.output - - pull_messages_operator = PubSubPullOperator( - task_id="pull_messages", - ack_messages=True, - project_id=GCP_PROJECT_ID, - subscription=subscription, - ) - # [END howto_operator_gcp_pubsub_pull_message_with_operator] - - # [START howto_operator_gcp_pubsub_pull_messages_result] - pull_messages_result = BashOperator(task_id="pull_messages_result", bash_command=echo_cmd) - # [END howto_operator_gcp_pubsub_pull_messages_result] - - # [START howto_operator_gcp_pubsub_publish] - publish_task = PubSubPublishMessageOperator( - task_id="publish_task", - project_id=GCP_PROJECT_ID, - topic=TOPIC_FOR_OPERATOR_DAG, - messages=[MESSAGE, MESSAGE, MESSAGE], - ) - # [END howto_operator_gcp_pubsub_publish] - - # [START howto_operator_gcp_pubsub_unsubscribe] - unsubscribe_task = PubSubDeleteSubscriptionOperator( - task_id="unsubscribe_task", - project_id=GCP_PROJECT_ID, - subscription=subscription, - ) - # [END howto_operator_gcp_pubsub_unsubscribe] - - # [START howto_operator_gcp_pubsub_delete_topic] - delete_topic = PubSubDeleteTopicOperator( - task_id="delete_topic", topic=TOPIC_FOR_OPERATOR_DAG, project_id=GCP_PROJECT_ID - ) - # [END howto_operator_gcp_pubsub_delete_topic] - - ( - create_topic - >> subscribe_task - >> publish_task - >> pull_messages_operator - >> pull_messages_result - >> unsubscribe_task - >> delete_topic - ) - - # Task dependencies created via `XComArgs`: - # subscribe_task >> pull_messages_operator - # subscribe_task >> unsubscribe_task diff --git a/airflow/providers/google/cloud/example_dags/example_s3_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_s3_to_gcs.py deleted file mode 100644 index e3948f390dc45..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_s3_to_gcs.py +++ /dev/null @@ -1,78 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime - -from airflow import models -from airflow.decorators import task -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator -from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator -from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator - -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'gcp-project-id') -S3BUCKET_NAME = os.environ.get('S3BUCKET_NAME', 'example-s3bucket-name') -GCS_BUCKET = os.environ.get('GCP_GCS_BUCKET', 'example-gcsbucket-name') -GCS_BUCKET_URL = f"gs://{GCS_BUCKET}/" -UPLOAD_FILE = '/tmp/example-file.txt' -PREFIX = 'TESTS' - - -@task(task_id='upload_file_to_s3') -def upload_file(): - """A callable to upload file to AWS bucket""" - s3_hook = S3Hook() - s3_hook.load_file(filename=UPLOAD_FILE, key=PREFIX, bucket_name=S3BUCKET_NAME) - - -with models.DAG( - 'example_s3_to_gcs', - schedule_interval='@once', - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - create_s3_bucket = S3CreateBucketOperator( - task_id="create_s3_bucket", bucket_name=S3BUCKET_NAME, region_name='us-east-1' - ) - - create_gcs_bucket = GCSCreateBucketOperator( - task_id="create_bucket", - bucket_name=GCS_BUCKET, - project_id=GCP_PROJECT_ID, - ) - # [START howto_transfer_s3togcs_operator] - transfer_to_gcs = S3ToGCSOperator( - task_id='s3_to_gcs_task', bucket=S3BUCKET_NAME, prefix=PREFIX, dest_gcs=GCS_BUCKET_URL - ) - # [END howto_transfer_s3togcs_operator] - - delete_s3_bucket = S3DeleteBucketOperator( - task_id='delete_s3_bucket', bucket_name=S3BUCKET_NAME, force_delete=True - ) - - delete_gcs_bucket = GCSDeleteBucketOperator(task_id='delete_gcs_bucket', bucket_name=GCS_BUCKET) - - ( - create_s3_bucket - >> upload_file() - >> create_gcs_bucket - >> transfer_to_gcs - >> delete_s3_bucket - >> delete_gcs_bucket - ) diff --git a/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py index 4cf449f538cf0..c41e515b6547c 100644 --- a/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ Example Airflow DAG that shows how to use SalesforceToGcsOperator. 
""" +from __future__ import annotations + import os from datetime import datetime @@ -44,7 +45,6 @@ with models.DAG( "example_salesforce_to_gcs", - schedule_interval='@once', # Override to match your needs start_date=datetime(2021, 1, 1), catchup=False, ) as dag: @@ -62,7 +62,7 @@ bucket_name=GCS_BUCKET, object_name=GCS_OBJ_PATH, salesforce_conn_id=SALESFORCE_CONN_ID, - export_format='csv', + export_format="csv", coerce_to_timestamp=False, record_time_added=False, gcp_conn_id=GCS_CONN_ID, @@ -80,23 +80,23 @@ dataset_id=DATASET_NAME, table_id=TABLE_NAME, schema_fields=[ - {'name': 'id', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'company', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'phone', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'email', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'createddate', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'lastmodifieddate', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'isdeleted', 'type': 'BOOL', 'mode': 'NULLABLE'}, + {"name": "id", "type": "STRING", "mode": "NULLABLE"}, + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "company", "type": "STRING", "mode": "NULLABLE"}, + {"name": "phone", "type": "STRING", "mode": "NULLABLE"}, + {"name": "email", "type": "STRING", "mode": "NULLABLE"}, + {"name": "createddate", "type": "STRING", "mode": "NULLABLE"}, + {"name": "lastmodifieddate", "type": "STRING", "mode": "NULLABLE"}, + {"name": "isdeleted", "type": "BOOL", "mode": "NULLABLE"}, ], ) load_csv = GCSToBigQueryOperator( - task_id='gcs_to_bq', + task_id="gcs_to_bq", bucket=GCS_BUCKET, source_objects=[GCS_OBJ_PATH], destination_project_dataset_table=f"{DATASET_NAME}.{TABLE_NAME}", - write_disposition='WRITE_TRUNCATE', + write_disposition="WRITE_TRUNCATE", ) read_data_from_gcs = BigQueryInsertJobOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py deleted file mode 100644 index 46870f48ff5f8..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py +++ /dev/null @@ -1,79 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG for Google Cloud Storage to SFTP transfer operators. 
-""" - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.sftp_to_gcs import SFTPToGCSOperator - -BUCKET_SRC = os.environ.get("GCP_GCS_BUCKET_1_SRC", "test-sftp-gcs") - -TMP_PATH = "/tmp" -DIR = "tests_sftp_hook_dir" -SUBDIR = "subdir" - -OBJECT_SRC_1 = "parent-1.bin" -OBJECT_SRC_2 = "parent-2.bin" -OBJECT_SRC_3 = "parent-3.txt" - - -with models.DAG( - "example_sftp_to_gcs", - schedule_interval='@once', - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - # [START howto_operator_sftp_to_gcs_copy_single_file] - copy_file_from_sftp_to_gcs = SFTPToGCSOperator( - task_id="file-copy-sftp-to-gcs", - source_path=os.path.join(TMP_PATH, DIR, OBJECT_SRC_1), - destination_bucket=BUCKET_SRC, - ) - # [END howto_operator_sftp_to_gcs_copy_single_file] - - # [START howto_operator_sftp_to_gcs_move_single_file_destination] - move_file_from_sftp_to_gcs_destination = SFTPToGCSOperator( - task_id="file-move-sftp-to-gcs-destination", - source_path=os.path.join(TMP_PATH, DIR, OBJECT_SRC_2), - destination_bucket=BUCKET_SRC, - destination_path="destination_dir/destination_filename.bin", - move_object=True, - ) - # [END howto_operator_sftp_to_gcs_move_single_file_destination] - - # [START howto_operator_sftp_to_gcs_copy_directory] - copy_directory_from_sftp_to_gcs = SFTPToGCSOperator( - task_id="dir-copy-sftp-to-gcs", - source_path=os.path.join(TMP_PATH, DIR, SUBDIR, "*"), - destination_bucket=BUCKET_SRC, - ) - # [END howto_operator_sftp_to_gcs_copy_directory] - - # [START howto_operator_sftp_to_gcs_move_specific_files] - move_specific_files_from_gcs_to_sftp = SFTPToGCSOperator( - task_id="dir-move-specific-files-sftp-to-gcs", - source_path=os.path.join(TMP_PATH, DIR, SUBDIR, "*.bin"), - destination_bucket=BUCKET_SRC, - destination_path="specific_files/", - move_object=True, - ) - # [END howto_operator_sftp_to_gcs_move_specific_files] diff --git a/airflow/providers/google/cloud/example_dags/example_sheets_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_sheets_to_gcs.py deleted file mode 100644 index 0741fa0e3d332..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_sheets_to_gcs.py +++ /dev/null @@ -1,41 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.sheets_to_gcs import GoogleSheetsToGCSOperator - -BUCKET = os.environ.get("GCP_GCS_BUCKET", "test28397yeo") -SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "1234567890qwerty") - -with models.DAG( - "example_sheets_to_gcs", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - # [START upload_sheet_to_gcs] - upload_sheet_to_gcs = GoogleSheetsToGCSOperator( - task_id="upload_sheet_to_gcs", - destination_bucket=BUCKET, - spreadsheet_id=SPREADSHEET_ID, - ) - # [END upload_sheet_to_gcs] diff --git a/airflow/providers/google/cloud/example_dags/example_translate.py b/airflow/providers/google/cloud/example_dags/example_translate.py deleted file mode 100644 index e74e5a160914e..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_translate.py +++ /dev/null @@ -1,52 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that translates text in Google Cloud Translate -service in the Google Cloud. - -""" -from datetime import datetime - -from airflow import models -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.translate import CloudTranslateTextOperator - -with models.DAG( - 'example_gcp_translate', - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_translate_text] - product_set_create = CloudTranslateTextOperator( - task_id='translate', - values=['zażółć gęślą jaźń'], - target_language='en', - format_='text', - source_language=None, - model='base', - ) - # [END howto_operator_translate_text] - # [START howto_operator_translate_access] - translation_access = BashOperator( - task_id='access', bash_command="echo '{{ task_instance.xcom_pull(\"translate\")[0] }}'" - ) - product_set_create >> translation_access - # [END howto_operator_translate_access] diff --git a/airflow/providers/google/cloud/example_dags/example_translate_speech.py b/airflow/providers/google/cloud/example_dags/example_translate_speech.py deleted file mode 100644 index 35e129dd241c1..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_translate_speech.py +++ /dev/null @@ -1,86 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.text_to_speech import CloudTextToSpeechSynthesizeOperator -from airflow.providers.google.cloud.operators.translate_speech import CloudTranslateSpeechOperator - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -BUCKET_NAME = os.environ.get("GCP_TRANSLATE_SPEECH_TEST_BUCKET", "INVALID BUCKET NAME") - -# [START howto_operator_translate_speech_gcp_filename] -FILENAME = "gcp-speech-test-file" -# [END howto_operator_translate_speech_gcp_filename] - -# [START howto_operator_text_to_speech_api_arguments] -INPUT = {"text": "Sample text for demo purposes"} -VOICE = {"language_code": "en-US", "ssml_gender": "FEMALE"} -AUDIO_CONFIG = {"audio_encoding": "LINEAR16"} -# [END howto_operator_text_to_speech_api_arguments] - -# [START howto_operator_translate_speech_arguments] -CONFIG = {"encoding": "LINEAR16", "language_code": "en_US"} -AUDIO = {"uri": f"gs://{BUCKET_NAME}/{FILENAME}"} -TARGET_LANGUAGE = 'pl' -FORMAT = 'text' -MODEL = 'base' -SOURCE_LANGUAGE = None # type: None -# [END howto_operator_translate_speech_arguments] - - -with models.DAG( - "example_gcp_translate_speech", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - text_to_speech_synthesize_task = CloudTextToSpeechSynthesizeOperator( - project_id=GCP_PROJECT_ID, - input_data=INPUT, - voice=VOICE, - audio_config=AUDIO_CONFIG, - target_bucket_name=BUCKET_NAME, - target_filename=FILENAME, - task_id="text_to_speech_synthesize_task", - ) - # [START howto_operator_translate_speech] - translate_speech_task = CloudTranslateSpeechOperator( - project_id=GCP_PROJECT_ID, - audio=AUDIO, - config=CONFIG, - target_language=TARGET_LANGUAGE, - format_=FORMAT, - source_language=SOURCE_LANGUAGE, - model=MODEL, - task_id='translate_speech_task', - ) - translate_speech_task2 = CloudTranslateSpeechOperator( - audio=AUDIO, - config=CONFIG, - target_language=TARGET_LANGUAGE, - format_=FORMAT, - source_language=SOURCE_LANGUAGE, - model=MODEL, - task_id='translate_speech_task2', - ) - # [END howto_operator_translate_speech] - text_to_speech_synthesize_task >> translate_speech_task >> translate_speech_task2 diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py index cded48ae9b4de..4265ac31ce7e2 100644 --- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py +++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py @@ -15,10 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# mypy ignore arg types (for templated fields) -# type: ignore[arg-type] - """ Example Airflow DAG that demonstrates operators for the Google Vertex AI service in the Google Cloud Platform. 
@@ -26,21 +22,22 @@ This DAG relies on the following OS environment variables: * GCP_VERTEX_AI_BUCKET - Google Cloud Storage bucket where the model will be saved -after training process was finished. + after training process was finished. * CUSTOM_CONTAINER_URI - path to container with model. * PYTHON_PACKAGE_GSC_URI - path to test model in archive. * LOCAL_TRAINING_SCRIPT_PATH - path to local training script. * DATASET_ID - ID of dataset which will be used in training process. * MODEL_ID - ID of model which will be used in predict process. * MODEL_ARTIFACT_URI - The artifact_uri should be the path to a GCS directory containing saved model -artifacts. + artifacts. """ +from __future__ import annotations + import os from datetime import datetime from uuid import uuid4 from google.cloud import aiplatform -from google.protobuf import json_format from google.protobuf.struct_pb2 import Value from airflow import models @@ -94,6 +91,10 @@ UploadModelOperator, ) +# mypy ignore arg types (for templated fields) +# type: ignore[arg-type] + + PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id") REGION = os.environ.get("GCP_LOCATION", "us-central1") BUCKET = os.environ.get("GCP_VERTEX_AI_BUCKET", "vertex-ai-system-tests") @@ -195,22 +196,22 @@ JOB_DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}" BIGQUERY_SOURCE = f"bq://{PROJECT_ID}.test_iowa_liquor_sales_forecasting_us.2021_sales_predict" GCS_DESTINATION_PREFIX = "gs://test-vertex-ai-bucket-us/output" -MODEL_PARAMETERS = json_format.ParseDict({}, Value()) +MODEL_PARAMETERS: dict | None = {} ENDPOINT_CONF = { "display_name": f"endpoint_test_{uuid4()}", } DEPLOYED_MODEL = { # format: 'projects/{project}/locations/{location}/models/{model}' - 'model': f"projects/{PROJECT_ID}/locations/{REGION}/models/{MODEL_ID}", - 'display_name': f"temp_endpoint_test_{uuid4()}", + "model": f"projects/{PROJECT_ID}/locations/{REGION}/models/{MODEL_ID}", + "display_name": f"temp_endpoint_test_{uuid4()}", "dedicated_resources": { "machine_spec": { "machine_type": "n1-standard-2", "accelerator_type": aiplatform.gapic.AcceleratorType.NVIDIA_TESLA_K80, "accelerator_count": 1, }, - 'min_replica_count': 1, + "min_replica_count": 1, "max_replica_count": 1, }, } @@ -237,7 +238,6 @@ with models.DAG( "example_gcp_vertex_ai_custom_jobs", - schedule_interval="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as custom_jobs_dag: @@ -327,7 +327,6 @@ with models.DAG( "example_gcp_vertex_ai_dataset", - schedule_interval="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as dataset_dag: @@ -367,7 +366,7 @@ # [START how_to_cloud_vertex_ai_delete_dataset_operator] delete_dataset_job = DeleteDatasetOperator( task_id="delete_dataset", - dataset_id=create_text_dataset_job.output['dataset_id'], + dataset_id=create_text_dataset_job.output["dataset_id"], region=REGION, project_id=PROJECT_ID, ) @@ -378,14 +377,14 @@ task_id="get_dataset", project_id=PROJECT_ID, region=REGION, - dataset_id=create_tabular_dataset_job.output['dataset_id'], + dataset_id=create_tabular_dataset_job.output["dataset_id"], ) # [END how_to_cloud_vertex_ai_get_dataset_operator] # [START how_to_cloud_vertex_ai_export_data_operator] export_data_job = ExportDataOperator( task_id="export_data", - dataset_id=create_image_dataset_job.output['dataset_id'], + dataset_id=create_image_dataset_job.output["dataset_id"], region=REGION, project_id=PROJECT_ID, export_config=TEST_EXPORT_CONFIG, @@ -395,7 +394,7 @@ # [START how_to_cloud_vertex_ai_import_data_operator] import_data_job = ImportDataOperator( 
task_id="import_data", - dataset_id=create_image_dataset_job.output['dataset_id'], + dataset_id=create_image_dataset_job.output["dataset_id"], region=REGION, project_id=PROJECT_ID, import_configs=TEST_IMPORT_CONFIG, @@ -415,7 +414,7 @@ task_id="update_dataset", project_id=PROJECT_ID, region=REGION, - dataset_id=create_video_dataset_job.output['dataset_id'], + dataset_id=create_video_dataset_job.output["dataset_id"], dataset=DATASET_TO_UPDATE, update_mask=TEST_UPDATE_MASK, ) @@ -430,7 +429,6 @@ with models.DAG( "example_gcp_vertex_ai_auto_ml", - schedule_interval="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as auto_ml_dag: @@ -547,7 +545,6 @@ with models.DAG( "example_gcp_vertex_ai_batch_prediction_job", - schedule_interval="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as batch_prediction_job_dag: @@ -576,7 +573,7 @@ # [START how_to_cloud_vertex_ai_delete_batch_prediction_job_operator] delete_batch_prediction_job = DeleteBatchPredictionJobOperator( task_id="delete_batch_prediction_job", - batch_prediction_job_id=create_batch_prediction_job.output['batch_prediction_job_id'], + batch_prediction_job_id=create_batch_prediction_job.output["batch_prediction_job_id"], region=REGION, project_id=PROJECT_ID, ) @@ -587,7 +584,6 @@ with models.DAG( "example_gcp_vertex_ai_endpoint", - schedule_interval="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as endpoint_dag: @@ -603,7 +599,7 @@ # [START how_to_cloud_vertex_ai_delete_endpoint_operator] delete_endpoint = DeleteEndpointOperator( task_id="delete_endpoint", - endpoint_id=create_endpoint.output['endpoint_id'], + endpoint_id=create_endpoint.output["endpoint_id"], region=REGION, project_id=PROJECT_ID, ) @@ -620,9 +616,9 @@ # [START how_to_cloud_vertex_ai_deploy_model_operator] deploy_model = DeployModelOperator( task_id="deploy_model", - endpoint_id=create_endpoint.output['endpoint_id'], + endpoint_id=create_endpoint.output["endpoint_id"], deployed_model=DEPLOYED_MODEL, - traffic_split={'0': 100}, + traffic_split={"0": 100}, region=REGION, project_id=PROJECT_ID, ) @@ -631,8 +627,8 @@ # [START how_to_cloud_vertex_ai_undeploy_model_operator] undeploy_model = UndeployModelOperator( task_id="undeploy_model", - endpoint_id=create_endpoint.output['endpoint_id'], - deployed_model_id=deploy_model.output['deployed_model_id'], + endpoint_id=create_endpoint.output["endpoint_id"], + deployed_model_id=deploy_model.output["deployed_model_id"], region=REGION, project_id=PROJECT_ID, ) @@ -643,7 +639,6 @@ with models.DAG( "example_gcp_vertex_ai_hyperparameter_tuning_job", - schedule_interval="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as hyperparameter_tuning_job_dag: @@ -669,16 +664,16 @@ region=REGION, project_id=PROJECT_ID, parameter_spec={ - 'learning_rate': aiplatform.hyperparameter_tuning.DoubleParameterSpec( - min=0.01, max=1, scale='log' + "learning_rate": aiplatform.hyperparameter_tuning.DoubleParameterSpec( + min=0.01, max=1, scale="log" ), - 'momentum': aiplatform.hyperparameter_tuning.DoubleParameterSpec(min=0, max=1, scale='linear'), - 'num_neurons': aiplatform.hyperparameter_tuning.DiscreteParameterSpec( - values=[64, 128, 512], scale='linear' + "momentum": aiplatform.hyperparameter_tuning.DoubleParameterSpec(min=0, max=1, scale="linear"), + "num_neurons": aiplatform.hyperparameter_tuning.DiscreteParameterSpec( + values=[64, 128, 512], scale="linear" ), }, metric_spec={ - 'accuracy': 'maximize', + "accuracy": "maximize", }, max_trial_count=15, parallel_trial_count=3, @@ -716,7 +711,6 @@ with models.DAG( 
"example_gcp_vertex_ai_model_service", - schedule_interval="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as model_service_dag: diff --git a/airflow/providers/google/cloud/example_dags/example_video_intelligence.py b/airflow/providers/google/cloud/example_dags/example_video_intelligence.py deleted file mode 100644 index 7280cf3e085c1..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_video_intelligence.py +++ /dev/null @@ -1,116 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that demonstrates operators for the Google Cloud Video Intelligence service in the Google -Cloud Platform. - -This DAG relies on the following OS environment variables: - -* GCP_BUCKET_NAME - Google Cloud Storage bucket where the file exists. -""" -import os -from datetime import datetime - -from google.api_core.retry import Retry - -from airflow import models -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.video_intelligence import ( - CloudVideoIntelligenceDetectVideoExplicitContentOperator, - CloudVideoIntelligenceDetectVideoLabelsOperator, - CloudVideoIntelligenceDetectVideoShotsOperator, -) - -# [START howto_operator_video_intelligence_os_args] -GCP_BUCKET_NAME = os.environ.get("GCP_VIDEO_INTELLIGENCE_BUCKET_NAME", "INVALID BUCKET NAME") -# [END howto_operator_video_intelligence_os_args] - - -# [START howto_operator_video_intelligence_other_args] -INPUT_URI = f"gs://{GCP_BUCKET_NAME}/video.mp4" -# [END howto_operator_video_intelligence_other_args] - - -with models.DAG( - "example_gcp_video_intelligence", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - - # [START howto_operator_video_intelligence_detect_labels] - detect_video_label = CloudVideoIntelligenceDetectVideoLabelsOperator( - input_uri=INPUT_URI, - output_uri=None, - video_context=None, - timeout=5, - task_id="detect_video_label", - ) - # [END howto_operator_video_intelligence_detect_labels] - - # [START howto_operator_video_intelligence_detect_labels_result] - detect_video_label_result = BashOperator( - bash_command="echo {{ task_instance.xcom_pull('detect_video_label')" - "['annotationResults'][0]['shotLabelAnnotations'][0]['entity']}}", - task_id="detect_video_label_result", - ) - # [END howto_operator_video_intelligence_detect_labels_result] - - # [START howto_operator_video_intelligence_detect_explicit_content] - detect_video_explicit_content = CloudVideoIntelligenceDetectVideoExplicitContentOperator( - input_uri=INPUT_URI, - output_uri=None, - video_context=None, - retry=Retry(maximum=10.0), - timeout=5, - task_id="detect_video_explicit_content", - ) - # [END 
howto_operator_video_intelligence_detect_explicit_content] - - # [START howto_operator_video_intelligence_detect_explicit_content_result] - detect_video_explicit_content_result = BashOperator( - bash_command="echo {{ task_instance.xcom_pull('detect_video_explicit_content')" - "['annotationResults'][0]['explicitAnnotation']['frames'][0]}}", - task_id="detect_video_explicit_content_result", - ) - # [END howto_operator_video_intelligence_detect_explicit_content_result] - - # [START howto_operator_video_intelligence_detect_video_shots] - detect_video_shots = CloudVideoIntelligenceDetectVideoShotsOperator( - input_uri=INPUT_URI, - output_uri=None, - video_context=None, - retry=Retry(maximum=10.0), - timeout=5, - task_id="detect_video_shots", - ) - # [END howto_operator_video_intelligence_detect_video_shots] - - # [START howto_operator_video_intelligence_detect_video_shots_result] - detect_video_shots_result = BashOperator( - bash_command="echo {{ task_instance.xcom_pull('detect_video_shots')" - "['annotationResults'][0]['shotAnnotations'][0]}}", - task_id="detect_video_shots_result", - ) - # [END howto_operator_video_intelligence_detect_video_shots_result] - - detect_video_label >> detect_video_label_result - detect_video_explicit_content >> detect_video_explicit_content_result - detect_video_shots >> detect_video_shots_result diff --git a/airflow/providers/google/cloud/example_dags/example_vision.py b/airflow/providers/google/cloud/example_dags/example_vision.py deleted file mode 100644 index eb3703d3598ed..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_vision.py +++ /dev/null @@ -1,531 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that creates, gets, updates and deletes Products and Product Sets in the Google Cloud -Vision service. - -This DAG relies on the following OS environment variables - -* GCP_VISION_LOCATION - Zone where the instance exists. -* GCP_VISION_PRODUCT_SET_ID - Product Set ID. -* GCP_VISION_PRODUCT_ID - Product ID. -* GCP_VISION_REFERENCE_IMAGE_ID - Reference Image ID. -* GCP_VISION_REFERENCE_IMAGE_URL - A link to the bucket that contains the reference image. -* GCP_VISION_ANNOTATE_IMAGE_URL - A link to the bucket that contains the file to be annotated. 
- -""" - -import os -from datetime import datetime - -from airflow import models -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.operators.vision import ( - CloudVisionAddProductToProductSetOperator, - CloudVisionCreateProductOperator, - CloudVisionCreateProductSetOperator, - CloudVisionCreateReferenceImageOperator, - CloudVisionDeleteProductOperator, - CloudVisionDeleteProductSetOperator, - CloudVisionDeleteReferenceImageOperator, - CloudVisionDetectImageLabelsOperator, - CloudVisionDetectImageSafeSearchOperator, - CloudVisionDetectTextOperator, - CloudVisionGetProductOperator, - CloudVisionGetProductSetOperator, - CloudVisionImageAnnotateOperator, - CloudVisionRemoveProductFromProductSetOperator, - CloudVisionTextDetectOperator, - CloudVisionUpdateProductOperator, - CloudVisionUpdateProductSetOperator, -) - -# [START howto_operator_vision_retry_import] - - -from google.api_core.retry import Retry # isort:skip - -# [END howto_operator_vision_retry_import] -# [START howto_operator_vision_product_set_import] -from google.cloud.vision_v1.types import ProductSet # isort:skip - -# [END howto_operator_vision_product_set_import] -# [START howto_operator_vision_product_import] -from google.cloud.vision_v1.types import Product # isort:skip - -# [END howto_operator_vision_product_import] -# [START howto_operator_vision_reference_image_import] -from google.cloud.vision_v1.types import ReferenceImage # isort:skip - -# [END howto_operator_vision_reference_image_import] -# [START howto_operator_vision_enums_import] -from google.cloud.vision import enums # isort:skip - -# [END howto_operator_vision_enums_import] - -START_DATE = datetime(2021, 1, 1) - -GCP_VISION_LOCATION = os.environ.get('GCP_VISION_LOCATION', 'europe-west1') - -GCP_VISION_PRODUCT_SET_ID = os.environ.get('GCP_VISION_PRODUCT_SET_ID', 'product_set_explicit_id') -GCP_VISION_PRODUCT_ID = os.environ.get('GCP_VISION_PRODUCT_ID', 'product_explicit_id') -GCP_VISION_REFERENCE_IMAGE_ID = os.environ.get('GCP_VISION_REFERENCE_IMAGE_ID', 'reference_image_explicit_id') -GCP_VISION_REFERENCE_IMAGE_URL = os.environ.get( - 'GCP_VISION_REFERENCE_IMAGE_URL', 'gs://INVALID BUCKET NAME/image1.jpg' -) -GCP_VISION_ANNOTATE_IMAGE_URL = os.environ.get( - 'GCP_VISION_ANNOTATE_IMAGE_URL', 'gs://INVALID BUCKET NAME/image2.jpg' -) - -# [START howto_operator_vision_product_set] -product_set = ProductSet(display_name='My Product Set') -# [END howto_operator_vision_product_set] - -# [START howto_operator_vision_product] -product = Product(display_name='My Product 1', product_category='toys') -# [END howto_operator_vision_product] - -# [START howto_operator_vision_reference_image] -reference_image = ReferenceImage(uri=GCP_VISION_REFERENCE_IMAGE_URL) -# [END howto_operator_vision_reference_image] - -# [START howto_operator_vision_annotate_image_request] -annotate_image_request = { - 'image': {'source': {'image_uri': GCP_VISION_ANNOTATE_IMAGE_URL}}, - 'features': [{'type': enums.Feature.Type.LOGO_DETECTION}], -} -# [END howto_operator_vision_annotate_image_request] - -# [START howto_operator_vision_detect_image_param] -DETECT_IMAGE = {"source": {"image_uri": GCP_VISION_ANNOTATE_IMAGE_URL}} -# [END howto_operator_vision_detect_image_param] - -with models.DAG( - 'example_gcp_vision_autogenerated_id', - schedule_interval='@once', - start_date=START_DATE, - catchup=False, -) as dag_autogenerated_id: - # ################################## # - # ### Autogenerated IDs examples ### # - # ################################## # - - # [START 
howto_operator_vision_product_set_create] - product_set_create = CloudVisionCreateProductSetOperator( - location=GCP_VISION_LOCATION, - product_set=product_set, - retry=Retry(maximum=10.0), - timeout=5, - task_id='product_set_create', - ) - # [END howto_operator_vision_product_set_create] - - product_set_create_output = product_set_create.output - - # [START howto_operator_vision_product_set_get] - product_set_get = CloudVisionGetProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=product_set_create_output, - task_id='product_set_get', - ) - # [END howto_operator_vision_product_set_get] - - # [START howto_operator_vision_product_set_update] - product_set_update = CloudVisionUpdateProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=product_set_create_output, - product_set=ProductSet(display_name='My Product Set 2'), - task_id='product_set_update', - ) - # [END howto_operator_vision_product_set_update] - - # [START howto_operator_vision_product_set_delete] - product_set_delete = CloudVisionDeleteProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=product_set_create_output, - task_id='product_set_delete', - ) - # [END howto_operator_vision_product_set_delete] - - # [START howto_operator_vision_product_create] - product_create = CloudVisionCreateProductOperator( - location=GCP_VISION_LOCATION, - product=product, - retry=Retry(maximum=10.0), - timeout=5, - task_id='product_create', - ) - # [END howto_operator_vision_product_create] - - product_create_output = product_create.output - - # [START howto_operator_vision_product_get] - product_get = CloudVisionGetProductOperator( - location=GCP_VISION_LOCATION, - product_id=product_create_output, - task_id='product_get', - ) - # [END howto_operator_vision_product_get] - - # [START howto_operator_vision_product_update] - product_update = CloudVisionUpdateProductOperator( - location=GCP_VISION_LOCATION, - product_id=product_create_output, - product=Product(display_name='My Product 2', description='My updated description'), - task_id='product_update', - ) - # [END howto_operator_vision_product_update] - - # [START howto_operator_vision_product_delete] - product_delete = CloudVisionDeleteProductOperator( - location=GCP_VISION_LOCATION, - product_id=product_create_output, - task_id='product_delete', - ) - # [END howto_operator_vision_product_delete] - - # [START howto_operator_vision_reference_image_create] - reference_image_create = CloudVisionCreateReferenceImageOperator( - location=GCP_VISION_LOCATION, - reference_image=reference_image, - product_id=product_create_output, - reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, - retry=Retry(maximum=10.0), - timeout=5, - task_id='reference_image_create', - ) - # [END howto_operator_vision_reference_image_create] - - # [START howto_operator_vision_reference_image_delete] - reference_image_delete = CloudVisionDeleteReferenceImageOperator( - location=GCP_VISION_LOCATION, - product_id=product_create_output, - reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, - retry=Retry(maximum=10.0), - timeout=5, - task_id='reference_image_delete', - ) - # [END howto_operator_vision_reference_image_delete] - - # [START howto_operator_vision_add_product_to_product_set] - add_product_to_product_set = CloudVisionAddProductToProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=product_set_create_output, - product_id=product_create_output, - retry=Retry(maximum=10.0), - timeout=5, - task_id='add_product_to_product_set', - ) - # [END 
howto_operator_vision_add_product_to_product_set] - - # [START howto_operator_vision_remove_product_from_product_set] - remove_product_from_product_set = CloudVisionRemoveProductFromProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=product_set_create_output, - product_id=product_create_output, - retry=Retry(maximum=10.0), - timeout=5, - task_id='remove_product_from_product_set', - ) - # [END howto_operator_vision_remove_product_from_product_set] - - # Product path - product_create >> product_get >> product_update >> product_delete - - # ProductSet path - product_set_get >> product_set_update >> product_set_delete - - # ReferenceImage path - reference_image_create >> reference_image_delete >> product_delete - - # Product/ProductSet path - product_create >> add_product_to_product_set - add_product_to_product_set >> remove_product_from_product_set - remove_product_from_product_set >> product_delete - remove_product_from_product_set >> product_set_delete - - # Task dependencies created via `XComArgs`: - # product_set_create >> product_set_get - # product_set_create >> product_set_update - # product_set_create >> product_set_delete - # product_create >> product_get - # product_create >> product_delete - # product_create >> reference_image_create - # product_create >> reference_image_delete - # product_set_create >> add_product_to_product_set - # product_create >> add_product_to_product_set - # product_set_create >> remove_product_from_product_set - # product_create >> remove_product_from_product_set - - -with models.DAG( - 'example_gcp_vision_explicit_id', - schedule_interval='@once', - start_date=START_DATE, - catchup=False, -) as dag_explicit_id: - # ############################# # - # ### Explicit IDs examples ### # - # ############################# # - - # [START howto_operator_vision_product_set_create_2] - product_set_create_2 = CloudVisionCreateProductSetOperator( - product_set_id=GCP_VISION_PRODUCT_SET_ID, - location=GCP_VISION_LOCATION, - product_set=product_set, - retry=Retry(maximum=10.0), - timeout=5, - task_id='product_set_create_2', - ) - # [END howto_operator_vision_product_set_create_2] - - # Second 'create' task with the same product_set_id to demonstrate idempotence - product_set_create_2_idempotence = CloudVisionCreateProductSetOperator( - product_set_id=GCP_VISION_PRODUCT_SET_ID, - location=GCP_VISION_LOCATION, - product_set=product_set, - retry=Retry(maximum=10.0), - timeout=5, - task_id='product_set_create_2_idempotence', - ) - - # [START howto_operator_vision_product_set_get_2] - product_set_get_2 = CloudVisionGetProductSetOperator( - location=GCP_VISION_LOCATION, product_set_id=GCP_VISION_PRODUCT_SET_ID, task_id='product_set_get_2' - ) - # [END howto_operator_vision_product_set_get_2] - - # [START howto_operator_vision_product_set_update_2] - product_set_update_2 = CloudVisionUpdateProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=GCP_VISION_PRODUCT_SET_ID, - product_set=ProductSet(display_name='My Product Set 2'), - task_id='product_set_update_2', - ) - # [END howto_operator_vision_product_set_update_2] - - # [START howto_operator_vision_product_set_delete_2] - product_set_delete_2 = CloudVisionDeleteProductSetOperator( - location=GCP_VISION_LOCATION, product_set_id=GCP_VISION_PRODUCT_SET_ID, task_id='product_set_delete_2' - ) - # [END howto_operator_vision_product_set_delete_2] - - # [START howto_operator_vision_product_create_2] - product_create_2 = CloudVisionCreateProductOperator( - product_id=GCP_VISION_PRODUCT_ID, - 
location=GCP_VISION_LOCATION, - product=product, - retry=Retry(maximum=10.0), - timeout=5, - task_id='product_create_2', - ) - # [END howto_operator_vision_product_create_2] - - # Second 'create' task with the same product_id to demonstrate idempotence - product_create_2_idempotence = CloudVisionCreateProductOperator( - product_id=GCP_VISION_PRODUCT_ID, - location=GCP_VISION_LOCATION, - product=product, - retry=Retry(maximum=10.0), - timeout=5, - task_id='product_create_2_idempotence', - ) - - # [START howto_operator_vision_product_get_2] - product_get_2 = CloudVisionGetProductOperator( - location=GCP_VISION_LOCATION, product_id=GCP_VISION_PRODUCT_ID, task_id='product_get_2' - ) - # [END howto_operator_vision_product_get_2] - - # [START howto_operator_vision_product_update_2] - product_update_2 = CloudVisionUpdateProductOperator( - location=GCP_VISION_LOCATION, - product_id=GCP_VISION_PRODUCT_ID, - product=Product(display_name='My Product 2', description='My updated description'), - task_id='product_update_2', - ) - # [END howto_operator_vision_product_update_2] - - # [START howto_operator_vision_product_delete_2] - product_delete_2 = CloudVisionDeleteProductOperator( - location=GCP_VISION_LOCATION, product_id=GCP_VISION_PRODUCT_ID, task_id='product_delete_2' - ) - # [END howto_operator_vision_product_delete_2] - - # [START howto_operator_vision_reference_image_create_2] - reference_image_create_2 = CloudVisionCreateReferenceImageOperator( - location=GCP_VISION_LOCATION, - reference_image=reference_image, - product_id=GCP_VISION_PRODUCT_ID, - reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, - retry=Retry(maximum=10.0), - timeout=5, - task_id='reference_image_create_2', - ) - # [END howto_operator_vision_reference_image_create_2] - - # [START howto_operator_vision_reference_image_delete_2] - reference_image_delete_2 = CloudVisionDeleteReferenceImageOperator( - location=GCP_VISION_LOCATION, - reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, - product_id=GCP_VISION_PRODUCT_ID, - retry=Retry(maximum=10.0), - timeout=5, - task_id='reference_image_delete_2', - ) - # [END howto_operator_vision_reference_image_delete_2] - - # Second 'create' task with the same product_id to demonstrate idempotence - reference_image_create_2_idempotence = CloudVisionCreateReferenceImageOperator( - location=GCP_VISION_LOCATION, - reference_image=reference_image, - product_id=GCP_VISION_PRODUCT_ID, - reference_image_id=GCP_VISION_REFERENCE_IMAGE_ID, - retry=Retry(maximum=10.0), - timeout=5, - task_id='reference_image_create_2_idempotence', - ) - - # [START howto_operator_vision_add_product_to_product_set_2] - add_product_to_product_set_2 = CloudVisionAddProductToProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=GCP_VISION_PRODUCT_SET_ID, - product_id=GCP_VISION_PRODUCT_ID, - retry=Retry(maximum=10.0), - timeout=5, - task_id='add_product_to_product_set_2', - ) - # [END howto_operator_vision_add_product_to_product_set_2] - - # [START howto_operator_vision_remove_product_from_product_set_2] - remove_product_from_product_set_2 = CloudVisionRemoveProductFromProductSetOperator( - location=GCP_VISION_LOCATION, - product_set_id=GCP_VISION_PRODUCT_SET_ID, - product_id=GCP_VISION_PRODUCT_ID, - retry=Retry(maximum=10.0), - timeout=5, - task_id='remove_product_from_product_set_2', - ) - # [END howto_operator_vision_remove_product_from_product_set_2] - - # Product path - product_create_2 >> product_create_2_idempotence >> product_get_2 >> product_update_2 >> product_delete_2 - - # ProductSet path - 
product_set_create_2 >> product_set_get_2 >> product_set_update_2 >> product_set_delete_2 - product_set_create_2 >> product_set_create_2_idempotence >> product_set_delete_2 - - # ReferenceImage path - product_create_2 >> reference_image_create_2 >> reference_image_create_2_idempotence - reference_image_create_2_idempotence >> reference_image_delete_2 >> product_delete_2 - - # Product/ProductSet path - add_product_to_product_set_2 >> remove_product_from_product_set_2 - product_set_create_2 >> add_product_to_product_set_2 - product_create_2 >> add_product_to_product_set_2 - remove_product_from_product_set_2 >> product_set_delete_2 - remove_product_from_product_set_2 >> product_delete_2 - -with models.DAG( - 'example_gcp_vision_annotate_image', - schedule_interval='@once', - start_date=START_DATE, - catchup=False, -) as dag_annotate_image: - # ############################## # - # ### Annotate image example ### # - # ############################## # - - # [START howto_operator_vision_annotate_image] - annotate_image = CloudVisionImageAnnotateOperator( - request=annotate_image_request, retry=Retry(maximum=10.0), timeout=5, task_id='annotate_image' - ) - # [END howto_operator_vision_annotate_image] - - # [START howto_operator_vision_annotate_image_result] - annotate_image_result = BashOperator( - bash_command="echo {{ task_instance.xcom_pull('annotate_image')" - "['logoAnnotations'][0]['description'] }}", - task_id='annotate_image_result', - ) - # [END howto_operator_vision_annotate_image_result] - - # [START howto_operator_vision_detect_text] - detect_text = CloudVisionDetectTextOperator( - image=DETECT_IMAGE, - retry=Retry(maximum=10.0), - timeout=5, - task_id="detect_text", - language_hints="en", - web_detection_params={'include_geo_results': True}, - ) - # [END howto_operator_vision_detect_text] - - # [START howto_operator_vision_detect_text_result] - detect_text_result = BashOperator( - bash_command="echo {{ task_instance.xcom_pull('detect_text')['textAnnotations'][0] }}", - task_id="detect_text_result", - ) - # [END howto_operator_vision_detect_text_result] - - # [START howto_operator_vision_document_detect_text] - document_detect_text = CloudVisionTextDetectOperator( - image=DETECT_IMAGE, retry=Retry(maximum=10.0), timeout=5, task_id="document_detect_text" - ) - # [END howto_operator_vision_document_detect_text] - - # [START howto_operator_vision_document_detect_text_result] - document_detect_text_result = BashOperator( - bash_command="echo {{ task_instance.xcom_pull('document_detect_text')['textAnnotations'][0] }}", - task_id="document_detect_text_result", - ) - # [END howto_operator_vision_document_detect_text_result] - - # [START howto_operator_vision_detect_labels] - detect_labels = CloudVisionDetectImageLabelsOperator( - image=DETECT_IMAGE, retry=Retry(maximum=10.0), timeout=5, task_id="detect_labels" - ) - # [END howto_operator_vision_detect_labels] - - # [START howto_operator_vision_detect_labels_result] - detect_labels_result = BashOperator( - bash_command="echo {{ task_instance.xcom_pull('detect_labels')['labelAnnotations'][0] }}", - task_id="detect_labels_result", - ) - # [END howto_operator_vision_detect_labels_result] - - # [START howto_operator_vision_detect_safe_search] - detect_safe_search = CloudVisionDetectImageSafeSearchOperator( - image=DETECT_IMAGE, retry=Retry(maximum=10.0), timeout=5, task_id="detect_safe_search" - ) - # [END howto_operator_vision_detect_safe_search] - - # [START howto_operator_vision_detect_safe_search_result] - detect_safe_search_result = 
BashOperator( - bash_command=f"echo {detect_safe_search.output}", - task_id="detect_safe_search_result", - ) - # [END howto_operator_vision_detect_safe_search_result] - - annotate_image >> annotate_image_result - - detect_text >> detect_text_result - document_detect_text >> document_detect_text_result - detect_labels >> detect_labels_result - detect_safe_search >> detect_safe_search_result diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py index dbe375ac5a08e..6d3af91e20ac7 100644 --- a/airflow/providers/google/cloud/hooks/automl.py +++ b/airflow/providers/google/cloud/hooks/automl.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains a Google AutoML hook. @@ -23,21 +22,9 @@ PredictResponse """ -import sys -from typing import Dict, Optional, Sequence, Tuple, Union - -from google.cloud.automl_v1beta1.services.auto_ml.pagers import ( - ListColumnSpecsPager, - ListDatasetsPager, - ListTableSpecsPager, -) +from __future__ import annotations -from airflow.providers.google.common.consts import CLIENT_INFO - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation @@ -54,8 +41,15 @@ PredictionServiceClient, PredictResponse, ) +from google.cloud.automl_v1beta1.services.auto_ml.pagers import ( + ListColumnSpecsPager, + ListDatasetsPager, + ListTableSpecsPager, +) from google.protobuf.field_mask_pb2 import FieldMask +from airflow.compat.functools import cached_property +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook @@ -70,18 +64,18 @@ class CloudAutoMLHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client = None # type: Optional[AutoMlClient] + self._client: AutoMlClient | None = None @staticmethod - def extract_object_id(obj: Dict) -> str: + def extract_object_id(obj: dict) -> str: """Returns unique id of the object.""" return obj["name"].rpartition("/")[-1] @@ -90,10 +84,9 @@ def get_conn(self) -> AutoMlClient: Retrieves connection to AutoML. :return: Google Cloud AutoML client object. - :rtype: google.cloud.automl_v1beta1.AutoMlClient """ if self._client is None: - self._client = AutoMlClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @cached_property @@ -102,19 +95,18 @@ def prediction_client(self) -> PredictionServiceClient: Creates PredictionServiceClient. :return: Google Cloud AutoML PredictionServiceClient client object. 
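Nearly every signature change in the hook files below follows one pattern: `from __future__ import annotations` is added so that PEP 604/585 syntax (`str | None`, `dict[str, Any]`, `Sequence[tuple[str, str]]`) can replace `Optional`/`Union`/`Dict` even on Python 3.7, where those unions would otherwise fail at runtime. A minimal standalone sketch of the idiom, with hypothetical parameter names:

from __future__ import annotations  # annotations are stored as strings and never evaluated

from typing import Sequence


def make_hook(
    gcp_conn_id: str = "google_cloud_default",
    delegate_to: str | None = None,                           # was Optional[str]
    impersonation_chain: str | Sequence[str] | None = None,   # was Optional[Union[str, Sequence[str]]]
    labels: dict | None = None,                               # was Optional[Dict]
) -> None:
    """Runs on Python 3.7+ because the annotations above are never executed."""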
- :rtype: google.cloud.automl_v1beta1.PredictionServiceClient """ - return PredictionServiceClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + return PredictionServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) @GoogleBaseHook.fallback_to_default_project_id def create_model( self, - model: Union[dict, Model], + model: dict | Model, location: str, project_id: str = PROVIDE_PROJECT_ID, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, ) -> Operation: """ Creates a model_id. Returns a Model in the `response` field when it @@ -138,7 +130,7 @@ def create_model( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}" return client.create_model( - request={'parent': parent, 'model': model}, + request={"parent": parent, "model": model}, retry=retry, timeout=timeout, metadata=metadata, @@ -148,14 +140,14 @@ def create_model( def batch_predict( self, model_id: str, - input_config: Union[dict, BatchPredictInputConfig], - output_config: Union[dict, BatchPredictOutputConfig], + input_config: dict | BatchPredictInputConfig, + output_config: dict | BatchPredictOutputConfig, location: str, project_id: str = PROVIDE_PROJECT_ID, - params: Optional[Dict[str, str]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + params: dict[str, str] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Perform a batch prediction. Unlike the online `Predict`, batch @@ -186,10 +178,10 @@ def batch_predict( name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.batch_predict( request={ - 'name': name, - 'input_config': input_config, - 'output_config': output_config, - 'params': params, + "name": name, + "input_config": input_config, + "output_config": output_config, + "params": params, }, retry=retry, timeout=timeout, @@ -201,13 +193,13 @@ def batch_predict( def predict( self, model_id: str, - payload: Union[dict, ExamplePayload], + payload: dict | ExamplePayload, location: str, project_id: str = PROVIDE_PROJECT_ID, - params: Optional[Dict[str, str]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + params: dict[str, str] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> PredictResponse: """ Perform an online prediction. 
The prediction result will be directly @@ -232,7 +224,7 @@ def predict( client = self.prediction_client name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.predict( - request={'name': name, 'payload': payload, 'params': params}, + request={"name": name, "payload": payload, "params": params}, retry=retry, timeout=timeout, metadata=metadata, @@ -242,12 +234,12 @@ def predict( @GoogleBaseHook.fallback_to_default_project_id def create_dataset( self, - dataset: Union[dict, Dataset], + dataset: dict | Dataset, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Dataset: """ Creates a dataset. @@ -268,7 +260,7 @@ def create_dataset( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}" result = client.create_dataset( - request={'parent': parent, 'dataset': dataset}, + request={"parent": parent, "dataset": dataset}, retry=retry, timeout=timeout, metadata=metadata, @@ -280,11 +272,11 @@ def import_data( self, dataset_id: str, location: str, - input_config: Union[dict, InputConfig], + input_config: dict | InputConfig, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Imports data into a dataset. For Tables this method can only be called on an empty Dataset. @@ -306,7 +298,7 @@ def import_data( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" result = client.import_data( - request={'name': name, 'input_config': input_config}, + request={"name": name, "input_config": input_config}, retry=retry, timeout=timeout, metadata=metadata, @@ -320,12 +312,12 @@ def list_column_specs( table_spec_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - field_mask: Optional[Union[dict, FieldMask]] = None, - filter_: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + field_mask: dict | FieldMask | None = None, + filter_: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListColumnSpecsPager: """ Lists column specs in a table spec. @@ -359,7 +351,7 @@ def list_column_specs( table_spec=table_spec_id, ) result = client.list_column_specs( - request={'parent': parent, 'field_mask': field_mask, 'filter': filter_, 'page_size': page_size}, + request={"parent": parent, "field_mask": field_mask, "filter": filter_, "page_size": page_size}, retry=retry, timeout=timeout, metadata=metadata, @@ -372,9 +364,9 @@ def get_model( model_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Model: """ Gets a AutoML model. 
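A hedged usage sketch of the `predict` call above, assuming a configured `google_cloud_default` connection; the project, location, model id and Tables-style payload are illustrative placeholders, not values from this change:

from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook

hook = CloudAutoMLHook(gcp_conn_id="google_cloud_default")

# The hook expands "projects/{project_id}/locations/{location}/models/{model_id}"
# and forwards payload/params to the PredictionServiceClient.
response = hook.predict(
    model_id="TBL1234567890",                      # placeholder model id
    location="us-central1",
    project_id="my-project",
    payload={"row": {"values": ["42", "blue"]}},   # ExamplePayload given as a plain dict
)
for annotation in response.payload:
    print(annotation)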
@@ -394,7 +386,7 @@ def get_model( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.get_model( - request={'name': name}, + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata, @@ -407,9 +399,9 @@ def delete_model( model_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a AutoML model. @@ -429,7 +421,7 @@ def delete_model( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.delete_model( - request={'name': name}, + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata, @@ -438,11 +430,11 @@ def delete_model( def update_dataset( self, - dataset: Union[dict, Dataset], - update_mask: Optional[Union[dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + dataset: dict | Dataset, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Dataset: """ Updates a dataset. @@ -461,7 +453,7 @@ def update_dataset( """ client = self.get_conn() result = client.update_dataset( - request={'dataset': dataset, 'update_mask': update_mask}, + request={"dataset": dataset, "update_mask": update_mask}, retry=retry, timeout=timeout, metadata=metadata, @@ -474,15 +466,15 @@ def deploy_model( model_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - image_detection_metadata: Optional[Union[ImageObjectDetectionModelDeploymentMetadata, dict]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + image_detection_metadata: ImageObjectDetectionModelDeploymentMetadata | dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deploys a model. If a model is already deployed, deploying it with the same parameters has no effect. Deploying with different parameters (as e.g. changing node_number) will - reset the deployment state without pausing the model_id’s availability. + reset the deployment state without pausing the model_id's availability. Only applicable for Text Classification, Image Object Detection and Tables; all other domains manage deployment automatically. 
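A short sketch of calling `deploy_model` (identifiers are placeholders); it returns a long-running `google.api_core.operation.Operation`, so a caller that needs the model ready would block on `.result()`:

from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook

hook = CloudAutoMLHook()

operation = hook.deploy_model(
    model_id="ICN1234567890",   # placeholder; only some AutoML domains need explicit deployment
    location="us-central1",
    project_id="my-project",
)
operation.result()  # no-op on the service side if already deployed with the same parameters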
@@ -506,8 +498,8 @@ def deploy_model( name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.deploy_model( request={ - 'name': name, - 'image_object_detection_model_deployment_metadata': image_detection_metadata, + "name": name, + "image_object_detection_model_deployment_metadata": image_detection_metadata, }, retry=retry, timeout=timeout, @@ -519,12 +511,12 @@ def list_table_specs( self, dataset_id: str, location: str, - project_id: Optional[str] = None, - filter_: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + filter_: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListTableSpecsPager: """ Lists table specs in a dataset_id. @@ -553,7 +545,7 @@ def list_table_specs( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" result = client.list_table_specs( - request={'parent': parent, 'filter': filter_, 'page_size': page_size}, + request={"parent": parent, "filter": filter_, "page_size": page_size}, retry=retry, timeout=timeout, metadata=metadata, @@ -565,9 +557,9 @@ def list_datasets( self, location: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListDatasetsPager: """ Lists datasets in a project. @@ -589,7 +581,7 @@ def list_datasets( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}" result = client.list_datasets( - request={'parent': parent}, + request={"parent": parent}, retry=retry, timeout=timeout, metadata=metadata, @@ -602,9 +594,9 @@ def delete_dataset( dataset_id: str, location: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a dataset and all of its contents. @@ -624,7 +616,7 @@ def delete_dataset( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" result = client.delete_dataset( - request={'name': name}, + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata, diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index d4f54f56cef09..48bf39c12688d 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -15,20 +15,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains a BigQuery Hook, as well as a very basic PEP 249 implementation for BigQuery. 
""" +from __future__ import annotations + import hashlib import json import logging +import re import time +import uuid import warnings from copy import deepcopy from datetime import datetime, timedelta -from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union +from typing import Any, Iterable, Mapping, NoReturn, Sequence, Union, cast +from aiohttp import ClientSession as ClientSession +from gcloud.aio.bigquery import Job, Table as Table_async from google.api_core.retry import Retry from google.cloud.bigquery import ( DEFAULT_RETRY, @@ -47,12 +52,14 @@ from pandas import DataFrame from pandas_gbq import read_gbq from pandas_gbq.gbq import GbqConnector # noqa +from requests import Session from sqlalchemy import create_engine from airflow.exceptions import AirflowException -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.google.cloud.utils.bigquery import bq_cast from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field from airflow.utils.helpers import convert_camel_to_snake from airflow.utils.log.logging_mixin import LoggingMixin @@ -75,20 +82,20 @@ class BigQueryHook(GoogleBaseHook, DbApiHook): :param labels: The BigQuery resource label. """ - conn_name_attr = 'gcp_conn_id' - default_conn_name = 'google_cloud_bigquery_default' - conn_type = 'gcpbigquery' - hook_name = 'Google Bigquery' + conn_name_attr = "gcp_conn_id" + default_conn_name = "google_cloud_bigquery_default" + conn_type = "gcpbigquery" + hook_name = "Google Bigquery" def __init__( self, gcp_conn_id: str = GoogleBaseHook.default_conn_name, - delegate_to: Optional[str] = None, + delegate_to: str | None = None, use_legacy_sql: bool = True, - location: Optional[str] = None, - api_resource_configs: Optional[Dict] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - labels: Optional[Dict] = None, + location: str | None = None, + api_resource_configs: dict | None = None, + impersonation_chain: str | Sequence[str] | None = None, + labels: dict | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -97,12 +104,12 @@ def __init__( ) self.use_legacy_sql = use_legacy_sql self.location = location - self.running_job_id = None # type: Optional[str] - self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict + self.running_job_id: str | None = None + self.api_resource_configs: dict = api_resource_configs if api_resource_configs else {} self.labels = labels self.credentials_path = "bigquery_hook_credentials.json" - def get_conn(self) -> "BigQueryConnection": + def get_conn(self) -> BigQueryConnection: """Returns a BigQuery PEP 249 connection object.""" service = self.get_service() return BigQueryConnection( @@ -120,9 +127,9 @@ def get_service(self) -> Resource: "This method will be deprecated. 
Please use `BigQueryHook.get_client` method", DeprecationWarning ) http_authorized = self._authorize() - return build('bigquery', 'v2', http=http_authorized, cache_discovery=False) + return build("bigquery", "v2", http=http_authorized, cache_discovery=False) - def get_client(self, project_id: Optional[str] = None, location: Optional[str] = None) -> Client: + def get_client(self, project_id: str | None = None, location: str | None = None) -> Client: """ Returns authenticated BigQuery Client. @@ -134,7 +141,7 @@ def get_client(self, project_id: Optional[str] = None, location: Optional[str] = client_info=CLIENT_INFO, project=project_id, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), ) def get_uri(self) -> str: @@ -150,15 +157,14 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): """ if engine_kwargs is None: engine_kwargs = {} - connection = self.get_connection(self.gcp_conn_id) - if connection.extra_dejson.get("extra__google_cloud_platform__key_path"): - credentials_path = connection.extra_dejson['extra__google_cloud_platform__key_path'] + extras = self.get_connection(self.gcp_conn_id).extra_dejson + credentials_path = get_field(extras, "key_path") + if credentials_path: return create_engine(self.get_uri(), credentials_path=credentials_path, **engine_kwargs) - elif connection.extra_dejson.get("extra__google_cloud_platform__keyfile_dict"): - credential_file_content = json.loads( - connection.extra_dejson["extra__google_cloud_platform__keyfile_dict"] - ) - return create_engine(self.get_uri(), credentials_info=credential_file_content, **engine_kwargs) + keyfile_dict = get_field(extras, "keyfile_dict") + if keyfile_dict: + keyfile_content = keyfile_dict if isinstance(keyfile_dict, dict) else json.loads(keyfile_dict) + return create_engine(self.get_uri(), credentials_info=keyfile_content, **engine_kwargs) try: # 1. If the environment variable GOOGLE_APPLICATION_CREDENTIALS is set # ADC uses the service account key or configuration file that the variable points to. @@ -169,9 +175,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): self.log.error(e) raise AirflowException( "For now, we only support instantiating SQLAlchemy engine by" - " using ADC" - ", extra__google_cloud_platform__key_path" - "and extra__google_cloud_platform__keyfile_dict" + " using ADC or extra fields `key_path` and `keyfile_dict`." 
) def get_records(self, sql, parameters=None): @@ -181,11 +185,11 @@ def get_records(self, sql, parameters=None): @staticmethod def _resolve_table_reference( - table_resource: Dict[str, Any], - project_id: Optional[str] = None, - dataset_id: Optional[str] = None, - table_id: Optional[str] = None, - ) -> Dict[str, Any]: + table_resource: dict[str, Any], + project_id: str | None = None, + dataset_id: str | None = None, + table_id: str | None = None, + ) -> dict[str, Any]: try: # Check if tableReference is present and is valid TableReference.from_api_repr(table_resource["tableReference"]) @@ -223,8 +227,8 @@ def insert_rows( def get_pandas_df( self, sql: str, - parameters: Optional[Union[Iterable, Mapping]] = None, - dialect: Optional[str] = None, + parameters: Iterable | Mapping | None = None, + dialect: str | None = None, **kwargs, ) -> DataFrame: """ @@ -243,9 +247,9 @@ def get_pandas_df( :param kwargs: (optional) passed into pandas_gbq.read_gbq method """ if dialect is None: - dialect = 'legacy' if self.use_legacy_sql else 'standard' + dialect = "legacy" if self.use_legacy_sql else "standard" - credentials, project_id = self._get_credentials_and_project_id() + credentials, project_id = self.get_credentials_and_project_id() return read_gbq( sql, project_id=project_id, dialect=dialect, verbose=False, credentials=credentials, **kwargs @@ -294,19 +298,19 @@ def table_partition_exists( @GoogleBaseHook.fallback_to_default_project_id def create_empty_table( self, - project_id: Optional[str] = None, - dataset_id: Optional[str] = None, - table_id: Optional[str] = None, - table_resource: Optional[Dict[str, Any]] = None, - schema_fields: Optional[List] = None, - time_partitioning: Optional[Dict] = None, - cluster_fields: Optional[List[str]] = None, - labels: Optional[Dict] = None, - view: Optional[Dict] = None, - materialized_view: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None, - retry: Optional[Retry] = DEFAULT_RETRY, - location: Optional[str] = None, + project_id: str | None = None, + dataset_id: str | None = None, + table_id: str | None = None, + table_resource: dict[str, Any] | None = None, + schema_fields: list | None = None, + time_partitioning: dict | None = None, + cluster_fields: list[str] | None = None, + labels: dict | None = None, + view: dict | None = None, + materialized_view: dict | None = None, + encryption_configuration: dict | None = None, + retry: Retry | None = DEFAULT_RETRY, + location: str | None = None, exists_ok: bool = True, ) -> Table: """ @@ -361,29 +365,28 @@ def create_empty_table( :param exists_ok: If ``True``, ignore "already exists" errors when creating the table. 
:return: Created table """ - - _table_resource: Dict[str, Any] = {} + _table_resource: dict[str, Any] = {} if self.location: - _table_resource['location'] = self.location + _table_resource["location"] = self.location if schema_fields: - _table_resource['schema'] = {'fields': schema_fields} + _table_resource["schema"] = {"fields": schema_fields} if time_partitioning: - _table_resource['timePartitioning'] = time_partitioning + _table_resource["timePartitioning"] = time_partitioning if cluster_fields: - _table_resource['clustering'] = {'fields': cluster_fields} + _table_resource["clustering"] = {"fields": cluster_fields} if labels: - _table_resource['labels'] = labels + _table_resource["labels"] = labels if view: - _table_resource['view'] = view + _table_resource["view"] = view if materialized_view: - _table_resource['materializedView'] = materialized_view + _table_resource["materializedView"] = materialized_view if encryption_configuration: _table_resource["encryptionConfiguration"] = encryption_configuration @@ -403,12 +406,12 @@ def create_empty_table( @GoogleBaseHook.fallback_to_default_project_id def create_empty_dataset( self, - dataset_id: Optional[str] = None, - project_id: Optional[str] = None, - location: Optional[str] = None, - dataset_reference: Optional[Dict[str, Any]] = None, + dataset_id: str | None = None, + project_id: str | None = None, + location: str | None = None, + dataset_reference: dict[str, Any] | None = None, exists_ok: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a new empty dataset: https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert @@ -447,25 +450,26 @@ def create_empty_dataset( dataset_reference["datasetReference"][param] = value location = location or self.location + project_id = project_id or self.project_id if location: dataset_reference["location"] = dataset_reference.get("location", location) dataset: Dataset = Dataset.from_api_repr(dataset_reference) - self.log.info('Creating dataset: %s in project: %s ', dataset.dataset_id, dataset.project) - dataset_object = self.get_client(location=location).create_dataset( + self.log.info("Creating dataset: %s in project: %s ", dataset.dataset_id, dataset.project) + dataset_object = self.get_client(project_id=project_id, location=location).create_dataset( dataset=dataset, exists_ok=exists_ok ) - self.log.info('Dataset created successfully.') + self.log.info("Dataset created successfully.") return dataset_object.to_api_repr() @GoogleBaseHook.fallback_to_default_project_id def get_dataset_tables( self, dataset_id: str, - project_id: Optional[str] = None, - max_results: Optional[int] = None, + project_id: str | None = None, + max_results: int | None = None, retry: Retry = DEFAULT_RETRY, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Get the list of tables for a given dataset. @@ -479,7 +483,7 @@ def get_dataset_tables( :param retry: How to retry the RPC. :return: List of tables associated with the dataset. 
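A hedged usage sketch for the refactored `create_empty_table` (project, dataset and schema are placeholders); the keyword arguments map directly onto the `_table_resource` keys assembled in the method body, and `get_dataset_tables` then returns the dataset's tables as API representations:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook(gcp_conn_id="google_cloud_default")

table = hook.create_empty_table(
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="events",
    schema_fields=[
        {"name": "event_id", "type": "STRING", "mode": "REQUIRED"},
        {"name": "ts", "type": "TIMESTAMP", "mode": "NULLABLE"},
    ],
    time_partitioning={"type": "DAY", "field": "ts"},
    cluster_fields=["event_id"],
    exists_ok=True,
)
print(table.full_table_id)

for entry in hook.get_dataset_tables(dataset_id="my_dataset", project_id="my-project"):
    print(entry["tableId"])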
""" - self.log.info('Start getting tables list from dataset: %s.%s', project_id, dataset_id) + self.log.info("Start getting tables list from dataset: %s.%s", project_id, dataset_id) tables = self.get_client().list_tables( dataset=DatasetReference(project=project_id, dataset_id=dataset_id), max_results=max_results, @@ -492,7 +496,7 @@ def get_dataset_tables( def delete_dataset( self, dataset_id: str, - project_id: Optional[str] = None, + project_id: str | None = None, delete_contents: bool = False, retry: Retry = DEFAULT_RETRY, ) -> None: @@ -505,7 +509,7 @@ def delete_dataset( If False and the dataset contains tables, the request will fail. :param retry: How to retry the RPC. """ - self.log.info('Deleting from project: %s Dataset:%s', project_id, dataset_id) + self.log.info("Deleting from project: %s Dataset:%s", project_id, dataset_id) self.get_client(project_id=project_id).delete_dataset( dataset=DatasetReference(project=project_id, dataset_id=dataset_id), delete_contents=delete_contents, @@ -517,25 +521,25 @@ def delete_dataset( def create_external_table( self, external_project_dataset_table: str, - schema_fields: List, - source_uris: List, - source_format: str = 'CSV', + schema_fields: list, + source_uris: list, + source_format: str = "CSV", autodetect: bool = False, - compression: str = 'NONE', + compression: str = "NONE", ignore_unknown_values: bool = False, max_bad_records: int = 0, skip_leading_rows: int = 0, - field_delimiter: str = ',', - quote_character: Optional[str] = None, + field_delimiter: str = ",", + quote_character: str | None = None, allow_quoted_newlines: bool = False, allow_jagged_rows: bool = False, encoding: str = "UTF-8", - src_fmt_configs: Optional[Dict] = None, - labels: Optional[Dict] = None, - description: Optional[str] = None, - encryption_configuration: Optional[Dict] = None, - location: Optional[str] = None, - project_id: Optional[str] = None, + src_fmt_configs: dict | None = None, + labels: dict | None = None, + description: str | None = None, + encryption_configuration: dict | None = None, + location: str | None = None, + project_id: str | None = None, ) -> Table: """ Creates a new external table in the dataset with the data from Google @@ -606,34 +610,34 @@ def create_external_table( compression = compression.upper() external_config_api_repr = { - 'autodetect': autodetect, - 'sourceFormat': source_format, - 'sourceUris': source_uris, - 'compression': compression, - 'ignoreUnknownValues': ignore_unknown_values, + "autodetect": autodetect, + "sourceFormat": source_format, + "sourceUris": source_uris, + "compression": compression, + "ignoreUnknownValues": ignore_unknown_values, } # if following fields are not specified in src_fmt_configs, # honor the top-level params for backward-compatibility backward_compatibility_configs = { - 'skipLeadingRows': skip_leading_rows, - 'fieldDelimiter': field_delimiter, - 'quote': quote_character, - 'allowQuotedNewlines': allow_quoted_newlines, - 'allowJaggedRows': allow_jagged_rows, - 'encoding': encoding, + "skipLeadingRows": skip_leading_rows, + "fieldDelimiter": field_delimiter, + "quote": quote_character, + "allowQuotedNewlines": allow_quoted_newlines, + "allowJaggedRows": allow_jagged_rows, + "encoding": encoding, } - src_fmt_to_param_mapping = {'CSV': 'csvOptions', 'GOOGLE_SHEETS': 'googleSheetsOptions'} + src_fmt_to_param_mapping = {"CSV": "csvOptions", "GOOGLE_SHEETS": "googleSheetsOptions"} src_fmt_to_configs_mapping = { - 'csvOptions': [ - 'allowJaggedRows', - 'allowQuotedNewlines', - 'fieldDelimiter', - 
'skipLeadingRows', - 'quote', - 'encoding', + "csvOptions": [ + "allowJaggedRows", + "allowQuotedNewlines", + "fieldDelimiter", + "skipLeadingRows", + "quote", + "encoding", ], - 'googleSheetsOptions': ['skipLeadingRows'], + "googleSheetsOptions": ["skipLeadingRows"], } if source_format in src_fmt_to_param_mapping.keys(): valid_configs = src_fmt_to_configs_mapping[src_fmt_to_param_mapping[source_format]] @@ -661,22 +665,22 @@ def create_external_table( if encryption_configuration: table.encryption_configuration = EncryptionConfiguration.from_api_repr(encryption_configuration) - self.log.info('Creating external table: %s', external_project_dataset_table) + self.log.info("Creating external table: %s", external_project_dataset_table) table_object = self.create_empty_table( table_resource=table.to_api_repr(), project_id=project_id, location=location, exists_ok=True ) - self.log.info('External table created successfully: %s', external_project_dataset_table) + self.log.info("External table created successfully: %s", external_project_dataset_table) return table_object @GoogleBaseHook.fallback_to_default_project_id def update_table( self, - table_resource: Dict[str, Any], - fields: Optional[List[str]] = None, - dataset_id: Optional[str] = None, - table_id: Optional[str] = None, - project_id: Optional[str] = None, - ) -> Dict[str, Any]: + table_resource: dict[str, Any], + fields: list[str] | None = None, + dataset_id: str | None = None, + table_id: str | None = None, + project_id: str | None = None, + ) -> dict[str, Any]: """ Change some fields of a table. @@ -706,9 +710,9 @@ def update_table( ) table = Table.from_api_repr(table_resource) - self.log.info('Updating table: %s', table_resource["tableReference"]) + self.log.info("Updating table: %s", table_resource["tableReference"]) table_object = self.get_client(project_id=project_id).update_table(table=table, fields=fields) - self.log.info('Table %s.%s.%s updated successfully', project_id, dataset_id, table_id) + self.log.info("Table %s.%s.%s updated successfully", project_id, dataset_id, table_id) return table_object.to_api_repr() @GoogleBaseHook.fallback_to_default_project_id @@ -716,17 +720,17 @@ def patch_table( self, dataset_id: str, table_id: str, - project_id: Optional[str] = None, - description: Optional[str] = None, - expiration_time: Optional[int] = None, - external_data_configuration: Optional[Dict] = None, - friendly_name: Optional[str] = None, - labels: Optional[Dict] = None, - schema: Optional[List] = None, - time_partitioning: Optional[Dict] = None, - view: Optional[Dict] = None, - require_partition_filter: Optional[bool] = None, - encryption_configuration: Optional[Dict] = None, + project_id: str | None = None, + description: str | None = None, + expiration_time: int | None = None, + external_data_configuration: dict | None = None, + friendly_name: str | None = None, + labels: dict | None = None, + schema: list | None = None, + time_partitioning: dict | None = None, + view: dict | None = None, + require_partition_filter: bool | None = None, + encryption_configuration: dict | None = None, ) -> None: """ Patch information in an existing table. 
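Because `patch_table` just below is kept only behind a deprecation warning, a hedged sketch of the preferred `update_table` call (identifiers are placeholders); `fields` limits which properties of the table resource are written:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook()

hook.update_table(
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="events",
    fields=["description", "labels"],   # only these table properties are updated
    table_resource={
        "description": "Raw click events",
        "labels": {"team": "analytics"},
    },
)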
@@ -779,26 +783,26 @@ def patch_table( "This method is deprecated, please use ``BigQueryHook.update_table`` method.", DeprecationWarning, ) - table_resource: Dict[str, Any] = {} + table_resource: dict[str, Any] = {} if description is not None: - table_resource['description'] = description + table_resource["description"] = description if expiration_time is not None: - table_resource['expirationTime'] = expiration_time + table_resource["expirationTime"] = expiration_time if external_data_configuration: - table_resource['externalDataConfiguration'] = external_data_configuration + table_resource["externalDataConfiguration"] = external_data_configuration if friendly_name is not None: - table_resource['friendlyName'] = friendly_name + table_resource["friendlyName"] = friendly_name if labels: - table_resource['labels'] = labels + table_resource["labels"] = labels if schema: - table_resource['schema'] = {'fields': schema} + table_resource["schema"] = {"fields": schema} if time_partitioning: - table_resource['timePartitioning'] = time_partitioning + table_resource["timePartitioning"] = time_partitioning if view: - table_resource['view'] = view + table_resource["view"] = view if require_partition_filter is not None: - table_resource['requirePartitionFilter'] = require_partition_filter + table_resource["requirePartitionFilter"] = require_partition_filter if encryption_configuration: table_resource["encryptionConfiguration"] = encryption_configuration @@ -816,7 +820,7 @@ def insert_all( project_id: str, dataset_id: str, table_id: str, - rows: List, + rows: list, ignore_unknown_values: bool = False, skip_invalid_rows: bool = False, fail_on_error: bool = False, @@ -847,7 +851,7 @@ def insert_all( The default value is false, which indicates the task should not fail even if any insertion errors occur. """ - self.log.info('Inserting %s row(s) into table %s:%s.%s', len(rows), project_id, dataset_id, table_id) + self.log.info("Inserting %s row(s) into table %s:%s.%s", len(rows), project_id, dataset_id, table_id) table_ref = TableReference(dataset_ref=DatasetReference(project_id, dataset_id), table_id=table_id) bq_client = self.get_client(project_id=project_id) @@ -862,17 +866,17 @@ def insert_all( error_msg = f"{len(errors)} insert error(s) occurred. Details: {errors}" self.log.error(error_msg) if fail_on_error: - raise AirflowException(f'BigQuery job failed. Error was: {error_msg}') + raise AirflowException(f"BigQuery job failed. 
Error was: {error_msg}") else: - self.log.info('All row(s) inserted successfully: %s:%s.%s', project_id, dataset_id, table_id) + self.log.info("All row(s) inserted successfully: %s:%s.%s", project_id, dataset_id, table_id) @GoogleBaseHook.fallback_to_default_project_id def update_dataset( self, fields: Sequence[str], - dataset_resource: Dict[str, Any], - dataset_id: Optional[str] = None, - project_id: Optional[str] = None, + dataset_resource: dict[str, Any], + dataset_id: str | None = None, + project_id: str | None = None, retry: Retry = DEFAULT_RETRY, ) -> Dataset: """ @@ -904,7 +908,7 @@ def update_dataset( if value and not spec_value: dataset_resource["datasetReference"][key] = value - self.log.info('Start updating dataset') + self.log.info("Start updating dataset") dataset = self.get_client(project_id=project_id).update_dataset( dataset=Dataset.from_api_repr(dataset_resource), fields=fields, @@ -913,9 +917,7 @@ def update_dataset( self.log.info("Dataset successfully updated: %s", dataset) return dataset - def patch_dataset( - self, dataset_id: str, dataset_resource: Dict, project_id: Optional[str] = None - ) -> Dict: + def patch_dataset(self, dataset_id: str, dataset_resource: dict, project_id: str | None = None) -> dict: """ Patches information in an existing dataset. It only replaces fields that are provided in the submitted dataset resource. @@ -927,8 +929,6 @@ def patch_dataset( in request body. https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource :param project_id: The Google Cloud Project ID - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ warnings.warn("This method is deprecated. Please use ``update_dataset``.", DeprecationWarning) project_id = project_id or self.project_id @@ -940,7 +940,7 @@ def patch_dataset( service = self.get_service() dataset_project_id = project_id or self.project_id - self.log.info('Start patching dataset: %s:%s', dataset_project_id, dataset_id) + self.log.info("Start patching dataset: %s:%s", dataset_project_id, dataset_id) dataset = ( service.datasets() .patch( @@ -957,10 +957,10 @@ def patch_dataset( def get_dataset_tables_list( self, dataset_id: str, - project_id: Optional[str] = None, - table_prefix: Optional[str] = None, - max_results: Optional[int] = None, - ) -> List[Dict[str, Any]]: + project_id: str | None = None, + table_prefix: str | None = None, + max_results: int | None = None, + ) -> list[dict[str, Any]]: """ Method returns tables list of a BigQuery tables. If table prefix is specified, only tables beginning by it are returned. @@ -993,13 +993,13 @@ def get_dataset_tables_list( @GoogleBaseHook.fallback_to_default_project_id def get_datasets_list( self, - project_id: Optional[str] = None, + project_id: str | None = None, include_all: bool = False, - filter_: Optional[str] = None, - max_results: Optional[int] = None, - page_token: Optional[str] = None, + filter_: str | None = None, + max_results: int | None = None, + page_token: str | None = None, retry: Retry = DEFAULT_RETRY, - ) -> List[DatasetListItem]: + ) -> list[DatasetListItem]: """ Method returns full list of BigQuery datasets in the current project @@ -1034,7 +1034,7 @@ def get_datasets_list( return datasets_list @GoogleBaseHook.fallback_to_default_project_id - def get_dataset(self, dataset_id: str, project_id: Optional[str] = None) -> Dataset: + def get_dataset(self, dataset_id: str, project_id: str | None = None) -> Dataset: """ Fetch the dataset referenced by dataset_id. 
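`patch_dataset` above is likewise deprecated in favour of `update_dataset`; a hedged sketch of the replacement call (identifiers are placeholders), where `fields` names the dataset properties to overwrite and the resource uses the plain API representation:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook()

dataset = hook.update_dataset(
    project_id="my-project",
    dataset_id="my_dataset",
    fields=["description", "labels"],
    dataset_resource={
        "description": "Curated analytics tables",
        "labels": {"team": "analytics"},
    },
)
print(dataset.description)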
@@ -1058,9 +1058,9 @@ def run_grant_dataset_view_access( source_dataset: str, view_dataset: str, view_table: str, - view_project: Optional[str] = None, - project_id: Optional[str] = None, - ) -> Dict[str, Any]: + view_project: str | None = None, + project_id: str | None = None, + ) -> dict[str, Any]: """ Grant authorized view access of a dataset to a view table. If this view has already been granted access to the dataset, do nothing. @@ -1079,7 +1079,7 @@ def run_grant_dataset_view_access( view_access = AccessEntry( role=None, entity_type="view", - entity_id={'projectId': view_project, 'datasetId': view_dataset, 'tableId': view_table}, + entity_id={"projectId": view_project, "datasetId": view_dataset, "tableId": view_table}, ) dataset = self.get_dataset(project_id=project_id, dataset_id=source_dataset) @@ -1087,7 +1087,7 @@ def run_grant_dataset_view_access( # Check to see if the view we want to add already exists. if view_access not in dataset.access_entries: self.log.info( - 'Granting table %s:%s.%s authorized view access to %s:%s dataset.', + "Granting table %s:%s.%s authorized view access to %s:%s dataset.", view_project, view_dataset, view_table, @@ -1100,7 +1100,7 @@ def run_grant_dataset_view_access( ) else: self.log.info( - 'Table %s:%s.%s already has authorized view access to %s:%s dataset.', + "Table %s:%s.%s already has authorized view access to %s:%s dataset.", view_project, view_dataset, view_table, @@ -1111,8 +1111,8 @@ def run_grant_dataset_view_access( @GoogleBaseHook.fallback_to_default_project_id def run_table_upsert( - self, dataset_id: str, table_resource: Dict[str, Any], project_id: Optional[str] = None - ) -> Dict[str, Any]: + self, dataset_id: str, table_resource: dict[str, Any], project_id: str | None = None + ) -> dict[str, Any]: """ If the table already exists, update the existing table if not create new. Since BigQuery does not natively allow table upserts, this is not an @@ -1125,17 +1125,17 @@ def run_table_upsert( project will be self.project_id. :return: """ - table_id = table_resource['tableReference']['tableId'] + table_id = table_resource["tableReference"]["tableId"] table_resource = self._resolve_table_reference( table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id ) tables_list_resp = self.get_dataset_tables(dataset_id=dataset_id, project_id=project_id) - if any(table['tableId'] == table_id for table in tables_list_resp): - self.log.info('Table %s:%s.%s exists, updating.', project_id, dataset_id, table_id) + if any(table["tableId"] == table_id for table in tables_list_resp): + self.log.info("Table %s:%s.%s exists, updating.", project_id, dataset_id, table_id) table = self.update_table(table_resource=table_resource) else: - self.log.info('Table %s:%s.%s does not exist. creating.', project_id, dataset_id, table_id) + self.log.info("Table %s:%s.%s does not exist. creating.", project_id, dataset_id, table_id) table = self.create_empty_table( table_resource=table_resource, project_id=project_id ).to_api_repr() @@ -1162,7 +1162,7 @@ def delete_table( self, table_id: str, not_found_ok: bool = True, - project_id: Optional[str] = None, + project_id: str | None = None, ) -> None: """ Delete an existing table from the dataset. 
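A sketch of `run_table_upsert` as refactored above (all identifiers are placeholders); the method reads `tableReference.tableId`, then updates the table if that id is already present in the dataset and creates it otherwise, since BigQuery has no native upsert for table definitions:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook()

table_resource = {
    "tableReference": {
        "projectId": "my-project",
        "datasetId": "my_dataset",
        "tableId": "users",
    },
    "schema": {
        "fields": [
            {"name": "user_id", "type": "INTEGER", "mode": "REQUIRED"},
            {"name": "email", "type": "STRING", "mode": "NULLABLE"},
        ]
    },
}

hook.run_table_upsert(
    dataset_id="my_dataset",
    table_resource=table_resource,
    project_id="my-project",
)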
If the table does not exist, return an error @@ -1175,20 +1175,20 @@ def delete_table( :param project_id: the project used to perform the request """ self.get_client(project_id=project_id).delete_table( - table=Table.from_string(table_id), + table=table_id, not_found_ok=not_found_ok, ) - self.log.info('Deleted table %s', table_id) + self.log.info("Deleted table %s", table_id) def get_tabledata( self, dataset_id: str, table_id: str, - max_results: Optional[int] = None, - selected_fields: Optional[str] = None, - page_token: Optional[str] = None, - start_index: Optional[int] = None, - ) -> List[Dict]: + max_results: int | None = None, + selected_fields: str | None = None, + page_token: str | None = None, + start_index: int | None = None, + ) -> list[dict]: """ Get the data of a given dataset.table and optionally with selected columns. see https://cloud.google.com/bigquery/docs/reference/v2/tabledata/list @@ -1219,13 +1219,13 @@ def list_rows( self, dataset_id: str, table_id: str, - max_results: Optional[int] = None, - selected_fields: Optional[Union[List[str], str]] = None, - page_token: Optional[str] = None, - start_index: Optional[int] = None, - project_id: Optional[str] = None, - location: Optional[str] = None, - ) -> List[Row]: + max_results: int | None = None, + selected_fields: list[str] | str | None = None, + page_token: str | None = None, + start_index: int | None = None, + project_id: str | None = None, + location: str | None = None, + ) -> list[Row]: """ List the rows of the table. See https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/list @@ -1268,7 +1268,7 @@ def list_rows( return list(result) @GoogleBaseHook.fallback_to_default_project_id - def get_schema(self, dataset_id: str, table_id: str, project_id: Optional[str] = None) -> dict: + def get_schema(self, dataset_id: str, table_id: str, project_id: str | None = None) -> dict: """ Get the schema for a given dataset and table. see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource @@ -1286,12 +1286,12 @@ def get_schema(self, dataset_id: str, table_id: str, project_id: Optional[str] = @GoogleBaseHook.fallback_to_default_project_id def update_table_schema( self, - schema_fields_updates: List[Dict[str, Any]], + schema_fields_updates: list[dict[str, Any]], include_policy_tags: bool, dataset_id: str, table_id: str, - project_id: Optional[str] = None, - ) -> Dict[str, Any]: + project_id: str | None = None, + ) -> dict[str, Any]: """ Update fields within a schema for a given dataset and table. 
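Two hedged sketches for the methods touched above (identifiers are placeholders). `delete_table` now passes the string id straight through, and the underlying client accepts fully-qualified "project.dataset.table" references; `list_rows` returns `google.cloud.bigquery` Row objects:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook()

rows = hook.list_rows(
    dataset_id="my_dataset",
    table_id="events",
    max_results=10,
    selected_fields=["event_id", "ts"],
    project_id="my-project",
)
for row in rows:
    print(dict(row))   # Row behaves like a mapping of column name -> value

# not_found_ok defaults to True, so a missing table is silently ignored.
hook.delete_table(table_id="my-project.my_dataset.events", project_id="my-project")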
Note that some fields in schemas are immutable and trying to change them will cause @@ -1322,8 +1322,8 @@ def update_table_schema( """ def _build_new_schema( - current_schema: List[Dict[str, Any]], schema_fields_updates: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + current_schema: list[dict[str, Any]], schema_fields_updates: list[dict[str, Any]] + ) -> list[dict[str, Any]]: # Turn schema_field_updates into a dict keyed on field names schema_fields_updates_dict = {field["name"]: field for field in deepcopy(schema_fields_updates)} @@ -1350,7 +1350,7 @@ def _build_new_schema( return list(new_schema.values()) - def _remove_policy_tags(schema: List[Dict[str, Any]]): + def _remove_policy_tags(schema: list[dict[str, Any]]): for field in schema: if "policyTags" in field: del field["policyTags"] @@ -1378,8 +1378,8 @@ def _remove_policy_tags(schema: List[Dict[str, Any]]): def poll_job_complete( self, job_id: str, - project_id: Optional[str] = None, - location: Optional[str] = None, + project_id: str | None = None, + location: str | None = None, retry: Retry = DEFAULT_RETRY, ) -> bool: """ @@ -1389,7 +1389,6 @@ def poll_job_complete( :param project_id: Google Cloud Project where the job is running :param location: location the job is running :param retry: How to retry the RPC. - :rtype: bool """ location = location or self.location job = self.get_client(project_id=project_id, location=location).get_job(job_id=job_id) @@ -1404,29 +1403,30 @@ def cancel_query(self) -> None: if self.running_job_id: self.cancel_job(job_id=self.running_job_id) else: - self.log.info('No running BigQuery jobs to cancel.') + self.log.info("No running BigQuery jobs to cancel.") @GoogleBaseHook.fallback_to_default_project_id def cancel_job( self, job_id: str, - project_id: Optional[str] = None, - location: Optional[str] = None, + project_id: str | None = None, + location: str | None = None, ) -> None: """ - Cancels a job an wait for cancellation to complete + Cancel a job and wait for cancellation to complete :param job_id: id of the job. :param project_id: Google Cloud Project where the job is running :param location: location the job is running """ + project_id = project_id or self.project_id location = location or self.location - if self.poll_job_complete(job_id=job_id): - self.log.info('No running BigQuery jobs to cancel.') + if self.poll_job_complete(job_id=job_id, project_id=project_id, location=location): + self.log.info("No running BigQuery jobs to cancel.") return - self.log.info('Attempting to cancel job : %s, %s', project_id, job_id) + self.log.info("Attempting to cancel job : %s, %s", project_id, job_id) self.get_client(location=location, project_id=project_id).cancel_job(job_id=job_id) # Wait for all the calls to cancel to finish @@ -1436,26 +1436,27 @@ def cancel_job( job_complete = False while polling_attempts < max_polling_attempts and not job_complete: polling_attempts += 1 - job_complete = self.poll_job_complete(job_id) + job_complete = self.poll_job_complete(job_id=job_id, project_id=project_id, location=location) if job_complete: - self.log.info('Job successfully canceled: %s, %s', project_id, job_id) + self.log.info("Job successfully canceled: %s, %s", project_id, job_id) elif polling_attempts == max_polling_attempts: self.log.info( - "Stopping polling due to timeout. Job with id %s " + "Stopping polling due to timeout. 
Job %s, %s " "has not completed cancel and may or may not finish.", + project_id, job_id, ) else: - self.log.info('Waiting for canceled job with id %s to finish.', job_id) + self.log.info("Waiting for canceled job %s, %s to finish.", project_id, job_id) time.sleep(5) @GoogleBaseHook.fallback_to_default_project_id def get_job( self, - job_id: Optional[str] = None, - project_id: Optional[str] = None, - location: Optional[str] = None, - ) -> Union[CopyJob, QueryJob, LoadJob, ExtractJob]: + job_id: str | None = None, + project_id: str | None = None, + location: str | None = None, + ) -> CopyJob | QueryJob | LoadJob | ExtractJob: """ Retrieves a BigQuery job. For more information see: https://cloud.google.com/bigquery/docs/reference/v2/jobs @@ -1471,7 +1472,7 @@ def get_job( return job @staticmethod - def _custom_job_id(configuration: Dict[str, Any]) -> str: + def _custom_job_id(configuration: dict[str, Any]) -> str: hash_base = json.dumps(configuration, sort_keys=True) uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest() microseconds_from_epoch = int( @@ -1482,13 +1483,13 @@ def _custom_job_id(configuration: Dict[str, Any]) -> str: @GoogleBaseHook.fallback_to_default_project_id def insert_job( self, - configuration: Dict, - job_id: Optional[str] = None, - project_id: Optional[str] = None, - location: Optional[str] = None, + configuration: dict, + job_id: str | None = None, + project_id: str | None = None, + location: str | None = None, nowait: bool = False, retry: Retry = DEFAULT_RETRY, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> BigQueryJob: """ Executes a BigQuery job. Waits for the job to complete and returns job id. @@ -1565,27 +1566,27 @@ def run_with_configuration(self, configuration: dict) -> str: def run_load( self, destination_project_dataset_table: str, - source_uris: List, - schema_fields: Optional[List] = None, - source_format: str = 'CSV', - create_disposition: str = 'CREATE_IF_NEEDED', + source_uris: list, + schema_fields: list | None = None, + source_format: str = "CSV", + create_disposition: str = "CREATE_IF_NEEDED", skip_leading_rows: int = 0, - write_disposition: str = 'WRITE_EMPTY', - field_delimiter: str = ',', + write_disposition: str = "WRITE_EMPTY", + field_delimiter: str = ",", max_bad_records: int = 0, - quote_character: Optional[str] = None, + quote_character: str | None = None, ignore_unknown_values: bool = False, allow_quoted_newlines: bool = False, allow_jagged_rows: bool = False, encoding: str = "UTF-8", - schema_update_options: Optional[Iterable] = None, - src_fmt_configs: Optional[Dict] = None, - time_partitioning: Optional[Dict] = None, - cluster_fields: Optional[List] = None, + schema_update_options: Iterable | None = None, + src_fmt_configs: dict | None = None, + time_partitioning: dict | None = None, + cluster_fields: list | None = None, autodetect: bool = False, - encryption_configuration: Optional[Dict] = None, - labels: Optional[Dict] = None, - description: Optional[str] = None, + encryption_configuration: dict | None = None, + labels: dict | None = None, + description: str | None = None, ) -> str: """ Executes a BigQuery load command to load data from Google Cloud Storage @@ -1668,7 +1669,7 @@ def run_load( # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat # noqa if schema_fields is None and not autodetect: - raise ValueError('You must either pass a schema or autodetect=True.') + raise ValueError("You must either pass a schema or 
autodetect=True.") if src_fmt_configs is None: src_fmt_configs = {} @@ -1692,44 +1693,44 @@ def run_load( # as a side effect of a load # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions - allowed_schema_update_options = ['ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"] + allowed_schema_update_options = ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"] if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): raise ValueError( f"{schema_update_options} contains invalid schema update options. " f"Please only use one or more of the following options: {allowed_schema_update_options}" ) - destination_project, destination_dataset, destination_table = _split_tablename( + destination_project, destination_dataset, destination_table = self.split_tablename( table_input=destination_project_dataset_table, default_project_id=self.project_id, - var_name='destination_project_dataset_table', + var_name="destination_project_dataset_table", ) - configuration: Dict[str, Any] = { - 'load': { - 'autodetect': autodetect, - 'createDisposition': create_disposition, - 'destinationTable': { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, + configuration: dict[str, Any] = { + "load": { + "autodetect": autodetect, + "createDisposition": create_disposition, + "destinationTable": { + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, }, - 'sourceFormat': source_format, - 'sourceUris': source_uris, - 'writeDisposition': write_disposition, - 'ignoreUnknownValues': ignore_unknown_values, + "sourceFormat": source_format, + "sourceUris": source_uris, + "writeDisposition": write_disposition, + "ignoreUnknownValues": ignore_unknown_values, } } time_partitioning = _cleanse_time_partitioning(destination_project_dataset_table, time_partitioning) if time_partitioning: - configuration['load'].update({'timePartitioning': time_partitioning}) + configuration["load"].update({"timePartitioning": time_partitioning}) if cluster_fields: - configuration['load'].update({'clustering': {'fields': cluster_fields}}) + configuration["load"].update({"clustering": {"fields": cluster_fields}}) if schema_fields: - configuration['load']['schema'] = {'fields': schema_fields} + configuration["load"]["schema"] = {"fields": schema_fields} if schema_update_options: if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: @@ -1740,39 +1741,39 @@ def run_load( ) else: self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options) - configuration['load']['schemaUpdateOptions'] = schema_update_options + configuration["load"]["schemaUpdateOptions"] = schema_update_options if max_bad_records: - configuration['load']['maxBadRecords'] = max_bad_records + configuration["load"]["maxBadRecords"] = max_bad_records if encryption_configuration: configuration["load"]["destinationEncryptionConfiguration"] = encryption_configuration if labels or description: - configuration['load'].update({'destinationTableProperties': {}}) + configuration["load"].update({"destinationTableProperties": {}}) if labels: - configuration['load']['destinationTableProperties']['labels'] = labels + configuration["load"]["destinationTableProperties"]["labels"] = labels if description: - configuration['load']['destinationTableProperties']['description'] = description + configuration["load"]["destinationTableProperties"]["description"] = description 
src_fmt_to_configs_mapping = { - 'CSV': [ - 'allowJaggedRows', - 'allowQuotedNewlines', - 'autodetect', - 'fieldDelimiter', - 'skipLeadingRows', - 'ignoreUnknownValues', - 'nullMarker', - 'quote', - 'encoding', + "CSV": [ + "allowJaggedRows", + "allowQuotedNewlines", + "autodetect", + "fieldDelimiter", + "skipLeadingRows", + "ignoreUnknownValues", + "nullMarker", + "quote", + "encoding", ], - 'DATASTORE_BACKUP': ['projectionFields'], - 'NEWLINE_DELIMITED_JSON': ['autodetect', 'ignoreUnknownValues'], - 'PARQUET': ['autodetect', 'ignoreUnknownValues'], - 'AVRO': ['useAvroLogicalTypes'], + "DATASTORE_BACKUP": ["projectionFields"], + "NEWLINE_DELIMITED_JSON": ["autodetect", "ignoreUnknownValues"], + "PARQUET": ["autodetect", "ignoreUnknownValues"], + "AVRO": ["useAvroLogicalTypes"], } valid_configs = src_fmt_to_configs_mapping[source_format] @@ -1780,22 +1781,22 @@ def run_load( # if following fields are not specified in src_fmt_configs, # honor the top-level params for backward-compatibility backward_compatibility_configs = { - 'skipLeadingRows': skip_leading_rows, - 'fieldDelimiter': field_delimiter, - 'ignoreUnknownValues': ignore_unknown_values, - 'quote': quote_character, - 'allowQuotedNewlines': allow_quoted_newlines, - 'encoding': encoding, + "skipLeadingRows": skip_leading_rows, + "fieldDelimiter": field_delimiter, + "ignoreUnknownValues": ignore_unknown_values, + "quote": quote_character, + "allowQuotedNewlines": allow_quoted_newlines, + "encoding": encoding, } src_fmt_configs = _validate_src_fmt_configs( source_format, src_fmt_configs, valid_configs, backward_compatibility_configs ) - configuration['load'].update(src_fmt_configs) + configuration["load"].update(src_fmt_configs) if allow_jagged_rows: - configuration['load']['allowJaggedRows'] = allow_jagged_rows + configuration["load"]["allowJaggedRows"] = allow_jagged_rows job = self.insert_job(configuration=configuration, project_id=self.project_id) self.running_job_id = job.job_id @@ -1803,12 +1804,12 @@ def run_load( def run_copy( self, - source_project_dataset_tables: Union[List, str], + source_project_dataset_tables: list | str, destination_project_dataset_table: str, - write_disposition: str = 'WRITE_EMPTY', - create_disposition: str = 'CREATE_IF_NEEDED', - labels: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None, + write_disposition: str = "WRITE_EMPTY", + create_disposition: str = "CREATE_IF_NEEDED", + labels: dict | None = None, + encryption_configuration: dict | None = None, ) -> str: """ Executes a BigQuery copy command to copy data from one BigQuery table @@ -1851,33 +1852,33 @@ def run_copy( source_project_dataset_tables_fixup = [] for source_project_dataset_table in source_project_dataset_tables: - source_project, source_dataset, source_table = _split_tablename( + source_project, source_dataset, source_table = self.split_tablename( table_input=source_project_dataset_table, default_project_id=self.project_id, - var_name='source_project_dataset_table', + var_name="source_project_dataset_table", ) source_project_dataset_tables_fixup.append( - {'projectId': source_project, 'datasetId': source_dataset, 'tableId': source_table} + {"projectId": source_project, "datasetId": source_dataset, "tableId": source_table} ) - destination_project, destination_dataset, destination_table = _split_tablename( + destination_project, destination_dataset, destination_table = self.split_tablename( table_input=destination_project_dataset_table, default_project_id=self.project_id ) configuration = { - 'copy': { - 
'createDisposition': create_disposition, - 'writeDisposition': write_disposition, - 'sourceTables': source_project_dataset_tables_fixup, - 'destinationTable': { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, + "copy": { + "createDisposition": create_disposition, + "writeDisposition": write_disposition, + "sourceTables": source_project_dataset_tables_fixup, + "destinationTable": { + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, }, } } if labels: - configuration['labels'] = labels + configuration["labels"] = labels if encryption_configuration: configuration["copy"]["destinationEncryptionConfiguration"] = encryption_configuration @@ -1889,13 +1890,14 @@ def run_copy( def run_extract( self, source_project_dataset_table: str, - destination_cloud_storage_uris: List[str], - compression: str = 'NONE', - export_format: str = 'CSV', - field_delimiter: str = ',', + destination_cloud_storage_uris: list[str], + compression: str = "NONE", + export_format: str = "CSV", + field_delimiter: str = ",", print_header: bool = True, - labels: Optional[Dict] = None, - ) -> str: + labels: dict | None = None, + return_full_job: bool = False, + ) -> str | BigQueryJob: """ Executes a BigQuery extract command to copy data from BigQuery to Google Cloud Storage. See here: @@ -1916,6 +1918,7 @@ def run_extract( :param print_header: Whether to print a header for a CSV file extract. :param labels: a dictionary containing labels for the job/query, passed to BigQuery + :param return_full_job: return full job instead of job id only """ warnings.warn( "This method is deprecated. Please use `BigQueryHook.insert_job` method.", DeprecationWarning @@ -1923,60 +1926,62 @@ def run_extract( if not self.project_id: raise ValueError("The project_id should be set") - source_project, source_dataset, source_table = _split_tablename( + source_project, source_dataset, source_table = self.split_tablename( table_input=source_project_dataset_table, default_project_id=self.project_id, - var_name='source_project_dataset_table', + var_name="source_project_dataset_table", ) - configuration: Dict[str, Any] = { - 'extract': { - 'sourceTable': { - 'projectId': source_project, - 'datasetId': source_dataset, - 'tableId': source_table, + configuration: dict[str, Any] = { + "extract": { + "sourceTable": { + "projectId": source_project, + "datasetId": source_dataset, + "tableId": source_table, }, - 'compression': compression, - 'destinationUris': destination_cloud_storage_uris, - 'destinationFormat': export_format, + "compression": compression, + "destinationUris": destination_cloud_storage_uris, + "destinationFormat": export_format, } } if labels: - configuration['labels'] = labels + configuration["labels"] = labels - if export_format == 'CSV': + if export_format == "CSV": # Only set fieldDelimiter and printHeader fields if using CSV. # Google does not like it if you set these fields for other export # formats. 
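For illustration, the shape of the extract configuration run_extract() builds, including the CSV-only fields set just below; project, dataset, table and bucket are placeholders:

    configuration = {
        "extract": {
            "sourceTable": {
                "projectId": "my-project",
                "datasetId": "my_dataset",
                "tableId": "my_table",
            },
            "compression": "NONE",
            "destinationUris": ["gs://my-bucket/export/my_table-*.csv"],
            "destinationFormat": "CSV",
            # only valid when destinationFormat is CSV:
            "fieldDelimiter": ",",
            "printHeader": True,
        }
    }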
- configuration['extract']['fieldDelimiter'] = field_delimiter - configuration['extract']['printHeader'] = print_header + configuration["extract"]["fieldDelimiter"] = field_delimiter + configuration["extract"]["printHeader"] = print_header job = self.insert_job(configuration=configuration, project_id=self.project_id) self.running_job_id = job.job_id + if return_full_job: + return job return job.job_id def run_query( self, sql: str, - destination_dataset_table: Optional[str] = None, - write_disposition: str = 'WRITE_EMPTY', + destination_dataset_table: str | None = None, + write_disposition: str = "WRITE_EMPTY", allow_large_results: bool = False, - flatten_results: Optional[bool] = None, - udf_config: Optional[List] = None, - use_legacy_sql: Optional[bool] = None, - maximum_billing_tier: Optional[int] = None, - maximum_bytes_billed: Optional[float] = None, - create_disposition: str = 'CREATE_IF_NEEDED', - query_params: Optional[List] = None, - labels: Optional[Dict] = None, - schema_update_options: Optional[Iterable] = None, - priority: str = 'INTERACTIVE', - time_partitioning: Optional[Dict] = None, - api_resource_configs: Optional[Dict] = None, - cluster_fields: Optional[List[str]] = None, - location: Optional[str] = None, - encryption_configuration: Optional[Dict] = None, + flatten_results: bool | None = None, + udf_config: list | None = None, + use_legacy_sql: bool | None = None, + maximum_billing_tier: int | None = None, + maximum_bytes_billed: float | None = None, + create_disposition: str = "CREATE_IF_NEEDED", + query_params: list | None = None, + labels: dict | None = None, + schema_update_options: Iterable | None = None, + priority: str = "INTERACTIVE", + time_partitioning: dict | None = None, + api_resource_configs: dict | None = None, + cluster_fields: list[str] | None = None, + location: str | None = None, + encryption_configuration: dict | None = None, ) -> str: """ Executes a BigQuery SQL query. 
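The run_query() signature above is likewise a wrapper that assembles a configuration["query"] dict of roughly this shape before delegating to insert_job(); the SQL and identifiers are invented:

    configuration = {
        "query": {
            "query": "SELECT COUNT(*) FROM `my-project.my_dataset.my_table`",
            "useLegacySql": False,
            "priority": "INTERACTIVE",
        }
    }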
Optionally persists results in a BigQuery @@ -2056,23 +2061,23 @@ def run_query( if not api_resource_configs: api_resource_configs = self.api_resource_configs else: - _validate_value('api_resource_configs', api_resource_configs, dict) + _validate_value("api_resource_configs", api_resource_configs, dict) configuration = deepcopy(api_resource_configs) - if 'query' not in configuration: - configuration['query'] = {} + if "query" not in configuration: + configuration["query"] = {} else: - _validate_value("api_resource_configs['query']", configuration['query'], dict) + _validate_value("api_resource_configs['query']", configuration["query"], dict) - if sql is None and not configuration['query'].get('query', None): - raise TypeError('`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`') + if sql is None and not configuration["query"].get("query", None): + raise TypeError("`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`") # BigQuery also allows you to define how you want a table's schema to change # as a side effect of a query job # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions # noqa - allowed_schema_update_options = ['ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"] + allowed_schema_update_options = ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"] if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): raise ValueError( @@ -2089,60 +2094,60 @@ def run_query( ) if destination_dataset_table: - destination_project, destination_dataset, destination_table = _split_tablename( + destination_project, destination_dataset, destination_table = self.split_tablename( table_input=destination_dataset_table, default_project_id=self.project_id ) destination_dataset_table = { # type: ignore - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, } if cluster_fields: - cluster_fields = {'fields': cluster_fields} # type: ignore - - query_param_list = [ - (sql, 'query', None, (str,)), - (priority, 'priority', 'INTERACTIVE', (str,)), - (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool), - (query_params, 'queryParameters', None, list), - (udf_config, 'userDefinedFunctionResources', None, list), - (maximum_billing_tier, 'maximumBillingTier', None, int), - (maximum_bytes_billed, 'maximumBytesBilled', None, float), - (time_partitioning, 'timePartitioning', {}, dict), - (schema_update_options, 'schemaUpdateOptions', None, list), - (destination_dataset_table, 'destinationTable', None, dict), - (cluster_fields, 'clustering', None, dict), - ] # type: List[Tuple] + cluster_fields = {"fields": cluster_fields} # type: ignore + + query_param_list: list[tuple[Any, str, str | bool | None | dict, type | tuple[type]]] = [ + (sql, "query", None, (str,)), + (priority, "priority", "INTERACTIVE", (str,)), + (use_legacy_sql, "useLegacySql", self.use_legacy_sql, bool), + (query_params, "queryParameters", None, list), + (udf_config, "userDefinedFunctionResources", None, list), + (maximum_billing_tier, "maximumBillingTier", None, int), + (maximum_bytes_billed, "maximumBytesBilled", None, float), + (time_partitioning, "timePartitioning", {}, dict), + (schema_update_options, "schemaUpdateOptions", None, list), + (destination_dataset_table, "destinationTable", None, dict), + (cluster_fields, "clustering", None, dict), + ] for 
param, param_name, param_default, param_type in query_param_list: - if param_name not in configuration['query'] and param in [None, {}, ()]: - if param_name == 'timePartitioning': + if param_name not in configuration["query"] and param in [None, {}, ()]: + if param_name == "timePartitioning": param_default = _cleanse_time_partitioning(destination_dataset_table, time_partitioning) param = param_default if param in [None, {}, ()]: continue - _api_resource_configs_duplication_check(param_name, param, configuration['query']) + _api_resource_configs_duplication_check(param_name, param, configuration["query"]) - configuration['query'][param_name] = param + configuration["query"][param_name] = param # check valid type of provided param, # it last step because we can get param from 2 sources, # and first of all need to find it - _validate_value(param_name, configuration['query'][param_name], param_type) + _validate_value(param_name, configuration["query"][param_name], param_type) - if param_name == 'schemaUpdateOptions' and param: + if param_name == "schemaUpdateOptions" and param: self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options) - if param_name != 'destinationTable': + if param_name != "destinationTable": continue - for key in ['projectId', 'datasetId', 'tableId']: - if key not in configuration['query']['destinationTable']: + for key in ["projectId", "datasetId", "tableId"]: + if key not in configuration["query"]["destinationTable"]: raise ValueError( "Not correct 'destinationTable' in " "api_resource_configs. 'destinationTable' " @@ -2150,25 +2155,25 @@ def run_query( "'datasetId':'', 'tableId':''}" ) - configuration['query'].update( + configuration["query"].update( { - 'allowLargeResults': allow_large_results, - 'flattenResults': flatten_results, - 'writeDisposition': write_disposition, - 'createDisposition': create_disposition, + "allowLargeResults": allow_large_results, + "flattenResults": flatten_results, + "writeDisposition": write_disposition, + "createDisposition": create_disposition, } ) if ( - 'useLegacySql' in configuration['query'] - and configuration['query']['useLegacySql'] - and 'queryParameters' in configuration['query'] + "useLegacySql" in configuration["query"] + and configuration["query"]["useLegacySql"] + and "queryParameters" in configuration["query"] ): raise ValueError("Query parameters are not allowed when using legacy SQL") if labels: - _api_resource_configs_duplication_check('labels', labels, configuration) - configuration['labels'] = labels + _api_resource_configs_duplication_check("labels", labels, configuration) + configuration["labels"] = labels if encryption_configuration: configuration["query"]["destinationEncryptionConfiguration"] = encryption_configuration @@ -2177,6 +2182,83 @@ def run_query( self.running_job_id = job.job_id return job.job_id + def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False): + if force_rerun: + hash_base = str(uuid.uuid4()) + else: + hash_base = json.dumps(configuration, sort_keys=True) + + uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest() + + if job_id: + return f"{job_id}_{uniqueness_suffix}" + + exec_date = logical_date.isoformat() + job_id = f"airflow_{dag_id}_{task_id}_{exec_date}_{uniqueness_suffix}" + return re.sub(r"[:\-+.]", "_", job_id) + + def split_tablename( + self, table_input: str, default_project_id: str, var_name: str | None = None + ) -> tuple[str, str, str]: + + if "." 
not in table_input: + raise ValueError(f"Expected table name in the format of <dataset>.<table>. Got: {table_input}") + + if not default_project_id: + raise ValueError("INTERNAL: No default project is specified") + + def var_print(var_name): + if var_name is None: + return "" + else: + return f"Format exception for {var_name}: " + + if table_input.count(".") + table_input.count(":") > 3: + raise Exception(f"{var_print(var_name)}Use either : or . to specify project got {table_input}") + cmpt = table_input.rsplit(":", 1) + project_id = None + rest = table_input + if len(cmpt) == 1: + project_id = None + rest = cmpt[0] + elif len(cmpt) == 2 and cmpt[0].count(":") <= 1: + if cmpt[-1].count(".") != 2: + project_id = cmpt[0] + rest = cmpt[1] + else: + raise Exception( + f"{var_print(var_name)}Expect format of (<project.|<project:)<dataset>.<table>
, got {table_input}" + ) + + cmpt = rest.split(".") + if len(cmpt) == 3: + if project_id: + raise ValueError(f"{var_print(var_name)}Use either : or . to specify project") + project_id = cmpt[0] + dataset_id = cmpt[1] + table_id = cmpt[2] + + elif len(cmpt) == 2: + dataset_id = cmpt[0] + table_id = cmpt[1] + else: + raise Exception( + f"{var_print(var_name)} Expect format of (<project.|<project:)<dataset>.<table>
, " + f"got {table_input}" + ) + + if project_id is None: + if var_name is not None: + self.log.info( + 'Project is not included in %s: %s; using project "%s"', + var_name, + table_input, + default_project_id, + ) + project_id = default_project_id + + return project_id, dataset_id, table_id + class BigQueryConnection: """ @@ -2195,7 +2277,7 @@ def close(self) -> None: def commit(self) -> None: """The BigQueryConnection does not support transactions""" - def cursor(self) -> "BigQueryCursor": + def cursor(self) -> BigQueryCursor: """Return a new :py:class:`Cursor` object using the connection""" return BigQueryCursor(*self._args, **self._kwargs) @@ -2217,26 +2299,25 @@ def __init__( project_id: str, hook: BigQueryHook, use_legacy_sql: bool = True, - api_resource_configs: Optional[Dict] = None, - location: Optional[str] = None, + api_resource_configs: dict | None = None, + location: str | None = None, num_retries: int = 5, - labels: Optional[Dict] = None, + labels: dict | None = None, ) -> None: - super().__init__() self.service = service self.project_id = project_id self.use_legacy_sql = use_legacy_sql if api_resource_configs: _validate_value("api_resource_configs", api_resource_configs, dict) - self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict - self.running_job_id = None # type: Optional[str] + self.api_resource_configs: dict = api_resource_configs if api_resource_configs else {} + self.running_job_id: str | None = None self.location = location self.num_retries = num_retries self.labels = labels self.hook = hook - def create_empty_table(self, *args, **kwargs) -> None: + def create_empty_table(self, *args, **kwargs): """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table` @@ -2249,7 +2330,7 @@ def create_empty_table(self, *args, **kwargs) -> None: ) return self.hook.create_empty_table(*args, **kwargs) - def create_empty_dataset(self, *args, **kwargs) -> Dict[str, Any]: + def create_empty_dataset(self, *args, **kwargs) -> dict[str, Any]: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_dataset` @@ -2262,7 +2343,7 @@ def create_empty_dataset(self, *args, **kwargs) -> Dict[str, Any]: ) return self.hook.create_empty_dataset(*args, **kwargs) - def get_dataset_tables(self, *args, **kwargs) -> List[Dict[str, Any]]: + def get_dataset_tables(self, *args, **kwargs) -> list[dict[str, Any]]: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables` @@ -2288,7 +2369,7 @@ def delete_dataset(self, *args, **kwargs) -> None: ) return self.hook.delete_dataset(*args, **kwargs) - def create_external_table(self, *args, **kwargs) -> None: + def create_external_table(self, *args, **kwargs): """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_external_table` @@ -2327,7 +2408,7 @@ def insert_all(self, *args, **kwargs) -> None: ) return self.hook.insert_all(*args, **kwargs) - def update_dataset(self, *args, **kwargs) -> Dict: + def update_dataset(self, *args, **kwargs) -> dict: """ This method is deprecated. 
Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_dataset` @@ -2340,7 +2421,7 @@ def update_dataset(self, *args, **kwargs) -> Dict: ) return Dataset.to_api_repr(self.hook.update_dataset(*args, **kwargs)) - def patch_dataset(self, *args, **kwargs) -> Dict: + def patch_dataset(self, *args, **kwargs) -> dict: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.patch_dataset` @@ -2353,7 +2434,7 @@ def patch_dataset(self, *args, **kwargs) -> Dict: ) return self.hook.patch_dataset(*args, **kwargs) - def get_dataset_tables_list(self, *args, **kwargs) -> List[Dict[str, Any]]: + def get_dataset_tables_list(self, *args, **kwargs) -> list[dict[str, Any]]: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables_list` @@ -2432,7 +2513,7 @@ def run_table_delete(self, *args, **kwargs) -> None: ) return self.hook.run_table_delete(*args, **kwargs) - def get_tabledata(self, *args, **kwargs) -> List[dict]: + def get_tabledata(self, *args, **kwargs) -> list[dict]: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_tabledata` @@ -2565,7 +2646,7 @@ def __init__( project_id: str, hook: BigQueryHook, use_legacy_sql: bool = True, - location: Optional[str] = None, + location: str | None = None, num_retries: int = 5, ) -> None: super().__init__( @@ -2576,16 +2657,21 @@ def __init__( location=location, num_retries=num_retries, ) - self.buffersize = None # type: Optional[int] - self.page_token = None # type: Optional[str] - self.job_id = None # type: Optional[str] - self.buffer = [] # type: list - self.all_pages_loaded = False # type: bool + self.buffersize: int | None = None + self.page_token: str | None = None + self.job_id: str | None = None + self.buffer: list = [] + self.all_pages_loaded: bool = False + self._description: list = [] @property - def description(self) -> None: - """The schema description method is not currently implemented""" - raise NotImplementedError + def description(self) -> list: + """Return the cursor description""" + return self._description + + @description.setter + def description(self, value): + self._description = value def close(self) -> None: """By default, do nothing""" @@ -2595,7 +2681,7 @@ def rowcount(self) -> int: """By default, return -1 to indicate that this is not supported""" return -1 - def execute(self, operation: str, parameters: Optional[dict] = None) -> None: + def execute(self, operation: str, parameters: dict | None = None) -> None: """ Executes a BigQuery query, and returns the job ID. @@ -2606,6 +2692,12 @@ def execute(self, operation: str, parameters: Optional[dict] = None) -> None: self.flush_results() self.job_id = self.hook.run_query(sql) + query_results = self._get_query_result() + if "schema" in query_results: + self.description = _format_schema_for_description(query_results["schema"]) + else: + self.description = [] + def executemany(self, operation: str, seq_of_parameters: list) -> None: """ Execute a BigQuery query multiple times with different parameters. @@ -2624,11 +2716,11 @@ def flush_results(self) -> None: self.all_pages_loaded = False self.buffer = [] - def fetchone(self) -> Union[List, None]: + def fetchone(self) -> list | None: """Fetch the next row of a query result set""" return self.next() - def next(self) -> Union[List, None]: + def next(self) -> list | None: """ Helper method for fetchone, which returns the next row from a buffer. 
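A hedged sketch of the DB-API surface touched here: after execute(), description is now populated from the query schema instead of raising NotImplementedError (connection defaults, SQL and identifiers are placeholders):

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    cursor = BigQueryHook().get_conn().cursor()
    cursor.execute("SELECT name, age FROM `my-project.my_dataset.people`")
    print(cursor.description)  # e.g. [("name", "STRING", None, None, None, None, True), ...]
    print(cursor.fetchall())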
If the buffer is empty, attempts to paginate through the result set for @@ -2641,25 +2733,15 @@ def next(self) -> Union[List, None]: if self.all_pages_loaded: return None - query_results = ( - self.service.jobs() - .getQueryResults( - projectId=self.project_id, - jobId=self.job_id, - location=self.location, - pageToken=self.page_token, - ) - .execute(num_retries=self.num_retries) - ) - - if 'rows' in query_results and query_results['rows']: - self.page_token = query_results.get('pageToken') - fields = query_results['schema']['fields'] - col_types = [field['type'] for field in fields] - rows = query_results['rows'] + query_results = self._get_query_result() + if "rows" in query_results and query_results["rows"]: + self.page_token = query_results.get("pageToken") + fields = query_results["schema"]["fields"] + col_types = [field["type"] for field in fields] + rows = query_results["rows"] for dict_row in rows: - typed_row = [_bq_cast(vs['v'], col_types[idx]) for idx, vs in enumerate(dict_row['f'])] + typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] self.buffer.append(typed_row) if not self.page_token: @@ -2672,7 +2754,7 @@ def next(self) -> Union[List, None]: return self.buffer.pop(0) - def fetchmany(self, size: Optional[int] = None) -> list: + def fetchmany(self, size: int | None = None) -> list: """ Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a list of tuples). An empty sequence is returned when no more rows are @@ -2694,7 +2776,7 @@ def fetchmany(self, size: Optional[int] = None) -> list: result.append(one) return result - def fetchall(self) -> List[list]: + def fetchall(self) -> list[list]: """ Fetch all (remaining) rows of a query result, returning them as a sequence of sequences (e.g. a list of tuples). @@ -2723,14 +2805,29 @@ def setinputsizes(self, sizes: Any) -> None: def setoutputsize(self, size: Any, column: Any = None) -> None: """Does nothing by default""" + def _get_query_result(self) -> dict: + """Get job query results like data, schema, job type...""" + query_results = ( + self.service.jobs() + .getQueryResults( + projectId=self.project_id, + jobId=self.job_id, + location=self.location, + pageToken=self.page_token, + ) + .execute(num_retries=self.num_retries) + ) + + return query_results + def _bind_parameters(operation: str, parameters: dict) -> str: """Helper method that binds parameters to a SQL query""" # inspired by MySQL Python Connector (conversion.py) - string_parameters = {} # type Dict[str, str] + string_parameters = {} # type dict[str, str] for (name, value) in parameters.items(): if value is None: - string_parameters[name] = 'NULL' + string_parameters[name] = "NULL" elif isinstance(value, str): string_parameters[name] = "'" + _escape(value) + "'" else: @@ -2741,39 +2838,19 @@ def _bind_parameters(operation: str, parameters: dict) -> str: def _escape(s: str) -> str: """Helper method that escapes parameters to a SQL query""" e = s - e = e.replace('\\', '\\\\') - e = e.replace('\n', '\\n') - e = e.replace('\r', '\\r') + e = e.replace("\\", "\\\\") + e = e.replace("\n", "\\n") + e = e.replace("\r", "\\r") e = e.replace("'", "\\'") e = e.replace('"', '\\"') return e -def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]: - """ - Helper method that casts a BigQuery row to the appropriate data types. - This is useful because BigQuery returns all fields as strings. 
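The removed _bq_cast() below is superseded by a bq_cast helper that now lives outside this module (its import is not shown in this hunk). Its behaviour, re-implemented standalone purely for illustration, is:

    def cast_cell(value, bq_type):
        # illustrative re-implementation, not the provider's helper
        if value is None:
            return None
        if bq_type == "INTEGER":
            return int(value)
        if bq_type in ("FLOAT", "TIMESTAMP"):
            return float(value)
        if bq_type == "BOOLEAN":
            return value == "true"
        return value

    assert cast_cell("42", "INTEGER") == 42
    assert cast_cell(None, "STRING") is None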
- """ - if string_field is None: - return None - elif bq_type == 'INTEGER': - return int(string_field) - elif bq_type in ('FLOAT', 'TIMESTAMP'): - return float(string_field) - elif bq_type == 'BOOLEAN': - if string_field not in ['true', 'false']: - raise ValueError(f"{string_field} must have value 'true' or 'false'") - return string_field == 'true' - else: - return string_field - - -def _split_tablename( - table_input: str, default_project_id: str, var_name: Optional[str] = None -) -> Tuple[str, str, str]: - - if '.' not in table_input: - raise ValueError(f'Expected table name in the format of .
. Got: {table_input}') +def split_tablename( + table_input: str, default_project_id: str, var_name: str | None = None +) -> tuple[str, str, str]: + if "." not in table_input: + raise ValueError(f"Expected table name in the format of .
. Got: {table_input}") if not default_project_id: raise ValueError("INTERNAL: No default project is specified") @@ -2784,24 +2861,24 @@ def var_print(var_name): else: return f"Format exception for {var_name}: " - if table_input.count('.') + table_input.count(':') > 3: - raise Exception(f'{var_print(var_name)}Use either : or . to specify project got {table_input}') - cmpt = table_input.rsplit(':', 1) + if table_input.count(".") + table_input.count(":") > 3: + raise Exception(f"{var_print(var_name)}Use either : or . to specify project got {table_input}") + cmpt = table_input.rsplit(":", 1) project_id = None rest = table_input if len(cmpt) == 1: project_id = None rest = cmpt[0] - elif len(cmpt) == 2 and cmpt[0].count(':') <= 1: - if cmpt[-1].count('.') != 2: + elif len(cmpt) == 2 and cmpt[0].count(":") <= 1: + if cmpt[-1].count(".") != 2: project_id = cmpt[0] rest = cmpt[1] else: raise Exception( - f'{var_print(var_name)}Expect format of (.
, got {table_input}' + f"{var_print(var_name)}Expect format of (.
, got {table_input}" ) - cmpt = rest.split('.') + cmpt = rest.split(".") if len(cmpt) == 3: if project_id: raise ValueError(f"{var_print(var_name)}Use either : or . to specify project") @@ -2814,13 +2891,13 @@ def var_print(var_name): table_id = cmpt[1] else: raise Exception( - f'{var_print(var_name)}Expect format of (.
, got {table_input}' + f"{var_print(var_name)}Expect format of (.
, got {table_input}" ) if project_id is None: if var_name is not None: log.info( - 'Project not included in %s: %s; using project "%s"', + 'Project is not included in %s: %s; using project "%s"', var_name, table_input, default_project_id, @@ -2831,27 +2908,27 @@ def var_print(var_name): def _cleanse_time_partitioning( - destination_dataset_table: Optional[str], time_partitioning_in: Optional[Dict] -) -> Dict: # if it is a partitioned table ($ is in the table name) add partition load option + destination_dataset_table: str | None, time_partitioning_in: dict | None +) -> dict: # if it is a partitioned table ($ is in the table name) add partition load option if time_partitioning_in is None: time_partitioning_in = {} time_partitioning_out = {} - if destination_dataset_table and '$' in destination_dataset_table: - time_partitioning_out['type'] = 'DAY' + if destination_dataset_table and "$" in destination_dataset_table: + time_partitioning_out["type"] = "DAY" time_partitioning_out.update(time_partitioning_in) return time_partitioning_out -def _validate_value(key: Any, value: Any, expected_type: Type) -> None: +def _validate_value(key: Any, value: Any, expected_type: type | tuple[type]) -> None: """Function to check expected type and raise error if type is not correct""" if not isinstance(value, expected_type): raise TypeError(f"{key} argument must have a type {expected_type} not {type(value)}") def _api_resource_configs_duplication_check( - key: Any, value: Any, config_dict: dict, config_dict_name='api_resource_configs' + key: Any, value: Any, config_dict: dict, config_dict_name="api_resource_configs" ) -> None: if key in config_dict and value != config_dict[key]: raise ValueError( @@ -2867,9 +2944,9 @@ def _api_resource_configs_duplication_check( def _validate_src_fmt_configs( source_format: str, src_fmt_configs: dict, - valid_configs: List[str], - backward_compatibility_configs: Optional[Dict] = None, -) -> Dict: + valid_configs: list[str], + backward_compatibility_configs: dict | None = None, +) -> dict: """ Validates the given src_fmt_configs against a valid configuration for the source format. Adds the backward compatibility config to the src_fmt_configs. @@ -2891,3 +2968,276 @@ def _validate_src_fmt_configs( raise ValueError(f"{k} is not a valid src_fmt_configs for type {source_format}.") return src_fmt_configs + + +def _format_schema_for_description(schema: dict) -> list: + """ + Reformat the schema to match cursor description standard which is a tuple + of 7 elemenbts (name, type, display_size, internal_size, precision, scale, null_ok) + """ + description = [] + for field in schema["fields"]: + mode = field.get("mode", "NULLABLE") + field_description = ( + field["name"], + field["type"], + None, + None, + None, + None, + mode == "NULLABLE", + ) + description.append(field_description) + return description + + +class BigQueryAsyncHook(GoogleBaseAsyncHook): + """Uses gcloud-aio library to retrieve Job details""" + + sync_hook_class = BigQueryHook + + async def get_job_instance( + self, project_id: str | None, job_id: str | None, session: ClientSession + ) -> Job: + """Get the specified job resource by job ID and project ID.""" + with await self.service_file_as_context() as f: + return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session)) + + async def get_job_status( + self, + job_id: str | None, + project_id: str | None = None, + ) -> str | None: + """ + Polls for job status asynchronously using gcloud-aio. 
+ + Note that an OSError is raised when Job results are still pending. + Exception means that Job finished with errors + """ + async with ClientSession() as s: + try: + self.log.info("Executing get_job_status...") + job_client = await self.get_job_instance(project_id, job_id, s) + job_status_response = await job_client.result(cast(Session, s)) + if job_status_response: + job_status = "success" + except OSError: + job_status = "pending" + except Exception as e: + self.log.info("Query execution finished with errors...") + job_status = str(e) + return job_status + + async def get_job_output( + self, + job_id: str | None, + project_id: str | None = None, + ) -> dict[str, Any]: + """Get the big query job output for the given job id asynchronously using gcloud-aio.""" + async with ClientSession() as session: + self.log.info("Executing get_job_output..") + job_client = await self.get_job_instance(project_id, job_id, session) + job_query_response = await job_client.get_query_results(cast(Session, session)) + return job_query_response + + def get_records(self, query_results: dict[str, Any]) -> list[Any]: + """ + Given the output query response from gcloud-aio bigquery, convert the response to records. + + :param query_results: the results from a SQL query + """ + buffer = [] + if "rows" in query_results and query_results["rows"]: + rows = query_results["rows"] + fields = query_results["schema"]["fields"] + col_types = [field["type"] for field in fields] + for dict_row in rows: + typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] + buffer.append(typed_row) + return buffer + + def value_check( + self, + sql: str, + pass_value: Any, + records: list[Any], + tolerance: float | None = None, + ) -> None: + """ + Match a single query resulting row and tolerance with pass_value + + :return: If Match fail, we throw an AirflowException. 
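The numeric branch of value_check() delegates to _get_numeric_matches(); its tolerance test, shown standalone with made-up numbers:

    records = [95.0, 102.0]
    pass_value, tolerance = 100.0, 0.05
    matches = [
        pass_value * (1 - tolerance) <= record <= pass_value * (1 + tolerance)
        for record in records
    ]
    assert matches == [True, True]  # both rows fall within +/- 5 % of the expected value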
+ """ + if not records: + raise AirflowException("The query returned None") + pass_value_conv = self._convert_to_float_if_possible(pass_value) + is_numeric_value_check = isinstance(pass_value_conv, float) + tolerance_pct_str = str(tolerance * 100) + "%" if tolerance else None + + error_msg = ( + "Test failed.\nPass value:{pass_value_conv}\n" + "Tolerance:{tolerance_pct_str}\n" + "Query:\n{sql}\nResults:\n{records!s}" + ).format( + pass_value_conv=pass_value_conv, + tolerance_pct_str=tolerance_pct_str, + sql=sql, + records=records, + ) + + if not is_numeric_value_check: + tests = [str(record) == pass_value_conv for record in records] + else: + try: + numeric_records = [float(record) for record in records] + except (ValueError, TypeError): + raise AirflowException(f"Converting a result to float failed.\n{error_msg}") + tests = self._get_numeric_matches(numeric_records, pass_value_conv, tolerance) + + if not all(tests): + raise AirflowException(error_msg) + + @staticmethod + def _get_numeric_matches( + records: list[float], pass_value: Any, tolerance: float | None = None + ) -> list[bool]: + """ + A helper function to match numeric pass_value, tolerance with records value + + :param records: List of value to match against + :param pass_value: Expected value + :param tolerance: Allowed tolerance for match to succeed + """ + if tolerance: + return [ + pass_value * (1 - tolerance) <= record <= pass_value * (1 + tolerance) for record in records + ] + + return [record == pass_value for record in records] + + @staticmethod + def _convert_to_float_if_possible(s: Any) -> Any: + """ + A small helper function to convert a string to a numeric value if appropriate + + :param s: the string to be converted + """ + try: + return float(s) + except (ValueError, TypeError): + return s + + def interval_check( + self, + row1: str | None, + row2: str | None, + metrics_thresholds: dict[str, Any], + ignore_zero: bool, + ratio_formula: str, + ) -> None: + """ + Checks that the values of metrics given as SQL expressions are within a certain tolerance + + :param row1: first resulting row of a query execution job for first SQL query + :param row2: first resulting row of a query execution job for second SQL query + :param metrics_thresholds: a dictionary of ratios indexed by metrics, for + example 'COUNT(*)': 1.5 would require a 50 percent or less difference + between the current day, and the prior days_back. + :param ignore_zero: whether we should ignore zero metrics + :param ratio_formula: which formula to use to compute the ratio between + the two metrics. Assuming cur is the metric of today and ref is + the metric to today - days_back. 
+ max_over_min: computes max(cur, ref) / min(cur, ref) + relative_diff: computes abs(cur-ref) / ref + """ + if not row2: + raise AirflowException("The second SQL query returned None") + if not row1: + raise AirflowException("The first SQL query returned None") + + ratio_formulas = { + "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref), + "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref, + } + + metrics_sorted = sorted(metrics_thresholds.keys()) + + current = dict(zip(metrics_sorted, row1)) + reference = dict(zip(metrics_sorted, row2)) + ratios: dict[str, Any] = {} + test_results: dict[str, Any] = {} + + for metric in metrics_sorted: + cur = float(current[metric]) + ref = float(reference[metric]) + threshold = float(metrics_thresholds[metric]) + if cur == 0 or ref == 0: + ratios[metric] = None + test_results[metric] = ignore_zero + else: + ratios[metric] = ratio_formulas[ratio_formula]( + float(current[metric]), float(reference[metric]) + ) + test_results[metric] = float(ratios[metric]) < threshold + + self.log.info( + ( + "Current metric for %s: %s\n" + "Past metric for %s: %s\n" + "Ratio for %s: %s\n" + "Threshold: %s\n" + ), + metric, + cur, + metric, + ref, + metric, + ratios[metric], + threshold, + ) + + if not all(test_results.values()): + failed_tests = [metric for metric, value in test_results.items() if not value] + self.log.warning( + "The following %s tests out of %s failed:", + len(failed_tests), + len(metrics_sorted), + ) + for k in failed_tests: + self.log.warning( + "'%s' check failed. %s is above %s", + k, + ratios[k], + metrics_thresholds[k], + ) + raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}") + + self.log.info("All tests have passed") + + +class BigQueryTableAsyncHook(GoogleBaseAsyncHook): + """Class to get async hook for Bigquery Table Async""" + + sync_hook_class = BigQueryHook + + async def get_table_client( + self, dataset: str, table_id: str, project_id: str, session: ClientSession + ) -> Table_async: + """ + Returns a Google Big Query Table object. + + :param dataset: The name of the dataset in which to look for the table storage bucket. + :param table_id: The name of the table to check the existence of. + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + :param session: aiohttp ClientSession + """ + with await self.service_file_as_context() as file: + return Table_async( + dataset_name=dataset, + table_name=table_id, + project=project_id, + service_file=file, + session=cast(Session, session), + ) diff --git a/airflow/providers/google/cloud/hooks/bigquery_dts.py b/airflow/providers/google/cloud/hooks/bigquery_dts.py index 5a842189cb1d7..aee1e3b6b6fe2 100644 --- a/airflow/providers/google/cloud/hooks/bigquery_dts.py +++ b/airflow/providers/google/cloud/hooks/bigquery_dts.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a BigQuery Hook.""" +from __future__ import annotations + from copy import copy -from typing import Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -47,14 +48,14 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook): keyword arguments rather than positional. 
""" - _conn = None # type: Optional[Resource] + _conn: Resource | None = None def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -64,7 +65,7 @@ def __init__( self.location = location @staticmethod - def _disable_auto_scheduling(config: Union[dict, TransferConfig]) -> TransferConfig: + def _disable_auto_scheduling(config: dict | TransferConfig) -> TransferConfig: """ In the case of Airflow, the customer needs to create a transfer config with the automatic scheduling disabled (UI, CLI or an Airflow operator) and @@ -90,23 +91,22 @@ def get_conn(self) -> DataTransferServiceClient: Retrieves connection to Google Bigquery. :return: Google Bigquery API client - :rtype: google.cloud.bigquery_datatransfer_v1.DataTransferServiceClient """ if not self._conn: self._conn = DataTransferServiceClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO + credentials=self.get_credentials(), client_info=CLIENT_INFO ) return self._conn @GoogleBaseHook.fallback_to_default_project_id def create_transfer_config( self, - transfer_config: Union[dict, TransferConfig], + transfer_config: dict | TransferConfig, project_id: str = PROVIDE_PROJECT_ID, - authorization_code: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + authorization_code: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TransferConfig: """ Creates a new data transfer configuration. @@ -132,9 +132,9 @@ def create_transfer_config( return client.create_transfer_config( request={ - 'parent': parent, - 'transfer_config': self._disable_auto_scheduling(transfer_config), - 'authorization_code': authorization_code, + "parent": parent, + "transfer_config": self._disable_auto_scheduling(transfer_config), + "authorization_code": authorization_code, }, retry=retry, timeout=timeout, @@ -146,9 +146,9 @@ def delete_transfer_config( self, transfer_config_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes transfer configuration. 
@@ -168,11 +168,11 @@ def delete_transfer_config( client = self.get_conn() project = f"projects/{project_id}" if self.location: - project = f"/{project}/locations/{self.location}" + project = f"{project}/locations/{self.location}" name = f"{project}/transferConfigs/{transfer_config_id}" return client.delete_transfer_config( - request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or () ) @GoogleBaseHook.fallback_to_default_project_id @@ -180,11 +180,11 @@ def start_manual_transfer_runs( self, transfer_config_id: str, project_id: str = PROVIDE_PROJECT_ID, - requested_time_range: Optional[dict] = None, - requested_run_time: Optional[dict] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + requested_time_range: dict | None = None, + requested_run_time: dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> StartManualTransferRunsResponse: """ Start manual transfer runs to be executed now with schedule_time equal @@ -219,9 +219,9 @@ def start_manual_transfer_runs( parent = f"{project}/transferConfigs/{transfer_config_id}" return client.start_manual_transfer_runs( request={ - 'parent': parent, - 'requested_time_range': requested_time_range, - 'requested_run_time': requested_run_time, + "parent": parent, + "requested_time_range": requested_time_range, + "requested_run_time": requested_run_time, }, retry=retry, timeout=timeout, @@ -234,9 +234,9 @@ def get_transfer_run( run_id: str, transfer_config_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TransferRun: """ Returns information about the particular transfer run. @@ -261,5 +261,5 @@ def get_transfer_run( name = f"{project}/transferConfigs/{transfer_config_id}/runs/{run_id}" return client.get_transfer_run( - request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or () ) diff --git a/airflow/providers/google/cloud/hooks/bigtable.py b/airflow/providers/google/cloud/hooks/bigtable.py index e43e03f6fcf5e..ee9281c23e63b 100644 --- a/airflow/providers/google/cloud/hooks/bigtable.py +++ b/airflow/providers/google/cloud/hooks/bigtable.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Google Cloud Bigtable Hook.""" +from __future__ import annotations + import enum -from typing import Dict, List, Optional, Sequence, Union +from typing import Sequence from google.cloud.bigtable import Client from google.cloud.bigtable.cluster import Cluster @@ -41,8 +43,8 @@ class BigtableHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -55,7 +57,7 @@ def _get_client(self, project_id: str): if not self._client: self._client = Client( project=project_id, - credentials=self._get_credentials(), + credentials=self.get_credentials(), client_info=CLIENT_INFO, admin=True, ) @@ -104,13 +106,13 @@ def create_instance( main_cluster_id: str, main_cluster_zone: str, project_id: str, - replica_clusters: Optional[List[Dict[str, str]]] = None, - instance_display_name: Optional[str] = None, + replica_clusters: list[dict[str, str]] | None = None, + instance_display_name: str | None = None, instance_type: enums.Instance.Type = enums.Instance.Type.TYPE_UNSPECIFIED, - instance_labels: Optional[Dict] = None, - cluster_nodes: Optional[int] = None, + instance_labels: dict | None = None, + cluster_nodes: int | None = None, cluster_storage_type: enums.StorageType = enums.StorageType.STORAGE_TYPE_UNSPECIFIED, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> Instance: """ Creates new instance. @@ -174,10 +176,10 @@ def update_instance( self, instance_id: str, project_id: str, - instance_display_name: Optional[str] = None, - instance_type: Optional[Union[enums.Instance.Type, enum.IntEnum]] = None, - instance_labels: Optional[Dict] = None, - timeout: Optional[float] = None, + instance_display_name: str | None = None, + instance_type: enums.Instance.Type | enum.IntEnum | None = None, + instance_labels: dict | None = None, + timeout: float | None = None, ) -> Instance: """ Update an existing instance. @@ -212,8 +214,8 @@ def update_instance( def create_table( instance: Instance, table_id: str, - initial_split_keys: Optional[List] = None, - column_families: Optional[Dict[str, GarbageCollectionRule]] = None, + initial_split_keys: list | None = None, + column_families: dict[str, GarbageCollectionRule] | None = None, ) -> None: """ Creates the specified Cloud Bigtable table. @@ -264,7 +266,7 @@ def update_cluster(instance: Instance, cluster_id: str, nodes: int) -> None: cluster.update() @staticmethod - def get_column_families_for_table(instance: Instance, table_id: str) -> Dict[str, ColumnFamily]: + def get_column_families_for_table(instance: Instance, table_id: str) -> dict[str, ColumnFamily]: """ Fetches Column Families for the specified table in Cloud Bigtable. @@ -276,7 +278,7 @@ def get_column_families_for_table(instance: Instance, table_id: str) -> Dict[str return table.list_column_families() @staticmethod - def get_cluster_states_for_table(instance: Instance, table_id: str) -> Dict[str, ClusterState]: + def get_cluster_states_for_table(instance: Instance, table_id: str) -> dict[str, ClusterState]: """ Fetches Cluster States for the specified table in Cloud Bigtable. Raises google.api_core.exceptions.NotFound if the table does not exist. 
diff --git a/airflow/providers/google/cloud/hooks/cloud_build.py b/airflow/providers/google/cloud/hooks/cloud_build.py index f9cc9282a6a72..6ba6fd06e9fe6 100644 --- a/airflow/providers/google/cloud/hooks/cloud_build.py +++ b/airflow/providers/google/cloud/hooks/cloud_build.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Hook for Google Cloud Build service.""" +from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation @@ -55,13 +55,13 @@ class CloudBuildHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain ) - self._client: Optional[CloudBuildClient] = None + self._client: CloudBuildClient | None = None def _get_build_id_from_operation(self, operation: Operation) -> str: """ @@ -71,7 +71,6 @@ def _get_build_id_from_operation(self, operation: Operation) -> str: version to :return: Cloud Build ID - :rtype: str """ try: return operation.metadata.build.id @@ -83,10 +82,9 @@ def get_conn(self) -> CloudBuildClient: Retrieves the connection to Google Cloud Build. :return: Google Cloud Build client object. - :rtype: `google.cloud.devtools.cloudbuild_v1.CloudBuildClient` """ if not self._client: - self._client = CloudBuildClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = CloudBuildClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @GoogleBaseHook.fallback_to_default_project_id @@ -94,9 +92,9 @@ def cancel_build( self, id_: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Build: """ Cancels a build in progress. @@ -110,14 +108,13 @@ def cancel_build( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. - :rtype: `google.cloud.devtools.cloudbuild_v1.types.Build` """ client = self.get_conn() self.log.info("Start cancelling build: %s.", id_) build = client.cancel_build( - request={'project_id': project_id, 'id': id_}, + request={"project_id": project_id, "id": id_}, retry=retry, timeout=timeout, metadata=metadata, @@ -129,12 +126,12 @@ def cancel_build( @GoogleBaseHook.fallback_to_default_project_id def create_build( self, - build: Union[Dict, Build], + build: dict | Build, project_id: str = PROVIDE_PROJECT_ID, wait: bool = True, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Build: """ Starts a build with the specified configuration. 
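As a rough usage sketch of the converted CloudBuildHook API (image name and project are invented), the build body passed to create_build() is the plain Cloud Build job definition:

    from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook

    build = {
        "steps": [
            {
                "name": "gcr.io/cloud-builders/docker",
                "args": ["build", "-t", "gcr.io/my-project/my-image", "."],
            }
        ],
        "images": ["gcr.io/my-project/my-image"],
    }
    result = CloudBuildHook().create_build(build=build, project_id="my-project")
    print(result.id)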
@@ -150,14 +147,13 @@ def create_build( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. - :rtype: `google.cloud.devtools.cloudbuild_v1.types.Build` """ client = self.get_conn() self.log.info("Start creating build.") operation = client.create_build( - request={'project_id': project_id, 'build': build}, + request={"project_id": project_id, "build": build}, retry=retry, timeout=timeout, metadata=metadata, @@ -177,11 +173,11 @@ def create_build( @GoogleBaseHook.fallback_to_default_project_id def create_build_trigger( self, - trigger: Union[dict, BuildTrigger], + trigger: dict | BuildTrigger, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> BuildTrigger: """ Creates a new BuildTrigger. @@ -196,14 +192,13 @@ def create_build_trigger( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. - :rtype: `google.cloud.devtools.cloudbuild_v1.types.BuildTrigger` """ client = self.get_conn() self.log.info("Start creating build trigger.") trigger = client.create_build_trigger( - request={'project_id': project_id, 'trigger': trigger}, + request={"project_id": project_id, "trigger": trigger}, retry=retry, timeout=timeout, metadata=metadata, @@ -218,9 +213,9 @@ def delete_build_trigger( self, trigger_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a BuildTrigger by its project ID and trigger ID. @@ -239,7 +234,7 @@ def delete_build_trigger( self.log.info("Start deleting build trigger: %s.", trigger_id) client.delete_build_trigger( - request={'project_id': project_id, 'trigger_id': trigger_id}, + request={"project_id": project_id, "trigger_id": trigger_id}, retry=retry, timeout=timeout, metadata=metadata, @@ -252,9 +247,9 @@ def get_build( self, id_: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Build: """ Returns information about a previously requested build. @@ -268,14 +263,13 @@ def get_build( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. 
- :rtype: `google.cloud.devtools.cloudbuild_v1.types.Build` """ client = self.get_conn() self.log.info("Start retrieving build: %s.", id_) build = client.get_build( - request={'project_id': project_id, 'id': id_}, + request={"project_id": project_id, "id": id_}, retry=retry, timeout=timeout, metadata=metadata, @@ -290,9 +284,9 @@ def get_build_trigger( self, trigger_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> BuildTrigger: """ Returns information about a BuildTrigger. @@ -306,14 +300,13 @@ def get_build_trigger( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. - :rtype: `google.cloud.devtools.cloudbuild_v1.types.BuildTrigger` """ client = self.get_conn() self.log.info("Start retrieving build trigger: %s.", trigger_id) trigger = client.get_build_trigger( - request={'project_id': project_id, 'trigger_id': trigger_id}, + request={"project_id": project_id, "trigger_id": trigger_id}, retry=retry, timeout=timeout, metadata=metadata, @@ -328,12 +321,12 @@ def list_build_triggers( self, location: str, project_id: str = PROVIDE_PROJECT_ID, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[BuildTrigger]: + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[BuildTrigger]: """ Lists existing BuildTriggers. @@ -348,7 +341,6 @@ def list_build_triggers( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. - :rtype: `google.cloud.devtools.cloudbuild_v1.types.BuildTrigger` """ client = self.get_conn() @@ -358,10 +350,10 @@ def list_build_triggers( response = client.list_build_triggers( request={ - 'parent': parent, - 'project_id': project_id, - 'page_size': page_size, - 'page_token': page_token, + "parent": parent, + "project_id": project_id, + "page_size": page_size, + "page_token": page_token, }, retry=retry, timeout=timeout, @@ -377,13 +369,13 @@ def list_builds( self, location: str, project_id: str = PROVIDE_PROJECT_ID, - page_size: Optional[int] = None, - page_token: Optional[int] = None, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[Build]: + page_size: int | None = None, + page_token: int | None = None, + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[Build]: """ Lists previously requested builds. @@ -399,7 +391,6 @@ def list_builds( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. 
- :rtype: List[`google.cloud.devtools.cloudbuild_v1.types.Build`] """ client = self.get_conn() @@ -409,11 +400,11 @@ def list_builds( response = client.list_builds( request={ - 'parent': parent, - 'project_id': project_id, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter_, + "parent": parent, + "project_id": project_id, + "page_size": page_size, + "page_token": page_token, + "filter": filter_, }, retry=retry, timeout=timeout, @@ -430,9 +421,9 @@ def retry_build( id_: str, project_id: str = PROVIDE_PROJECT_ID, wait: bool = True, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Build: """ Creates a new build based on the specified build. This method creates a new build @@ -448,14 +439,13 @@ def retry_build( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. - :rtype: `google.cloud.devtools.cloudbuild_v1.types.Build` """ client = self.get_conn() self.log.info("Start retrying build: %s.", id_) operation = client.retry_build( - request={'project_id': project_id, 'id': id_}, + request={"project_id": project_id, "id": id_}, retry=retry, timeout=timeout, metadata=metadata, @@ -476,12 +466,12 @@ def retry_build( def run_build_trigger( self, trigger_id: str, - source: Union[dict, RepoSource], + source: dict | RepoSource, project_id: str = PROVIDE_PROJECT_ID, wait: bool = True, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Build: """ Runs a BuildTrigger at a particular source revision. @@ -498,14 +488,13 @@ def run_build_trigger( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. - :rtype: `google.cloud.devtools.cloudbuild_v1.types.Build` """ client = self.get_conn() self.log.info("Start running build trigger: %s.", trigger_id) operation = client.run_build_trigger( - request={'project_id': project_id, 'trigger_id': trigger_id, 'source': source}, + request={"project_id": project_id, "trigger_id": trigger_id, "source": source}, retry=retry, timeout=timeout, metadata=metadata, @@ -526,11 +515,11 @@ def run_build_trigger( def update_build_trigger( self, trigger_id: str, - trigger: Union[dict, BuildTrigger], + trigger: dict | BuildTrigger, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> BuildTrigger: """ Updates a BuildTrigger by its project ID and trigger ID. @@ -546,14 +535,13 @@ def update_build_trigger( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Optional, additional metadata that is provided to the method. 
- :rtype: `google.cloud.devtools.cloudbuild_v1.types.BuildTrigger` """ client = self.get_conn() self.log.info("Start updating build trigger: %s.", trigger_id) trigger = client.update_build_trigger( - request={'project_id': project_id, 'trigger_id': trigger_id, 'trigger': trigger}, + request={"project_id": project_id, "trigger_id": trigger_id, "trigger": trigger}, retry=retry, timeout=timeout, metadata=metadata, diff --git a/airflow/providers/google/cloud/hooks/cloud_composer.py b/airflow/providers/google/cloud/hooks/cloud_composer.py index 33f987c1c1040..793aec9f20d27 100644 --- a/airflow/providers/google/cloud/hooks/cloud_composer.py +++ b/airflow/providers/google/cloud/hooks/cloud_composer.py @@ -15,14 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation +from google.api_core.operation_async import AsyncOperation from google.api_core.retry import Retry -from google.cloud.orchestration.airflow.service_v1 import EnvironmentsClient, ImageVersionsClient +from google.cloud.orchestration.airflow.service_v1 import ( + EnvironmentsAsyncClient, + EnvironmentsClient, + ImageVersionsClient, +) from google.cloud.orchestration.airflow.service_v1.services.environments.pagers import ListEnvironmentsPager from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import ( ListImageVersionsPager, @@ -38,27 +44,25 @@ class CloudComposerHook(GoogleBaseHook): """Hook for Google Cloud Composer APIs.""" - client_options = ClientOptions(api_endpoint='composer.googleapis.com:443') + client_options = ClientOptions(api_endpoint="composer.googleapis.com:443") def get_environment_client(self) -> EnvironmentsClient: """Retrieves client library object that allow access Environments service.""" return EnvironmentsClient( - credentials=self._get_credentials(), + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=self.client_options, ) - def get_image_versions_client( - self, - ) -> ImageVersionsClient: + def get_image_versions_client(self) -> ImageVersionsClient: """Retrieves client library object that allow access Image Versions service.""" return ImageVersionsClient( - credentials=self._get_credentials(), + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=self.client_options, ) - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -70,20 +74,20 @@ def get_operation(self, operation_name): return self.get_environment_client().transport.operations_client.get_operation(name=operation_name) def get_environment_name(self, project_id, region, environment_id): - return f'projects/{project_id}/locations/{region}/environments/{environment_id}' + return f"projects/{project_id}/locations/{region}/environments/{environment_id}" def get_parent(self, project_id, region): - return f'projects/{project_id}/locations/{region}' + return f"projects/{project_id}/locations/{region}" @GoogleBaseHook.fallback_to_default_project_id def create_environment( self, 
project_id: str, region: str, - environment: Union[Environment, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + environment: Environment | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Create a new environment. @@ -98,7 +102,7 @@ def create_environment( """ client = self.get_environment_client() result = client.create_environment( - request={'parent': self.get_parent(project_id, region), 'environment': environment}, + request={"parent": self.get_parent(project_id, region), "environment": environment}, retry=retry, timeout=timeout, metadata=metadata, @@ -111,9 +115,9 @@ def delete_environment( project_id: str, region: str, environment_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Delete an environment. @@ -138,9 +142,9 @@ def get_environment( project_id: str, region: str, environment_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Environment: """ Get an existing environment. @@ -153,7 +157,7 @@ def get_environment( """ client = self.get_environment_client() result = client.get_environment( - request={'name': self.get_environment_name(project_id, region, environment_id)}, + request={"name": self.get_environment_name(project_id, region, environment_id)}, retry=retry, timeout=timeout, metadata=metadata, @@ -165,11 +169,11 @@ def list_environments( self, project_id: str, region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListEnvironmentsPager: """ List environments. @@ -202,11 +206,11 @@ def update_environment( project_id: str, region: str, environment_id: str, - environment: Union[Environment, Dict], - update_mask: Union[Dict, FieldMask], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + environment: Environment | dict, + update_mask: dict | FieldMask, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: r""" Update an environment. @@ -242,12 +246,12 @@ def list_image_versions( self, project_id: str, region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, + page_size: int | None = None, + page_token: str | None = None, include_past_releases: bool = False, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListImageVersionsPager: """ List ImageVersions for provided location. 
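The Composer hook keeps its call shape; only annotations, string quoting, and the credentials accessor change in the hunks above. A hedged usage sketch against those signatures (project, region, and environment IDs are placeholders):

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook

hook = CloudComposerHook(gcp_conn_id="google_cloud_default")
environment = hook.get_environment(
    project_id="example-project",
    region="us-central1",
    environment_id="example-environment",
)
# Resource names follow the get_environment_name helper above:
# projects/<project_id>/locations/<region>/environments/<environment_id>
print(environment.name)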
@@ -265,7 +269,7 @@ def list_image_versions( client = self.get_image_versions_client() result = client.list_image_versions( request={ - 'parent': self.get_parent(project_id, region), + "parent": self.get_parent(project_id, region), "page_size": page_size, "page_token": page_token, "include_past_releases": include_past_releases, @@ -275,3 +279,123 @@ def list_image_versions( metadata=metadata, ) return result + + +class CloudComposerAsyncHook(GoogleBaseHook): + """Hook for Google Cloud Composer async APIs.""" + + client_options = ClientOptions(api_endpoint="composer.googleapis.com:443") + + def get_environment_client(self) -> EnvironmentsAsyncClient: + """Retrieves client library object that allow access Environments service.""" + return EnvironmentsAsyncClient( + credentials=self.get_credentials(), + client_info=CLIENT_INFO, + client_options=self.client_options, + ) + + def get_environment_name(self, project_id, region, environment_id): + return f"projects/{project_id}/locations/{region}/environments/{environment_id}" + + def get_parent(self, project_id, region): + return f"projects/{project_id}/locations/{region}" + + async def get_operation(self, operation_name): + return await self.get_environment_client().transport.operations_client.get_operation( + name=operation_name + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def create_environment( + self, + project_id: str, + region: str, + environment: Environment | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> AsyncOperation: + """ + Create a new environment. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param environment: The environment to create. This corresponds to the ``environment`` field on the + ``request`` instance; if ``request`` is provided, this should not be set. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_environment_client() + return await client.create_environment( + request={"parent": self.get_parent(project_id, region), "environment": environment}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def delete_environment( + self, + project_id: str, + region: str, + environment_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> AsyncOperation: + """ + Delete an environment. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param environment_id: Required. The ID of the Google Cloud environment that the service belongs to. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_environment_client() + name = self.get_environment_name(project_id, region, environment_id) + return await client.delete_environment( + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def update_environment( + self, + project_id: str, + region: str, + environment_id: str, + environment: Environment | dict, + update_mask: dict | FieldMask, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> AsyncOperation: + r""" + Update an environment. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param environment_id: Required. The ID of the Google Cloud environment that the service belongs to. + :param environment: A patch environment. Fields specified by the ``updateMask`` will be copied from + the patch environment into the environment under update. + + This corresponds to the ``environment`` field on the ``request`` instance; if ``request`` is + provided, this should not be set. + :param update_mask: Required. A comma-separated list of paths, relative to ``Environment``, of fields + to update. If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_environment_client() + name = self.get_environment_name(project_id, region, environment_id) + + return await client.update_environment( + request={"name": name, "environment": environment, "update_mask": update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/airflow/providers/google/cloud/hooks/cloud_memorystore.py b/airflow/providers/google/cloud/hooks/cloud_memorystore.py index 0abb308dcf500..20ac955c5edb0 100644 --- a/airflow/providers/google/cloud/hooks/cloud_memorystore.py +++ b/airflow/providers/google/cloud/hooks/cloud_memorystore.py @@ -25,7 +25,9 @@ pb memcache """ -from typing import Dict, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Sequence from google.api_core import path_template from google.api_core.exceptions import NotFound @@ -71,26 +73,26 @@ class CloudMemorystoreHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client: Optional[CloudRedisClient] = None + self._client: CloudRedisClient | None = None def get_conn(self) -> CloudRedisClient: """Retrieves client library object that allow access to Cloud Memorystore service.""" if not self._client: - self._client = CloudRedisClient(credentials=self._get_credentials()) + self._client = CloudRedisClient(credentials=self.get_credentials()) return self._client @staticmethod def _append_label(instance: Instance, key: str, val: str) -> Instance: """ - Append labels to provided Instance type + Append labels to provided Instance type. 
Labels must fit the regex ``[a-z]([-a-z0-9]*[a-z0-9])?`` (current airflow version string follows semantic versioning spec: x.y.z). @@ -110,11 +112,11 @@ def create_instance( self, location: str, instance_id: str, - instance: Union[Dict, Instance], + instance: dict | Instance, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Creates a Redis instance based on the specified tier and memory size. @@ -154,7 +156,7 @@ def create_instance( try: self.log.info("Fetching instance: %s", instance_name) instance = client.get_instance( - request={'name': instance_name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": instance_name}, retry=retry, timeout=timeout, metadata=metadata or () ) self.log.info("Instance exists. Skipping creation.") return instance @@ -164,7 +166,7 @@ def create_instance( self._append_label(instance, "airflow-version", "v" + version.version) result = client.create_instance( - request={'parent': parent, 'instance_id': instance_id, 'instance': instance}, + request={"parent": parent, "instance_id": instance_id, "instance": instance}, retry=retry, timeout=timeout, metadata=metadata, @@ -172,7 +174,7 @@ def create_instance( result.result() self.log.info("Instance created.") return client.get_instance( - request={'name': instance_name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": instance_name}, retry=retry, timeout=timeout, metadata=metadata or () ) @GoogleBaseHook.fallback_to_default_project_id @@ -181,9 +183,9 @@ def delete_instance( location: str, instance: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Deletes a specific Redis instance. Instance stops serving and data is deleted. @@ -202,7 +204,7 @@ def delete_instance( name = f"projects/{project_id}/locations/{location}/instances/{instance}" self.log.info("Fetching Instance: %s", name) instance = client.get_instance( - request={'name': name}, + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata, @@ -213,7 +215,7 @@ def delete_instance( self.log.info("Deleting Instance: %s", name) result = client.delete_instance( - request={'name': name}, + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata, @@ -226,11 +228,11 @@ def export_instance( self, location: str, instance: str, - output_config: Union[Dict, OutputConfig], + output_config: dict | OutputConfig, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Export Redis instance data into a Redis RDB format file in Cloud Storage. 
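Two behaviours visible in the Memorystore hunks above carry over unchanged: `create_instance` first tries to fetch the instance and skips creation when it already exists, and it stamps an `airflow-version` label before creating. A hedged sketch of the create and export calls (instance sizing, names, and the GCS URI are placeholders, not from the patch):

from google.cloud.redis_v1 import Instance

from airflow.providers.google.cloud.hooks.cloud_memorystore import CloudMemorystoreHook

hook = CloudMemorystoreHook(gcp_conn_id="google_cloud_default")
# Returns the existing instance unchanged if "example-redis" is already present.
instance = hook.create_instance(
    location="europe-west1",
    instance_id="example-redis",
    instance=Instance(tier=Instance.Tier.BASIC, memory_size_gb=1),  # placeholder sizing
    project_id="example-project",
)
hook.export_instance(
    location="europe-west1",
    instance="example-redis",
    output_config={"gcs_destination": {"uri": "gs://example-bucket/backup.rdb"}},  # placeholder bucket
    project_id="example-project",
)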
@@ -255,7 +257,7 @@ def export_instance( name = f"projects/{project_id}/locations/{location}/instances/{instance}" self.log.info("Exporting Instance: %s", name) result = client.export_instance( - request={'name': name, 'output_config': output_config}, + request={"name": name, "output_config": output_config}, retry=retry, timeout=timeout, metadata=metadata, @@ -270,11 +272,13 @@ def failover_instance( instance: str, data_protection_mode: FailoverInstanceRequest.DataProtectionMode, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ + Failover of the primary node to current replica node. + Initiates a failover of the primary node to current replica node for a specific STANDARD tier Cloud Memorystore for Redis instance. @@ -296,7 +300,7 @@ def failover_instance( self.log.info("Failovering Instance: %s", name) result = client.failover_instance( - request={'name': name, 'data_protection_mode': data_protection_mode}, + request={"name": name, "data_protection_mode": data_protection_mode}, retry=retry, timeout=timeout, metadata=metadata, @@ -310,9 +314,9 @@ def get_instance( location: str, instance: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Gets the details of a specific Redis instance. @@ -330,7 +334,7 @@ def get_instance( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/instances/{instance}" result = client.get_instance( - request={'name': name}, + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata, @@ -343,11 +347,11 @@ def import_instance( self, location: str, instance: str, - input_config: Union[Dict, InputConfig], + input_config: dict | InputConfig, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Import a Redis RDB snapshot file from Cloud Storage into a Redis instance. @@ -373,7 +377,7 @@ def import_instance( name = f"projects/{project_id}/locations/{location}/instances/{instance}" self.log.info("Importing Instance: %s", name) result = client.import_instance( - request={'name': name, 'input_config': input_config}, + request={"name": name, "input_config": input_config}, retry=retry, timeout=timeout, metadata=metadata, @@ -387,13 +391,12 @@ def list_instances( location: str, page_size: int, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ - Lists all Redis instances owned by a project in either the specified location (region) or all - locations. + List Redis instances owned by a project at the specified location (region) or all locations. 
:param location: The location of the Cloud Memorystore instance (for example europe-west1) @@ -413,7 +416,7 @@ def list_instances( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}" result = client.list_instances( - request={'parent': parent, 'page_size': page_size}, + request={"parent": parent, "page_size": page_size}, retry=retry, timeout=timeout, metadata=metadata, @@ -424,14 +427,14 @@ def list_instances( @GoogleBaseHook.fallback_to_default_project_id def update_instance( self, - update_mask: Union[Dict, FieldMask], - instance: Union[Dict, Instance], + update_mask: dict | FieldMask, + instance: dict | Instance, project_id: str = PROVIDE_PROJECT_ID, - location: Optional[str] = None, - instance_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + instance_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Updates the metadata and configuration of a specific Redis instance. @@ -473,13 +476,14 @@ def update_instance( self.log.info("Updating instances: %s", instance.name) result = client.update_instance( - request={'update_mask': update_mask, 'instance': instance}, + request={"update_mask": update_mask, "instance": instance}, retry=retry, timeout=timeout, metadata=metadata, ) - result.result() + updated_instance = result.result() self.log.info("Instance updated: %s", instance.name) + return updated_instance class CloudMemorystoreMemcachedHook(GoogleBaseHook): @@ -506,28 +510,26 @@ class CloudMemorystoreMemcachedHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client: Optional[CloudMemcacheClient] = None + self._client: CloudMemcacheClient | None = None - def get_conn( - self, - ): + def get_conn(self): """Retrieves client library object that allow access to Cloud Memorystore Memcached service.""" if not self._client: - self._client = CloudMemcacheClient(credentials=self._get_credentials()) + self._client = CloudMemcacheClient(credentials=self.get_credentials()) return self._client @staticmethod def _append_label(instance: cloud_memcache.Instance, key: str, val: str) -> cloud_memcache.Instance: """ - Append labels to provided Instance type + Append labels to provided Instance type. Labels must fit the regex ``[a-z]([-a-z0-9]*[a-z0-9])?`` (current airflow version string follows semantic versioning spec: x.y.z). @@ -550,9 +552,9 @@ def apply_parameters( project_id: str, location: str, instance_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Will update current set of Parameters to the set of specified nodes of the Memcached Instance. 
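The other change repeated across these hooks is the move from the private `self._get_credentials()` to the public `self.get_credentials()` accessor when the client is built. The lazily created client in `get_conn` above then takes roughly this shape (the hook class below is illustrative, not part of the patch):

from __future__ import annotations

from google.cloud.redis_v1 import CloudRedisClient

from airflow.providers.google.common.hooks.base_google import GoogleBaseHook


class ExampleRedisHook(GoogleBaseHook):
    """Illustrative only: mirrors the lazy get_conn pattern used by the hooks in this patch."""

    def __init__(self, gcp_conn_id: str = "google_cloud_default", **kwargs) -> None:
        super().__init__(gcp_conn_id=gcp_conn_id, **kwargs)
        self._client: CloudRedisClient | None = None

    def get_conn(self) -> CloudRedisClient:
        if not self._client:
            # get_credentials() is the public accessor that replaces _get_credentials()
            self._client = CloudRedisClient(credentials=self.get_credentials())
        return self._client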
@@ -592,11 +594,11 @@ def create_instance( self, location: str, instance_id: str, - instance: Union[Dict, cloud_memcache.Instance], + instance: dict | cloud_memcache.Instance, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Creates a Memcached instance based on the specified tier and memory size. @@ -670,9 +672,9 @@ def delete_instance( location: str, instance: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Deletes a specific Memcached instance. Instance stops serving and data is deleted. @@ -717,9 +719,9 @@ def get_instance( location: str, instance: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Gets the details of a specific Memcached instance. @@ -746,13 +748,12 @@ def list_instances( self, location: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ - Lists all Memcached instances owned by a project in either the specified location (region) or all - locations. + List Memcached instances owned by a project at the specified location (region) or all locations. :param location: The location of the Cloud Memorystore instance (for example europe-west1) @@ -783,14 +784,14 @@ def list_instances( @GoogleBaseHook.fallback_to_default_project_id def update_instance( self, - update_mask: Union[Dict, FieldMask], - instance: Union[Dict, cloud_memcache.Instance], + update_mask: dict | FieldMask, + instance: dict | cloud_memcache.Instance, project_id: str, - location: Optional[str] = None, - instance_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + instance_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Updates the metadata and configuration of a specific Memcached instance. @@ -802,7 +803,7 @@ def update_instance( If a dict is provided, it must be of the same form as the protobuf message :class:`~google.protobuf.field_mask_pb2.FieldMask`) - Union[Dict, google.protobuf.field_mask_pb2.FieldMask] + Union[dict, google.protobuf.field_mask_pb2.FieldMask] :param instance: Required. Update description. Only fields specified in ``update_mask`` are updated. 
If a dict is provided, it must be of the same form as the protobuf message @@ -833,30 +834,31 @@ def update_instance( result = client.update_instance( update_mask=update_mask, resource=instance, retry=retry, timeout=timeout, metadata=metadata or () ) - result.result() + updated_instance = result.result() self.log.info("Instance updated: %s", instance.name) + return updated_instance @GoogleBaseHook.fallback_to_default_project_id def update_parameters( self, - update_mask: Union[Dict, FieldMask], - parameters: Union[Dict, cloud_memcache.MemcacheParameters], + update_mask: dict | FieldMask, + parameters: dict | cloud_memcache.MemcacheParameters, project_id: str, location: str, instance_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ - Updates the defined Memcached Parameters for an existing Instance. This method only stages the - parameters, it must be followed by apply_parameters to apply the parameters to nodes of - the Memcached Instance. + Update the defined Memcached Parameters for an existing Instance. + + This method only stages the parameters, it must be followed by apply_parameters + to apply the parameters to nodes of the Memcached Instance. :param update_mask: Required. Mask of fields to update. If a dict is provided, it must be of the same form as the protobuf message :class:`~google.protobuf.field_mask_pb2.FieldMask` - Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param parameters: The parameters to apply to the instance. If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.memcache_v1beta2.types.cloud_memcache.MemcacheParameters` diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py index 95f16fe9c19f2..330d9f3fcd840 100644 --- a/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains a Google Cloud SQL Hook.""" +from __future__ import annotations import errno import json @@ -35,7 +35,7 @@ from pathlib import Path from subprocess import PIPE, Popen from tempfile import gettempdir -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Sequence from urllib.parse import quote_plus import httpx @@ -48,7 +48,7 @@ # For requests that are "retriable" from airflow.hooks.base import BaseHook from airflow.models import Connection -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils.log.logging_mixin import LoggingMixin @@ -82,17 +82,17 @@ class CloudSQLHook(GoogleBaseHook): credentials. 
""" - conn_name_attr = 'gcp_conn_id' - default_conn_name = 'google_cloud_sql_default' - conn_type = 'gcpcloudsql' - hook_name = 'Google Cloud SQL' + conn_name_attr = "gcp_conn_id" + default_conn_name = "google_cloud_sql_default" + conn_type = "gcpcloudsql" + hook_name = "Google Cloud SQL" def __init__( self, api_version: str, gcp_conn_id: str = default_conn_name, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -107,11 +107,10 @@ def get_conn(self) -> Resource: Retrieves connection to Cloud SQL. :return: Google Cloud SQL services object. - :rtype: dict """ if not self._conn: http_authorized = self._authorize() - self._conn = build('sqladmin', self.api_version, http=http_authorized, cache_discovery=False) + self._conn = build("sqladmin", self.api_version, http=http_authorized, cache_discovery=False) return self._conn @GoogleBaseHook.fallback_to_default_project_id @@ -123,7 +122,6 @@ def get_instance(self, instance: str, project_id: str) -> dict: :param project_id: Project ID of the project that contains the instance. If set to None or missing, the default project_id from the Google Cloud connection is used. :return: A Cloud SQL instance resource. - :rtype: dict """ return ( self.get_conn() @@ -134,7 +132,7 @@ def get_instance(self, instance: str, project_id: str) -> dict: @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() - def create_instance(self, body: Dict, project_id: str) -> None: + def create_instance(self, body: dict, project_id: str) -> None: """ Creates a new Cloud SQL instance. @@ -209,7 +207,6 @@ def get_database(self, instance: str, database: str, project_id: str) -> dict: to None or missing, the default project_id from the Google Cloud connection is used. :return: A Cloud SQL database resource, as described in https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases#resource. - :rtype: dict """ return ( self.get_conn() @@ -220,7 +217,7 @@ def get_database(self, instance: str, database: str, project_id: str) -> dict: @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() - def create_database(self, instance: str, body: Dict, project_id: str) -> None: + def create_database(self, instance: str, body: dict, project_id: str) -> None: """ Creates a new database inside a Cloud SQL instance. @@ -246,7 +243,7 @@ def patch_database( self, instance: str, database: str, - body: Dict, + body: dict, project_id: str, ) -> None: """ @@ -295,7 +292,7 @@ def delete_database(self, instance: str, database: str, project_id: str) -> None @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() - def export_instance(self, instance: str, body: Dict, project_id: str) -> None: + def export_instance(self, instance: str, body: dict, project_id: str) -> None: """ Exports data from a Cloud SQL instance to a Cloud Storage bucket as a SQL dump or CSV file. 
@@ -318,7 +315,7 @@ def export_instance(self, instance: str, body: Dict, project_id: str) -> None: self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id - def import_instance(self, instance: str, body: Dict, project_id: str) -> None: + def import_instance(self, instance: str, body: dict, project_id: str) -> None: """ Imports data into a Cloud SQL instance from a SQL dump or CSV file in Cloud Storage. @@ -341,7 +338,7 @@ def import_instance(self, instance: str, body: Dict, project_id: str) -> None: operation_name = response["name"] self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) except HttpError as ex: - raise AirflowException(f'Importing instance {instance} failed: {ex.content}') + raise AirflowException(f"Importing instance {instance} failed: {ex.content}") def _wait_for_operation_to_complete(self, project_id: str, operation_name: str) -> None: """ @@ -375,9 +372,6 @@ def _wait_for_operation_to_complete(self, project_id: str, operation_name: str) "https://storage.googleapis.com/cloudsql-proxy/{}/cloud_sql_proxy.{}.{}" ) -GCP_CREDENTIALS_KEY_PATH = "extra__google_cloud_platform__key_path" -GCP_CREDENTIALS_KEYFILE_DICT = "extra__google_cloud_platform__keyfile_dict" - class CloudSqlProxyRunner(LoggingMixin): """ @@ -414,10 +408,10 @@ def __init__( self, path_prefix: str, instance_specification: str, - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - sql_proxy_version: Optional[str] = None, - sql_proxy_binary_path: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + sql_proxy_version: str | None = None, + sql_proxy_binary_path: str | None = None, ) -> None: super().__init__() self.path_prefix = path_prefix @@ -426,11 +420,11 @@ def __init__( self.sql_proxy_was_downloaded = False self.sql_proxy_version = sql_proxy_version self.download_sql_proxy_dir = None - self.sql_proxy_process = None # type: Optional[Popen] + self.sql_proxy_process: Popen | None = None self.instance_specification = instance_specification self.project_id = project_id self.gcp_conn_id = gcp_conn_id - self.command_line_parameters = [] # type: List[str] + self.command_line_parameters: list[str] = [] self.cloud_sql_proxy_socket_directory = self.path_prefix self.sql_proxy_path = ( sql_proxy_binary_path if sql_proxy_binary_path else self.path_prefix + "_cloud_sql_proxy" @@ -439,12 +433,12 @@ def __init__( self._build_command_line_parameters() def _build_command_line_parameters(self) -> None: - self.command_line_parameters.extend(['-dir', self.cloud_sql_proxy_socket_directory]) - self.command_line_parameters.extend(['-instances', self.instance_specification]) + self.command_line_parameters.extend(["-dir", self.cloud_sql_proxy_socket_directory]) + self.command_line_parameters.extend(["-instances", self.instance_specification]) @staticmethod def _is_os_64bit() -> bool: - return platform.machine().endswith('64') + return platform.machine().endswith("64") def _download_sql_proxy_if_needed(self) -> None: if os.path.isfile(self.sql_proxy_path): @@ -470,7 +464,7 @@ def _download_sql_proxy_if_needed(self) -> None: response = httpx.get(download_url, allow_redirects=True) # type: ignore[call-arg] # Downloading to .tmp file first to avoid case where partially downloaded # binary is used by parallel operator which uses the same fixed binary path - with open(proxy_path_tmp, 'wb') as file: + with open(proxy_path_tmp, "wb") as file: 
file.write(response.content) if response.status_code != 200: raise AirflowException( @@ -483,17 +477,18 @@ def _download_sql_proxy_if_needed(self) -> None: os.chmod(self.sql_proxy_path, 0o744) # Set executable bit self.sql_proxy_was_downloaded = True - def _get_credential_parameters(self) -> List[str]: - connection = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id) - - if connection.extra_dejson.get(GCP_CREDENTIALS_KEY_PATH): - credential_params = ['-credential_file', connection.extra_dejson[GCP_CREDENTIALS_KEY_PATH]] - elif connection.extra_dejson.get(GCP_CREDENTIALS_KEYFILE_DICT): - credential_file_content = json.loads(connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT]) + def _get_credential_parameters(self) -> list[str]: + extras = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id).extra_dejson + key_path = get_field(extras, "key_path") + keyfile_dict = get_field(extras, "keyfile_dict") + if key_path: + credential_params = ["-credential_file", key_path] + elif keyfile_dict: + keyfile_content = keyfile_dict if isinstance(keyfile_dict, dict) else json.loads(keyfile_dict) self.log.info("Saving credentials to %s", self.credentials_path) with open(self.credentials_path, "w") as file: - json.dump(credential_file_content, file) - credential_params = ['-credential_file', self.credentials_path] + json.dump(keyfile_content, file) + credential_params = ["-credential_file", self.credentials_path] else: self.log.info( "The credentials are not supplied by neither key_path nor " @@ -504,7 +499,7 @@ def _get_credential_parameters(self) -> List[str]: credential_params = [] if not self.instance_specification: - project_id = connection.extra_dejson.get('extra__google_cloud_platform__project') + project_id = get_field(extras, "project") if self.project_id: project_id = self.project_id if not project_id: @@ -514,7 +509,7 @@ def _get_credential_parameters(self) -> List[str]: "by project_id extra in the Google Cloud connection or by " "project_id provided in the operator." ) - credential_params.extend(['-projects', project_id]) + credential_params.extend(["-projects", project_id]) return credential_params def start_proxy(self) -> None: @@ -538,17 +533,17 @@ def start_proxy(self) -> None: self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid) while True: line = ( - self.sql_proxy_process.stderr.readline().decode('utf-8') + self.sql_proxy_process.stderr.readline().decode("utf-8") if self.sql_proxy_process.stderr else "" ) return_code = self.sql_proxy_process.poll() - if line == '' and return_code is not None: + if line == "" and return_code is not None: self.sql_proxy_process = None raise AirflowException( f"The cloud_sql_proxy finished early with return code {return_code}!" 
) - if line != '': + if line != "": self.log.info(line) if "googleapi: Error" in line or "invalid instance name:" in line: self.stop_proxy() @@ -586,13 +581,13 @@ def stop_proxy(self) -> None: # Here file cannot be delete by concurrent task (each task has its own copy) os.remove(self.credentials_path) - def get_proxy_version(self) -> Optional[str]: + def get_proxy_version(self) -> str | None: """Returns version of the Cloud SQL Proxy.""" self._download_sql_proxy_if_needed() command_to_run = [self.sql_proxy_path] - command_to_run.extend(['--version']) + command_to_run.extend(["--version"]) command_to_run.extend(self._get_credential_parameters()) - result = subprocess.check_output(command_to_run).decode('utf-8') + result = subprocess.check_output(command_to_run).decode("utf-8") pattern = re.compile("^.*[V|v]ersion ([^;]*);.*$") matched = pattern.match(result) if matched: @@ -605,12 +600,11 @@ def get_socket_path(self) -> str: Retrieves UNIX socket path used by Cloud SQL Proxy. :return: The dynamically generated path for the socket created by the proxy. - :rtype: str """ return self.cloud_sql_proxy_socket_directory + "/" + self.instance_specification -CONNECTION_URIS = { +CONNECTION_URIS: dict[str, dict[str, dict[str, str]]] = { "postgres": { "proxy": { "tcp": "postgresql://{user}:{password}@127.0.0.1:{proxy_port}/{database}", @@ -635,9 +629,9 @@ def get_socket_path(self) -> str: "non-ssl": "mysql://{user}:{password}@{public_ip}:{public_port}/{database}", }, }, -} # type: Dict[str, Dict[str, Dict[str, str]]] +} -CLOUD_SQL_VALID_DATABASE_TYPES = ['postgres', 'mysql'] +CLOUD_SQL_VALID_DATABASE_TYPES = ["postgres", "mysql"] class CloudSQLDatabaseHook(BaseHook): @@ -687,46 +681,46 @@ class CloudSQLDatabaseHook(BaseHook): in the connection URL """ - conn_name_attr = 'gcp_cloudsql_conn_id' - default_conn_name = 'google_cloud_sqldb_default' - conn_type = 'gcpcloudsqldb' - hook_name = 'Google Cloud SQL Database' + conn_name_attr = "gcp_cloudsql_conn_id" + default_conn_name = "google_cloud_sqldb_default" + conn_type = "gcpcloudsqldb" + hook_name = "Google Cloud SQL Database" - _conn = None # type: Optional[Any] + _conn = None def __init__( self, - gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', - gcp_conn_id: str = 'google_cloud_default', - default_gcp_project_id: Optional[str] = None, + gcp_cloudsql_conn_id: str = "google_cloud_sql_default", + gcp_conn_id: str = "google_cloud_default", + default_gcp_project_id: str | None = None, ) -> None: super().__init__() self.gcp_conn_id = gcp_conn_id self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id self.cloudsql_connection = self.get_connection(self.gcp_cloudsql_conn_id) self.extras = self.cloudsql_connection.extra_dejson - self.project_id = self.extras.get('project_id', default_gcp_project_id) # type: Optional[str] - self.instance = self.extras.get('instance') # type: Optional[str] - self.database = self.cloudsql_connection.schema # type: Optional[str] - self.location = self.extras.get('location') # type: Optional[str] - self.database_type = self.extras.get('database_type') # type: Optional[str] - self.use_proxy = self._get_bool(self.extras.get('use_proxy', 'False')) # type: bool - self.use_ssl = self._get_bool(self.extras.get('use_ssl', 'False')) # type: bool - self.sql_proxy_use_tcp = self._get_bool(self.extras.get('sql_proxy_use_tcp', 'False')) # type: bool - self.sql_proxy_version = self.extras.get('sql_proxy_version') # type: Optional[str] - self.sql_proxy_binary_path = self.extras.get('sql_proxy_binary_path') # type: Optional[str] - self.user = 
self.cloudsql_connection.login # type: Optional[str] - self.password = self.cloudsql_connection.password # type: Optional[str] - self.public_ip = self.cloudsql_connection.host # type: Optional[str] - self.public_port = self.cloudsql_connection.port # type: Optional[int] - self.sslcert = self.extras.get('sslcert') # type: Optional[str] - self.sslkey = self.extras.get('sslkey') # type: Optional[str] - self.sslrootcert = self.extras.get('sslrootcert') # type: Optional[str] + self.project_id = self.extras.get("project_id", default_gcp_project_id) + self.instance = self.extras.get("instance") + self.database = self.cloudsql_connection.schema + self.location = self.extras.get("location") + self.database_type = self.extras.get("database_type") + self.use_proxy = self._get_bool(self.extras.get("use_proxy", "False")) + self.use_ssl = self._get_bool(self.extras.get("use_ssl", "False")) + self.sql_proxy_use_tcp = self._get_bool(self.extras.get("sql_proxy_use_tcp", "False")) + self.sql_proxy_version = self.extras.get("sql_proxy_version") + self.sql_proxy_binary_path = self.extras.get("sql_proxy_binary_path") + self.user = self.cloudsql_connection.login + self.password = self.cloudsql_connection.password + self.public_ip = self.cloudsql_connection.host + self.public_port = self.cloudsql_connection.port + self.sslcert = self.extras.get("sslcert") + self.sslkey = self.extras.get("sslkey") + self.sslrootcert = self.extras.get("sslrootcert") # Port and socket path and db_hook are automatically generated self.sql_proxy_tcp_port = None - self.sql_proxy_unique_path = None # type: Optional[str] - self.db_hook = None # type: Optional[Union[PostgresHook, MySqlHook]] - self.reserved_tcp_socket = None # type: Optional[socket.socket] + self.sql_proxy_unique_path: str | None = None + self.db_hook: PostgresHook | MySqlHook | None = None + self.reserved_tcp_socket: socket.socket | None = None # Generated based on clock + clock sequence. Unique per host (!). 
# This is important as different hosts share the database self.db_conn_id = str(uuid.uuid1()) @@ -734,7 +728,7 @@ def __init__( @staticmethod def _get_bool(val: Any) -> bool: - if val == 'False' or val is False: + if val == "False" or val is False: return False return True @@ -746,7 +740,7 @@ def _check_ssl_file(file_to_check, name) -> None: raise AirflowException(f"The {file_to_check} must be a readable file") def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required extra 'project_id' is empty") if not self.location: raise AirflowException("The required extra 'location' is empty or None") @@ -782,7 +776,7 @@ def validate_socket_path_length(self) -> None: :return: None or rises AirflowException """ if self.use_proxy and not self.sql_proxy_use_tcp: - if self.database_type == 'postgres': + if self.database_type == "postgres": suffix = "/.s.PGSQL.5432" else: suffix = "" @@ -809,13 +803,13 @@ def _generate_unique_path() -> str: random.seed() while True: candidate = os.path.join( - gettempdir(), ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) + gettempdir(), "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) ) if not os.path.exists(candidate): return candidate @staticmethod - def _quote(value) -> Optional[str]: + def _quote(value) -> str | None: return quote_plus(value) if value else None def _generate_connection_uri(self) -> str: @@ -828,23 +822,23 @@ def _generate_connection_uri(self) -> str: if not self.database_type: raise ValueError("The database_type should be set") - database_uris = CONNECTION_URIS[self.database_type] # type: Dict[str, Dict[str, str]] + database_uris = CONNECTION_URIS[self.database_type] ssl_spec = None socket_path = None if self.use_proxy: - proxy_uris = database_uris['proxy'] # type: Dict[str, str] + proxy_uris = database_uris["proxy"] if self.sql_proxy_use_tcp: - format_string = proxy_uris['tcp'] + format_string = proxy_uris["tcp"] else: - format_string = proxy_uris['socket'] + format_string = proxy_uris["socket"] socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}" else: - public_uris = database_uris['public'] # type: Dict[str, str] + public_uris = database_uris["public"] if self.use_ssl: - format_string = public_uris['ssl'] - ssl_spec = {'cert': self.sslcert, 'key': self.sslkey, 'ca': self.sslrootcert} + format_string = public_uris["ssl"] + ssl_spec = {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert} else: - format_string = public_uris['non-ssl'] + format_string = public_uris["non-ssl"] if not self.user: raise AirflowException("The login parameter needs to be set in connection") if not self.public_ip: @@ -855,28 +849,28 @@ def _generate_connection_uri(self) -> str: raise AirflowException("The database parameter needs to be set in connection") connection_uri = format_string.format( - user=quote_plus(self.user) if self.user else '', - password=quote_plus(self.password) if self.password else '', - database=quote_plus(self.database) if self.database else '', + user=quote_plus(self.user) if self.user else "", + password=quote_plus(self.password) if self.password else "", + database=quote_plus(self.database) if self.database else "", public_ip=self.public_ip, public_port=self.public_port, proxy_port=self.sql_proxy_tcp_port, socket_path=self._quote(socket_path), - ssl_spec=self._quote(json.dumps(ssl_spec)) if ssl_spec else '', - client_cert_file=self._quote(self.sslcert) if self.sslcert else '', 
- client_key_file=self._quote(self.sslkey) if self.sslcert else '', - server_ca_file=self._quote(self.sslrootcert if self.sslcert else ''), + ssl_spec=self._quote(json.dumps(ssl_spec)) if ssl_spec else "", + client_cert_file=self._quote(self.sslcert) if self.sslcert else "", + client_key_file=self._quote(self.sslkey) if self.sslcert else "", + server_ca_file=self._quote(self.sslrootcert if self.sslcert else ""), ) self.log.info( "DB connection URI %s", connection_uri.replace( - quote_plus(self.password) if self.password else 'PASSWORD', 'XXXXXXXXXXXX' + quote_plus(self.password) if self.password else "PASSWORD", "XXXXXXXXXXXX" ), ) return connection_uri def _get_instance_socket_name(self) -> str: - return self.project_id + ":" + self.location + ":" + self.instance # type: ignore + return self.project_id + ":" + self.location + ":" + self.instance def _get_sqlproxy_instance_specification(self) -> str: instance_specification = self._get_instance_socket_name() @@ -900,7 +894,6 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: lifecycle per task. :return: The Cloud SQL Proxy runner. - :rtype: CloudSqlProxyRunner """ if not self.use_proxy: raise ValueError("Proxy runner can only be retrieved in case of use_proxy = True") @@ -915,25 +908,26 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: gcp_conn_id=self.gcp_conn_id, ) - def get_database_hook(self, connection: Connection) -> Union[PostgresHook, MySqlHook]: + def get_database_hook(self, connection: Connection) -> PostgresHook | MySqlHook: """ Retrieve database hook. This is the actual Postgres or MySQL database hook that uses proxy or connects directly to the Google Cloud SQL database. """ - if self.database_type == 'postgres': - self.db_hook = PostgresHook(connection=connection, schema=self.database) + if self.database_type == "postgres": + db_hook: PostgresHook | MySqlHook = PostgresHook(connection=connection, schema=self.database) else: - self.db_hook = MySqlHook(connection=connection, schema=self.database) - return self.db_hook + db_hook = MySqlHook(connection=connection, schema=self.database) + self.db_hook = db_hook + return db_hook def cleanup_database_hook(self) -> None: """Clean up database hook after it was used.""" - if self.database_type == 'postgres': + if self.database_type == "postgres": if not self.db_hook: raise ValueError("The db_hook should be set") if not isinstance(self.db_hook, PostgresHook): raise ValueError(f"The db_hook should be PostgresHook and is {type(self.db_hook)}") - conn = getattr(self.db_hook, 'conn') + conn = getattr(self.db_hook, "conn") if conn and conn.notices: for output in self.db_hook.conn.notices: self.log.info(output) @@ -941,7 +935,7 @@ def cleanup_database_hook(self) -> None: def reserve_free_tcp_port(self) -> None: """Reserve free TCP port to be used by Cloud SQL Proxy""" self.reserved_tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.reserved_tcp_socket.bind(('127.0.0.1', 0)) + self.reserved_tcp_socket.bind(("127.0.0.1", 0)) self.sql_proxy_tcp_port = self.reserved_tcp_socket.getsockname()[1] def free_reserved_port(self) -> None: diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py index 04b57db0b468d..2942f09c21366 100644 --- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the 
License. """This module contains a Google Storage Transfer Service Hook.""" +from __future__ import annotations import json import logging @@ -23,7 +24,7 @@ import warnings from copy import deepcopy from datetime import timedelta -from typing import List, Optional, Sequence, Set, Union +from typing import Sequence from googleapiclient.discovery import Resource, build from googleapiclient.errors import HttpError @@ -59,44 +60,44 @@ class GcpTransferOperationStatus: ACCESS_KEY_ID = "accessKeyId" ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink" AWS_ACCESS_KEY = "awsAccessKey" -AWS_S3_DATA_SOURCE = 'awsS3DataSource' -BODY = 'body' -BUCKET_NAME = 'bucketName' -COUNTERS = 'counters' -DAY = 'day' +AWS_S3_DATA_SOURCE = "awsS3DataSource" +BODY = "body" +BUCKET_NAME = "bucketName" +COUNTERS = "counters" +DAY = "day" DESCRIPTION = "description" -FILTER = 'filter' -FILTER_JOB_NAMES = 'job_names' -FILTER_PROJECT_ID = 'project_id' -GCS_DATA_SINK = 'gcsDataSink' -GCS_DATA_SOURCE = 'gcsDataSource' +FILTER = "filter" +FILTER_JOB_NAMES = "job_names" +FILTER_PROJECT_ID = "project_id" +GCS_DATA_SINK = "gcsDataSink" +GCS_DATA_SOURCE = "gcsDataSource" HOURS = "hours" -HTTP_DATA_SOURCE = 'httpDataSource' -JOB_NAME = 'name' -LIST_URL = 'list_url' -METADATA = 'metadata' +HTTP_DATA_SOURCE = "httpDataSource" +JOB_NAME = "name" +LIST_URL = "list_url" +METADATA = "metadata" MINUTES = "minutes" -MONTH = 'month' -NAME = 'name' -OBJECT_CONDITIONS = 'object_conditions' -OPERATIONS = 'operations' -PATH = 'path' -PROJECT_ID = 'projectId' -SCHEDULE = 'schedule' -SCHEDULE_END_DATE = 'scheduleEndDate' -SCHEDULE_START_DATE = 'scheduleStartDate' +MONTH = "month" +NAME = "name" +OBJECT_CONDITIONS = "object_conditions" +OPERATIONS = "operations" +PATH = "path" +PROJECT_ID = "projectId" +SCHEDULE = "schedule" +SCHEDULE_END_DATE = "scheduleEndDate" +SCHEDULE_START_DATE = "scheduleStartDate" SECONDS = "seconds" SECRET_ACCESS_KEY = "secretAccessKey" -START_TIME_OF_DAY = 'startTimeOfDay' +START_TIME_OF_DAY = "startTimeOfDay" STATUS = "status" -STATUS1 = 'status' -TRANSFER_JOB = 'transfer_job' -TRANSFER_JOBS = 'transferJobs' -TRANSFER_JOB_FIELD_MASK = 'update_transfer_job_field_mask' -TRANSFER_OPERATIONS = 'transferOperations' -TRANSFER_OPTIONS = 'transfer_options' -TRANSFER_SPEC = 'transferSpec' -YEAR = 'year' +STATUS1 = "status" +TRANSFER_JOB = "transfer_job" +TRANSFER_JOBS = "transferJobs" +TRANSFER_JOB_FIELD_MASK = "update_transfer_job_field_mask" +TRANSFER_OPERATIONS = "transferOperations" +TRANSFER_OPTIONS = "transfer_options" +TRANSFER_SPEC = "transferSpec" +YEAR = "year" ALREADY_EXIST_CODE = 409 NEGATIVE_STATUSES = {GcpTransferOperationStatus.FAILED, GcpTransferOperationStatus.ABORTED} @@ -124,10 +125,10 @@ class CloudDataTransferServiceHook(GoogleBaseHook): def __init__( self, - api_version: str = 'v1', + api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -142,12 +143,11 @@ def get_conn(self) -> Resource: Retrieves connection to Google Storage Transfer service. 
:return: Google Storage Transfer service object - :rtype: dict """ if not self._conn: http_authorized = self._authorize() self._conn = build( - 'storagetransfer', self.api_version, http=http_authorized, cache_discovery=False + "storagetransfer", self.api_version, http=http_authorized, cache_discovery=False ) return self._conn @@ -160,11 +160,10 @@ def create_transfer_job(self, body: dict) -> dict: :return: transfer job. See: https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs#TransferJob - :rtype: dict """ body = self._inject_project_id(body, BODY, PROJECT_ID) - try: + try: transfer_job = ( self.get_conn().transferJobs().create(body=body).execute(num_retries=self.num_retries) ) @@ -206,7 +205,6 @@ def get_transfer_job(self, job_name: str, project_id: str) -> dict: Job. If set to None or missing, the default project_id from the Google Cloud connection is used. :return: Transfer Job - :rtype: dict """ return ( self.get_conn() @@ -215,7 +213,7 @@ def get_transfer_job(self, job_name: str, project_id: str) -> dict: .execute(num_retries=self.num_retries) ) - def list_transfer_job(self, request_filter: Optional[dict] = None, **kwargs) -> List[dict]: + def list_transfer_job(self, request_filter: dict | None = None, **kwargs) -> list[dict]: """ Lists long-running operations in Google Storage Transfer Service that match the specified filter. @@ -223,13 +221,12 @@ def list_transfer_job(self, request_filter: Optional[dict] = None, **kwargs) -> :param request_filter: (Required) A request filter, as described in https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter :return: List of Transfer Jobs - :rtype: list[dict] """ # To preserve backward compatibility # TODO: remove one day if request_filter is None: - if 'filter' in kwargs: - request_filter = kwargs['filter'] + if "filter" in kwargs: + request_filter = kwargs["filter"] if not isinstance(request_filter, dict): raise ValueError(f"The request_filter should be dict and is {type(request_filter)}") warnings.warn("Use 'request_filter' instead of 'filter'", DeprecationWarning) @@ -239,7 +236,7 @@ def list_transfer_job(self, request_filter: Optional[dict] = None, **kwargs) -> conn = self.get_conn() request_filter = self._inject_project_id(request_filter, FILTER, FILTER_PROJECT_ID) request = conn.transferJobs().list(filter=json.dumps(request_filter)) - jobs: List[dict] = [] + jobs: list[dict] = [] while request is not None: response = request.execute(num_retries=self.num_retries) @@ -259,7 +256,6 @@ def enable_transfer_job(self, job_name: str, project_id: str) -> dict: Job. If set to None or missing, the default project_id from the Google Cloud connection is used. :return: If successful, TransferJob. - :rtype: dict """ return ( self.get_conn() @@ -283,7 +279,6 @@ def update_transfer_job(self, job_name: str, body: dict) -> dict: :param body: A request body, as described in https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/patch#request-body :return: If successful, TransferJob. - :rtype: dict """ body = self._inject_project_id(body, BODY, PROJECT_ID) return ( @@ -305,7 +300,6 @@ def delete_transfer_job(self, job_name: str, project_id: str) -> None: :param project_id: (Optional) the ID of the project that owns the Transfer Job. If set to None or missing, the default project_id from the Google Cloud connection is used. 
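# Illustrative usage sketch, not part of this diff: how list_transfer_job() is
# meant to be called now that the deprecated ``filter`` kwarg is funnelled into
# ``request_filter``.  The project and job names below are hypothetical
# placeholders, and a configured "google_cloud_default" connection is assumed.
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    FILTER_JOB_NAMES,
    FILTER_PROJECT_ID,
    CloudDataTransferServiceHook,
)

hook = CloudDataTransferServiceHook(gcp_conn_id="google_cloud_default")

# Preferred spelling after this change:
jobs = hook.list_transfer_job(
    request_filter={FILTER_PROJECT_ID: "my-project", FILTER_JOB_NAMES: ["transferJobs/example-job"]}
)

# Still accepted for backward compatibility, but it emits a DeprecationWarning
# and is routed through the same request_filter code path shown above:
legacy_jobs = hook.list_transfer_job(filter={FILTER_PROJECT_ID: "my-project"})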
- :rtype: None """ ( self.get_conn() @@ -326,7 +320,6 @@ def cancel_transfer_operation(self, operation_name: str) -> None: Cancels an transfer operation in Google Storage Transfer Service. :param operation_name: Name of the transfer operation. - :rtype: None """ self.get_conn().transferOperations().cancel(name=operation_name).execute(num_retries=self.num_retries) @@ -338,7 +331,6 @@ def get_transfer_operation(self, operation_name: str) -> dict: :return: transfer operation See: https://cloud.google.com/storage-transfer/docs/reference/rest/v1/Operation - :rtype: dict """ return ( self.get_conn() @@ -347,7 +339,7 @@ def get_transfer_operation(self, operation_name: str) -> dict: .execute(num_retries=self.num_retries) ) - def list_transfer_operations(self, request_filter: Optional[dict] = None, **kwargs) -> List[dict]: + def list_transfer_operations(self, request_filter: dict | None = None, **kwargs) -> list[dict]: """ Gets an transfer operation in Google Storage Transfer Service. @@ -360,13 +352,12 @@ def list_transfer_operations(self, request_filter: Optional[dict] = None, **kwar See: :doc:`/connections/gcp` :return: transfer operation - :rtype: list[dict] """ # To preserve backward compatibility # TODO: remove one day if request_filter is None: - if 'filter' in kwargs: - request_filter = kwargs['filter'] + if "filter" in kwargs: + request_filter = kwargs["filter"] if not isinstance(request_filter, dict): raise ValueError(f"The request_filter should be dict and is {type(request_filter)}") warnings.warn("Use 'request_filter' instead of 'filter'", DeprecationWarning) @@ -379,7 +370,7 @@ def list_transfer_operations(self, request_filter: Optional[dict] = None, **kwar request_filter = self._inject_project_id(request_filter, FILTER, FILTER_PROJECT_ID) - operations: List[dict] = [] + operations: list[dict] = [] request = conn.transferOperations().list(name=TRANSFER_OPERATIONS, filter=json.dumps(request_filter)) @@ -399,7 +390,6 @@ def pause_transfer_operation(self, operation_name: str) -> None: Pauses an transfer operation in Google Storage Transfer Service. :param operation_name: (Required) Name of the transfer operation. - :rtype: None """ self.get_conn().transferOperations().pause(name=operation_name).execute(num_retries=self.num_retries) @@ -408,15 +398,14 @@ def resume_transfer_operation(self, operation_name: str) -> None: Resumes an transfer operation in Google Storage Transfer Service. :param operation_name: (Required) Name of the transfer operation. - :rtype: None """ self.get_conn().transferOperations().resume(name=operation_name).execute(num_retries=self.num_retries) def wait_for_transfer_job( self, job: dict, - expected_statuses: Optional[Set[str]] = None, - timeout: Optional[Union[float, timedelta]] = None, + expected_statuses: set[str] | None = None, + timeout: float | timedelta | None = None, ) -> None: """ Waits until the job reaches the expected state. @@ -429,7 +418,6 @@ def wait_for_transfer_job( https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status :param timeout: Time in which the operation must end in seconds. If not specified, defaults to 60 seconds. 
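# Sketch only (not introduced by this commit): exercising the updated
# wait_for_transfer_job() signature, which now uses PEP 604 unions
# (``set[str] | None``, ``float | timedelta | None``).  The job and project
# names are hypothetical placeholders.
from datetime import timedelta

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceHook,
    GcpTransferOperationStatus,
)

hook = CloudDataTransferServiceHook()
job = hook.get_transfer_job(job_name="transferJobs/example-job", project_id="my-project")

# Wait up to ten minutes for the job's operations to reach SUCCESS; when
# expected_statuses is omitted it falls back to {GcpTransferOperationStatus.SUCCESS},
# as the default handling right below shows.
hook.wait_for_transfer_job(
    job,
    expected_statuses={GcpTransferOperationStatus.SUCCESS},
    timeout=timedelta(minutes=10),
)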
- :rtype: None """ expected_statuses = ( {GcpTransferOperationStatus.SUCCESS} if not expected_statuses else expected_statuses @@ -464,7 +452,7 @@ def _inject_project_id(self, body: dict, param_name: str, target_key: str) -> di @staticmethod def operations_contain_expected_statuses( - operations: List[dict], expected_statuses: Union[Set[str], str] + operations: list[dict], expected_statuses: set[str] | str ) -> bool: """ Checks whether the operation list has an operation with the @@ -480,7 +468,6 @@ def operations_contain_expected_statuses( in the operation list, returns true, :raises: airflow.exceptions.AirflowException If it encounters operations with a state in the list, - :rtype: bool """ expected_statuses_set = ( {expected_statuses} if isinstance(expected_statuses, str) else set(expected_statuses) diff --git a/airflow/providers/google/cloud/hooks/compute.py b/airflow/providers/google/cloud/hooks/compute.py index 86d5e808e41bc..7bc9d33e7c30e 100644 --- a/airflow/providers/google/cloud/hooks/compute.py +++ b/airflow/providers/google/cloud/hooks/compute.py @@ -16,10 +16,16 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Compute Engine Hook.""" +from __future__ import annotations import time -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Sequence +from google.api_core.retry import Retry +from google.cloud.compute_v1.services.instance_group_managers import InstanceGroupManagersClient +from google.cloud.compute_v1.services.instance_templates import InstanceTemplatesClient +from google.cloud.compute_v1.services.instances import InstancesClient +from google.cloud.compute_v1.types import Instance, InstanceGroupManager, InstanceTemplate from googleapiclient.discovery import build from airflow.exceptions import AirflowException @@ -47,10 +53,10 @@ class ComputeEngineHook(GoogleBaseHook): def __init__( self, - api_version: str = 'v1', - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + api_version: str = "v1", + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -59,7 +65,7 @@ def __init__( ) self.api_version = api_version - _conn: Optional[Any] = None + _conn: Any | None = None def get_conn(self): """ @@ -69,9 +75,352 @@ def get_conn(self): """ if not self._conn: http_authorized = self._authorize() - self._conn = build('compute', self.api_version, http=http_authorized, cache_discovery=False) + self._conn = build("compute", self.api_version, http=http_authorized, cache_discovery=False) return self._conn + def get_compute_instance_template_client(self): + """Returns Compute Engine Instance Template Client.""" + return InstanceTemplatesClient(credentials=self._get_credentials(), client_info=self.client_info) + + def get_compute_instance_client(self): + """Returns Compute Engine Instance Client.""" + return InstancesClient(credentials=self._get_credentials(), client_info=self.client_info) + + def get_compute_instance_group_managers_client(self): + """Returns Compute Engine Instance Group Managers Client.""" + return InstanceGroupManagersClient(credentials=self._get_credentials(), client_info=self.client_info) + + @GoogleBaseHook.fallback_to_default_project_id + def insert_instance_template( + self, + body: dict, + request_id: str | None = None, + project_id: str = 
PROVIDE_PROJECT_ID, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Creates Instance Template using body specified. + Must be called with keyword arguments rather than positional. + + :param body: Instance Template representation as an object. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param project_id: Google Cloud project ID where the Compute Engine Instance Template exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_compute_instance_template_client() + client.insert( + # Calling method insert() on client to create Instance Template. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.InsertInstanceTemplateRequest, dict] to construct a request + # message. + # The request object should be represented using arguments: + # instance_template_resource (google.cloud.compute_v1.types.InstanceTemplate): + # The body resource for this request. + # request_id (str): + # An optional request ID to identify requests. + # project (str): + # Project ID for this request. + request={ + "instance_template_resource": body, + "request_id": request_id, + "project": project_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_instance_template( + self, + resource_id: str, + request_id: str | None = None, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Deletes Instance Template. + Deleting an Instance Template is permanent and cannot be undone. It + is not possible to delete templates that are already in use by a managed instance group. + Must be called with keyword arguments rather than positional. + + :param resource_id: Name of the Compute Engine Instance Template resource. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param project_id: Google Cloud project ID where the Compute Engine Instance Template exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
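# Illustrative sketch of the new python-client based template methods added to
# ComputeEngineHook (insert_instance_template / delete_instance_template); not
# part of the diff itself.  The template body is a hypothetical minimal
# example and the project id is a placeholder.
import uuid

from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook

hook = ComputeEngineHook(gcp_conn_id="google_cloud_default")

template_body = {
    "name": "example-template",
    # a real template would also carry disk and network configuration; it is
    # omitted here because this is only a sketch
    "properties": {"machine_type": "n1-standard-1"},
}

# request_id makes the call idempotent if it has to be retried (RFC 4122 UUID).
hook.insert_instance_template(
    body=template_body,
    request_id=str(uuid.uuid4()),
    project_id="my-project",
)

# Deletion is permanent and fails while the template is still referenced by a
# managed instance group, as the docstring above notes.
hook.delete_instance_template(resource_id="example-template", project_id="my-project")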
+ """ + client = self.get_compute_instance_template_client() + client.delete( + # Calling method delete() on client to delete Instance Template. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.DeleteInstanceTemplateRequest, dict] to + # construct a request message. + # The request object should be represented using arguments: + # instance_template (str): + # The name of the Instance Template to delete. + # project (str): + # Project ID for this request. + # request_id (str): + # An optional request ID to identify requests. + request={ + "instance_template": resource_id, + "project": project_id, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance_template( + self, + resource_id: str, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> InstanceTemplate: + """ + Retrieves Instance Template by project_id and resource_id. + Must be called with keyword arguments rather than positional. + + :param resource_id: Name of the Instance Template. + :param project_id: Google Cloud project ID where the Compute Engine Instance Template exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :return: Instance Template representation as object according to + https://cloud.google.com/compute/docs/reference/rest/v1/instanceTemplates + :rtype: object + """ + client = self.get_compute_instance_template_client() + instance_template_obj = client.get( + # Calling method get() on client to get the specified Instance Template. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.GetInstanceTemplateRequest, dict] to construct a request + # message. + # The request object should be represented using arguments: + # instance_template (str): + # The name of the Instance Template. + # project (str): + # Project ID for this request. + request={ + "instance_template": resource_id, + "project": project_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return instance_template_obj + + @GoogleBaseHook.fallback_to_default_project_id + def insert_instance( + self, + body: dict, + zone: str, + project_id: str = PROVIDE_PROJECT_ID, + request_id: str | None = None, + source_instance_template: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Creates Instance using body specified. + Must be called with keyword arguments rather than positional. + + :param body: Instance representation as an object. Should at least include 'name', 'machine_type', + 'disks' and 'network_interfaces' fields but doesn't include 'zone' field, as it will be specified + in 'zone' parameter. + Full or partial URL and can be represented as examples below: + 1. "machine_type": "projects/your-project-name/zones/your-zone/machineTypes/your-machine-type" + 2. 
"source_image": "projects/your-project-name/zones/your-zone/diskTypes/your-disk-type" + 3. "subnetwork": "projects/your-project-name/regions/your-region/subnetworks/your-subnetwork" + :param zone: Google Cloud zone where the Instance exists + :param project_id: Google Cloud project ID where the Compute Engine Instance Template exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param source_instance_template: Existing Instance Template that will be used as a base while + creating new Instance. + When specified, only name of new Instance should be provided as input arguments in 'body' + parameter when creating new Instance. All other parameters, such as machine_type, disks + and network_interfaces and etc will be passed to Instance as they are specified + in the Instance Template. + Full or partial URL and can be represented as examples below: + 1. "https://www.googleapis.com/compute/v1/projects/your-project/global/instanceTemplates/temp" + 2. "projects/your-project/global/instanceTemplates/temp" + 3. "global/instanceTemplates/temp" + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_compute_instance_client() + client.insert( + # Calling method insert() on client to create Instance. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.InsertInstanceRequest, dict] to construct a request + # message. + # The request object should be represented using arguments: + # instance_resource (google.cloud.compute_v1.types.Instance): + # The body resource for this request. + # request_id (str): + # Optional, request ID to identify requests. + # project (str): + # Project ID for this request. + # zone (str): + # The name of the zone for this request. + # source_instance_template (str): + # Optional, link to Instance Template, that can be used to create new Instance. + request={ + "instance_resource": body, + "request_id": request_id, + "project": project_id, + "zone": zone, + "source_instance_template": source_instance_template, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_instance( + self, + resource_id: str, + zone: str, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Instance: + """ + Retrieves Instance by project_id and resource_id. + Must be called with keyword arguments rather than positional. + + :param resource_id: Name of the Instance + :param zone: Google Cloud zone where the Instance exists + :param project_id: Google Cloud project ID where the Compute Engine Instance exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. 
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :return: Instance representation as object according to + https://cloud.google.com/compute/docs/reference/rest/v1/instances + :rtype: object + """ + client = self.get_compute_instance_client() + instance_obj = client.get( + # Calling method get() on client to get the specified Instance. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.GetInstanceRequest, dict] to construct a request + # message. + # The request object should be represented using arguments: + # instance (str): + # The name of the Instance. + # project (str): + # Project ID for this request. + # zone (str): + # The name of the zone for this request. + request={ + "instance": resource_id, + "project": project_id, + "zone": zone, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return instance_obj + + @GoogleBaseHook.fallback_to_default_project_id + def delete_instance( + self, + resource_id: str, + zone: str, + project_id: str = PROVIDE_PROJECT_ID, + request_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Deletes Instance. + Deleting an Instance is permanent and cannot be undone. + It is not possible to delete Instances that are already in use by a managed instance group. + Must be called with keyword arguments rather than positional. + + :param resource_id: Name of the Compute Engine Instance Template resource. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param project_id: Google Cloud project ID where the Compute Engine Instance Template exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param zone: Google Cloud zone where the Instance exists + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_compute_instance_client() + client.delete( + # Calling method delete() on client to delete Instance. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.DeleteInstanceRequest, dict] to construct a request + # message. + # The request object should be represented using arguments: + # instance (str): + # Name of the Instance resource to delete. + # project (str): + # Project ID for this request. + # request_id (str): + # An optional request ID to identify requests. + # zone (str): + # The name of the zone for this request. 
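# Sketch only, not part of the change: creating an instance from an existing
# template with insert_instance() and reading it back with get_instance().
# As the insert_instance docstring above notes, when source_instance_template
# is supplied the body only has to name the new instance.  Project, zone and
# resource names are hypothetical placeholders.
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook

hook = ComputeEngineHook()
hook.insert_instance(
    body={"name": "example-instance"},
    zone="europe-central2-b",
    project_id="my-project",
    source_instance_template="global/instanceTemplates/example-template",
)

# The instance may still be provisioning at this point; get_instance() only
# fetches its current representation.
instance = hook.get_instance(
    resource_id="example-instance",
    zone="europe-central2-b",
    project_id="my-project",
)
print(instance.name)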
+ request={ + "instance": resource_id, + "project": project_id, + "request_id": request_id, + "zone": zone, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + @GoogleBaseHook.fallback_to_default_project_id def start_instance(self, zone: str, resource_id: str, project_id: str) -> None: """ @@ -155,88 +504,167 @@ def _execute_set_machine_type(self, zone: str, resource_id: str, body: dict, pro ) @GoogleBaseHook.fallback_to_default_project_id - def get_instance_template(self, resource_id: str, project_id: str) -> dict: + def insert_instance_group_manager( + self, + body: dict, + zone: str, + project_id: str = PROVIDE_PROJECT_ID, + request_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: """ - Retrieves instance template by project_id and resource_id. + Creates an Instance Group Managers using the body specified. + After the group is created, instances in the group are created using the specified Instance Template. Must be called with keyword arguments rather than positional. - :param resource_id: Name of the instance template - :param project_id: Optional, Google Cloud project ID where the - Compute Engine Instance exists. If set to None or missing, - the default project_id from the Google Cloud connection is used. - :return: Instance template representation as object according to - https://cloud.google.com/compute/docs/reference/rest/v1/instanceTemplates - :rtype: dict + :param body: Instance Group Manager representation as an object. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new Instance Group Managers again) + It should be in UUID format as defined in RFC 4122 + :param project_id: Google Cloud project ID where the Compute Engine Instance Group Managers exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param zone: Google Cloud zone where the Instance exists + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. """ - response = ( - self.get_conn() - .instanceTemplates() - .get(project=project_id, instanceTemplate=resource_id) - .execute(num_retries=self.num_retries) + client = self.get_compute_instance_group_managers_client() + client.insert( + # Calling method insert() on client to create the specified Instance Group Managers. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.InsertInstanceGroupManagerRequest, dict] to construct + # a request message. + # The request object should be represented using arguments: + # instance_group_manager_resource (google.cloud.compute_v1.types.InstanceGroupManager): + # The body resource for this request. + # project (str): + # Project ID for this request. + # zone (str): + # The name of the zone where you want to create the managed instance group. + # request_id (str): + # An optional request ID to identify requests. 
+ request={ + "instance_group_manager_resource": body, + "project": project_id, + "zone": zone, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, ) - return response @GoogleBaseHook.fallback_to_default_project_id - def insert_instance_template( + def get_instance_group_manager( self, - body: dict, + resource_id: str, + zone: str, project_id: str = PROVIDE_PROJECT_ID, - request_id: Optional[str] = None, - ) -> None: + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> InstanceGroupManager: """ - Inserts instance template using body specified + Retrieves Instance Group Manager by project_id, zone and resource_id. Must be called with keyword arguments rather than positional. - :param body: Instance template representation as object according to - https://cloud.google.com/compute/docs/reference/rest/v1/instanceTemplates - :param request_id: Optional, unique request_id that you might add to achieve - full idempotence (for example when client call times out repeating the request - with the same request id will not create a new instance template again) - It should be in UUID format as defined in RFC 4122 - :param project_id: Optional, Google Cloud project ID where the - Compute Engine Instance exists. If set to None or missing, - the default project_id from the Google Cloud connection is used. - :return: None + :param resource_id: The name of the Managed Instance Group + :param zone: Google Cloud zone where the Instance Group Managers exists + :param project_id: Google Cloud project ID where the Compute Engine Instance Group Managers exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :return: Instance Group Managers representation as object according to + https://cloud.google.com/compute/docs/reference/rest/v1/instanceGroupManagers + :rtype: object """ - response = ( - self.get_conn() - .instanceTemplates() - .insert(project=project_id, body=body, requestId=request_id) - .execute(num_retries=self.num_retries) + client = self.get_compute_instance_group_managers_client() + instance_group_manager_obj = client.get( + # Calling method get() on client to get the specified Instance Group Manager. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.GetInstanceGroupManagerRequest, dict] to construct a + # request message. + # The request object should be represented using arguments: + # instance_group_manager (str): + # The name of the Managed Instance Group. + # project (str): + # Project ID for this request. + # zone (str): + # The name of the zone for this request. 
+ request={ + "instance_group_manager": resource_id, + "project": project_id, + "zone": zone, + }, + retry=retry, + timeout=timeout, + metadata=metadata, ) - try: - operation_name = response["name"] - except KeyError: - raise AirflowException(f"Wrong response '{response}' returned - it should contain 'name' field") - self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) + return instance_group_manager_obj @GoogleBaseHook.fallback_to_default_project_id - def get_instance_group_manager( + def delete_instance_group_manager( self, - zone: str, resource_id: str, + zone: str, project_id: str = PROVIDE_PROJECT_ID, - ) -> dict: + request_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: """ - Retrieves Instance Group Manager by project_id, zone and resource_id. + Deletes Instance Group Managers. + Deleting an Instance Group Manager is permanent and cannot be undone. Must be called with keyword arguments rather than positional. - :param zone: Google Cloud zone where the Instance Group Manager exists - :param resource_id: Name of the Instance Group Manager - :param project_id: Optional, Google Cloud project ID where the - Compute Engine Instance exists. If set to None or missing, - the default project_id from the Google Cloud connection is used. - :return: Instance group manager representation as object according to - https://cloud.google.com/compute/docs/reference/rest/beta/instanceGroupManagers - :rtype: dict + :param resource_id: Name of the Compute Engine Instance Group Managers resource. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param project_id: Google Cloud project ID where the Compute Engine Instance Group Managers exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param zone: Google Cloud zone where the Instance Group Managers exists + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. """ - response = ( - self.get_conn() - .instanceGroupManagers() - .get(project=project_id, zone=zone, instanceGroupManager=resource_id) - .execute(num_retries=self.num_retries) + client = self.get_compute_instance_group_managers_client() + client.delete( + # Calling method delete() on client to delete Instance Group Managers. + # This method accepts request object as an argument and should be of type + # Union[google.cloud.compute_v1.types.DeleteInstanceGroupManagerRequest, dict] to construct a + # request message. + # The request object should be represented using arguments: + # instance_group_manager (str): + # Name of the Instance resource to delete. + # project (str): + # Project ID for this request. + # request_id (str): + # An optional request ID to identify requests. + # zone (str): + # The name of the zone for this request. 
+ request={ + "instance_group_manager": resource_id, + "project": project_id, + "request_id": request_id, + "zone": zone, + }, + retry=retry, + timeout=timeout, + metadata=metadata, ) - return response @GoogleBaseHook.fallback_to_default_project_id def patch_instance_group_manager( @@ -245,7 +673,7 @@ def patch_instance_group_manager( resource_id: str, body: dict, project_id: str, - request_id: Optional[str] = None, + request_id: str | None = None, ) -> None: """ Patches Instance Group Manager with the specified body. @@ -284,7 +712,7 @@ def patch_instance_group_manager( self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name, zone=zone) def _wait_for_operation_to_complete( - self, project_id: str, operation_name: str, zone: Optional[str] = None + self, project_id: str, operation_name: str, zone: str | None = None ) -> None: """ Waits for the named operation to complete - checks status of the async call. @@ -338,7 +766,7 @@ def _check_global_operation_status( ) @GoogleBaseHook.fallback_to_default_project_id - def get_instance_info(self, zone: str, resource_id: str, project_id: str) -> Dict[str, Any]: + def get_instance_info(self, zone: str, resource_id: str, project_id: str) -> dict[str, Any]: """ Gets instance information. @@ -381,7 +809,7 @@ def get_instance_address( @GoogleBaseHook.fallback_to_default_project_id def set_instance_metadata( - self, zone: str, resource_id: str, metadata: Dict[str, str], project_id: str + self, zone: str, resource_id: str, metadata: dict[str, str], project_id: str ) -> None: """ Set instance metadata. diff --git a/airflow/providers/google/cloud/hooks/compute_ssh.py b/airflow/providers/google/cloud/hooks/compute_ssh.py index 3b65ce4a72097..5e8d343196ba4 100644 --- a/airflow/providers/google/cloud/hooks/compute_ssh.py +++ b/airflow/providers/google/cloud/hooks/compute_ssh.py @@ -14,20 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import shlex -import sys import time from io import StringIO -from typing import Any, Dict, Optional - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from typing import Any from google.api_core.retry import exponential_sleep_generator from airflow import AirflowException +from airflow.compat.functools import cached_property from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook from airflow.providers.google.cloud.hooks.os_login import OSLoginHook from airflow.providers.ssh.hooks.ssh import SSHHook @@ -87,31 +84,31 @@ class ComputeEngineSSHHook(SSHHook): domain-wide delegation enabled. 
""" - conn_name_attr = 'gcp_conn_id' - default_conn_name = 'google_cloud_ssh_default' - conn_type = 'gcpssh' - hook_name = 'Google Cloud SSH' + conn_name_attr = "gcp_conn_id" + default_conn_name = "google_cloud_ssh_default" + conn_type = "gcpssh" + hook_name = "Google Cloud SSH" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: return { - "hidden_fields": ['host', 'schema', 'login', 'password', 'port', 'extra'], + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], "relabeling": {}, } def __init__( self, - gcp_conn_id: str = 'google_cloud_default', - instance_name: Optional[str] = None, - zone: Optional[str] = None, - user: Optional[str] = 'root', - project_id: Optional[str] = None, - hostname: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + instance_name: str | None = None, + zone: str | None = None, + user: str | None = "root", + project_id: str | None = None, + hostname: str | None = None, use_internal_ip: bool = False, use_iap_tunnel: bool = False, use_oslogin: bool = True, expire_time: int = 300, - delegate_to: Optional[str] = None, + delegate_to: str | None = None, ) -> None: # Ignore original constructor # super().__init__() @@ -126,7 +123,7 @@ def __init__( self.expire_time = expire_time self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to - self._conn: Optional[Any] = None + self._conn: Any | None = None @cached_property def _oslogin_hook(self) -> OSLoginHook: @@ -141,23 +138,23 @@ def _boolify(value): if isinstance(value, bool): return value if isinstance(value, str): - if value.lower() == 'false': + if value.lower() == "false": return False - elif value.lower() == 'true': + elif value.lower() == "true": return True return False def intify(key, value, default): if value is None: return default - if isinstance(value, str) and value.strip() == '': + if isinstance(value, str) and value.strip() == "": return default try: return int(value) except ValueError: raise AirflowException( f"The {key} field should be a integer. " - f"Current value: \"{value}\" (type: {type(value)}). " + f'Current value: "{value}" (type: {type(value)}). ' f"Please check the connection configuration." ) @@ -220,15 +217,15 @@ def get_conn(self) -> paramiko.SSHClient: proxy_command = None if self.use_iap_tunnel: proxy_command_args = [ - 'gcloud', - 'compute', - 'start-iap-tunnel', + "gcloud", + "compute", + "start-iap-tunnel", str(self.instance_name), - '22', - '--listen-on-stdin', - f'--project={self.project_id}', - f'--zone={self.zone}', - '--verbosity=warning', + "22", + "--listen-on-stdin", + f"--project={self.project_id}", + f"--zone={self.zone}", + "--verbosity=warning", ] proxy_command = " ".join(shlex.quote(arg) for arg in proxy_command_args) @@ -260,7 +257,7 @@ def _connect_to_instance(self, user, hostname, pkey, proxy_command) -> paramiko. raise self.log.info("Failed to connect. 
Waiting %ds to retry", time_to_wait) time.sleep(time_to_wait) - raise AirflowException("Caa not connect to instance") + raise AirflowException("Can not connect to instance") def _authorize_compute_engine_instance_metadata(self, pubkey): self.log.info("Appending SSH public key to instance metadata") @@ -269,16 +266,16 @@ def _authorize_compute_engine_instance_metadata(self, pubkey): ) keys = self.user + ":" + pubkey + "\n" - metadata = instance_info['metadata'] + metadata = instance_info["metadata"] items = metadata.get("items", []) for item in items: if item.get("key") == "ssh-keys": keys += item["value"] - item['value'] = keys + item["value"] = keys break else: - new_dict = dict(key='ssh-keys', value=keys) - metadata['items'] = [new_dict] + new_dict = dict(key="ssh-keys", value=keys) + metadata["items"] = [new_dict] self._compute_hook.set_instance_metadata( zone=self.zone, resource_id=self.instance_name, metadata=metadata, project_id=self.project_id diff --git a/airflow/providers/google/cloud/hooks/datacatalog.py b/airflow/providers/google/cloud/hooks/datacatalog.py index 8aa5aed934905..876987e07fb04 100644 --- a/airflow/providers/google/cloud/hooks/datacatalog.py +++ b/airflow/providers/google/cloud/hooks/datacatalog.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -58,20 +59,20 @@ class CloudDataCatalogHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client: Optional[DataCatalogClient] = None + self._client: DataCatalogClient | None = None def get_conn(self) -> DataCatalogClient: """Retrieves client library object that allow access to Cloud Data Catalog service.""" if not self._client: - self._client = DataCatalogClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = DataCatalogClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @GoogleBaseHook.fallback_to_default_project_id @@ -80,11 +81,11 @@ def create_entry( location: str, entry_group: str, entry_id: str, - entry: Union[dict, Entry], + entry: dict | Entry, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Entry: """ Creates an entry. 
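# Illustrative usage sketch for CloudDataCatalogHook.create_entry(); not part
# of this commit.  The location, ids and the entry payload are hypothetical;
# create_entry() accepts either a dict (converted by the client) or a
# google.cloud.datacatalog Entry message.
from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook

hook = CloudDataCatalogHook(gcp_conn_id="google_cloud_default")
entry = hook.create_entry(
    location="us-central1",
    entry_group="example-entry-group",
    entry_id="example-entry",
    entry={
        "display_name": "Example fileset",
        # minimal, hypothetical fileset entry; see the Data Catalog API
        # reference for the full Entry schema
        "type_": "FILESET",
        "gcs_fileset_spec": {"file_patterns": ["gs://example-bucket/**"]},
    },
    project_id="my-project",
)
print(entry.name)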
@@ -107,14 +108,14 @@ def create_entry( """ client = self.get_conn() parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" - self.log.info('Creating a new entry: parent=%s', parent) + self.log.info("Creating a new entry: parent=%s", parent) result = client.create_entry( - request={'parent': parent, 'entry_id': entry_id, 'entry': entry}, + request={"parent": parent, "entry_id": entry_id, "entry": entry}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Created a entry: name=%s', result.name) + self.log.info("Created a entry: name=%s", result.name) return result @GoogleBaseHook.fallback_to_default_project_id @@ -122,11 +123,11 @@ def create_entry_group( self, location: str, entry_group_id: str, - entry_group: Union[Dict, EntryGroup], + entry_group: dict | EntryGroup, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> EntryGroup: """ Creates an EntryGroup. @@ -149,15 +150,15 @@ def create_entry_group( """ client = self.get_conn() parent = f"projects/{project_id}/locations/{location}" - self.log.info('Creating a new entry group: parent=%s', parent) + self.log.info("Creating a new entry group: parent=%s", parent) result = client.create_entry_group( - request={'parent': parent, 'entry_group_id': entry_group_id, 'entry_group': entry_group}, + request={"parent": parent, "entry_group_id": entry_group_id, "entry_group": entry_group}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Created a entry group: name=%s', result.name) + self.log.info("Created a entry group: name=%s", result.name) return result @@ -167,12 +168,12 @@ def create_tag( location: str, entry_group: str, entry: str, - tag: Union[dict, Tag], + tag: dict | Tag, project_id: str = PROVIDE_PROJECT_ID, - template_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + template_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Tag: """ Creates a tag on an entry. @@ -201,16 +202,16 @@ def create_tag( tag["template"] = template_path parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" - self.log.info('Creating a new tag: parent=%s', parent) + self.log.info("Creating a new tag: parent=%s", parent) # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a # primitive type, so we need to convert it manually. 
# See: https://github.com/googleapis/python-datacatalog/issues/84 if isinstance(tag, dict): tag = Tag( - name=tag.get('name'), - template=tag.get('template'), - template_display_name=tag.get('template_display_name'), - column=tag.get('column'), + name=tag.get("name"), + template=tag.get("template"), + template_display_name=tag.get("template_display_name"), + column=tag.get("column"), fields={ k: datacatalog.TagField(**v) if isinstance(v, dict) else v for k, v in tag.get("fields", {}).items() @@ -222,7 +223,7 @@ def create_tag( ) result = client.create_tag(request=request, retry=retry, timeout=timeout, metadata=metadata or ()) - self.log.info('Created a tag: name=%s', result.name) + self.log.info("Created a tag: name=%s", result.name) return result @@ -231,11 +232,11 @@ def create_tag_template( self, location, tag_template_id: str, - tag_template: Union[dict, TagTemplate], + tag_template: dict | TagTemplate, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TagTemplate: """ Creates a tag template. @@ -257,7 +258,7 @@ def create_tag_template( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}" - self.log.info('Creating a new tag template: parent=%s', parent) + self.log.info("Creating a new tag template: parent=%s", parent) # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a # primitive type, so we need to convert it manually. # See: https://github.com/googleapis/python-datacatalog/issues/84 @@ -280,7 +281,7 @@ def create_tag_template( timeout=timeout, metadata=metadata, ) - self.log.info('Created a tag template: name=%s', result.name) + self.log.info("Created a tag template: name=%s", result.name) return result @@ -290,11 +291,11 @@ def create_tag_template_field( location: str, tag_template: str, tag_template_field_id: str, - tag_template_field: Union[dict, TagTemplateField], + tag_template_field: dict | TagTemplateField, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TagTemplateField: r""" Creates a field in a tag template. 
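# Sketch only (not in the diff): passing a plain-dict Tag to create_tag().
# As the hook's own conversion code above shows, the nested "fields" mapping
# of a dict tag is turned into datacatalog.TagField messages before the
# request is sent, so a hypothetical dict like the one below is accepted.
# Ids and values are placeholders.
from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook

hook = CloudDataCatalogHook()
tag = hook.create_tag(
    location="us-central1",
    entry_group="example-entry-group",
    entry="example-entry",
    template_id="example-tag-template",
    tag={"fields": {"owner": {"string_value": "data-platform-team"}}},
    project_id="my-project",
)
print(tag.name)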
@@ -320,20 +321,20 @@ def create_tag_template_field( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" - self.log.info('Creating a new tag template field: parent=%s', parent) + self.log.info("Creating a new tag template field: parent=%s", parent) result = client.create_tag_template_field( request={ - 'parent': parent, - 'tag_template_field_id': tag_template_field_id, - 'tag_template_field': tag_template_field, + "parent": parent, + "tag_template_field_id": tag_template_field_id, + "tag_template_field": tag_template_field, }, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Created a tag template field: name=%s', result.name) + self.log.info("Created a tag template field: name=%s", result.name) return result @@ -344,9 +345,9 @@ def delete_entry( entry_group: str, entry: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes an existing entry. @@ -364,9 +365,9 @@ def delete_entry( """ client = self.get_conn() name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" - self.log.info('Deleting a entry: name=%s', name) - client.delete_entry(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()) - self.log.info('Deleted a entry: name=%s', name) + self.log.info("Deleting a entry: name=%s", name) + client.delete_entry(request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or ()) + self.log.info("Deleted a entry: name=%s", name) @GoogleBaseHook.fallback_to_default_project_id def delete_entry_group( @@ -374,9 +375,9 @@ def delete_entry_group( location, entry_group, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes an EntryGroup. @@ -396,11 +397,11 @@ def delete_entry_group( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" - self.log.info('Deleting a entry group: name=%s', name) + self.log.info("Deleting a entry group: name=%s", name) client.delete_entry_group( - request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or () ) - self.log.info('Deleted a entry group: name=%s', name) + self.log.info("Deleted a entry group: name=%s", name) @GoogleBaseHook.fallback_to_default_project_id def delete_tag( @@ -410,9 +411,9 @@ def delete_tag( entry: str, tag: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a tag. 
@@ -434,9 +435,9 @@ def delete_tag( f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}/tags/{tag}" ) - self.log.info('Deleting a tag: name=%s', name) - client.delete_tag(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()) - self.log.info('Deleted a tag: name=%s', name) + self.log.info("Deleting a tag: name=%s", name) + client.delete_tag(request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or ()) + self.log.info("Deleted a tag: name=%s", name) @GoogleBaseHook.fallback_to_default_project_id def delete_tag_template( @@ -445,9 +446,9 @@ def delete_tag_template( tag_template, force: bool, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a tag template and all tags using the template. @@ -468,11 +469,11 @@ def delete_tag_template( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" - self.log.info('Deleting a tag template: name=%s', name) + self.log.info("Deleting a tag template: name=%s", name) client.delete_tag_template( - request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name, "force": force}, retry=retry, timeout=timeout, metadata=metadata or () ) - self.log.info('Deleted a tag template: name=%s', name) + self.log.info("Deleted a tag template: name=%s", name) @GoogleBaseHook.fallback_to_default_project_id def delete_tag_template_field( @@ -482,9 +483,9 @@ def delete_tag_template_field( field: str, force: bool, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a field in a tag template and all uses of that field. @@ -504,11 +505,11 @@ def delete_tag_template_field( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}" - self.log.info('Deleting a tag template field: name=%s', name) + self.log.info("Deleting a tag template field: name=%s", name) client.delete_tag_template_field( - request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name, "force": force}, retry=retry, timeout=timeout, metadata=metadata or () ) - self.log.info('Deleted a tag template field: name=%s', name) + self.log.info("Deleted a tag template field: name=%s", name) @GoogleBaseHook.fallback_to_default_project_id def get_entry( @@ -517,9 +518,9 @@ def get_entry( entry_group: str, entry: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Entry: """ Gets an entry. 
@@ -538,11 +539,11 @@ def get_entry( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" - self.log.info('Getting a entry: name=%s', name) + self.log.info("Getting a entry: name=%s", name) result = client.get_entry( - request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or () ) - self.log.info('Received a entry: name=%s', result.name) + self.log.info("Received a entry: name=%s", result.name) return result @@ -552,10 +553,10 @@ def get_entry_group( location: str, entry_group: str, project_id: str, - read_mask: Optional[FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + read_mask: FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> EntryGroup: """ Gets an entry group. @@ -577,16 +578,16 @@ def get_entry_group( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" - self.log.info('Getting a entry group: name=%s', name) + self.log.info("Getting a entry group: name=%s", name) result = client.get_entry_group( - request={'name': name, 'read_mask': read_mask}, + request={"name": name, "read_mask": read_mask}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Received a entry group: name=%s', result.name) + self.log.info("Received a entry group: name=%s", result.name) return result @@ -596,9 +597,9 @@ def get_tag_template( location: str, tag_template: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TagTemplate: """ Gets a tag template. @@ -616,13 +617,13 @@ def get_tag_template( client = self.get_conn() name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" - self.log.info('Getting a tag template: name=%s', name) + self.log.info("Getting a tag template: name=%s", name) result = client.get_tag_template( - request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or () ) - self.log.info('Received a tag template: name=%s', result.name) + self.log.info("Received a tag template: name=%s", result.name) return result @@ -634,9 +635,9 @@ def list_tags( entry: str, project_id: str, page_size: int = 100, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Lists the tags on an Entry. 
@@ -658,16 +659,16 @@ def list_tags( client = self.get_conn() parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" - self.log.info('Listing tag on entry: entry_name=%s', parent) + self.log.info("Listing tag on entry: entry_name=%s", parent) result = client.list_tags( - request={'parent': parent, 'page_size': page_size}, + request={"parent": parent, "page_size": page_size}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Received tags.') + self.log.info("Received tags.") return result @@ -679,9 +680,9 @@ def get_tag_for_template_name( entry: str, template_name: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Tag: """ Gets for a tag with a specific template for a specific entry. @@ -712,11 +713,11 @@ def get_tag_for_template_name( def lookup_entry( self, - linked_resource: Optional[str] = None, - sql_resource: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + linked_resource: str | None = None, + sql_resource: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Entry: r""" Get an entry by target resource name. @@ -743,22 +744,22 @@ def lookup_entry( raise AirflowException("At least one of linked_resource, sql_resource should be set.") if linked_resource: - self.log.info('Getting entry: linked_resource=%s', linked_resource) + self.log.info("Getting entry: linked_resource=%s", linked_resource) result = client.lookup_entry( - request={'linked_resource': linked_resource}, + request={"linked_resource": linked_resource}, retry=retry, timeout=timeout, metadata=metadata, ) else: - self.log.info('Getting entry: sql_resource=%s', sql_resource) + self.log.info("Getting entry: sql_resource=%s", sql_resource) result = client.lookup_entry( - request={'sql_resource': sql_resource}, + request={"sql_resource": sql_resource}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Received entry. name=%s', result.name) + self.log.info("Received entry. name=%s", result.name) return result @@ -770,9 +771,9 @@ def rename_tag_template_field( field: str, new_tag_template_field_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TagTemplateField: """ Renames a field in a tag template. 
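# --- Hedged usage sketch --------------------------------------------------------
# The lookup_entry method shown above requires exactly one of linked_resource or
# sql_resource, otherwise it raises AirflowException. The hook class name, import
# path, connection id and resource string below are assumptions made for
# illustration; only the method contract comes from this change.
from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook

hook = CloudDataCatalogHook(gcp_conn_id="google_cloud_default")
entry = hook.lookup_entry(
    linked_resource=(
        "//bigquery.googleapis.com/projects/my-project/datasets/my_dataset/tables/my_table"
    )
)
print(entry.name)
# --------------------------------------------------------------------------------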
@@ -795,29 +796,29 @@ def rename_tag_template_field( name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}" self.log.info( - 'Renaming field: old_name=%s, new_tag_template_field_id=%s', name, new_tag_template_field_id + "Renaming field: old_name=%s, new_tag_template_field_id=%s", name, new_tag_template_field_id ) result = client.rename_tag_template_field( - request={'name': name, 'new_tag_template_field_id': new_tag_template_field_id}, + request={"name": name, "new_tag_template_field_id": new_tag_template_field_id}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Renamed tag template field.') + self.log.info("Renamed tag template field.") return result def search_catalog( self, - scope: Union[Dict, SearchCatalogRequest.Scope], + scope: dict | SearchCatalogRequest.Scope, query: str, page_size: int = 100, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): r""" Searches Data Catalog for multiple resources like entries, tags that match a query. @@ -870,28 +871,28 @@ def search_catalog( order_by, ) result = client.search_catalog( - request={'scope': scope, 'query': query, 'page_size': page_size, 'order_by': order_by}, + request={"scope": scope, "query": query, "page_size": page_size, "order_by": order_by}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Received items.') + self.log.info("Received items.") return result @GoogleBaseHook.fallback_to_default_project_id def update_entry( self, - entry: Union[Dict, Entry], - update_mask: Union[dict, FieldMask], + entry: dict | Entry, + update_mask: dict | FieldMask, project_id: str, - location: Optional[str] = None, - entry_group: Optional[str] = None, - entry_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + entry_group: str | None = None, + entry_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Entry: """ Updates an existing entry. @@ -939,29 +940,29 @@ def update_entry( if isinstance(entry, dict): entry = Entry(**entry) result = client.update_entry( - request={'entry': entry, 'update_mask': update_mask}, + request={"entry": entry, "update_mask": update_mask}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Updated entry.') + self.log.info("Updated entry.") return result @GoogleBaseHook.fallback_to_default_project_id def update_tag( self, - tag: Union[Dict, Tag], - update_mask: Union[Dict, FieldMask], + tag: dict | Tag, + update_mask: dict | FieldMask, project_id: str, - location: Optional[str] = None, - entry_group: Optional[str] = None, - entry: Optional[str] = None, - tag_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + entry_group: str | None = None, + entry: str | None = None, + tag_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Tag: """ Updates an existing tag. 
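# --- Illustrative aside ---------------------------------------------------------
# The update_* methods in this hook accept update_mask either as a FieldMask
# message or as its dict form. A minimal sketch, assuming the usual protobuf
# import and a made-up field path:
from google.protobuf.field_mask_pb2 import FieldMask

update_mask_as_dict = {"paths": ["display_name"]}
update_mask_as_message = FieldMask(paths=["display_name"])
# --------------------------------------------------------------------------------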
@@ -1012,26 +1013,26 @@ def update_tag( if isinstance(tag, dict): tag = Tag(**tag) result = client.update_tag( - request={'tag': tag, 'update_mask': update_mask}, + request={"tag": tag, "update_mask": update_mask}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Updated tag.') + self.log.info("Updated tag.") return result @GoogleBaseHook.fallback_to_default_project_id def update_tag_template( self, - tag_template: Union[dict, TagTemplate], - update_mask: Union[dict, FieldMask], + tag_template: dict | TagTemplate, + update_mask: dict | FieldMask, project_id: str, - location: Optional[str] = None, - tag_template_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + tag_template_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TagTemplate: """ Updates a tag template. @@ -1084,28 +1085,28 @@ def update_tag_template( if isinstance(tag_template, dict): tag_template = TagTemplate(**tag_template) result = client.update_tag_template( - request={'tag_template': tag_template, 'update_mask': update_mask}, + request={"tag_template": tag_template, "update_mask": update_mask}, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Updated tag template.') + self.log.info("Updated tag template.") return result @GoogleBaseHook.fallback_to_default_project_id def update_tag_template_field( self, - tag_template_field: Union[dict, TagTemplateField], - update_mask: Union[dict, FieldMask], + tag_template_field: dict | TagTemplateField, + update_mask: dict | FieldMask, project_id: str, - tag_template_field_name: Optional[str] = None, - location: Optional[str] = None, - tag_template: Optional[str] = None, - tag_template_field_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + tag_template_field_name: str | None = None, + location: str | None = None, + tag_template: str | None = None, + tag_template_field_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Updates a field in a tag template. This method cannot be used to update the field type. @@ -1149,14 +1150,14 @@ def update_tag_template_field( result = client.update_tag_template_field( request={ - 'name': tag_template_field_name, - 'tag_template_field': tag_template_field, - 'update_mask': update_mask, + "name": tag_template_field_name, + "tag_template_field": tag_template_field, + "update_mask": update_mask, }, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('Updated tag template field.') + self.log.info("Updated tag template field.") return result diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 66b259838e522..b9dbf9478cfc4 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Google Dataflow Hook.""" +from __future__ import annotations + import functools import json import re @@ -25,7 +27,7 @@ import uuid import warnings from copy import deepcopy -from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Set, TypeVar, Union, cast +from typing import Any, Callable, Generator, Sequence, TypeVar, cast from googleapiclient.discovery import build @@ -48,7 +50,7 @@ def process_line_and_extract_dataflow_job_id_callback( - on_new_job_id_callback: Optional[Callable[[str], None]] + on_new_job_id_callback: Callable[[str], None] | None ) -> Callable[[str], None]: """ Returns callback which triggers function passed as `on_new_job_id_callback` when Dataflow job_id is found. @@ -60,7 +62,7 @@ def process_line_and_extract_dataflow_job_id_callback( def _process_line_and_extract_job_id( line: str, - # on_new_job_id_callback: Optional[Callable[[str], None]] + # on_new_job_id_callback: Callable[[str], None] | None ) -> None: # Job id info: https://goo.gl/SE29y9. matched_job = JOB_ID_PATTERN.search(line) @@ -85,7 +87,7 @@ def _wrapper(func: T) -> T: """ @functools.wraps(func) - def inner_wrapper(self: "DataflowHook", *args, **kwargs): + def inner_wrapper(self: DataflowHook, *args, **kwargs): if args: raise AirflowException( "You must use keyword arguments in this methods rather than positional" @@ -189,13 +191,13 @@ def __init__( project_number: str, location: str, poll_sleep: int = 10, - name: Optional[str] = None, - job_id: Optional[str] = None, + name: str | None = None, + job_id: str | None = None, num_retries: int = 0, multiple_jobs: bool = False, drain_pipeline: bool = False, - cancel_timeout: Optional[int] = 5 * 60, - wait_until_finished: Optional[bool] = None, + cancel_timeout: int | None = 5 * 60, + wait_until_finished: bool | None = None, ) -> None: super().__init__() @@ -208,7 +210,7 @@ def __init__( self._num_retries = num_retries self._poll_sleep = poll_sleep self._cancel_timeout = cancel_timeout - self._jobs: Optional[List[dict]] = None + self._jobs: list[dict] | None = None self.drain_pipeline = drain_pipeline self._wait_until_finished = wait_until_finished @@ -217,7 +219,6 @@ def is_job_running(self) -> bool: Helper method to check if jos is still running in dataflow :return: True if job is running. - :rtype: bool """ self._refresh_jobs() if not self._jobs: @@ -228,15 +229,16 @@ def is_job_running(self) -> bool: return True return False - def _get_current_jobs(self) -> List[dict]: + def _get_current_jobs(self) -> list[dict]: """ Helper method to get list of jobs that start with job name or id :return: list of jobs including id's - :rtype: list """ if not self._multiple_jobs and self._job_id: return [self.fetch_job_by_id(self._job_id)] + elif self._jobs: + return [self.fetch_job_by_id(job["id"]) for job in self._jobs] elif self._job_name: jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower()) if len(jobs) == 1: @@ -251,7 +253,6 @@ def fetch_job_by_id(self, job_id: str) -> dict: :param job_id: Job ID to get. :return: the Job - :rtype: dict """ return ( self._dataflow.projects() @@ -272,7 +273,6 @@ def fetch_job_metrics_by_id(self, job_id: str) -> dict: :param job_id: Job ID to get. :return: the JobMetrics. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics - :rtype: dict """ result = ( self._dataflow.projects() @@ -292,7 +292,6 @@ def _fetch_list_job_messages_responses(self, job_id: str) -> Generator[dict, Non :param job_id: Job ID to get. :return: yields the ListJobMessagesResponse. 
See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse - :rtype: Generator[dict, None, None] """ request = ( self._dataflow.projects() @@ -314,42 +313,40 @@ def _fetch_list_job_messages_responses(self, job_id: str) -> Generator[dict, Non .list_next(previous_request=request, previous_response=response) ) - def fetch_job_messages_by_id(self, job_id: str) -> List[dict]: + def fetch_job_messages_by_id(self, job_id: str) -> list[dict]: """ Helper method to fetch the job messages with the specified Job ID. :param job_id: Job ID to get. :return: the list of JobMessages. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#JobMessage - :rtype: List[dict] """ - messages: List[dict] = [] + messages: list[dict] = [] for response in self._fetch_list_job_messages_responses(job_id=job_id): messages.extend(response.get("jobMessages", [])) return messages - def fetch_job_autoscaling_events_by_id(self, job_id: str) -> List[dict]: + def fetch_job_autoscaling_events_by_id(self, job_id: str) -> list[dict]: """ Helper method to fetch the job autoscaling events with the specified Job ID. :param job_id: Job ID to get. :return: the list of AutoscalingEvents. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#autoscalingevent - :rtype: List[dict] """ - autoscaling_events: List[dict] = [] + autoscaling_events: list[dict] = [] for response in self._fetch_list_job_messages_responses(job_id=job_id): autoscaling_events.extend(response.get("autoscalingEvents", [])) return autoscaling_events - def _fetch_all_jobs(self) -> List[dict]: + def _fetch_all_jobs(self) -> list[dict]: request = ( self._dataflow.projects() .locations() .jobs() .list(projectId=self._project_number, location=self._job_location) ) - all_jobs: List[dict] = [] + all_jobs: list[dict] = [] while request is not None: response = request.execute(num_retries=self._num_retries) jobs = response.get("jobs") @@ -365,7 +362,7 @@ def _fetch_all_jobs(self) -> List[dict]: ) return all_jobs - def _fetch_jobs_by_prefix_name(self, prefix_name: str) -> List[dict]: + def _fetch_jobs_by_prefix_name(self, prefix_name: str) -> list[dict]: jobs = self._fetch_all_jobs() jobs = [job for job in jobs if job["name"].startswith(prefix_name)] return jobs @@ -375,7 +372,6 @@ def _refresh_jobs(self) -> None: Helper method to get all jobs by name :return: jobs - :rtype: list """ self._jobs = self._get_current_jobs() @@ -395,27 +391,26 @@ def _check_dataflow_job_state(self, job) -> bool: if job failed raise exception :return: True if job is done. 
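# --- Generic sketch of the pagination idiom used above --------------------------
# Both _fetch_list_job_messages_responses and _fetch_all_jobs page through
# googleapiclient results by calling list_next() with the previous request and
# response until it returns None. A collection-agnostic version of the same loop;
# "collection" is a placeholder for any discovery-based API resource.
def iterate_pages(collection, num_retries: int = 0, **list_kwargs):
    request = collection.list(**list_kwargs)
    while request is not None:
        response = request.execute(num_retries=num_retries)
        yield response
        request = collection.list_next(previous_request=request, previous_response=response)
# --------------------------------------------------------------------------------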
- :rtype: bool :raise: Exception """ if self._wait_until_finished is None: - wait_for_running = job.get('type') == DataflowJobType.JOB_TYPE_STREAMING + wait_for_running = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING else: wait_for_running = not self._wait_until_finished - if job['currentState'] == DataflowJobStatus.JOB_STATE_DONE: + if job["currentState"] == DataflowJobStatus.JOB_STATE_DONE: return True - elif job['currentState'] == DataflowJobStatus.JOB_STATE_FAILED: + elif job["currentState"] == DataflowJobStatus.JOB_STATE_FAILED: raise Exception(f"Google Cloud Dataflow job {job['name']} has failed.") - elif job['currentState'] == DataflowJobStatus.JOB_STATE_CANCELLED: + elif job["currentState"] == DataflowJobStatus.JOB_STATE_CANCELLED: raise Exception(f"Google Cloud Dataflow job {job['name']} was cancelled.") - elif job['currentState'] == DataflowJobStatus.JOB_STATE_DRAINED: + elif job["currentState"] == DataflowJobStatus.JOB_STATE_DRAINED: raise Exception(f"Google Cloud Dataflow job {job['name']} was drained.") - elif job['currentState'] == DataflowJobStatus.JOB_STATE_UPDATED: + elif job["currentState"] == DataflowJobStatus.JOB_STATE_UPDATED: raise Exception(f"Google Cloud Dataflow job {job['name']} was updated.") - elif job['currentState'] == DataflowJobStatus.JOB_STATE_RUNNING and wait_for_running: + elif job["currentState"] == DataflowJobStatus.JOB_STATE_RUNNING and wait_for_running: return True - elif job['currentState'] in DataflowJobStatus.AWAITING_STATES: + elif job["currentState"] in DataflowJobStatus.AWAITING_STATES: return self._wait_until_finished is False self.log.debug("Current job: %s", str(job)) raise Exception(f"Google Cloud Dataflow job {job['name']} was unknown state: {job['currentState']}") @@ -429,13 +424,12 @@ def wait_for_done(self) -> None: time.sleep(self._poll_sleep) self._refresh_jobs() - def get_jobs(self, refresh: bool = False) -> List[dict]: + def get_jobs(self, refresh: bool = False) -> list[dict]: """ Returns Dataflow jobs. :param refresh: Forces the latest data to be fetched. 
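# --- Condensed sketch of the state handling above -------------------------------
# State names are copied from the hunk; the constants and error type below are
# simplified stand-ins, and unlike the hook this sketch just keeps polling on any
# state it does not recognise instead of raising.
TERMINAL_OK = {"JOB_STATE_DONE"}
TERMINAL_ERROR = {"JOB_STATE_FAILED", "JOB_STATE_CANCELLED", "JOB_STATE_DRAINED", "JOB_STATE_UPDATED"}


def job_is_done(current_state: str, wait_for_running: bool) -> bool:
    if current_state in TERMINAL_OK:
        return True
    if current_state in TERMINAL_ERROR:
        raise RuntimeError(f"Google Cloud Dataflow job ended in state {current_state}")
    if current_state == "JOB_STATE_RUNNING" and wait_for_running:
        return True
    return False  # still pending / awaiting -> poll again
# --------------------------------------------------------------------------------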
:return: list of jobs - :rtype: list """ if not self._jobs or refresh: self._refresh_jobs() @@ -444,20 +438,20 @@ def get_jobs(self, refresh: bool = False) -> List[dict]: return self._jobs - def _wait_for_states(self, expected_states: Set[str]): + def _wait_for_states(self, expected_states: set[str]): """Waiting for the jobs to reach a certain state.""" if not self._jobs: raise ValueError("The _jobs should be set") while True: self._refresh_jobs() - job_states = {job['currentState'] for job in self._jobs} + job_states = {job["currentState"] for job in self._jobs} if not job_states.difference(expected_states): return - unexpected_failed_end_states = expected_states - DataflowJobStatus.FAILED_END_STATES + unexpected_failed_end_states = DataflowJobStatus.FAILED_END_STATES - expected_states if unexpected_failed_end_states.intersection(job_states): - unexpected_failed_jobs = { - job for job in self._jobs if job['currentState'] in unexpected_failed_end_states - } + unexpected_failed_jobs = [ + job for job in self._jobs if job["currentState"] in unexpected_failed_end_states + ] raise AirflowException( "Jobs failed: " + ", ".join( @@ -469,18 +463,19 @@ def _wait_for_states(self, expected_states: Set[str]): def cancel(self) -> None: """Cancels or drains current job""" - jobs = self.get_jobs() - job_ids = [job["id"] for job in jobs if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES] + self._jobs = [ + job for job in self.get_jobs() if job["currentState"] not in DataflowJobStatus.TERMINAL_STATES + ] + job_ids = [job["id"] for job in self._jobs] if job_ids: - batch = self._dataflow.new_batch_http_request() self.log.info("Canceling jobs: %s", ", ".join(job_ids)) - for job in jobs: + for job in self._jobs: requested_state = ( DataflowJobStatus.JOB_STATE_DRAINED if self.drain_pipeline and job["type"] == DataflowJobType.JOB_TYPE_STREAMING else DataflowJobStatus.JOB_STATE_CANCELLED ) - batch.add( + request = ( self._dataflow.projects() .locations() .jobs() @@ -491,14 +486,16 @@ def cancel(self) -> None: body={"requestedState": requested_state}, ) ) - batch.execute() + request.execute(num_retries=self._num_retries) if self._cancel_timeout and isinstance(self._cancel_timeout, int): timeout_error_message = ( f"Canceling jobs failed due to timeout ({self._cancel_timeout}s): {', '.join(job_ids)}" ) tm = timeout(seconds=self._cancel_timeout, error_message=timeout_error_message) with tm: - self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED}) + self._wait_for_states( + {DataflowJobStatus.JOB_STATE_CANCELLED, DataflowJobStatus.JOB_STATE_DRAINED} + ) else: self.log.info("No jobs to cancel") @@ -514,18 +511,18 @@ class DataflowHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, poll_sleep: int = 10, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, drain_pipeline: bool = False, - cancel_timeout: Optional[int] = 5 * 60, - wait_until_finished: Optional[bool] = None, + cancel_timeout: int | None = 5 * 60, + wait_until_finished: bool | None = None, ) -> None: self.poll_sleep = poll_sleep self.drain_pipeline = drain_pipeline self.cancel_timeout = cancel_timeout self.wait_until_finished = wait_until_finished - self.job_id: Optional[str] = None + self.job_id: str | None = None self.beam_hook = BeamHook(BeamRunnerType.DataflowRunner) super().__init__( gcp_conn_id=gcp_conn_id, @@ -547,10 +544,10 @@ def start_java_dataflow( 
variables: dict, jar: str, project_id: str, - job_class: Optional[str] = None, + job_class: str | None = None, append_job_name: bool = True, multiple_jobs: bool = False, - on_new_job_id_callback: Optional[Callable[[str], None]] = None, + on_new_job_id_callback: Callable[[str], None] | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, ) -> None: """ @@ -610,10 +607,10 @@ def start_template_dataflow( dataflow_template: str, project_id: str, append_job_name: bool = True, - on_new_job_id_callback: Optional[Callable[[str], None]] = None, - on_new_job_callback: Optional[Callable[[dict], None]] = None, + on_new_job_id_callback: Callable[[str], None] | None = None, + on_new_job_callback: Callable[[dict], None] | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, - environment: Optional[dict] = None, + environment: dict | None = None, ) -> dict: """ Starts Dataflow template job. @@ -728,8 +725,8 @@ def start_flex_template( body: dict, location: str, project_id: str, - on_new_job_id_callback: Optional[Callable[[str], None]] = None, - on_new_job_callback: Optional[Callable[[dict], None]] = None, + on_new_job_id_callback: Callable[[str], None] | None = None, + on_new_job_callback: Callable[[dict], None] | None = None, ): """ Starts flex templates with the Dataflow pipeline. @@ -786,13 +783,13 @@ def start_python_dataflow( job_name: str, variables: dict, dataflow: str, - py_options: List[str], + py_options: list[str], project_id: str, py_interpreter: str = "python3", - py_requirements: Optional[List[str]] = None, + py_requirements: list[str] | None = None, py_system_site_packages: bool = False, append_job_name: bool = True, - on_new_job_id_callback: Optional[Callable[[str], None]] = None, + on_new_job_id_callback: Callable[[str], None] | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, ): """ @@ -881,7 +878,7 @@ def is_job_dataflow_running( name: str, project_id: str, location: str = DEFAULT_DATAFLOW_LOCATION, - variables: Optional[dict] = None, + variables: dict | None = None, ) -> bool: """ Helper method to check if jos is still running in dataflow @@ -891,7 +888,6 @@ def is_job_dataflow_running( If set to None or missing, the default project_id from the Google Cloud connection is used. :param location: Job location. :return: True if job is running. - :rtype: bool """ if variables: warnings.warn( @@ -916,8 +912,8 @@ def is_job_dataflow_running( def cancel_job( self, project_id: str, - job_name: Optional[str] = None, - job_id: Optional[str] = None, + job_name: str | None = None, + job_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, ) -> None: """ @@ -949,11 +945,11 @@ def start_sql_job( self, job_name: str, query: str, - options: Dict[str, Any], + options: dict[str, Any], project_id: str, location: str = DEFAULT_DATAFLOW_LOCATION, - on_new_job_id_callback: Optional[Callable[[str], None]] = None, - on_new_job_callback: Optional[Callable[[dict], None]] = None, + on_new_job_id_callback: Callable[[str], None] | None = None, + on_new_job_callback: Callable[[dict], None] | None = None, ): """ Starts Dataflow SQL query. @@ -972,16 +968,31 @@ def start_sql_job( :param on_new_job_callback: Callback called when the job is known. 
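# --- Forward note on the start_sql_job change in the hunk that follows ----------
# A single impersonation account (a plain string or a one-element chain) is turned
# into an extra gcloud flag; longer chains are rejected. Sketch with placeholder
# values only:
impersonation_chain = ["sa@my-project.iam.gserviceaccount.com"]  # placeholder

if isinstance(impersonation_chain, str):
    impersonation_account = impersonation_chain
elif len(impersonation_chain) == 1:
    impersonation_account = impersonation_chain[0]
else:
    raise ValueError("Chained list of accounts is not supported")

gcp_options = [
    "--project=my-project",
    "--format=value(job.id)",
    "--job-name=my-dataflow-sql-job",
    "--region=us-central1",
    f"--impersonate-service-account={impersonation_account}",
]
# The executed command is then "gcloud dataflow sql query <query>" followed by
# these options plus the Beam pipeline options converted by beam_options_to_args().
# --------------------------------------------------------------------------------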
:return: the new job object """ + gcp_options = [ + f"--project={project_id}", + "--format=value(job.id)", + f"--job-name={job_name}", + f"--region={location}", + ] + + if self.impersonation_chain: + if isinstance(self.impersonation_chain, str): + impersonation_account = self.impersonation_chain + elif len(self.impersonation_chain) == 1: + impersonation_account = self.impersonation_chain[0] + else: + raise AirflowException( + "Chained list of accounts is not supported, please specify only one service account" + ) + gcp_options.append(f"--impersonate-service-account={impersonation_account}") + cmd = [ "gcloud", "dataflow", "sql", "query", query, - f"--project={project_id}", - "--format=value(job.id)", - f"--job-name={job_name}", - f"--region={location}", + *gcp_options, *(beam_options_to_args(options)), ] self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd)) @@ -1038,7 +1049,6 @@ def get_job( :param location: The location of the Dataflow job (for example europe-west1). See: https://cloud.google.com/dataflow/docs/concepts/regional-endpoints :return: the Job - :rtype: dict """ jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), @@ -1064,7 +1074,6 @@ def fetch_job_metrics_by_id( https://cloud.google.com/dataflow/docs/concepts/regional-endpoints :return: the JobMetrics. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics - :rtype: dict """ jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), @@ -1079,7 +1088,7 @@ def fetch_job_messages_by_id( job_id: str, project_id: str, location: str = DEFAULT_DATAFLOW_LOCATION, - ) -> List[dict]: + ) -> list[dict]: """ Gets the job messages with the specified Job ID. @@ -1089,7 +1098,6 @@ def fetch_job_messages_by_id( :param location: Job location. :return: the list of JobMessages. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#JobMessage - :rtype: List[dict] """ jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), @@ -1104,7 +1112,7 @@ def fetch_job_autoscaling_events_by_id( job_id: str, project_id: str, location: str = DEFAULT_DATAFLOW_LOCATION, - ) -> List[dict]: + ) -> list[dict]: """ Gets the job autoscaling events with the specified Job ID. @@ -1114,7 +1122,6 @@ def fetch_job_autoscaling_events_by_id( :param location: Job location. :return: the list of AutoscalingEvents. See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/ListJobMessagesResponse#autoscalingevent - :rtype: List[dict] """ jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), @@ -1129,7 +1136,7 @@ def wait_for_done( job_name: str, location: str, project_id: str, - job_id: Optional[str] = None, + job_id: str | None = None, multiple_jobs: bool = False, ) -> None: """ diff --git a/airflow/providers/google/cloud/hooks/dataform.py b/airflow/providers/google/cloud/hooks/dataform.py new file mode 100644 index 0000000000000..aa8506c16ed7a --- /dev/null +++ b/airflow/providers/google/cloud/hooks/dataform.py @@ -0,0 +1,627 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import time +from typing import Sequence + +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.api_core.retry import Retry +from google.cloud.dataform_v1beta1 import DataformClient +from google.cloud.dataform_v1beta1.types import ( + CompilationResult, + InstallNpmPackagesResponse, + Repository, + WorkflowInvocation, + Workspace, + WriteFileResponse, +) + +from airflow import AirflowException +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + + +class DataformHook(GoogleBaseHook): + """Hook for Google Cloud DataForm APIs.""" + + def get_dataform_client(self) -> DataformClient: + """Retrieves client library object that allow access to Cloud Dataform service.""" + return DataformClient(credentials=self.get_credentials()) + + @GoogleBaseHook.fallback_to_default_project_id + def wait_for_workflow_invocation( + self, + workflow_invocation_id: str, + repository_id: str, + project_id: str, + region: str, + wait_time: int = 10, + timeout: int | None = None, + ) -> None: + """ + Helper method which polls a job to check if it finishes. + + :param workflow_invocation_id: Id of the Workflow Invocation + :param repository_id: Id of the Dataform repository + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param wait_time: Number of seconds between checks + :param timeout: How many seconds wait for job to be ready. Used only if ``asynchronous`` is False + """ + if region is None: + raise TypeError("missing 1 required keyword argument: 'region'") + state = None + start = time.monotonic() + while state not in ( + WorkflowInvocation.State.FAILED, + WorkflowInvocation.State.SUCCEEDED, + WorkflowInvocation.State.CANCELLED, + ): + if timeout and start + timeout < time.monotonic(): + raise AirflowException( + f"Timeout: workflow invocation {workflow_invocation_id} is not ready after {timeout}s" + ) + time.sleep(wait_time) + try: + workflow_invocation = self.get_workflow_invocation( + project_id=project_id, + region=region, + repository_id=repository_id, + workflow_invocation_id=workflow_invocation_id, + ) + state = workflow_invocation.state + except Exception as err: + self.log.info( + "Retrying. Dataform API returned error when waiting for workflow invocation: %s", err + ) + + if state == WorkflowInvocation.State.FAILED: + raise AirflowException(f"Workflow Invocation failed:\n{workflow_invocation}") + if state == WorkflowInvocation.State.CANCELLED: + raise AirflowException(f"Workflow Invocation was cancelled:\n{workflow_invocation}") + + @GoogleBaseHook.fallback_to_default_project_id + def create_compilation_result( + self, + project_id: str, + region: str, + repository_id: str, + compilation_result: CompilationResult | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> CompilationResult: + """ + Creates a new CompilationResult in a given project and location. + + :param project_id: Required. 
The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param compilation_result: Required. The compilation result to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + parent = f"projects/{project_id}/locations/{region}/repositories/{repository_id}" + return client.create_compilation_result( + request={ + "parent": parent, + "compilation_result": compilation_result, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_compilation_result( + self, + project_id: str, + region: str, + repository_id: str, + compilation_result_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> CompilationResult: + """ + Fetches a single CompilationResult. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param compilation_result_id: The Id of the Dataform Compilation Result + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + name = ( + f"projects/{project_id}/locations/{region}/repositories/" + f"{repository_id}/compilationResults/{compilation_result_id}" + ) + return client.get_compilation_result( + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_workflow_invocation( + self, + project_id: str, + region: str, + repository_id: str, + workflow_invocation: WorkflowInvocation | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> WorkflowInvocation: + """ + Creates a new WorkflowInvocation in a given Repository. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation: Required. The workflow invocation resource to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_dataform_client() + parent = f"projects/{project_id}/locations/{region}/repositories/{repository_id}" + return client.create_workflow_invocation( + request={"parent": parent, "workflow_invocation": workflow_invocation}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_workflow_invocation( + self, + project_id: str, + region: str, + repository_id: str, + workflow_invocation_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> WorkflowInvocation: + """ + Fetches a single WorkflowInvocation. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation_id: Required. The workflow invocation resource's id. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + name = ( + f"projects/{project_id}/locations/{region}/repositories/" + f"{repository_id}/workflowInvocations/{workflow_invocation_id}" + ) + return client.get_workflow_invocation( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def cancel_workflow_invocation( + self, + project_id: str, + region: str, + repository_id: str, + workflow_invocation_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Requests cancellation of a running WorkflowInvocation. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation_id: Required. The workflow invocation resource's id. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + name = ( + f"projects/{project_id}/locations/{region}/repositories/" + f"{repository_id}/workflowInvocations/{workflow_invocation_id}" + ) + client.cancel_workflow_invocation( + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_repository( + self, + *, + project_id: str, + region: str, + repository_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Repository: + """ + Creates repository + + :param project_id: Required. The ID of the Google Cloud project where repository should be. + :param region: Required. The ID of the Google Cloud region where repository should be. + :param repository_id: Required. The ID of the new Dataform repository. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
+ """ + client = self.get_dataform_client() + parent = f"projects/{project_id}/locations/{region}" + request = { + "parent": parent, + "repository_id": repository_id, + } + + repository = client.create_repository( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + return repository + + @GoogleBaseHook.fallback_to_default_project_id + def delete_repository( + self, + *, + project_id: str, + region: str, + repository_id: str, + force: bool = True, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Deletes repository. + + :param project_id: Required. The ID of the Google Cloud project where repository located. + :param region: Required. The ID of the Google Cloud region where repository located. + :param repository_id: Required. The ID of the Dataform repository that should be deleted. + :param force: If set to true, any child resources of this repository will also be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + name = f"projects/{project_id}/locations/{region}/repositories/{repository_id}" + request = { + "name": name, + "force": force, + } + + client.delete_repository( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_workspace( + self, + *, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Workspace: + """ + Creates workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace should be. + :param region: Required. The ID of the Google Cloud region where workspace should be. + :param repository_id: Required. The ID of the Dataform repository where workspace should be. + :param workspace_id: Required. The ID of the new Dataform workspace. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + parent = f"projects/{project_id}/locations/{region}/repositories/{repository_id}" + + request = {"parent": parent, "workspace_id": workspace_id} + + workspace = client.create_workspace( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + return workspace + + @GoogleBaseHook.fallback_to_default_project_id + def delete_workspace( + self, + *, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Deletes workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace that should be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. 
+ :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + workspace_path = ( + f"projects/{project_id}/locations/{region}/" + f"repositories/{repository_id}/workspaces/{workspace_id}" + ) + request = { + "name": workspace_path, + } + + client.delete_workspace( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def write_file( + self, + *, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + filepath: str, + contents: bytes, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> WriteFileResponse: + """ + Writes a new file to the specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where files should be created. + :param filepath: Required. Path to file including name of the file relative to workspace root. + :param contents: Required. Content of the file to be written. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + workspace_path = ( + f"projects/{project_id}/locations/{region}/" + f"repositories/{repository_id}/workspaces/{workspace_id}" + ) + request = { + "workspace": workspace_path, + "path": filepath, + "contents": contents, + } + + response = client.write_file( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + return response + + @GoogleBaseHook.fallback_to_default_project_id + def make_directory( + self, + *, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + path: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> dict: + """ + Makes new directory in specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where directory should be created. + :param path: Required. The directory's full path including new directory name, + relative to the workspace root. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
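# --- Illustrative follow-up (placeholder values) ---------------------------------
# The workspace helpers above can seed a brand-new workspace with a directory and
# a file before anything is compiled; project, repository, path and file contents
# below are made up.
from airflow.providers.google.cloud.hooks.dataform import DataformHook

hook = DataformHook(gcp_conn_id="google_cloud_default")
hook.create_workspace(
    project_id="my-project",
    region="us-central1",
    repository_id="my-repository",
    workspace_id="scratch",
)
hook.make_directory(
    project_id="my-project",
    region="us-central1",
    repository_id="my-repository",
    workspace_id="scratch",
    path="definitions",
)
hook.write_file(
    project_id="my-project",
    region="us-central1",
    repository_id="my-repository",
    workspace_id="scratch",
    filepath="definitions/my_table.sqlx",
    contents=b"SELECT 1 AS one",
)
# --------------------------------------------------------------------------------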
+ """ + client = self.get_dataform_client() + workspace_path = ( + f"projects/{project_id}/locations/{region}/" + f"repositories/{repository_id}/workspaces/{workspace_id}" + ) + request = { + "workspace": workspace_path, + "path": path, + } + + response = client.make_directory( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + return response + + @GoogleBaseHook.fallback_to_default_project_id + def remove_directory( + self, + *, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + path: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Removes directory in specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where directory located. + :param path: Required. The directory's full path including directory name, + relative to the workspace root. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + workspace_path = ( + f"projects/{project_id}/locations/{region}/" + f"repositories/{repository_id}/workspaces/{workspace_id}" + ) + request = { + "workspace": workspace_path, + "path": path, + } + + client.remove_directory( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def remove_file( + self, + *, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + filepath: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Removes file in specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where directory located. + :param filepath: Required. The full path including name of the file, relative to the workspace root. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + workspace_path = ( + f"projects/{project_id}/locations/{region}/" + f"repositories/{repository_id}/workspaces/{workspace_id}" + ) + request = { + "workspace": workspace_path, + "path": filepath, + } + + client.remove_file( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def install_npm_packages( + self, + *, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> InstallNpmPackagesResponse: + """ + Installs npm dependencies in the provided workspace. 
Requires "package.json" + to be created in workspace + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_dataform_client() + workspace_path = ( + f"projects/{project_id}/locations/{region}/" + f"repositories/{repository_id}/workspaces/{workspace_id}" + ) + request = { + "workspace": workspace_path, + } + + response = client.install_npm_packages( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + return response diff --git a/airflow/providers/google/cloud/hooks/datafusion.py b/airflow/providers/google/cloud/hooks/datafusion.py index 8068c3ece2431..89861a0c4f740 100644 --- a/airflow/providers/google/cloud/hooks/datafusion.py +++ b/airflow/providers/google/cloud/hooks/datafusion.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains Google DataFusion hook.""" +from __future__ import annotations + import json import os from time import monotonic, sleep -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, Sequence from urllib.parse import quote, urlencode import google.auth @@ -53,14 +54,14 @@ class PipelineStates: class DataFusionHook(GoogleBaseHook): """Hook for Google DataFusion.""" - _conn = None # type: Optional[Resource] + _conn: Resource | None = None def __init__( self, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -69,7 +70,7 @@ def __init__( ) self.api_version = api_version - def wait_for_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]: + def wait_for_operation(self, operation: dict[str, Any]) -> dict[str, Any]: """Waits for long-lasting operation to complete.""" for time_to_wait in exponential_sleep_generator(initial=10, maximum=120): sleep(time_to_wait) @@ -88,8 +89,8 @@ def wait_for_pipeline_state( pipeline_id: str, instance_url: str, namespace: str = "default", - success_states: Optional[List[str]] = None, - failure_states: Optional[List[str]] = None, + success_states: list[str] | None = None, + failure_states: list[str] | None = None, timeout: int = 5 * 60, ) -> None: """ @@ -138,12 +139,12 @@ def _base_url(instance_url: str, namespace: str) -> str: return os.path.join(instance_url, "v3", "namespaces", quote(namespace), "apps") def _cdap_request( - self, url: str, method: str, body: Optional[Union[List, Dict]] = None + self, url: str, method: str, body: list | dict | None = None ) -> google.auth.transport.Response: - headers: Dict[str, str] = {"Content-Type": "application/json"} + headers: dict[str, str] = {"Content-Type": "application/json"} request = google.auth.transport.requests.Request() - credentials = self._get_credentials() + credentials = self.get_credentials() 
credentials.before_request(request=request, method=method, url=url, headers=headers) payload = json.dumps(body) if body else None @@ -151,6 +152,16 @@ def _cdap_request( response = request(method=method, url=url, headers=headers, body=payload) return response + @staticmethod + def _check_response_status_and_data(response, message: str) -> None: + if response.status != 200: + raise AirflowException(message) + if response.data is None: + raise AirflowException( + "Empty response received. Please, check for possible root " + "causes of this behavior either in DAG code or on Cloud Datafusion side" + ) + def get_conn(self) -> Resource: """Retrieves connection to DataFusion.""" if not self._conn: @@ -206,7 +217,7 @@ def delete_instance(self, instance_name: str, location: str, project_id: str) -> def create_instance( self, instance_name: str, - instance: Dict[str, Any], + instance: dict[str, Any], location: str, project_id: str = PROVIDE_PROJECT_ID, ) -> Operation: @@ -234,7 +245,7 @@ def create_instance( return operation @GoogleBaseHook.fallback_to_default_project_id - def get_instance(self, instance_name: str, location: str, project_id: str) -> Dict[str, Any]: + def get_instance(self, instance_name: str, location: str, project_id: str) -> dict[str, Any]: """ Gets details of a single Data Fusion instance. @@ -256,7 +267,7 @@ def get_instance(self, instance_name: str, location: str, project_id: str) -> Di def patch_instance( self, instance_name: str, - instance: Dict[str, Any], + instance: dict[str, Any], update_mask: str, location: str, project_id: str = PROVIDE_PROJECT_ID, @@ -293,7 +304,7 @@ def patch_instance( def create_pipeline( self, pipeline_name: str, - pipeline: Dict[str, Any], + pipeline: dict[str, Any], instance_url: str, namespace: str = "default", ) -> None: @@ -310,16 +321,15 @@ def create_pipeline( """ url = os.path.join(self._base_url(instance_url, namespace), quote(pipeline_name)) response = self._cdap_request(url=url, method="PUT", body=pipeline) - if response.status != 200: - raise AirflowException( - f"Creating a pipeline failed with code {response.status} while calling {url}" - ) + self._check_response_status_and_data( + response, f"Creating a pipeline failed with code {response.status} while calling {url}" + ) def delete_pipeline( self, pipeline_name: str, instance_url: str, - version_id: Optional[str] = None, + version_id: str | None = None, namespace: str = "default", ) -> None: """ @@ -337,14 +347,15 @@ def delete_pipeline( url = os.path.join(url, "versions", version_id) response = self._cdap_request(url=url, method="DELETE", body=None) - if response.status != 200: - raise AirflowException(f"Deleting a pipeline failed with code {response.status}") + self._check_response_status_and_data( + response, f"Deleting a pipeline failed with code {response.status}" + ) def list_pipelines( self, instance_url: str, - artifact_name: Optional[str] = None, - artifact_version: Optional[str] = None, + artifact_name: str | None = None, + artifact_version: str | None = None, namespace: str = "default", ) -> dict: """ @@ -358,7 +369,7 @@ def list_pipelines( can create a namespace. 
""" url = self._base_url(instance_url, namespace) - query: Dict[str, str] = {} + query: dict[str, str] = {} if artifact_name: query = {"artifactName": artifact_name} if artifact_version: @@ -367,8 +378,9 @@ def list_pipelines( url = os.path.join(url, urlencode(query)) response = self._cdap_request(url=url, method="GET", body=None) - if response.status != 200: - raise AirflowException(f"Listing pipelines failed with code {response.status}") + self._check_response_status_and_data( + response, f"Listing pipelines failed with code {response.status}" + ) return json.loads(response.data) def get_pipeline_workflow( @@ -387,8 +399,9 @@ def get_pipeline_workflow( quote(pipeline_id), ) response = self._cdap_request(url=url, method="GET") - if response.status != 200: - raise AirflowException(f"Retrieving a pipeline state failed with code {response.status}") + self._check_response_status_and_data( + response, f"Retrieving a pipeline state failed with code {response.status}" + ) workflow = json.loads(response.data) return workflow @@ -397,7 +410,7 @@ def start_pipeline( pipeline_name: str, instance_url: str, namespace: str = "default", - runtime_args: Optional[Dict[str, Any]] = None, + runtime_args: dict[str, Any] | None = None, ) -> str: """ Starts a Cloud Data Fusion pipeline. Works for both batch and stream pipelines. @@ -429,9 +442,9 @@ def start_pipeline( } ] response = self._cdap_request(url=url, method="POST", body=body) - if response.status != 200: - raise AirflowException(f"Starting a pipeline failed with code {response.status}") - + self._check_response_status_and_data( + response, f"Starting a pipeline failed with code {response.status}" + ) response_json = json.loads(response.data) return response_json[0]["runId"] @@ -453,5 +466,6 @@ def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str = "stop", ) response = self._cdap_request(url=url, method="POST") - if response.status != 200: - raise AirflowException(f"Stopping a pipeline failed with code {response.status}") + self._check_response_status_and_data( + response, f"Stopping a pipeline failed with code {response.status}" + ) diff --git a/airflow/providers/google/cloud/hooks/dataplex.py b/airflow/providers/google/cloud/hooks/dataplex.py index c51158d4ec62d..e8121b582e08b 100644 --- a/airflow/providers/google/cloud/hooks/dataplex.py +++ b/airflow/providers/google/cloud/hooks/dataplex.py @@ -14,19 +14,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains Google Dataplex hook.""" -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Any, Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation from google.api_core.retry import Retry from google.cloud.dataplex_v1 import DataplexServiceClient -from google.cloud.dataplex_v1.types import Task +from google.cloud.dataplex_v1.types import Lake, Task from googleapiclient.discovery import Resource from airflow.exceptions import AirflowException +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -48,14 +50,14 @@ class DataplexHook(GoogleBaseHook): account from the list granting this role to the originating account (templated). 
""" - _conn = None # type: Optional[Resource] + _conn: Resource | None = None def __init__( self, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -66,13 +68,13 @@ def __init__( def get_dataplex_client(self) -> DataplexServiceClient: """Returns DataplexServiceClient.""" - client_options = ClientOptions(api_endpoint='dataplex.googleapis.com:443') + client_options = ClientOptions(api_endpoint="dataplex.googleapis.com:443") return DataplexServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) - def wait_for_operation(self, timeout: Optional[float], operation: Operation): + def wait_for_operation(self, timeout: float | None, operation: Operation): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -86,12 +88,12 @@ def create_task( project_id: str, region: str, lake_id: str, - body: Union[Dict[str, Any], Task], + body: dict[str, Any] | Task, dataplex_task_id: str, - validate_only: Optional[bool] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + validate_only: bool | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Any: """ Creates a task resource within a lake. @@ -109,14 +111,14 @@ def create_task( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. """ - parent = f'projects/{project_id}/locations/{region}/lakes/{lake_id}' + parent = f"projects/{project_id}/locations/{region}/lakes/{lake_id}" client = self.get_dataplex_client() result = client.create_task( request={ - 'parent': parent, - 'task_id': dataplex_task_id, - 'task': body, + "parent": parent, + "task_id": dataplex_task_id, + "task": body, }, retry=retry, timeout=timeout, @@ -131,9 +133,9 @@ def delete_task( region: str, lake_id: str, dataplex_task_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Any: """ Delete the task resource. @@ -148,12 +150,12 @@ def delete_task( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. 
""" - name = f'projects/{project_id}/locations/{region}/lakes/{lake_id}/tasks/{dataplex_task_id}' + name = f"projects/{project_id}/locations/{region}/lakes/{lake_id}/tasks/{dataplex_task_id}" client = self.get_dataplex_client() result = client.delete_task( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -167,13 +169,13 @@ def list_tasks( project_id: str, region: str, lake_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Any: """ Lists tasks under the given lake. @@ -195,16 +197,16 @@ def list_tasks( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. """ - parent = f'projects/{project_id}/locations/{region}/lakes/{lake_id}' + parent = f"projects/{project_id}/locations/{region}/lakes/{lake_id}" client = self.get_dataplex_client() result = client.list_tasks( request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'order_by': order_by, + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -219,9 +221,9 @@ def get_task( region: str, lake_id: str, dataplex_task_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Any: """ Get task resource. @@ -236,11 +238,121 @@ def get_task( Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. """ - name = f'projects/{project_id}/locations/{region}/lakes/{lake_id}/tasks/{dataplex_task_id}' + name = f"projects/{project_id}/locations/{region}/lakes/{lake_id}/tasks/{dataplex_task_id}" client = self.get_dataplex_client() result = client.get_task( request={ - 'name': name, + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_lake( + self, + project_id: str, + region: str, + lake_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Any: + """ + Delete the lake resource. + + :param project_id: Required. The ID of the Google Cloud project that the lake belongs to. + :param region: Required. The ID of the Google Cloud region that the lake belongs to. + :param lake_id: Required. The ID of the Google Cloud lake to be deleted. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
+ """ + name = f"projects/{project_id}/locations/{region}/lakes/{lake_id}" + + client = self.get_dataplex_client() + result = client.delete_lake( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_lake( + self, + project_id: str, + region: str, + lake_id: str, + body: dict[str, Any] | Lake, + validate_only: bool | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Any: + """ + Creates a lake resource. + + :param project_id: Required. The ID of the Google Cloud project that the lake belongs to. + :param region: Required. The ID of the Google Cloud region that the lake belongs to. + :param lake_id: Required. Lake identifier. + :param body: Required. The Request body contains an instance of Lake. + :param validate_only: Optional. Only validate the request, but do not perform mutations. + The default is false. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + parent = f"projects/{project_id}/locations/{region}" + client = self.get_dataplex_client() + result = client.create_lake( + request={ + "parent": parent, + "lake_id": lake_id, + "lake": body, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_lake( + self, + project_id: str, + region: str, + lake_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Any: + """ + Get lake resource. + + :param project_id: Required. The ID of the Google Cloud project that the lake belongs to. + :param region: Required. The ID of the Google Cloud region that the lake belongs to. + :param lake_id: Required. The ID of the Google Cloud lake to be retrieved. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + name = f"projects/{project_id}/locations/{region}/lakes/{lake_id}/" + client = self.get_dataplex_client() + result = client.get_lake( + request={ + "name": name, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/dataprep.py b/airflow/providers/google/cloud/hooks/dataprep.py index 945aefe42ece2..c7cbc3b55157b 100644 --- a/airflow/providers/google/cloud/hooks/dataprep.py +++ b/airflow/providers/google/cloud/hooks/dataprep.py @@ -16,9 +16,12 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Dataprep hook.""" +from __future__ import annotations + import json -import os -from typing import Any, Dict +from enum import Enum +from typing import Any +from urllib.parse import urljoin import requests from requests import HTTPError @@ -27,6 +30,31 @@ from airflow.hooks.base import BaseHook +def _get_field(extras: dict, field_name: str): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = "extra__dataprep__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." + ) + if field_name in extras: + return extras[field_name] or None + prefixed_name = f"{backcompat_prefix}{field_name}" + return extras.get(prefixed_name) or None + + +class JobGroupStatuses(str, Enum): + """Types of job group run statuses.""" + + CREATED = "Created" + UNDEFINED = "undefined" + IN_PROGRESS = "InProgress" + COMPLETE = "Complete" + FAILED = "Failed" + CANCELED = "Canceled" + + class GoogleDataprepHook(BaseHook): """ Hook for connection with Dataprep API. @@ -37,21 +65,21 @@ class GoogleDataprepHook(BaseHook): """ - conn_name_attr = 'dataprep_conn_id' - default_conn_name = 'google_cloud_dataprep_default' - conn_type = 'dataprep' - hook_name = 'Google Dataprep' + conn_name_attr = "dataprep_conn_id" + default_conn_name = "google_cloud_dataprep_default" + conn_type = "dataprep" + hook_name = "Google Dataprep" def __init__(self, dataprep_conn_id: str = default_conn_name) -> None: super().__init__() self.dataprep_conn_id = dataprep_conn_id conn = self.get_connection(self.dataprep_conn_id) - extra_dejson = conn.extra_dejson - self._token = extra_dejson.get("extra__dataprep__token") - self._base_url = extra_dejson.get("extra__dataprep__base_url", "https://api.clouddataprep.com") + extras = conn.extra_dejson + self._token = _get_field(extras, "token") + self._base_url = _get_field(extras, "base_url") or "https://api.clouddataprep.com" @property - def _headers(self) -> Dict[str, str]: + def _headers(self) -> dict[str, str]: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self._token}", @@ -59,20 +87,20 @@ def _headers(self) -> Dict[str, str]: return headers @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) - def get_jobs_for_job_group(self, job_id: int) -> Dict[str, Any]: + def get_jobs_for_job_group(self, job_id: int) -> dict[str, Any]: """ Get information about the batch jobs within a Cloud Dataprep job. :param job_id: The ID of the job that will be fetched """ endpoint_path = f"v4/jobGroups/{job_id}/jobs" - url: str = os.path.join(self._base_url, endpoint_path) + url: str = urljoin(self._base_url, endpoint_path) response = requests.get(url, headers=self._headers) self._raise_for_status(response) return response.json() @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) - def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) -> Dict[str, Any]: + def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) -> dict[str, Any]: """ Get the specified job group. A job group is a job that is executed from a specific node in a flow. 
@@ -81,15 +109,15 @@ def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) -> :param embed: Comma-separated list of objects to pull in as part of the response :param include_deleted: if set to "true", will include deleted objects """ - params: Dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted} + params: dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted} endpoint_path = f"v4/jobGroups/{job_group_id}" - url: str = os.path.join(self._base_url, endpoint_path) + url: str = urljoin(self._base_url, endpoint_path) response = requests.get(url, headers=self._headers, params=params) self._raise_for_status(response) return response.json() @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) - def run_job_group(self, body_request: dict) -> Dict[str, Any]: + def run_job_group(self, body_request: dict) -> dict[str, Any]: """ Creates a ``jobGroup``, which launches the specified job as the authenticated user. This performs the same action as clicking on the Run Job button in the application. @@ -99,14 +127,76 @@ def run_job_group(self, body_request: dict) -> Dict[str, Any]: :param body_request: The identifier for the recipe you would like to run. """ endpoint_path = "v4/jobGroups" - url: str = os.path.join(self._base_url, endpoint_path) + url: str = urljoin(self._base_url, endpoint_path) response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) self._raise_for_status(response) return response.json() + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def copy_flow( + self, *, flow_id: int, name: str = "", description: str = "", copy_datasources: bool = False + ) -> dict: + """ + Create a copy of the provided flow id, as well as all contained recipes. + + :param flow_id: ID of the flow to be copied + :param name: Name for the copy of the flow + :param description: Description of the copy of the flow + :param copy_datasources: Bool value to define should copies of data inputs be made or not. + """ + endpoint_path = f"v4/flows/{flow_id}/copy" + url: str = urljoin(self._base_url, endpoint_path) + body_request = { + "name": name, + "description": description, + "copyDatasources": copy_datasources, + } + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def delete_flow(self, *, flow_id: int) -> None: + """ + Delete the flow with the provided id. + + :param flow_id: ID of the flow to be copied + """ + endpoint_path = f"v4/flows/{flow_id}" + url: str = urljoin(self._base_url, endpoint_path) + response = requests.delete(url, headers=self._headers) + self._raise_for_status(response) + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def run_flow(self, *, flow_id: int, body_request: dict) -> dict: + """ + Runs the flow with the provided id copy of the provided flow id. + + :param flow_id: ID of the flow to be copied + :param body_request: Body of the POST request to be sent. 
+ """ + endpoint = f"v4/flows/{flow_id}/run" + url: str = urljoin(self._base_url, endpoint) + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def get_job_group_status(self, *, job_group_id: int) -> JobGroupStatuses: + """ + Check the status of the Dataprep task to be finished. + + :param job_group_id: ID of the job group to check + """ + endpoint = f"/v4/jobGroups/{job_group_id}/status" + url: str = urljoin(self._base_url, endpoint) + response = requests.get(url, headers=self._headers) + self._raise_for_status(response) + return response.json() + def _raise_for_status(self, response: requests.models.Response) -> None: try: response.raise_for_status() except HTTPError: - self.log.error(response.json().get('exception')) + self.log.error(response.json().get("exception")) raise diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py index d870d80726669..a3e32831bb06d 100644 --- a/airflow/providers/google/cloud/hooks/dataproc.py +++ b/airflow/providers/google/cloud/hooks/dataproc.py @@ -15,27 +15,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Cloud Dataproc hook.""" +from __future__ import annotations import time import uuid -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Sequence from google.api_core.client_options import ClientOptions from google.api_core.exceptions import ServerError from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation +from google.api_core.operation_async import AsyncOperation from google.api_core.retry import Retry from google.cloud.dataproc_v1 import ( Batch, + BatchControllerAsyncClient, BatchControllerClient, Cluster, + ClusterControllerAsyncClient, ClusterControllerClient, Job, + JobControllerAsyncClient, JobControllerClient, JobStatus, WorkflowTemplate, + WorkflowTemplateServiceAsyncClient, WorkflowTemplateServiceClient, ) from google.protobuf.duration_pb2 import Duration @@ -56,22 +61,22 @@ def __init__( task_id: str, cluster_name: str, job_type: str, - properties: Optional[Dict[str, str]] = None, + properties: dict[str, str] | None = None, ) -> None: name = f"{task_id.replace('.', '_')}_{uuid.uuid4()!s:.8}" self.job_type = job_type - self.job = { + self.job: dict[str, Any] = { "job": { "reference": {"project_id": project_id, "job_id": name}, "placement": {"cluster_name": cluster_name}, - "labels": {'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-')}, + "labels": {"airflow-version": "v" + airflow_version.replace(".", "-").replace("+", "-")}, job_type: {}, } - } # type: Dict[str, Any] + } if properties is not None: self.job["job"][job_type]["properties"] = properties - def add_labels(self, labels: Optional[dict] = None) -> None: + def add_labels(self, labels: dict | None = None) -> None: """ Set labels for Dataproc job. @@ -80,7 +85,7 @@ def add_labels(self, labels: Optional[dict] = None) -> None: if labels: self.job["job"]["labels"].update(labels) - def add_variables(self, variables: Optional[Dict] = None) -> None: + def add_variables(self, variables: dict | None = None) -> None: """ Set variables for Dataproc job. 
@@ -89,7 +94,7 @@ def add_variables(self, variables: Optional[Dict] = None) -> None: if variables is not None: self.job["job"][self.job_type]["script_variables"] = variables - def add_args(self, args: Optional[List[str]] = None) -> None: + def add_args(self, args: list[str] | None = None) -> None: """ Set args for Dataproc job. @@ -104,7 +109,7 @@ def add_query(self, query: str) -> None: :param query: query for the job. """ - self.job["job"][self.job_type]["query_list"] = {'queries': [query]} + self.job["job"][self.job_type]["query_list"] = {"queries": [query]} def add_query_uri(self, query_uri: str) -> None: """ @@ -114,7 +119,7 @@ def add_query_uri(self, query_uri: str) -> None: """ self.job["job"][self.job_type]["query_file_uri"] = query_uri - def add_jar_file_uris(self, jars: Optional[List[str]] = None) -> None: + def add_jar_file_uris(self, jars: list[str] | None = None) -> None: """ Set jars uris for Dataproc job. @@ -123,7 +128,7 @@ def add_jar_file_uris(self, jars: Optional[List[str]] = None) -> None: if jars is not None: self.job["job"][self.job_type]["jar_file_uris"] = jars - def add_archive_uris(self, archives: Optional[List[str]] = None) -> None: + def add_archive_uris(self, archives: list[str] | None = None) -> None: """ Set archives uris for Dataproc job. @@ -132,7 +137,7 @@ def add_archive_uris(self, archives: Optional[List[str]] = None) -> None: if archives is not None: self.job["job"][self.job_type]["archive_uris"] = archives - def add_file_uris(self, files: Optional[List[str]] = None) -> None: + def add_file_uris(self, files: list[str] | None = None) -> None: """ Set file uris for Dataproc job. @@ -141,7 +146,7 @@ def add_file_uris(self, files: Optional[List[str]] = None) -> None: if files is not None: self.job["job"][self.job_type]["file_uris"] = files - def add_python_file_uris(self, pyfiles: Optional[List[str]] = None) -> None: + def add_python_file_uris(self, pyfiles: list[str] | None = None) -> None: """ Set python file uris for Dataproc job. @@ -150,7 +155,7 @@ def add_python_file_uris(self, pyfiles: Optional[List[str]] = None) -> None: if pyfiles is not None: self.job["job"][self.job_type]["python_file_uris"] = pyfiles - def set_main(self, main_jar: Optional[str] = None, main_class: Optional[str] = None) -> None: + def set_main(self, main_jar: str | None = None, main_class: str | None = None) -> None: """ Set Dataproc main class. @@ -182,12 +187,11 @@ def set_job_name(self, name: str) -> None: sanitized_name = f"{name.replace('.', '_')}_{uuid.uuid4()!s:.8}" self.job["job"]["reference"]["job_id"] = sanitized_name - def build(self) -> Dict: + def build(self) -> dict: """ Returns Dataproc job. :return: Dataproc job - :rtype: dict """ return self.job @@ -200,50 +204,63 @@ class DataprocHook(GoogleBaseHook): keyword arguments rather than positional. 
""" - def get_cluster_client(self, region: Optional[str] = None) -> ClusterControllerClient: + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + ) -> None: + super().__init__(gcp_conn_id, delegate_to, impersonation_chain) + + def get_cluster_client(self, region: str | None = None) -> ClusterControllerClient: """Returns ClusterControllerClient.""" client_options = None - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") return ClusterControllerClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) - def get_template_client(self, region: Optional[str] = None) -> WorkflowTemplateServiceClient: + def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceClient: """Returns WorkflowTemplateServiceClient.""" client_options = None - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") return WorkflowTemplateServiceClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) - def get_job_client(self, region: Optional[str] = None) -> JobControllerClient: + def get_job_client(self, region: str | None = None) -> JobControllerClient: """Returns JobControllerClient.""" client_options = None - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") return JobControllerClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) - def get_batch_client(self, region: Optional[str] = None) -> BatchControllerClient: + def get_batch_client(self, region: str | None = None) -> BatchControllerClient: """Returns BatchControllerClient""" client_options = None - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") return BatchControllerClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + def wait_for_operation( + self, + operation: Operation, + timeout: float | None = None, + result_retry: Retry | _MethodDefault = DEFAULT, + ): """Waits for long-lasting operation to complete.""" try: - return operation.result(timeout=timeout) + return operation.result(timeout=timeout, retry=result_retry) except Exception: error = operation.exception(timeout=timeout) raise AirflowException(error) @@ -254,13 +271,13 
@@ def create_cluster( region: str, project_id: str, cluster_name: str, - cluster_config: Union[Dict, Cluster, None] = None, - virtual_cluster_config: Optional[Dict] = None, - labels: Optional[Dict[str, str]] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + cluster_config: dict | Cluster | None = None, + virtual_cluster_config: dict | None = None, + labels: dict[str, str] | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Creates a cluster in a project. @@ -289,25 +306,25 @@ def create_cluster( # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows # semantic versioning spec: x.y.z). labels = labels or {} - labels.update({'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-')}) + labels.update({"airflow-version": "v" + airflow_version.replace(".", "-").replace("+", "-")}) cluster = { "project_id": project_id, "cluster_name": cluster_name, } if virtual_cluster_config is not None: - cluster['virtual_cluster_config'] = virtual_cluster_config # type: ignore + cluster["virtual_cluster_config"] = virtual_cluster_config # type: ignore if cluster_config is not None: - cluster['config'] = cluster_config # type: ignore - cluster['labels'] = labels # type: ignore + cluster["config"] = cluster_config # type: ignore + cluster["labels"] = labels # type: ignore client = self.get_cluster_client(region=region) result = client.create_cluster( request={ - 'project_id': project_id, - 'region': region, - 'cluster': cluster, - 'request_id': request_id, + "project_id": project_id, + "region": region, + "cluster": cluster, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -321,11 +338,11 @@ def delete_cluster( region: str, cluster_name: str, project_id: str, - cluster_uuid: Optional[str] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + cluster_uuid: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Deletes a cluster in a project. @@ -347,11 +364,11 @@ def delete_cluster( client = self.get_cluster_client(region=region) result = client.delete_cluster( request={ - 'project_id': project_id, - 'region': region, - 'cluster_name': cluster_name, - 'cluster_uuid': cluster_uuid, - 'request_id': request_id, + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + "cluster_uuid": cluster_uuid, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -365,9 +382,9 @@ def diagnose_cluster( region: str, cluster_name: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Gets cluster diagnostic information. 
After the operation completes GCS uri to @@ -384,7 +401,7 @@ def diagnose_cluster( """ client = self.get_cluster_client(region=region) operation = client.diagnose_cluster( - request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name}, + request={"project_id": project_id, "region": region, "cluster_name": cluster_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -399,9 +416,9 @@ def get_cluster( region: str, cluster_name: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Gets the resource representation for a cluster in a project. @@ -417,7 +434,7 @@ def get_cluster( """ client = self.get_cluster_client(region=region) result = client.get_cluster( - request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name}, + request={"project_id": project_id, "region": region, "cluster_name": cluster_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -430,10 +447,10 @@ def list_clusters( region: str, filter_: str, project_id: str, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Lists all regions/{region}/clusters in a project. @@ -452,7 +469,7 @@ def list_clusters( """ client = self.get_cluster_client(region=region) result = client.list_clusters( - request={'project_id': project_id, 'region': region, 'filter': filter_, 'page_size': page_size}, + request={"project_id": project_id, "region": region, "filter": filter_, "page_size": page_size}, retry=retry, timeout=timeout, metadata=metadata, @@ -463,15 +480,15 @@ def list_clusters( def update_cluster( self, cluster_name: str, - cluster: Union[Dict, Cluster], - update_mask: Union[Dict, FieldMask], + cluster: dict | Cluster, + update_mask: dict | FieldMask, project_id: str, region: str, - graceful_decommission_timeout: Optional[Union[Dict, Duration]] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + graceful_decommission_timeout: dict | Duration | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Updates a cluster in a project. 
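# A hypothetical end-to-end sketch for the synchronous hook above: create a
# cluster and block on the returned long-running operation, which
# wait_for_operation() now resolves with an optional result_retry. Project,
# region, cluster name, and machine types are placeholders.
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

hook = DataprocHook(gcp_conn_id="google_cloud_default")
operation = hook.create_cluster(
    region="us-central1",            # placeholder
    project_id="my-project",         # placeholder
    cluster_name="example-cluster",  # placeholder
    cluster_config={
        "master_config": {"num_instances": 1, "machine_type_uri": "n1-standard-4"},
        "worker_config": {"num_instances": 2, "machine_type_uri": "n1-standard-4"},
    },
)
cluster = hook.wait_for_operation(operation, timeout=1800)  # returns the finished Cluster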
@@ -526,13 +543,13 @@ def update_cluster( client = self.get_cluster_client(region=region) operation = client.update_cluster( request={ - 'project_id': project_id, - 'region': region, - 'cluster_name': cluster_name, - 'cluster': cluster, - 'update_mask': update_mask, - 'graceful_decommission_timeout': graceful_decommission_timeout, - 'request_id': request_id, + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + "cluster": cluster, + "update_mask": update_mask, + "graceful_decommission_timeout": graceful_decommission_timeout, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -543,12 +560,12 @@ def update_cluster( @GoogleBaseHook.fallback_to_default_project_id def create_workflow_template( self, - template: Union[Dict, WorkflowTemplate], + template: dict | WorkflowTemplate, project_id: str, region: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> WorkflowTemplate: """ Creates new workflow template. @@ -567,9 +584,9 @@ def create_workflow_template( raise TypeError("missing 1 required keyword argument: 'region'") metadata = metadata or () client = self.get_template_client(region) - parent = f'projects/{project_id}/regions/{region}' + parent = f"projects/{project_id}/regions/{region}" return client.create_workflow_template( - request={'parent': parent, 'template': template}, retry=retry, timeout=timeout, metadata=metadata + request={"parent": parent, "template": template}, retry=retry, timeout=timeout, metadata=metadata ) @GoogleBaseHook.fallback_to_default_project_id @@ -578,12 +595,12 @@ def instantiate_workflow_template( template_name: str, project_id: str, region: str, - version: Optional[int] = None, - request_id: Optional[str] = None, - parameters: Optional[Dict[str, str]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + version: int | None = None, + request_id: str | None = None, + parameters: dict[str, str] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Instantiates a template and begins execution. 
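# A minimal sketch of the update_cluster() call documented above: resize the
# worker group by pinning config.worker_config.num_instances in the update
# mask and supplying the new value in the cluster body. Identifiers are
# placeholders.
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

hook = DataprocHook()
operation = hook.update_cluster(
    cluster_name="example-cluster",  # placeholder
    project_id="my-project",         # placeholder
    region="us-central1",            # placeholder
    cluster={"config": {"worker_config": {"num_instances": 5}}},
    update_mask=FieldMask(paths=["config.worker_config.num_instances"]),
)
hook.wait_for_operation(operation)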
@@ -611,9 +628,9 @@ def instantiate_workflow_template( raise TypeError("missing 1 required keyword argument: 'region'") metadata = metadata or () client = self.get_template_client(region) - name = f'projects/{project_id}/regions/{region}/workflowTemplates/{template_name}' + name = f"projects/{project_id}/regions/{region}/workflowTemplates/{template_name}" operation = client.instantiate_workflow_template( - request={'name': name, 'version': version, 'request_id': request_id, 'parameters': parameters}, + request={"name": name, "version": version, "request_id": request_id, "parameters": parameters}, retry=retry, timeout=timeout, metadata=metadata, @@ -623,13 +640,13 @@ def instantiate_workflow_template( @GoogleBaseHook.fallback_to_default_project_id def instantiate_inline_workflow_template( self, - template: Union[Dict, WorkflowTemplate], + template: dict | WorkflowTemplate, project_id: str, region: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Instantiates a template and begins execution. @@ -651,9 +668,9 @@ def instantiate_inline_workflow_template( raise TypeError("missing 1 required keyword argument: 'region'") metadata = metadata or () client = self.get_template_client(region) - parent = f'projects/{project_id}/regions/{region}' + parent = f"projects/{project_id}/regions/{region}" operation = client.instantiate_inline_workflow_template( - request={'parent': parent, 'template': template, 'request_id': request_id}, + request={"parent": parent, "template": template, "request_id": request_id}, retry=retry, timeout=timeout, metadata=metadata, @@ -667,7 +684,7 @@ def wait_for_job( project_id: str, region: str, wait_time: int = 10, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """ Helper method which polls a job to check if it finishes. @@ -693,9 +710,9 @@ def wait_for_job( self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err) if state == JobStatus.State.ERROR: - raise AirflowException(f'Job failed:\n{job}') + raise AirflowException(f"Job failed:\n{job}") if state == JobStatus.State.CANCELLED: - raise AirflowException(f'Job was cancelled:\n{job}') + raise AirflowException(f"Job was cancelled:\n{job}") @GoogleBaseHook.fallback_to_default_project_id def get_job( @@ -703,9 +720,9 @@ def get_job( job_id: str, project_id: str, region: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Job: """ Gets the resource representation for a job in a project. 
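# A hypothetical sketch of instantiating a stored workflow template with the
# methods above and waiting on the resulting operation. The template name,
# project, region, and parameters are placeholders.
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

hook = DataprocHook()
operation = hook.instantiate_workflow_template(
    template_name="sparkpi-template",        # placeholder
    project_id="my-project",                 # placeholder
    region="us-central1",                    # placeholder
    parameters={"CLUSTER_NAME": "example"},  # placeholder template parameters
)
hook.wait_for_operation(operation)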
@@ -723,7 +740,7 @@ def get_job( raise TypeError("missing 1 required keyword argument: 'region'") client = self.get_job_client(region=region) job = client.get_job( - request={'project_id': project_id, 'region': region, 'job_id': job_id}, + request={"project_id": project_id, "region": region, "job_id": job_id}, retry=retry, timeout=timeout, metadata=metadata, @@ -733,13 +750,13 @@ def get_job( @GoogleBaseHook.fallback_to_default_project_id def submit_job( self, - job: Union[dict, Job], + job: dict | Job, project_id: str, region: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Job: """ Submits a job to a cluster. @@ -761,7 +778,7 @@ def submit_job( raise TypeError("missing 1 required keyword argument: 'region'") client = self.get_job_client(region=region) return client.submit_job( - request={'project_id': project_id, 'region': region, 'job': job, 'request_id': request_id}, + request={"project_id": project_id, "region": region, "job": job, "request_id": request_id}, retry=retry, timeout=timeout, metadata=metadata, @@ -772,10 +789,10 @@ def cancel_job( self, job_id: str, project_id: str, - region: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + region: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Job: """ Starts a job cancellation request. @@ -792,7 +809,7 @@ def cancel_job( client = self.get_job_client(region=region) job = client.cancel_job( - request={'project_id': project_id, 'region': region, 'job_id': job_id}, + request={"project_id": project_id, "region": region, "job_id": job_id}, retry=retry, timeout=timeout, metadata=metadata, @@ -804,12 +821,12 @@ def create_batch( self, region: str, project_id: str, - batch: Union[Dict, Batch], - batch_id: Optional[str] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + batch: dict | Batch, + batch_id: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Creates a batch workload. @@ -830,14 +847,14 @@ def create_batch( :param metadata: Additional metadata that is provided to the method. """ client = self.get_batch_client(region) - parent = f'projects/{project_id}/regions/{region}' + parent = f"projects/{project_id}/regions/{region}" result = client.create_batch( request={ - 'parent': parent, - 'batch': batch, - 'batch_id': batch_id, - 'request_id': request_id, + "parent": parent, + "batch": batch, + "batch_id": batch_id, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -851,9 +868,9 @@ def delete_batch( batch_id: str, region: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes the batch workload resource. 
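# A hypothetical sketch tying the job-side pieces together: build a Spark job
# payload with DataProcJobBuilder, submit it via submit_job(), then block with
# wait_for_job(). All identifiers and URIs are placeholders; note that
# build() wraps the payload under a top-level "job" key.
from airflow.providers.google.cloud.hooks.dataproc import DataProcJobBuilder, DataprocHook

hook = DataprocHook()
builder = DataProcJobBuilder(
    project_id="my-project",         # placeholder
    task_id="example_task",
    cluster_name="example-cluster",  # placeholder
    job_type="spark_job",
)
builder.set_main(main_class="org.apache.spark.examples.SparkPi")
builder.add_jar_file_uris(["file:///usr/lib/spark/examples/jars/spark-examples.jar"])
job = builder.build()["job"]

submitted = hook.submit_job(job=job, project_id="my-project", region="us-central1")
hook.wait_for_job(job_id=submitted.reference.job_id, project_id="my-project", region="us-central1")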
@@ -874,7 +891,7 @@ def delete_batch( client.delete_batch( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -887,9 +904,9 @@ def get_batch( batch_id: str, region: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Batch: """ Gets the batch workload resource representation. @@ -910,7 +927,7 @@ def get_batch( result = client.get_batch( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -923,11 +940,11 @@ def list_batches( self, region: str, project_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Lists batch workloads. @@ -945,13 +962,746 @@ def list_batches( :param metadata: Additional metadata that is provided to the method. """ client = self.get_batch_client(region) - parent = f'projects/{project_id}/regions/{region}' + parent = f"projects/{project_id}/regions/{region}" result = client.list_batches( request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, + "parent": parent, + "page_size": page_size, + "page_token": page_token, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + +class DataprocAsyncHook(GoogleBaseHook): + """ + Asynchronous Hook for Google Cloud Dataproc APIs. + + All the methods in the hook where project_id is used must be called with + keyword arguments rather than positional. 
+ """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + ) -> None: + super().__init__(gcp_conn_id, delegate_to, impersonation_chain) + + def get_cluster_client(self, region: str | None = None) -> ClusterControllerAsyncClient: + """Returns ClusterControllerAsyncClient.""" + client_options = None + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") + + return ClusterControllerAsyncClient( + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options + ) + + def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceAsyncClient: + """Returns WorkflowTemplateServiceAsyncClient.""" + client_options = None + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") + + return WorkflowTemplateServiceAsyncClient( + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options + ) + + def get_job_client(self, region: str | None = None) -> JobControllerAsyncClient: + """Returns JobControllerAsyncClient.""" + client_options = None + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") + + return JobControllerAsyncClient( + credentials=self.get_credentials(), + client_info=CLIENT_INFO, + client_options=client_options, + ) + + def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient: + """Returns BatchControllerAsyncClient""" + client_options = None + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443") + + return BatchControllerAsyncClient( + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def create_cluster( + self, + region: str, + project_id: str, + cluster_name: str, + cluster_config: dict | Cluster | None = None, + virtual_cluster_config: dict | None = None, + labels: dict[str, str] | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Creates a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param cluster_name: Name of the cluster to create + :param labels: Labels that will be assigned to created cluster + :param cluster_config: Required. The cluster config to create. + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.ClusterConfig` + :param virtual_cluster_config: Optional. The virtual cluster config, used when creating a Dataproc + cluster that does not directly control the underlying compute resources, for example, when + creating a `Dataproc-on-GKE cluster` + :class:`~google.cloud.dataproc_v1.types.VirtualClusterConfig` + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``CreateClusterRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. 
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + # Dataproc labels must conform to the following regex: + # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows + # semantic versioning spec: x.y.z). + labels = labels or {} + labels.update({"airflow-version": "v" + airflow_version.replace(".", "-").replace("+", "-")}) + + cluster = { + "project_id": project_id, + "cluster_name": cluster_name, + } + if virtual_cluster_config is not None: + cluster["virtual_cluster_config"] = virtual_cluster_config # type: ignore + if cluster_config is not None: + cluster["config"] = cluster_config # type: ignore + cluster["labels"] = labels # type: ignore + + client = self.get_cluster_client(region=region) + result = await client.create_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster": cluster, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + async def delete_cluster( + self, + region: str, + cluster_name: str, + project_id: str, + cluster_uuid: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Deletes a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param cluster_name: Required. The cluster name. + :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail + if cluster with specified UUID does not exist. + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_cluster_client(region=region) + result = client.delete_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + "cluster_uuid": cluster_uuid, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + async def diagnose_cluster( + self, + region: str, + cluster_name: str, + project_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Gets cluster diagnostic information. After the operation completes GCS uri to + diagnose is returned + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. 
+ :param cluster_name: Required. The cluster name. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_cluster_client(region=region) + operation = await client.diagnose_cluster( + request={"project_id": project_id, "region": region, "cluster_name": cluster_name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + operation.result() + gcs_uri = str(operation.operation.response.value) + return gcs_uri + + @GoogleBaseHook.fallback_to_default_project_id + async def get_cluster( + self, + region: str, + cluster_name: str, + project_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Gets the resource representation for a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param cluster_name: Required. The cluster name. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_cluster_client(region=region) + result = await client.get_cluster( + request={"project_id": project_id, "region": region, "cluster_name": cluster_name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + async def list_clusters( + self, + region: str, + filter_: str, + project_id: str, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Lists all regions/{region}/clusters in a project. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param filter_: Optional. A filter constraining the clusters to list. Filters are case-sensitive. + :param page_size: The maximum number of resources contained in the underlying API response. If page + streaming is performed per- resource, this parameter does not affect the return value. If page + streaming is performed per-page, this determines the maximum number of resources in a page. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
+ """ + client = self.get_cluster_client(region=region) + result = await client.list_clusters( + request={"project_id": project_id, "region": region, "filter": filter_, "page_size": page_size}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + async def update_cluster( + self, + cluster_name: str, + cluster: dict | Cluster, + update_mask: dict | FieldMask, + project_id: str, + region: str, + graceful_decommission_timeout: dict | Duration | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Updates a cluster in a project. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param cluster_name: Required. The cluster name. + :param cluster: Required. The changes to the cluster. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.Cluster` + :param update_mask: Required. Specifies the path, relative to ``Cluster``, of the field to update. For + example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be + specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify + the new value, as follows: + + :: + + { "config":{ "workerConfig":{ "numInstances":"5" } } } + + Similarly, to change the number of preemptible workers in a cluster to 5, the ``update_mask`` + parameter would be ``config.secondary_worker_config.num_instances``, and the ``PATCH`` request + body would be set as follows: + + :: + + { "config":{ "secondaryWorkerConfig":{ "numInstances":"5" } } } + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.FieldMask` + :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful + decommissioning allows removing nodes from the cluster without interrupting jobs in progress. + Timeout specifies how long to wait for jobs in progress to finish before forcefully removing nodes + (and potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the + maximum allowed timeout is 1 day. + + Only supported on Dataproc image versions 1.2 and higher. + + If a dict is provided, it must be of the same form as the protobuf message + :class:`~google.cloud.dataproc_v1.types.Duration` + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
+ """ + if region is None: + raise TypeError("missing 1 required keyword argument: 'region'") + client = self.get_cluster_client(region=region) + operation = await client.update_cluster( + request={ + "project_id": project_id, + "region": region, + "cluster_name": cluster_name, + "cluster": cluster, + "update_mask": update_mask, + "graceful_decommission_timeout": graceful_decommission_timeout, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + async def create_workflow_template( + self, + template: dict | WorkflowTemplate, + project_id: str, + region: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> WorkflowTemplate: + """ + Creates new workflow template. + + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param template: The Dataproc workflow template to create. If a dict is provided, + it must be of the same form as the protobuf message WorkflowTemplate. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + if region is None: + raise TypeError("missing 1 required keyword argument: 'region'") + metadata = metadata or () + client = self.get_template_client(region) + parent = f"projects/{project_id}/regions/{region}" + return await client.create_workflow_template( + request={"parent": parent, "template": template}, retry=retry, timeout=timeout, metadata=metadata + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def instantiate_workflow_template( + self, + template_name: str, + project_id: str, + region: str, + version: int | None = None, + request_id: str | None = None, + parameters: dict[str, str] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Instantiates a template and begins execution. + + :param template_name: Name of template to instantiate. + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param version: Optional. The version of workflow template to instantiate. If specified, + the workflow will be instantiated only if the current version of + the workflow template has the supplied version. + This option cannot be used to instantiate a previous version of + workflow template. + :param request_id: Optional. A tag that prevents multiple concurrent workflow instances + with the same tag from running. This mitigates risk of concurrent + instances started due to retries. + :param parameters: Optional. Map from parameter names to values that should be used for those + parameters. Values may not exceed 100 characters. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. 
+ :param metadata: Additional metadata that is provided to the method. + """ + if region is None: + raise TypeError("missing 1 required keyword argument: 'region'") + metadata = metadata or () + client = self.get_template_client(region) + name = f"projects/{project_id}/regions/{region}/workflowTemplates/{template_name}" + operation = await client.instantiate_workflow_template( + request={"name": name, "version": version, "request_id": request_id, "parameters": parameters}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + async def instantiate_inline_workflow_template( + self, + template: dict | WorkflowTemplate, + project_id: str, + region: str, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Instantiates a template and begins execution. + + :param template: The workflow template to instantiate. If a dict is provided, + it must be of the same form as the protobuf message WorkflowTemplate + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param request_id: Optional. A tag that prevents multiple concurrent workflow instances + with the same tag from running. This mitigates risk of concurrent + instances started due to retries. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + if region is None: + raise TypeError("missing 1 required keyword argument: 'region'") + metadata = metadata or () + client = self.get_template_client(region) + parent = f"projects/{project_id}/regions/{region}" + operation = await client.instantiate_inline_workflow_template( + request={"parent": parent, "template": template, "request_id": request_id}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return operation + + @GoogleBaseHook.fallback_to_default_project_id + async def get_job( + self, + job_id: str, + project_id: str, + region: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Job: + """ + Gets the resource representation for a job in a project. + + :param job_id: Id of the Dataproc job + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
+ """ + if region is None: + raise TypeError("missing 1 required keyword argument: 'region'") + client = self.get_job_client(region=region) + job = await client.get_job( + request={"project_id": project_id, "region": region, "job_id": job_id}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return job + + @GoogleBaseHook.fallback_to_default_project_id + async def submit_job( + self, + job: dict | Job, + project_id: str, + region: str, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Job: + """ + Submits a job to a cluster. + + :param job: The job resource. If a dict is provided, + it must be of the same form as the protobuf message Job + :param project_id: Required. The ID of the Google Cloud project the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param request_id: Optional. A tag that prevents multiple concurrent workflow instances + with the same tag from running. This mitigates risk of concurrent + instances started due to retries. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + if region is None: + raise TypeError("missing 1 required keyword argument: 'region'") + client = self.get_job_client(region=region) + return await client.submit_job( + request={"project_id": project_id, "region": region, "job": job, "request_id": request_id}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def cancel_job( + self, + job_id: str, + project_id: str, + region: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Job: + """ + Starts a job cancellation request. + + :param project_id: Required. The ID of the Google Cloud project that the job belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param job_id: Required. The job ID. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_job_client(region=region) + + job = await client.cancel_job( + request={"project_id": project_id, "region": region, "job_id": job_id}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return job + + @GoogleBaseHook.fallback_to_default_project_id + async def create_batch( + self, + region: str, + project_id: str, + batch: dict | Batch, + batch_id: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> AsyncOperation: + """ + Creates a batch workload. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param batch: Required. 
The batch to create. + :param batch_id: Optional. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :param request_id: Optional. A unique id used to identify the request. If the server receives two + ``CreateBatchRequest`` requests with the same id, then the second request will be ignored and + the first ``google.longrunning.Operation`` created and stored in the backend is returned. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_batch_client(region) + parent = f"projects/{project_id}/regions/{region}" + + result = await client.create_batch( + request={ + "parent": parent, + "batch": batch, + "batch_id": batch_id, + "request_id": request_id, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + async def delete_batch( + self, + batch_id: str, + region: str, + project_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Deletes the batch workload resource. + + :param batch_id: Required. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_batch_client(region) + name = f"projects/{project_id}/regions/{region}/batches/{batch_id}" + + await client.delete_batch( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def get_batch( + self, + batch_id: str, + region: str, + project_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Batch: + """ + Gets the batch workload resource representation. + + :param batch_id: Required. The ID to use for the batch, which will become the final component + of the batch's resource name. + This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/. + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
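
A minimal sketch of the batch workflow above (create a serverless PySpark batch, then read it back), again assuming a DataprocAsyncHook-style class from earlier in this file and illustrative URIs and IDs:

    # Sketch only: the bucket, batch id, class name and connection id are assumptions.
    from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook


    async def run_batch(project_id: str, region: str) -> None:
        hook = DataprocAsyncHook(gcp_conn_id="google_cloud_default")
        # create_batch returns a long-running operation; the batch itself is
        # fetched separately via get_batch.
        await hook.create_batch(
            project_id=project_id,
            region=region,
            batch={"pyspark_batch": {"main_python_file_uri": "gs://my-bucket/job.py"}},
            batch_id="example-batch-001",
        )
        batch = await hook.get_batch(
            project_id=project_id, region=region, batch_id="example-batch-001"
        )
        print(batch.state)
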
+ """ + client = self.get_batch_client(region) + name = f"projects/{project_id}/regions/{region}/batches/{batch_id}" + + result = await client.get_batch( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + async def list_batches( + self, + region: str, + project_id: str, + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ): + """ + Lists batch workloads. + + :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to. + :param region: Required. The Cloud Dataproc region in which to handle the request. + :param page_size: Optional. The maximum number of batches to return in each response. The service may + return fewer than this value. The default page size is 20; the maximum page size is 1000. + :param page_token: Optional. A page token received from a previous ``ListBatches`` call. + Provide this token to retrieve the subsequent page. + :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + ``retry`` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + client = self.get_batch_client(region) + parent = f"projects/{project_id}/regions/{region}" + + result = await client.list_batches( + request={ + "parent": parent, + "page_size": page_size, + "page_token": page_token, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/dataproc_metastore.py b/airflow/providers/google/cloud/hooks/dataproc_metastore.py index 53d031bb050a3..c7dcebee9fcab 100644 --- a/airflow/providers/google/cloud/hooks/dataproc_metastore.py +++ b/airflow/providers/google/cloud/hooks/dataproc_metastore.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# """This module contains a Google Cloud Dataproc Metastore hook.""" +from __future__ import annotations -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -39,13 +39,13 @@ class DataprocMetastoreHook(GoogleBaseHook): def get_dataproc_metastore_client(self) -> DataprocMetastoreClient: """Returns DataprocMetastoreClient.""" - client_options = ClientOptions(api_endpoint='metastore.googleapis.com:443') + client_options = ClientOptions(api_endpoint="metastore.googleapis.com:443") return DataprocMetastoreClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) - def wait_for_operation(self, timeout: Optional[float], operation: Operation): + def wait_for_operation(self, timeout: float | None, operation: Operation): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -59,12 +59,12 @@ def create_backup( project_id: str, region: str, service_id: str, - backup: Union[Dict[Any, Any], Backup], + backup: dict[Any, Any] | Backup, backup_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Creates a new backup in a given project and location. @@ -94,15 +94,15 @@ def create_backup( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - parent = f'projects/{project_id}/locations/{region}/services/{service_id}' + parent = f"projects/{project_id}/locations/{region}/services/{service_id}" client = self.get_dataproc_metastore_client() result = client.create_backup( request={ - 'parent': parent, - 'backup': backup, - 'backup_id': backup_id, - 'request_id': request_id, + "parent": parent, + "backup": backup, + "backup_id": backup_id, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -116,12 +116,12 @@ def create_metadata_import( project_id: str, region: str, service_id: str, - metadata_import: MetadataImport, + metadata_import: dict | MetadataImport, metadata_import_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Creates a new MetadataImport in a given project and location. @@ -152,15 +152,15 @@ def create_metadata_import( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. 
""" - parent = f'projects/{project_id}/locations/{region}/services/{service_id}' + parent = f"projects/{project_id}/locations/{region}/services/{service_id}" client = self.get_dataproc_metastore_client() result = client.create_metadata_import( request={ - 'parent': parent, - 'metadata_import': metadata_import, - 'metadata_import_id': metadata_import_id, - 'request_id': request_id, + "parent": parent, + "metadata_import": metadata_import, + "metadata_import_id": metadata_import_id, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -173,12 +173,12 @@ def create_service( self, region: str, project_id: str, - service: Union[Dict, Service], + service: dict | Service, service_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Creates a metastore service in a project and location. @@ -202,15 +202,15 @@ def create_service( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - parent = f'projects/{project_id}/locations/{region}' + parent = f"projects/{project_id}/locations/{region}" client = self.get_dataproc_metastore_client() result = client.create_service( request={ - 'parent': parent, - 'service_id': service_id, - 'service': service if service else {}, - 'request_id': request_id, + "parent": parent, + "service_id": service_id, + "service": service if service else {}, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -225,10 +225,10 @@ def delete_backup( region: str, service_id: str, backup_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Deletes a single backup. @@ -253,13 +253,13 @@ def delete_backup( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - name = f'projects/{project_id}/locations/{region}/services/{service_id}/backups/{backup_id}' + name = f"projects/{project_id}/locations/{region}/services/{service_id}/backups/{backup_id}" client = self.get_dataproc_metastore_client() result = client.delete_backup( request={ - 'name': name, - 'request_id': request_id, + "name": name, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -273,10 +273,10 @@ def delete_service( project_id: str, region: str, service_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Deletes a single service. @@ -295,13 +295,13 @@ def delete_service( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. 
""" - name = f'projects/{project_id}/locations/{region}/services/{service_id}' + name = f"projects/{project_id}/locations/{region}/services/{service_id}" client = self.get_dataproc_metastore_client() result = client.delete_service( request={ - 'name': name, - 'request_id': request_id, + "name": name, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -316,11 +316,11 @@ def export_metadata( project_id: str, region: str, service_id: str, - request_id: Optional[str] = None, - database_dump_type: Optional[DatabaseDumpSpec] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + database_dump_type: DatabaseDumpSpec | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Exports metadata from a service. @@ -345,15 +345,15 @@ def export_metadata( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - service = f'projects/{project_id}/locations/{region}/services/{service_id}' + service = f"projects/{project_id}/locations/{region}/services/{service_id}" client = self.get_dataproc_metastore_client() result = client.export_metadata( request={ - 'destination_gcs_folder': destination_gcs_folder, - 'service': service, - 'request_id': request_id, - 'database_dump_type': database_dump_type, + "destination_gcs_folder": destination_gcs_folder, + "service": service, + "request_id": request_id, + "database_dump_type": database_dump_type, }, retry=retry, timeout=timeout, @@ -367,9 +367,9 @@ def get_service( project_id: str, region: str, service_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Gets the details of a single service. @@ -387,12 +387,12 @@ def get_service( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - name = f'projects/{project_id}/locations/{region}/services/{service_id}' + name = f"projects/{project_id}/locations/{region}/services/{service_id}" client = self.get_dataproc_metastore_client() result = client.get_service( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -407,9 +407,9 @@ def get_backup( region: str, service_id: str, backup_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Backup: """ Get backup from a service. @@ -428,11 +428,11 @@ def get_backup( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. 
""" - backup = f'projects/{project_id}/locations/{region}/services/{service_id}/backups/{backup_id}' + backup = f"projects/{project_id}/locations/{region}/services/{service_id}/backups/{backup_id}" client = self.get_dataproc_metastore_client() result = client.get_backup( request={ - 'name': backup, + "name": backup, }, retry=retry, timeout=timeout, @@ -446,13 +446,13 @@ def list_backups( project_id: str, region: str, service_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Lists backups in a service. @@ -489,16 +489,16 @@ def list_backups( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - parent = f'projects/{project_id}/locations/{region}/services/{service_id}/backups' + parent = f"projects/{project_id}/locations/{region}/services/{service_id}/backups" client = self.get_dataproc_metastore_client() result = client.list_backups( request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'order_by': order_by, + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -516,11 +516,11 @@ def restore_service( backup_region: str, backup_service_id: str, backup_id: str, - restore_type: Optional[Restore] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + restore_type: Restore | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Restores a service from a backup. @@ -550,19 +550,19 @@ def restore_service( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. 
""" - service = f'projects/{project_id}/locations/{region}/services/{service_id}' + service = f"projects/{project_id}/locations/{region}/services/{service_id}" backup = ( - f'projects/{backup_project_id}/locations/{backup_region}/services/' - f'{backup_service_id}/backups/{backup_id}' + f"projects/{backup_project_id}/locations/{backup_region}/services/" + f"{backup_service_id}/backups/{backup_id}" ) client = self.get_dataproc_metastore_client() result = client.restore_service( request={ - 'service': service, - 'backup': backup, - 'restore_type': restore_type, - 'request_id': request_id, + "service": service, + "backup": backup, + "restore_type": restore_type, + "request_id": request_id, }, retry=retry, timeout=timeout, @@ -576,12 +576,12 @@ def update_service( project_id: str, region: str, service_id: str, - service: Union[Dict, Service], + service: dict | Service, update_mask: FieldMask, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ Updates the parameters of a single service. @@ -615,15 +615,15 @@ def update_service( """ client = self.get_dataproc_metastore_client() - service_name = f'projects/{project_id}/locations/{region}/services/{service_id}' + service_name = f"projects/{project_id}/locations/{region}/services/{service_id}" service["name"] = service_name result = client.update_service( request={ - 'service': service, - 'update_mask': update_mask, - 'request_id': request_id, + "service": service, + "update_mask": update_mask, + "request_id": request_id, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/datastore.py b/airflow/providers/google/cloud/hooks/datastore.py index f62b3ccfd5b64..eae35da5c0340 100644 --- a/airflow/providers/google/cloud/hooks/datastore.py +++ b/airflow/providers/google/cloud/hooks/datastore.py @@ -15,12 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Datastore hook.""" - +from __future__ import annotations import time -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Sequence from googleapiclient.discovery import Resource, build @@ -40,9 +39,9 @@ class DatastoreHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + api_version: str = "v1", + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -57,12 +56,11 @@ def get_conn(self) -> Resource: Establishes a connection to the Google API. :return: a Google Cloud Datastore service object. - :rtype: Resource """ if not self.connection: http_authorized = self._authorize() self.connection = build( - 'datastore', self.api_version, http=http_authorized, cache_discovery=False + "datastore", self.api_version, http=http_authorized, cache_discovery=False ) return self.connection @@ -78,20 +76,19 @@ def allocate_ids(self, partial_keys: list, project_id: str) -> list: :param partial_keys: a list of partial keys. :param project_id: Google Cloud project ID against which to make the request. 
:return: a list of full keys. - :rtype: list """ - conn = self.get_conn() # type: Any + conn = self.get_conn() resp = ( conn.projects() - .allocateIds(projectId=project_id, body={'keys': partial_keys}) + .allocateIds(projectId=project_id, body={"keys": partial_keys}) .execute(num_retries=self.num_retries) ) - return resp['keys'] + return resp["keys"] @GoogleBaseHook.fallback_to_default_project_id - def begin_transaction(self, project_id: str, transaction_options: Dict[str, Any]) -> str: + def begin_transaction(self, project_id: str, transaction_options: dict[str, Any]) -> str: """ Begins a new transaction. @@ -101,9 +98,8 @@ def begin_transaction(self, project_id: str, transaction_options: Dict[str, Any] :param project_id: Google Cloud project ID against which to make the request. :param transaction_options: Options for a new transaction. :return: a transaction handle. - :rtype: str """ - conn = self.get_conn() # type: Any + conn = self.get_conn() resp = ( conn.projects() @@ -111,7 +107,7 @@ def begin_transaction(self, project_id: str, transaction_options: Dict[str, Any] .execute(num_retries=self.num_retries) ) - return resp['transaction'] + return resp["transaction"] @GoogleBaseHook.fallback_to_default_project_id def commit(self, body: dict, project_id: str) -> dict: @@ -124,9 +120,8 @@ def commit(self, body: dict, project_id: str) -> dict: :param body: the body of the commit request. :param project_id: Google Cloud project ID against which to make the request. :return: the response body of the commit request. - :rtype: dict """ - conn = self.get_conn() # type: Any + conn = self.get_conn() resp = conn.projects().commit(projectId=project_id, body=body).execute(num_retries=self.num_retries) @@ -137,8 +132,8 @@ def lookup( self, keys: list, project_id: str, - read_consistency: Optional[str] = None, - transaction: Optional[str] = None, + read_consistency: str | None = None, + transaction: str | None = None, ) -> dict: """ Lookup some entities by key. @@ -152,15 +147,14 @@ def lookup( :param transaction: the transaction to use, if any. :param project_id: Google Cloud project ID against which to make the request. :return: the response body of the lookup request. - :rtype: dict """ - conn = self.get_conn() # type: Any + conn = self.get_conn() - body = {'keys': keys} # type: Dict[str, Any] + body: dict[str, Any] = {"keys": keys} if read_consistency: - body['readConsistency'] = read_consistency + body["readConsistency"] = read_consistency if transaction: - body['transaction'] = transaction + body["transaction"] = transaction resp = conn.projects().lookup(projectId=project_id, body=body).execute(num_retries=self.num_retries) return resp @@ -178,7 +172,7 @@ def rollback(self, transaction: str, project_id: str) -> None: """ conn: Any = self.get_conn() - conn.projects().rollback(projectId=project_id, body={'transaction': transaction}).execute( + conn.projects().rollback(projectId=project_id, body={"transaction": transaction}).execute( num_retries=self.num_retries ) @@ -193,13 +187,12 @@ def run_query(self, body: dict, project_id: str) -> dict: :param body: the body of the query request. :param project_id: Google Cloud project ID against which to make the request. :return: the batch of query results. 
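
A minimal sketch of the transaction flow documented above (``begin_transaction`` followed by ``commit``), using illustrative entity data; the request bodies are forwarded to the Datastore REST API unchanged:

    # Sketch only: the kind, entity name and connection id are assumptions.
    from airflow.providers.google.cloud.hooks.datastore import DatastoreHook


    def upsert_entity(project_id: str) -> dict:
        hook = DatastoreHook(gcp_conn_id="google_cloud_default")
        txn = hook.begin_transaction(project_id=project_id, transaction_options={})
        body = {
            "mode": "TRANSACTIONAL",
            "transaction": txn,
            "mutations": [
                {
                    "upsert": {
                        "key": {"path": [{"kind": "Task", "name": "example"}]},
                        "properties": {"done": {"booleanValue": False}},
                    }
                }
            ],
        }
        # commit() applies the mutations within the transaction opened above.
        return hook.commit(body=body, project_id=project_id)
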
- :rtype: dict """ - conn = self.get_conn() # type: Any + conn = self.get_conn() resp = conn.projects().runQuery(projectId=project_id, body=body).execute(num_retries=self.num_retries) - return resp['batch'] + return resp["batch"] def get_operation(self, name: str) -> dict: """ @@ -210,7 +203,6 @@ def get_operation(self, name: str) -> dict: :param name: the name of the operation resource. :return: a resource operation instance. - :rtype: dict """ conn: Any = self.get_conn() @@ -227,30 +219,28 @@ def delete_operation(self, name: str) -> dict: :param name: the name of the operation resource. :return: none if successful. - :rtype: dict """ - conn = self.get_conn() # type: Any + conn = self.get_conn() resp = conn.projects().operations().delete(name=name).execute(num_retries=self.num_retries) return resp - def poll_operation_until_done(self, name: str, polling_interval_in_seconds: float) -> Dict: + def poll_operation_until_done(self, name: str, polling_interval_in_seconds: float) -> dict: """ Poll backup operation state until it's completed. :param name: the name of the operation resource :param polling_interval_in_seconds: The number of seconds to wait before calling another request. :return: a resource operation instance. - :rtype: dict """ while True: - result: Dict = self.get_operation(name) + result: dict = self.get_operation(name) - state: str = result['metadata']['common']['state'] - if state == 'PROCESSING': + state: str = result["metadata"]["common"]["state"] + if state == "PROCESSING": self.log.info( - 'Operation is processing. Re-polling state in %s seconds', polling_interval_in_seconds + "Operation is processing. Re-polling state in %s seconds", polling_interval_in_seconds ) time.sleep(polling_interval_in_seconds) else: @@ -261,9 +251,9 @@ def export_to_storage_bucket( self, bucket: str, project_id: str, - namespace: Optional[str] = None, - entity_filter: Optional[dict] = None, - labels: Optional[Dict[str, str]] = None, + namespace: str | None = None, + entity_filter: dict | None = None, + labels: dict[str, str] | None = None, ) -> dict: """ Export entities from Cloud Datastore to Cloud Storage for backup. @@ -280,20 +270,19 @@ def export_to_storage_bucket( :param labels: Client-assigned labels. :param project_id: Google Cloud project ID against which to make the request. :return: a resource operation instance. - :rtype: dict """ - admin_conn = self.get_conn() # type: Any + admin_conn = self.get_conn() - output_uri_prefix = 'gs://' + '/'.join(filter(None, [bucket, namespace])) # type: str + output_url_prefix = f"gs://{'/'.join(filter(None, [bucket, namespace]))}" if not entity_filter: entity_filter = {} if not labels: labels = {} body = { - 'outputUrlPrefix': output_uri_prefix, - 'entityFilter': entity_filter, - 'labels': labels, - } # type: Dict + "outputUrlPrefix": output_url_prefix, + "entityFilter": entity_filter, + "labels": labels, + } resp = ( admin_conn.projects() .export(projectId=project_id, body=body) @@ -308,9 +297,9 @@ def import_from_storage_bucket( bucket: str, file: str, project_id: str, - namespace: Optional[str] = None, - entity_filter: Optional[dict] = None, - labels: Optional[Union[dict, str]] = None, + namespace: str | None = None, + entity_filter: dict | None = None, + labels: dict | str | None = None, ) -> dict: """ Import a backup from Cloud Storage to Cloud Datastore. @@ -328,20 +317,19 @@ def import_from_storage_bucket( :param labels: Client-assigned labels. :param project_id: Google Cloud project ID against which to make the request. 
:return: a resource operation instance. - :rtype: dict """ - admin_conn = self.get_conn() # type: Any + admin_conn = self.get_conn() - input_url = 'gs://' + '/'.join(filter(None, [bucket, namespace, file])) # type: str + input_url = f"gs://{'/'.join(filter(None, [bucket, namespace, file]))}" if not entity_filter: entity_filter = {} if not labels: labels = {} body = { - 'inputUrl': input_url, - 'entityFilter': entity_filter, - 'labels': labels, - } # type: Dict + "inputUrl": input_url, + "entityFilter": entity_filter, + "labels": labels, + } resp = ( admin_conn.projects() .import_(projectId=project_id, body=body) diff --git a/airflow/providers/google/cloud/hooks/dlp.py b/airflow/providers/google/cloud/hooks/dlp.py index d0a14909ce333..41d595a984e08 100644 --- a/airflow/providers/google/cloud/hooks/dlp.py +++ b/airflow/providers/google/cloud/hooks/dlp.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module contains a CloudDLPHook which allows you to connect to Google Cloud DLP service. @@ -25,10 +24,11 @@ ImageRedactionConfig RedactImageRequest """ +from __future__ import annotations import re import time -from typing import List, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -88,8 +88,8 @@ class CloudDLPHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -103,10 +103,9 @@ def get_conn(self) -> DlpServiceClient: Provides a client for interacting with the Cloud DLP API. :return: Google Cloud DLP API Client - :rtype: google.cloud.dlp_v2.DlpServiceClient """ if not self._client: - self._client = DlpServiceClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = DlpServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @GoogleBaseHook.fallback_to_default_project_id @@ -114,9 +113,9 @@ def cancel_dlp_job( self, dlp_job_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Starts asynchronous cancellation on a long-running DLP job. 
@@ -142,13 +141,13 @@ def cancel_dlp_job( def create_deidentify_template( self, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - deidentify_template: Optional[Union[dict, DeidentifyTemplate]] = None, - template_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + deidentify_template: dict | DeidentifyTemplate | None = None, + template_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> DeidentifyTemplate: """ Creates a deidentify template for re-using frequently used configuration for @@ -167,7 +166,6 @@ def create_deidentify_template( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate """ client = self.get_conn() # Handle project_id from connection configuration @@ -193,12 +191,12 @@ def create_deidentify_template( def create_dlp_job( self, project_id: str = PROVIDE_PROJECT_ID, - inspect_job: Optional[Union[dict, InspectJobConfig]] = None, - risk_job: Optional[Union[dict, RiskAnalysisJobConfig]] = None, - job_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + inspect_job: dict | InspectJobConfig | None = None, + risk_job: dict | RiskAnalysisJobConfig | None = None, + job_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), wait_until_finished: bool = True, time_to_sleep_in_seconds: int = 60, ) -> DlpJob: @@ -219,7 +217,6 @@ def create_dlp_job( :param metadata: (Optional) Additional metadata that is provided to the method. :param wait_until_finished: (Optional) If true, it will keep polling the job state until it is set to DONE. - :rtype: google.cloud.dlp_v2.types.DlpJob :param time_to_sleep_in_seconds: (Optional) Time to sleep, in seconds, between active checks of the operation results. Defaults to 60. """ @@ -266,13 +263,13 @@ def create_dlp_job( def create_inspect_template( self, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - inspect_template: Optional[InspectTemplate] = None, - template_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + inspect_template: InspectTemplate | None = None, + template_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> InspectTemplate: """ Creates an inspect template for re-using frequently used configuration for @@ -291,7 +288,6 @@ def create_inspect_template( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: google.cloud.dlp_v2.types.InspectTemplate """ client = self.get_conn() @@ -318,11 +314,11 @@ def create_inspect_template( def create_job_trigger( self, project_id: str = PROVIDE_PROJECT_ID, - job_trigger: Optional[Union[dict, JobTrigger]] = None, - trigger_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + job_trigger: dict | JobTrigger | None = None, + trigger_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> JobTrigger: """ Creates a job trigger to run DLP actions such as scanning storage for sensitive @@ -339,7 +335,6 @@ def create_job_trigger( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.JobTrigger """ client = self.get_conn() @@ -355,13 +350,13 @@ def create_job_trigger( def create_stored_info_type( self, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - config: Optional[Union[dict, StoredInfoTypeConfig]] = None, - stored_info_type_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + config: dict | StoredInfoTypeConfig | None = None, + stored_info_type_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> StoredInfoType: """ Creates a pre-built stored info type to be used for inspection. @@ -379,7 +374,6 @@ def create_stored_info_type( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.StoredInfoType """ client = self.get_conn() @@ -406,14 +400,14 @@ def create_stored_info_type( def deidentify_content( self, project_id: str = PROVIDE_PROJECT_ID, - deidentify_config: Optional[Union[dict, DeidentifyConfig]] = None, - inspect_config: Optional[Union[dict, InspectConfig]] = None, - item: Optional[Union[dict, ContentItem]] = None, - inspect_template_name: Optional[str] = None, - deidentify_template_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + deidentify_config: dict | DeidentifyConfig | None = None, + inspect_config: dict | InspectConfig | None = None, + item: dict | ContentItem | None = None, + inspect_template_name: str | None = None, + deidentify_template_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> DeidentifyContentResponse: """ De-identifies potentially sensitive info from a content item. This method has limits @@ -439,7 +433,6 @@ def deidentify_content( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
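
A minimal sketch of ``deidentify_content`` (illustrative info type and masking config, not part of the patch), masking e-mail addresses found in a text item:

    # Sketch only: the connection id and the configs below are assumptions.
    from airflow.providers.google.cloud.hooks.dlp import CloudDLPHook


    def mask_emails(project_id: str, text: str) -> str:
        hook = CloudDLPHook(gcp_conn_id="google_cloud_default")
        response = hook.deidentify_content(
            project_id=project_id,
            item={"value": text},
            inspect_config={"info_types": [{"name": "EMAIL_ADDRESS"}]},
            deidentify_config={
                "info_type_transformations": {
                    "transformations": [
                        {
                            "primitive_transformation": {
                                "character_mask_config": {"masking_character": "*"}
                            }
                        }
                    ]
                }
            },
        )
        # The transformed text comes back on the response's content item.
        return response.item.value
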
- :rtype: google.cloud.dlp_v2.types.DeidentifyContentResponse """ client = self.get_conn() @@ -497,9 +490,9 @@ def delete_dlp_job( self, dlp_job_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a long-running DLP job. This method indicates that the client is no longer @@ -527,11 +520,11 @@ def delete_dlp_job( def delete_inspect_template( self, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes an inspect template. @@ -571,9 +564,9 @@ def delete_job_trigger( self, job_trigger_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a job trigger. @@ -600,11 +593,11 @@ def delete_job_trigger( def delete_stored_info_type( self, stored_info_type_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a stored info type. @@ -642,11 +635,11 @@ def delete_stored_info_type( def get_deidentify_template( self, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> DeidentifyTemplate: """ Gets a deidentify template. @@ -663,7 +656,6 @@ def get_deidentify_template( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate """ client = self.get_conn() @@ -687,9 +679,9 @@ def get_dlp_job( self, dlp_job_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> DlpJob: """ Gets the latest state of a long-running Dlp Job. @@ -704,7 +696,6 @@ def get_dlp_job( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: google.cloud.dlp_v2.types.DlpJob """ client = self.get_conn() @@ -717,11 +708,11 @@ def get_dlp_job( def get_inspect_template( self, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> InspectTemplate: """ Gets an inspect template. @@ -738,7 +729,6 @@ def get_inspect_template( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.InspectTemplate """ client = self.get_conn() @@ -762,9 +752,9 @@ def get_job_trigger( self, job_trigger_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> JobTrigger: """ Gets a DLP job trigger. @@ -779,7 +769,6 @@ def get_job_trigger( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.JobTrigger """ client = self.get_conn() @@ -792,11 +781,11 @@ def get_job_trigger( def get_stored_info_type( self, stored_info_type_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> StoredInfoType: """ Gets a stored info type. @@ -813,7 +802,6 @@ def get_stored_info_type( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.StoredInfoType """ client = self.get_conn() @@ -836,12 +824,12 @@ def get_stored_info_type( def inspect_content( self, project_id: str, - inspect_config: Optional[Union[dict, InspectConfig]] = None, - item: Optional[Union[dict, ContentItem]] = None, - inspect_template_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + inspect_config: dict | InspectConfig | None = None, + item: dict | ContentItem | None = None, + inspect_template_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> InspectContentResponse: """ Finds potentially sensitive info in content. This method has limits on input size, @@ -861,7 +849,6 @@ def inspect_content( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: google.cloud.dlp_v2.types.InspectContentResponse """ client = self.get_conn() @@ -878,14 +865,14 @@ def inspect_content( def list_deidentify_templates( self, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[DeidentifyTemplate]: + organization_id: str | None = None, + project_id: str | None = None, + page_size: int | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[DeidentifyTemplate]: """ Lists deidentify templates. @@ -904,7 +891,6 @@ def list_deidentify_templates( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: List[google.cloud.dlp_v2.types.DeidentifyTemplate] """ client = self.get_conn() @@ -933,14 +919,14 @@ def list_deidentify_templates( def list_dlp_jobs( self, project_id: str, - results_filter: Optional[str] = None, - page_size: Optional[int] = None, - job_type: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[DlpJob]: + results_filter: str | None = None, + page_size: int | None = None, + job_type: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[DlpJob]: """ Lists DLP jobs that match the specified filter in the request. @@ -959,7 +945,6 @@ def list_dlp_jobs( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: List[google.cloud.dlp_v2.types.DlpJob] """ client = self.get_conn() @@ -978,11 +963,11 @@ def list_dlp_jobs( def list_info_types( self, - language_code: Optional[str] = None, - results_filter: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + language_code: str | None = None, + results_filter: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListInfoTypesResponse: """ Returns a list of the sensitive information types that the DLP API supports. @@ -997,7 +982,6 @@ def list_info_types( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: google.cloud.dlp_v2.types.ListInfoTypesResponse """ client = self.get_conn() @@ -1011,14 +995,14 @@ def list_info_types( def list_inspect_templates( self, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[InspectTemplate]: + organization_id: str | None = None, + project_id: str | None = None, + page_size: int | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[InspectTemplate]: """ Lists inspect templates. @@ -1037,7 +1021,6 @@ def list_inspect_templates( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: List[google.cloud.dlp_v2.types.InspectTemplate] """ client = self.get_conn() @@ -1065,13 +1048,13 @@ def list_inspect_templates( def list_job_triggers( self, project_id: str, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - results_filter: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[JobTrigger]: + page_size: int | None = None, + order_by: str | None = None, + results_filter: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[JobTrigger]: """ Lists job triggers. @@ -1089,7 +1072,6 @@ def list_job_triggers( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: List[google.cloud.dlp_v2.types.JobTrigger] """ client = self.get_conn() @@ -1107,14 +1089,14 @@ def list_job_triggers( def list_stored_info_types( self, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[StoredInfoType]: + organization_id: str | None = None, + project_id: str | None = None, + page_size: int | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[StoredInfoType]: """ Lists stored info types. @@ -1133,7 +1115,6 @@ def list_stored_info_types( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: List[google.cloud.dlp_v2.types.StoredInfoType] """ client = self.get_conn() @@ -1161,15 +1142,13 @@ def list_stored_info_types( def redact_image( self, project_id: str, - inspect_config: Optional[Union[dict, InspectConfig]] = None, - image_redaction_configs: Optional[ - Union[List[dict], List[RedactImageRequest.ImageRedactionConfig]] - ] = None, - include_findings: Optional[bool] = None, - byte_item: Optional[Union[dict, ByteContentItem]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + inspect_config: dict | InspectConfig | None = None, + image_redaction_configs: None | (list[dict] | list[RedactImageRequest.ImageRedactionConfig]) = None, + include_findings: bool | None = None, + byte_item: dict | ByteContentItem | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> RedactImageResponse: """ Redacts potentially sensitive info from an image. This method has limits on @@ -1182,7 +1161,7 @@ def redact_image( here will override the template referenced by the inspect_template_name argument. :param image_redaction_configs: (Optional) The configuration for specifying what content to redact from images. - List[google.cloud.dlp_v2.types.RedactImageRequest.ImageRedactionConfig] + list[google.cloud.dlp_v2.types.RedactImageRequest.ImageRedactionConfig] :param include_findings: (Optional) Whether the response should include findings along with the redacted image. :param byte_item: (Optional) The content must be PNG, JPEG, SVG or BMP. @@ -1192,7 +1171,6 @@ def redact_image( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.RedactImageResponse """ client = self.get_conn() @@ -1212,14 +1190,14 @@ def redact_image( def reidentify_content( self, project_id: str, - reidentify_config: Optional[Union[dict, DeidentifyConfig]] = None, - inspect_config: Optional[Union[dict, InspectConfig]] = None, - item: Optional[Union[dict, ContentItem]] = None, - inspect_template_name: Optional[str] = None, - reidentify_template_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + reidentify_config: dict | DeidentifyConfig | None = None, + inspect_config: dict | InspectConfig | None = None, + item: dict | ContentItem | None = None, + inspect_template_name: str | None = None, + reidentify_template_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ReidentifyContentResponse: """ Re-identifies content that has been de-identified. @@ -1242,7 +1220,6 @@ def reidentify_content( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: google.cloud.dlp_v2.types.ReidentifyContentResponse """ client = self.get_conn() @@ -1262,13 +1239,13 @@ def reidentify_content( def update_deidentify_template( self, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - deidentify_template: Optional[Union[dict, DeidentifyTemplate]] = None, - update_mask: Optional[Union[dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + deidentify_template: dict | DeidentifyTemplate | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> DeidentifyTemplate: """ Updates the deidentify template. @@ -1287,7 +1264,6 @@ def update_deidentify_template( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate """ client = self.get_conn() @@ -1316,13 +1292,13 @@ def update_deidentify_template( def update_inspect_template( self, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - inspect_template: Optional[Union[dict, InspectTemplate]] = None, - update_mask: Optional[Union[dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + inspect_template: dict | InspectTemplate | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> InspectTemplate: """ Updates the inspect template. @@ -1341,7 +1317,6 @@ def update_inspect_template( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.InspectTemplate """ client = self.get_conn() @@ -1371,11 +1346,11 @@ def update_job_trigger( self, job_trigger_id: str, project_id: str, - job_trigger: Optional[Union[dict, JobTrigger]] = None, - update_mask: Optional[Union[dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + job_trigger: dict | JobTrigger | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> JobTrigger: """ Updates a job trigger. @@ -1392,7 +1367,6 @@ def update_job_trigger( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: google.cloud.dlp_v2.types.JobTrigger """ client = self.get_conn() @@ -1412,13 +1386,13 @@ def update_job_trigger( def update_stored_info_type( self, stored_info_type_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - config: Optional[Union[dict, StoredInfoTypeConfig]] = None, - update_mask: Optional[Union[dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + config: dict | StoredInfoTypeConfig | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> StoredInfoType: """ Updates the stored info type by creating a new version. @@ -1438,7 +1412,6 @@ def update_stored_info_type( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.dlp_v2.types.StoredInfoType """ client = self.get_conn() diff --git a/airflow/providers/google/cloud/hooks/functions.py b/airflow/providers/google/cloud/hooks/functions.py index bf66e314e8e77..3bf10e6688e38 100644 --- a/airflow/providers/google/cloud/hooks/functions.py +++ b/airflow/providers/google/cloud/hooks/functions.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Functions Hook.""" +from __future__ import annotations + import time -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Sequence import requests from googleapiclient.discovery import build @@ -37,14 +39,14 @@ class CloudFunctionsHook(GoogleBaseHook): keyword arguments rather than positional. """ - _conn = None # type: Optional[Any] + _conn = None def __init__( self, api_version: str, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -63,19 +65,18 @@ def _full_location(project_id: str, location: str) -> str: :param location: The location where the function is created. :return: """ - return f'projects/{project_id}/locations/{location}' + return f"projects/{project_id}/locations/{location}" def get_conn(self) -> build: """ Retrieves the connection to Cloud Functions. :return: Google Cloud Function services object. - :rtype: dict """ if not self._conn: http_authorized = self._authorize() self._conn = build( - 'cloudfunctions', self.api_version, http=http_authorized, cache_discovery=False + "cloudfunctions", self.api_version, http=http_authorized, cache_discovery=False ) return self._conn @@ -85,7 +86,6 @@ def get_function(self, name: str) -> dict: :param name: Name of the function. :return: A Cloud Functions object representing the function. 
- :rtype: dict """ # fmt: off return self.get_conn().projects().locations().functions().get( @@ -112,7 +112,7 @@ def create_new_function(self, location: str, body: dict, project_id: str) -> Non operation_name = response["name"] self._wait_for_operation_to_complete(operation_name=operation_name) - def update_function(self, name: str, body: dict, update_mask: List[str]) -> None: + def update_function(self, name: str, body: dict, update_mask: list[str]) -> None: """ Updates Cloud Functions according to the specified update mask. @@ -141,7 +141,6 @@ def upload_function_zip(self, location: str, zip_path: str, project_id: str) -> :param project_id: Optional, Google Cloud Project project_id where the function belongs. If set to None or missing, the default project_id from the Google Cloud connection is used. :return: The upload URL that was returned by generateUploadUrl method. - :rtype: str """ # fmt: off @@ -151,8 +150,8 @@ def upload_function_zip(self, location: str, zip_path: str, project_id: str) -> ).execute(num_retries=self.num_retries) # fmt: on - upload_url = response.get('uploadUrl') - with open(zip_path, 'rb') as file: + upload_url = response.get("uploadUrl") + with open(zip_path, "rb") as file: requests.put( url=upload_url, data=file, @@ -160,8 +159,8 @@ def upload_function_zip(self, location: str, zip_path: str, project_id: str) -> # https://cloud.google.com/functions/docs/reference/rest/v1/projects.locations.functions/generateUploadUrl # nopep8 headers={ - 'Content-type': 'application/zip', - 'x-goog-content-length-range': '0,104857600', + "Content-type": "application/zip", + "x-goog-content-length-range": "0,104857600", }, ) return upload_url @@ -184,7 +183,7 @@ def delete_function(self, name: str) -> None: def call_function( self, function_id: str, - input_data: Dict, + input_data: dict, location: str, project_id: str = PROVIDE_PROJECT_ID, ) -> dict: @@ -206,8 +205,8 @@ def call_function( body=input_data ).execute(num_retries=self.num_retries) # fmt: on - if 'error' in response: - raise AirflowException(response['error']) + if "error" in response: + raise AirflowException(response["error"]) return response def _wait_for_operation_to_complete(self, operation_name: str) -> dict: @@ -217,7 +216,6 @@ def _wait_for_operation_to_complete(self, operation_name: str) -> dict: :param operation_name: The name of the operation. :return: The response returned by the operation. - :rtype: dict :exception: AirflowException in case error is returned. """ service = self.get_conn() diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 36e693b608a33..7e2e081a464eb 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
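A quick aside on a pattern this changeset applies to every touched module, including the gcs.py hunk that follows: adding "from __future__ import annotations" postpones annotation evaluation, which is what lets PEP 604 unions such as "str | None" replace "Optional[str]" even on interpreters older than Python 3.10. A minimal sketch, not taken from the diff (the helper name and body are made up):

from __future__ import annotations

def blob_size(bucket_name: str, object_name: str | None = None) -> int | None:
    # hypothetical helper; only the annotation syntax is the point here
    return None if object_name is None else len(object_name)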
-# """This module contains a Google Cloud Storage hook.""" +from __future__ import annotations + import functools import gzip as gz import os @@ -28,21 +29,8 @@ from io import BytesIO from os import path from tempfile import NamedTemporaryFile -from typing import ( - IO, - Callable, - Generator, - List, - Optional, - Sequence, - Set, - Tuple, - TypeVar, - Union, - cast, - overload, -) -from urllib.parse import urlparse +from typing import IO, Callable, Generator, Sequence, TypeVar, cast, overload +from urllib.parse import urlsplit from google.api_core.exceptions import NotFound @@ -57,17 +45,21 @@ from airflow.utils import timezone from airflow.version import version -RT = TypeVar('RT') +RT = TypeVar("RT") T = TypeVar("T", bound=Callable) +# GCSHook has a method named 'list' (to junior devs: please don't do this), so +# we need to create an alias to prevent Mypy being confused. +List = list + # Use default timeout from google-cloud-storage DEFAULT_TIMEOUT = 60 def _fallback_object_url_to_object_name_and_bucket_name( - object_url_keyword_arg_name='object_url', - bucket_name_keyword_arg_name='bucket_name', - object_name_keyword_arg_name='object_name', + object_url_keyword_arg_name="object_url", + bucket_name_keyword_arg_name="bucket_name", + object_name_keyword_arg_name="object_name", ) -> Callable[[T], T]: """ Decorator factory that convert object URL parameter to object name and bucket name parameter. @@ -80,7 +72,7 @@ def _fallback_object_url_to_object_name_and_bucket_name( def _wrapper(func: T): @functools.wraps(func) - def _inner_wrapper(self: "GCSHook", *args, **kwargs) -> RT: + def _inner_wrapper(self: GCSHook, *args, **kwargs) -> RT: if args: raise AirflowException( "You must use keyword arguments in this methods rather than positional" @@ -127,7 +119,7 @@ def _inner_wrapper(self: "GCSHook", *args, **kwargs) -> RT: # A fake bucket to use in functions decorated by _fallback_object_url_to_object_name_and_bucket_name. -# This allows the 'bucket' argument to be of type str instead of Optional[str], +# This allows the 'bucket' argument to be of type str instead of str | None, # making it easier to type hint the function body without dealing with the None # case that can never happen at runtime. PROVIDE_BUCKET: str = cast(str, None) @@ -139,13 +131,13 @@ class GCSHook(GoogleBaseHook): connection. """ - _conn = None # type: Optional[storage.Client] + _conn: storage.Client | None = None def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -157,7 +149,7 @@ def get_conn(self) -> storage.Client: """Returns a Google Cloud Storage service object.""" if not self._conn: self._conn = storage.Client( - credentials=self._get_credentials(), client_info=CLIENT_INFO, project=self.project_id + credentials=self.get_credentials(), client_info=CLIENT_INFO, project=self.project_id ) return self._conn @@ -166,8 +158,8 @@ def copy( self, source_bucket: str, source_object: str, - destination_bucket: Optional[str] = None, - destination_object: Optional[str] = None, + destination_bucket: str | None = None, + destination_object: str | None = None, ) -> None: """ Copies an object from a bucket to another, with renaming if requested. 
@@ -188,11 +180,11 @@ def copy( if source_bucket == destination_bucket and source_object == destination_object: raise ValueError( - f'Either source/destination bucket or source/destination object must be different, ' - f'not both the same: bucket={source_bucket}, object={source_object}' + f"Either source/destination bucket or source/destination object must be different, " + f"not both the same: bucket={source_bucket}, object={source_object}" ) if not source_bucket or not source_object: - raise ValueError('source_bucket and source_object cannot be empty.') + raise ValueError("source_bucket and source_object cannot be empty.") client = self.get_conn() source_bucket = client.bucket(source_bucket) @@ -203,7 +195,7 @@ def copy( ) self.log.info( - 'Object %s in bucket %s copied to object %s in bucket %s', + "Object %s in bucket %s copied to object %s in bucket %s", source_object.name, # type: ignore[attr-defined] source_bucket.name, # type: ignore[attr-defined] destination_object.name, # type: ignore[union-attr] @@ -215,7 +207,7 @@ def rewrite( source_bucket: str, source_object: str, destination_bucket: str, - destination_object: Optional[str] = None, + destination_object: str | None = None, ) -> None: """ Has the same functionality as copy, except that will work on files @@ -233,11 +225,11 @@ def rewrite( destination_object = destination_object or source_object if source_bucket == destination_bucket and source_object == destination_object: raise ValueError( - f'Either source/destination bucket or source/destination object must be different, ' - f'not both the same: bucket={source_bucket}, object={source_object}' + f"Either source/destination bucket or source/destination object must be different, " + f"not both the same: bucket={source_bucket}, object={source_object}" ) if not source_bucket or not source_object: - raise ValueError('source_bucket and source_object cannot be empty.') + raise ValueError("source_bucket and source_object cannot be empty.") client = self.get_conn() source_bucket = client.bucket(source_bucket) @@ -248,16 +240,16 @@ def rewrite( blob_name=destination_object ).rewrite(source=source_object) - self.log.info('Total Bytes: %s | Bytes Written: %s', total_bytes, bytes_rewritten) + self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten) while token is not None: token, bytes_rewritten, total_bytes = destination_bucket.blob( # type: ignore[attr-defined] blob_name=destination_object ).rewrite(source=source_object, token=token) - self.log.info('Total Bytes: %s | Bytes Written: %s', total_bytes, bytes_rewritten) + self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten) self.log.info( - 'Object %s in bucket %s rewritten to object %s in bucket %s', + "Object %s in bucket %s rewritten to object %s in bucket %s", source_object.name, # type: ignore[attr-defined] source_bucket.name, # type: ignore[attr-defined] destination_object, @@ -270,9 +262,9 @@ def download( bucket_name: str, object_name: str, filename: None = None, - chunk_size: Optional[int] = None, - timeout: Optional[int] = DEFAULT_TIMEOUT, - num_max_attempts: Optional[int] = 1, + chunk_size: int | None = None, + timeout: int | None = DEFAULT_TIMEOUT, + num_max_attempts: int | None = 1, ) -> bytes: ... 
@@ -282,9 +274,9 @@ def download( bucket_name: str, object_name: str, filename: str, - chunk_size: Optional[int] = None, - timeout: Optional[int] = DEFAULT_TIMEOUT, - num_max_attempts: Optional[int] = 1, + chunk_size: int | None = None, + timeout: int | None = DEFAULT_TIMEOUT, + num_max_attempts: int | None = 1, ) -> str: ... @@ -292,11 +284,11 @@ def download( self, bucket_name: str, object_name: str, - filename: Optional[str] = None, - chunk_size: Optional[int] = None, - timeout: Optional[int] = DEFAULT_TIMEOUT, - num_max_attempts: Optional[int] = 1, - ) -> Union[str, bytes]: + filename: str | None = None, + chunk_size: int | None = None, + timeout: int | None = DEFAULT_TIMEOUT, + num_max_attempts: int | None = 1, + ) -> str | bytes: """ Downloads a file from Google Cloud Storage. @@ -326,7 +318,7 @@ def download( if filename: blob.download_to_filename(filename, timeout=timeout) - self.log.info('File downloaded to %s', filename) + self.log.info("File downloaded to %s", filename) return filename else: return blob.download_as_bytes() @@ -334,9 +326,9 @@ def download( except GoogleCloudError: if num_file_attempts == num_max_attempts: self.log.error( - 'Download attempt of object: %s from %s has failed. Attempt: %s, max %s.', - object_name, + "Download attempt of object: %s from %s has failed. Attempt: %s, max %s.", object_name, + bucket_name, num_file_attempts, num_max_attempts, ) @@ -351,9 +343,9 @@ def download_as_byte_array( self, bucket_name: str, object_name: str, - chunk_size: Optional[int] = None, - timeout: Optional[int] = DEFAULT_TIMEOUT, - num_max_attempts: Optional[int] = 1, + chunk_size: int | None = None, + timeout: int | None = DEFAULT_TIMEOUT, + num_max_attempts: int | None = 1, ) -> bytes: """ Downloads a file from Google Cloud Storage. @@ -383,9 +375,9 @@ def download_as_byte_array( def provide_file( self, bucket_name: str = PROVIDE_BUCKET, - object_name: Optional[str] = None, - object_url: Optional[str] = None, - dir: Optional[str] = None, + object_name: str | None = None, + object_url: str | None = None, + dir: str | None = None, ) -> Generator[IO[bytes], None, None]: """ Downloads the file to a temporary directory and returns a file handle @@ -412,8 +404,8 @@ def provide_file( def provide_file_and_upload( self, bucket_name: str = PROVIDE_BUCKET, - object_name: Optional[str] = None, - object_url: Optional[str] = None, + object_name: str | None = None, + object_url: str | None = None, ) -> Generator[IO[bytes], None, None]: """ Creates temporary file, returns a file handle and uploads the files content @@ -440,15 +432,15 @@ def upload( self, bucket_name: str, object_name: str, - filename: Optional[str] = None, - data: Optional[Union[str, bytes]] = None, - mime_type: Optional[str] = None, + filename: str | None = None, + data: str | bytes | None = None, + mime_type: str | None = None, gzip: bool = False, - encoding: str = 'utf-8', - chunk_size: Optional[int] = None, - timeout: Optional[int] = DEFAULT_TIMEOUT, + encoding: str = "utf-8", + chunk_size: int | None = None, + timeout: int | None = DEFAULT_TIMEOUT, num_max_attempts: int = 1, - metadata: Optional[dict] = None, + metadata: dict | None = None, ) -> None: """ Uploads a local file or file data as string or bytes to Google Cloud Storage. @@ -480,7 +472,7 @@ def _call_with_retry(f: Callable[[], None]) -> None: except GoogleCloudError as e: if num_file_attempts == num_max_attempts: self.log.error( - 'Upload attempt of object: %s from %s has failed. Attempt: %s, max %s.', + "Upload attempt of object: %s from %s has failed. 
Attempt: %s, max %s.", object_name, object_name, num_file_attempts, @@ -508,12 +500,12 @@ def _call_with_retry(f: Callable[[], None]) -> None: ) elif filename: if not mime_type: - mime_type = 'application/octet-stream' + mime_type = "application/octet-stream" if gzip: - filename_gz = filename + '.gz' + filename_gz = filename + ".gz" - with open(filename, 'rb') as f_in: - with gz.open(filename_gz, 'wb') as f_out: + with open(filename, "rb") as f_in: + with gz.open(filename_gz, "wb") as f_out: shutil.copyfileobj(f_in, f_out) filename = filename_gz @@ -523,10 +515,10 @@ def _call_with_retry(f: Callable[[], None]) -> None: if gzip: os.remove(filename) - self.log.info('File %s uploaded to %s in %s bucket', filename, object_name, bucket_name) + self.log.info("File %s uploaded to %s in %s bucket", filename, object_name, bucket_name) elif data: if not mime_type: - mime_type = 'text/plain' + mime_type = "text/plain" if gzip: if isinstance(data, str): data = bytes(data, encoding) @@ -537,7 +529,7 @@ def _call_with_retry(f: Callable[[], None]) -> None: _call_with_retry(partial(blob.upload_from_string, data, content_type=mime_type, timeout=timeout)) - self.log.info('Data stream uploaded to %s in %s bucket', object_name, bucket_name) + self.log.info("Data stream uploaded to %s in %s bucket", object_name, bucket_name) else: raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.") @@ -663,7 +655,7 @@ def delete(self, bucket_name: str, object_name: str) -> None: blob = bucket.blob(blob_name=object_name) blob.delete() - self.log.info('Blob %s deleted.', object_name) + self.log.info("Blob %s deleted.", object_name) def delete_bucket(self, bucket_name: str, force: bool = False) -> None: """ @@ -683,7 +675,7 @@ def delete_bucket(self, bucket_name: str, force: bool = False) -> None: except NotFound: self.log.info("Bucket %s not exists", bucket_name) - def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimiter=None) -> list: + def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimiter=None) -> List: """ List all objects from the bucket with the give string prefix in name @@ -730,10 +722,10 @@ def list_by_timespan( bucket_name: str, timespan_start: datetime, timespan_end: datetime, - versions: Optional[bool] = None, - max_results: Optional[int] = None, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, + versions: bool | None = None, + max_results: int | None = None, + prefix: str | None = None, + delimiter: str | None = None, ) -> List[str]: """ List all objects from the bucket with the give string prefix in name that were @@ -790,12 +782,12 @@ def get_size(self, bucket_name: str, object_name: str) -> int: cloud storage bucket_name. """ - self.log.info('Checking the file size of object: %s in bucket_name: %s', object_name, bucket_name) + self.log.info("Checking the file size of object: %s in bucket_name: %s", object_name, bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) blob_size = blob.size - self.log.info('The file size of %s is %s bytes.', object_name, blob_size) + self.log.info("The file size of %s is %s bytes.", object_name, blob_size) return blob_size def get_crc32c(self, bucket_name: str, object_name: str): @@ -807,7 +799,7 @@ def get_crc32c(self, bucket_name: str, object_name: str): storage bucket_name. 
""" self.log.info( - 'Retrieving the crc32c checksum of object_name: %s in bucket_name: %s', + "Retrieving the crc32c checksum of object_name: %s in bucket_name: %s", object_name, bucket_name, ) @@ -815,7 +807,7 @@ def get_crc32c(self, bucket_name: str, object_name: str): bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) blob_crc32c = blob.crc32c - self.log.info('The crc32c checksum of %s is %s', object_name, blob_crc32c) + self.log.info("The crc32c checksum of %s is %s", object_name, blob_crc32c) return blob_crc32c def get_md5hash(self, bucket_name: str, object_name: str) -> str: @@ -826,23 +818,23 @@ def get_md5hash(self, bucket_name: str, object_name: str) -> str: :param object_name: The name of the object to check in the Google cloud storage bucket_name. """ - self.log.info('Retrieving the MD5 hash of object: %s in bucket: %s', object_name, bucket_name) + self.log.info("Retrieving the MD5 hash of object: %s in bucket: %s", object_name, bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) blob_md5hash = blob.md5_hash - self.log.info('The md5Hash of %s is %s', object_name, blob_md5hash) + self.log.info("The md5Hash of %s is %s", object_name, blob_md5hash) return blob_md5hash @GoogleBaseHook.fallback_to_default_project_id def create_bucket( self, bucket_name: str, - resource: Optional[dict] = None, - storage_class: str = 'MULTI_REGIONAL', - location: str = 'US', - project_id: Optional[str] = None, - labels: Optional[dict] = None, + resource: dict | None = None, + storage_class: str = "MULTI_REGIONAL", + location: str = "US", + project_id: str | None = None, + labels: dict | None = None, ) -> str: """ Creates a new bucket. Google Cloud Storage uses a flat namespace, so @@ -879,12 +871,12 @@ def create_bucket( :return: If successful, it returns the ``id`` of the bucket. """ self.log.info( - 'Creating Bucket: %s; Location: %s; Storage Class: %s', bucket_name, location, storage_class + "Creating Bucket: %s; Location: %s; Storage Class: %s", bucket_name, location, storage_class ) # Add airflow-version label to the bucket labels = labels or {} - labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-') + labels["airflow-version"] = "v" + version.replace(".", "-").replace("+", "-") client = self.get_conn() bucket = client.bucket(bucket_name=bucket_name) @@ -900,7 +892,7 @@ def create_bucket( return bucket.id def insert_bucket_acl( - self, bucket_name: str, entity: str, role: str, user_project: Optional[str] = None + self, bucket_name: str, entity: str, role: str, user_project: str | None = None ) -> None: """ Creates a new ACL entry on the specified bucket_name. @@ -916,7 +908,7 @@ def insert_bucket_acl( :param user_project: (Optional) The project to be billed for this request. Required for Requester Pays buckets. 
""" - self.log.info('Creating a new ACL entry in bucket: %s', bucket_name) + self.log.info("Creating a new ACL entry in bucket: %s", bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name=bucket_name) bucket.acl.reload() @@ -925,7 +917,7 @@ def insert_bucket_acl( bucket.acl.user_project = user_project bucket.acl.save() - self.log.info('A new ACL entry created in bucket: %s', bucket_name) + self.log.info("A new ACL entry created in bucket: %s", bucket_name) def insert_object_acl( self, @@ -933,8 +925,8 @@ def insert_object_acl( object_name: str, entity: str, role: str, - generation: Optional[int] = None, - user_project: Optional[str] = None, + generation: int | None = None, + user_project: str | None = None, ) -> None: """ Creates a new ACL entry on the specified object. @@ -954,7 +946,7 @@ def insert_object_acl( :param user_project: (Optional) The project to be billed for this request. Required for Requester Pays buckets. """ - self.log.info('Creating a new ACL entry for object: %s in bucket: %s', object_name, bucket_name) + self.log.info("Creating a new ACL entry for object: %s in bucket: %s", object_name, bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name=bucket_name) blob = bucket.blob(blob_name=object_name, generation=generation) @@ -965,9 +957,9 @@ def insert_object_acl( blob.acl.user_project = user_project blob.acl.save() - self.log.info('A new ACL entry created for object: %s in bucket: %s', object_name, bucket_name) + self.log.info("A new ACL entry created for object: %s in bucket: %s", object_name, bucket_name) - def compose(self, bucket_name: str, source_objects: List, destination_object: str) -> None: + def compose(self, bucket_name: str, source_objects: List[str], destination_object: str) -> None: """ Composes a list of existing object into a new object in the same storage bucket_name @@ -983,10 +975,10 @@ def compose(self, bucket_name: str, source_objects: List, destination_object: st :param destination_object: The path of the object if given. 
""" if not source_objects: - raise ValueError('source_objects cannot be empty.') + raise ValueError("source_objects cannot be empty.") if not bucket_name or not destination_object: - raise ValueError('bucket_name and destination_object cannot be empty.') + raise ValueError("bucket_name and destination_object cannot be empty.") self.log.info("Composing %s to %s in the bucket %s", source_objects, destination_object, bucket_name) client = self.get_conn() @@ -1002,8 +994,8 @@ def sync( self, source_bucket: str, destination_bucket: str, - source_object: Optional[str] = None, - destination_object: Optional[str] = None, + source_object: str | None = None, + destination_object: str | None = None, recursive: bool = True, allow_overwrite: bool = False, delete_extra_files: bool = False, @@ -1104,7 +1096,7 @@ def sync( self.log.info("Synchronization finished.") def _calculate_sync_destination_path( - self, blob: storage.Blob, destination_object: Optional[str], source_object_prefix_len: int + self, blob: storage.Blob, destination_object: str | None, source_object_prefix_len: int ) -> str: return ( path.join(destination_object, blob.name[source_object_prefix_len:]) @@ -1116,10 +1108,10 @@ def _calculate_sync_destination_path( def _prepare_sync_plan( source_bucket: storage.Bucket, destination_bucket: storage.Bucket, - source_object: Optional[str], - destination_object: Optional[str], + source_object: str | None, + destination_object: str | None, recursive: bool, - ) -> Tuple[Set[storage.Blob], Set[storage.Blob], Set[storage.Blob]]: + ) -> tuple[set[storage.Blob], set[storage.Blob], set[storage.Blob]]: # Calculate the number of characters that remove from the name, because they contain information # about the parent's path source_object_prefix_len = len(source_object) if source_object else 0 @@ -1139,11 +1131,11 @@ def _prepare_sync_plan( # Determine objects to copy and delete to_copy = source_names - destination_names to_delete = destination_names - source_names - to_copy_blobs = {source_names_index[a] for a in to_copy} # type: Set[storage.Blob] - to_delete_blobs = {destination_names_index[a] for a in to_delete} # type: Set[storage.Blob] + to_copy_blobs: set[storage.Blob] = {source_names_index[a] for a in to_copy} + to_delete_blobs: set[storage.Blob] = {destination_names_index[a] for a in to_delete} # Find names that are in both buckets names_to_check = source_names.intersection(destination_names) - to_rewrite_blobs = set() # type: Set[storage.Blob] + to_rewrite_blobs: set[storage.Blob] = set() # Compare objects based on crc32 for current_name in names_to_check: source_blob = source_names_index[current_name] @@ -1161,21 +1153,21 @@ def gcs_object_is_directory(bucket: str) -> bool: """ _, blob = _parse_gcs_url(bucket) - return len(blob) == 0 or blob.endswith('/') + return len(blob) == 0 or blob.endswith("/") -def _parse_gcs_url(gsurl: str) -> Tuple[str, str]: +def _parse_gcs_url(gsurl: str) -> tuple[str, str]: """ Given a Google Cloud Storage URL (gs:///), returns a tuple containing the corresponding bucket and blob. 
""" - parsed_url = urlparse(gsurl) + parsed_url = urlsplit(gsurl) if not parsed_url.netloc: - raise AirflowException('Please provide a bucket name') + raise AirflowException("Please provide a bucket name") if parsed_url.scheme.lower() != "gs": raise AirflowException(f"Schema must be to 'gs://': Current schema: '{parsed_url.scheme}://'") bucket = parsed_url.netloc # Remove leading '/' but NOT trailing one - blob = parsed_url.path.lstrip('/') + blob = parsed_url.path.lstrip("/") return bucket, blob diff --git a/airflow/providers/google/cloud/hooks/gdm.py b/airflow/providers/google/cloud/hooks/gdm.py index ae0b55c1e3873..6b9229543b463 100644 --- a/airflow/providers/google/cloud/hooks/gdm.py +++ b/airflow/providers/google/cloud/hooks/gdm.py @@ -15,9 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Sequence from googleapiclient.discovery import Resource, build @@ -34,8 +34,8 @@ class GoogleDeploymentManagerHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -44,30 +44,25 @@ def __init__( ) def get_conn(self) -> Resource: - """ - Returns a Google Deployment Manager service object. - - :rtype: googleapiclient.discovery.Resource - """ + """Returns a Google Deployment Manager service object.""" http_authorized = self._authorize() - return build('deploymentmanager', 'v2', http=http_authorized, cache_discovery=False) + return build("deploymentmanager", "v2", http=http_authorized, cache_discovery=False) @GoogleBaseHook.fallback_to_default_project_id def list_deployments( self, - project_id: Optional[str] = None, - deployment_filter: Optional[str] = None, - order_by: Optional[str] = None, - ) -> List[Dict[str, Any]]: + project_id: str | None = None, + deployment_filter: str | None = None, + order_by: str | None = None, + ) -> list[dict[str, Any]]: """ Lists deployments in a google cloud project. :param project_id: The project ID for this request. :param deployment_filter: A filter expression which limits resources returned in the response. :param order_by: A field name to order by, ex: "creationTimestamp desc" - :rtype: list """ - deployments = [] # type: List[Dict] + deployments: list[dict] = [] conn = self.get_conn() request = conn.deployments().list(project=project_id, filter=deployment_filter, orderBy=order_by) @@ -81,7 +76,7 @@ def list_deployments( @GoogleBaseHook.fallback_to_default_project_id def delete_deployment( - self, project_id: Optional[str], deployment: Optional[str] = None, delete_policy: Optional[str] = None + self, project_id: str | None, deployment: str | None = None, delete_policy: str | None = None ) -> None: """ Deletes a deployment and all associated resources in a google cloud project. @@ -89,8 +84,6 @@ def delete_deployment( :param project_id: The project ID for this request. :param deployment: The name of the deployment for this request. :param delete_policy: Sets the policy to use for deleting resources. 
(ABANDON | DELETE) - - :rtype: None """ conn = self.get_conn() @@ -98,7 +91,7 @@ def delete_deployment( project=project_id, deployment=deployment, deletePolicy=delete_policy ) resp = request.execute() - if 'error' in resp.keys(): + if "error" in resp.keys(): raise AirflowException( - 'Errors deleting deployment: ', ', '.join(err['message'] for err in resp['error']['errors']) + "Errors deleting deployment: ", ", ".join(err["message"] for err in resp["error"]["errors"]) ) diff --git a/airflow/providers/google/cloud/hooks/kms.py b/airflow/providers/google/cloud/hooks/kms.py index 35169f132d778..9ecc6efb5c662 100644 --- a/airflow/providers/google/cloud/hooks/kms.py +++ b/airflow/providers/google/cloud/hooks/kms.py @@ -15,12 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Cloud KMS hook""" - +from __future__ import annotations import base64 -from typing import Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -61,26 +60,25 @@ class CloudKMSHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._conn = None # type: Optional[KeyManagementServiceClient] + self._conn: KeyManagementServiceClient | None = None def get_conn(self) -> KeyManagementServiceClient: """ Retrieves connection to Cloud Key Management service. :return: Cloud Key Management service object - :rtype: google.cloud.kms_v1.KeyManagementServiceClient """ if not self._conn: self._conn = KeyManagementServiceClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO + credentials=self.get_credentials(), client_info=CLIENT_INFO ) return self._conn @@ -88,10 +86,10 @@ def encrypt( self, key_name: str, plaintext: bytes, - authenticated_data: Optional[bytes] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + authenticated_data: bytes | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> str: """ Encrypts a plaintext message using Google Cloud KMS. @@ -108,13 +106,12 @@ def encrypt( retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. :return: The base 64 encoded ciphertext of the original message. 
- :rtype: str """ response = self.get_conn().encrypt( request={ - 'name': key_name, - 'plaintext': plaintext, - 'additional_authenticated_data': authenticated_data, + "name": key_name, + "plaintext": plaintext, + "additional_authenticated_data": authenticated_data, }, retry=retry, timeout=timeout, @@ -128,10 +125,10 @@ def decrypt( self, key_name: str, ciphertext: str, - authenticated_data: Optional[bytes] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + authenticated_data: bytes | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> bytes: """ Decrypts a ciphertext message using Google Cloud KMS. @@ -147,13 +144,12 @@ def decrypt( retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. :return: The original message. - :rtype: bytes """ response = self.get_conn().decrypt( request={ - 'name': key_name, - 'ciphertext': _b64decode(ciphertext), - 'additional_authenticated_data': authenticated_data, + "name": key_name, + "ciphertext": _b64decode(ciphertext), + "additional_authenticated_data": authenticated_data, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py index 31c6c6c1fd142..64e81a3623543 100644 --- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains a Google Kubernetes Engine Hook. 
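The KMS hunks just above keep the hook's convention of exchanging ciphertext as base64 text: encrypt() documents that it returns the base64-encoded ciphertext, and decrypt() passes its input through _b64decode before calling the API. A stand-alone round trip illustrating that convention (the helper names and bodies below are illustrative, not the module's own):

import base64

def to_b64(ciphertext: bytes) -> str:
    return base64.b64encode(ciphertext).decode("ascii")

def from_b64(encoded: str) -> bytes:
    return base64.b64decode(encoded.encode("utf-8"))

assert from_b64(to_b64(b"raw kms ciphertext")) == b"raw kms ciphertext"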
@@ -24,11 +23,12 @@ gapic enums """ +from __future__ import annotations import json import time import warnings -from typing import Dict, Optional, Sequence, Union +from typing import Sequence from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -58,22 +58,22 @@ class GKEHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client = None # type: Optional[ClusterManagerClient] + self._client: ClusterManagerClient | None = None self.location = location def get_cluster_manager_client(self) -> ClusterManagerClient: """Returns ClusterManagerClient.""" if self._client is None: - self._client = ClusterManagerClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = ClusterManagerClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client # To preserve backward compatibility @@ -94,7 +94,7 @@ def get_client(self) -> ClusterManagerClient: ) return self.get_conn() - def wait_for_operation(self, operation: Operation, project_id: Optional[str] = None) -> Operation: + def wait_for_operation(self, operation: Operation, project_id: str | None = None) -> Operation: """ Given an operation, continuously fetches the status from Google Cloud until either completion or an error occurring @@ -114,7 +114,7 @@ def wait_for_operation(self, operation: Operation, project_id: Optional[str] = N operation = self.get_operation(operation.name, project_id=project_id or self.project_id) return operation - def get_operation(self, operation_name: str, project_id: Optional[str] = None) -> Operation: + def get_operation(self, operation_name: str, project_id: str | None = None) -> Operation: """ Fetches the operation from Google Cloud @@ -124,8 +124,8 @@ def get_operation(self, operation_name: str, project_id: Optional[str] = None) - """ return self.get_cluster_manager_client().get_operation( name=( - f'projects/{project_id or self.project_id}' - f'/locations/{self.location}/operations/{operation_name}' + f"projects/{project_id or self.project_id}" + f"/locations/{self.location}/operations/{operation_name}" ) ) @@ -143,7 +143,7 @@ def _append_label(cluster_proto: Cluster, key: str, val: str) -> Cluster: :param val: :return: The cluster proto updated with new label """ - val = val.replace('.', '-').replace('+', '-') + val = val.replace(".", "-").replace("+", "-") cluster_proto.resource_labels.update({key: val}) return cluster_proto @@ -152,9 +152,9 @@ def delete_cluster( self, name: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - ) -> Optional[str]: + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + ) -> str | None: """ Deletes the cluster, including the Kubernetes endpoint and all worker nodes. 
Firewalls and routes that were configured during @@ -176,24 +176,24 @@ def delete_cluster( try: resource = self.get_cluster_manager_client().delete_cluster( - name=f'projects/{project_id}/locations/{self.location}/clusters/{name}', + name=f"projects/{project_id}/locations/{self.location}/clusters/{name}", retry=retry, timeout=timeout, ) - resource = self.wait_for_operation(resource) + resource = self.wait_for_operation(resource, project_id) # Returns server-defined url for the resource return resource.self_link except NotFound as error: - self.log.info('Assuming Success: %s', error.message) + self.log.info("Assuming Success: %s", error.message) return None @GoogleBaseHook.fallback_to_default_project_id def create_cluster( self, - cluster: Union[Dict, Cluster, None], + cluster: dict | Cluster | None, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, ) -> str: """ Creates a cluster, consisting of the specified number and type of Google Compute @@ -219,7 +219,7 @@ def create_cluster( elif not isinstance(cluster, Cluster): raise AirflowException("cluster is not instance of Cluster proto or python dict") - self._append_label(cluster, 'airflow-version', 'v' + version.version) # type: ignore + self._append_label(cluster, "airflow-version", "v" + version.version) # type: ignore self.log.info( "Creating (project_id=%s, location=%s, cluster_name=%s)", @@ -229,16 +229,16 @@ def create_cluster( ) try: resource = self.get_cluster_manager_client().create_cluster( - parent=f'projects/{project_id}/locations/{self.location}', + parent=f"projects/{project_id}/locations/{self.location}", cluster=cluster, # type: ignore retry=retry, timeout=timeout, ) - resource = self.wait_for_operation(resource) + resource = self.wait_for_operation(resource, project_id) return resource.target_link except AlreadyExists as error: - self.log.info('Assuming Success: %s', error.message) + self.log.info("Assuming Success: %s", error.message) return self.get_cluster(name=cluster.name, project_id=project_id) # type: ignore @GoogleBaseHook.fallback_to_default_project_id @@ -246,8 +246,8 @@ def get_cluster( self, name: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, ) -> Cluster: """ Gets details of specified cluster @@ -271,7 +271,7 @@ def get_cluster( return ( self.get_cluster_manager_client() .get_cluster( - name=f'projects/{project_id}/locations/{self.location}/clusters/{name}', + name=f"projects/{project_id}/locations/{self.location}/clusters/{name}", retry=retry, timeout=timeout, ) diff --git a/airflow/providers/google/cloud/hooks/life_sciences.py b/airflow/providers/google/cloud/hooks/life_sciences.py index 551d21980c396..983a4d303a2c3 100644 --- a/airflow/providers/google/cloud/hooks/life_sciences.py +++ b/airflow/providers/google/cloud/hooks/life_sciences.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. 
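The GKE hook above stamps every cluster it creates with an "airflow-version" resource label, and _append_label rewrites the value because label values may not contain dots or plus signs. A tiny sketch of that substitution (the function name is illustrative):

def airflow_version_label(version: str) -> str:
    # mirror the replacement performed before the label is attached to the cluster
    return "v" + version.replace(".", "-").replace("+", "-")

assert airflow_version_label("2.4.0") == "v2-4-0"
assert airflow_version_label("2.4.0+composer") == "v2-4-0-composer"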
"""Hook for Google Cloud Life Sciences service""" +from __future__ import annotations import time -from typing import Any, Optional, Sequence, Union +from typing import Sequence import google.api_core.path_template from googleapiclient.discovery import build @@ -52,14 +53,14 @@ class LifeSciencesHook(GoogleBaseHook): account from the list granting this role to the originating account. """ - _conn = None # type: Optional[Any] + _conn = None def __init__( self, api_version: str = "v2beta", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -88,7 +89,6 @@ def run_pipeline(self, body: dict, location: str, project_id: str) -> dict: :param location: The location of the project. For example: "us-east1". :param project_id: Optional, Google Cloud Project project_id where the function belongs. If set to None or missing, the default project_id from the Google Cloud connection is used. - :rtype: dict """ parent = self._location_path(project_id=project_id, location=location) service = self.get_conn() @@ -98,7 +98,7 @@ def run_pipeline(self, body: dict, location: str, project_id: str) -> dict: response = request.execute(num_retries=self.num_retries) # wait - operation_name = response['name'] + operation_name = response["name"] self._wait_for_operation_to_complete(operation_name) return response @@ -114,7 +114,7 @@ def _location_path(self, project_id: str, location: str) -> str: :param location: The location of the project. For example: "us-east1". """ return google.api_core.path_template.expand( - 'projects/{project}/locations/{location}', + "projects/{project}/locations/{location}", project=project_id, location=location, ) @@ -126,7 +126,6 @@ def _wait_for_operation_to_complete(self, operation_name: str) -> None: :param operation_name: The name of the operation. :return: The response returned by the operation. - :rtype: dict :exception: AirflowException in case error is returned. """ service = self.get_conn() @@ -138,7 +137,7 @@ def _wait_for_operation_to_complete(self, operation_name: str) -> None: .get(name=operation_name) .execute(num_retries=self.num_retries) ) - self.log.info('Waiting for pipeline operation to complete') + self.log.info("Waiting for pipeline operation to complete") if operation_response.get("done"): response = operation_response.get("response") error = operation_response.get("error") diff --git a/airflow/providers/google/cloud/hooks/looker.py b/airflow/providers/google/cloud/hooks/looker.py index 845211260bbbb..42477d2fbbba2 100644 --- a/airflow/providers/google/cloud/hooks/looker.py +++ b/airflow/providers/google/cloud/hooks/looker.py @@ -15,13 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# """This module contains a Google Cloud Looker hook.""" +from __future__ import annotations import json import time from enum import Enum -from typing import Dict, Optional from looker_sdk.rtl import api_settings, auth_session, requests_transport, serialize from looker_sdk.sdk.api40 import methods as methods40 @@ -43,13 +42,13 @@ def __init__( super().__init__() self.looker_conn_id = looker_conn_id # source is used to track origin of the requests - self.source = f'airflow:{version}' + self.source = f"airflow:{version}" def start_pdt_build( self, model: str, view: str, - query_params: Optional[Dict] = None, + query_params: dict | None = None, ): """ Submits a PDT materialization job to Looker. @@ -64,7 +63,7 @@ def start_pdt_build( sdk = self.get_looker_sdk() looker_ver = sdk.versions().looker_release_version if parse_version(looker_ver) < parse_version("22.2.0"): - raise AirflowException(f'This API requires Looker version 22.2+. Found: {looker_ver}.') + raise AirflowException(f"This API requires Looker version 22.2+. Found: {looker_ver}.") # unpack query_params dict into kwargs (if not None) if query_params: @@ -96,7 +95,7 @@ def check_pdt_build( def pdt_build_status( self, materialization_id: str, - ) -> Dict: + ) -> dict: """ Gets the PDT materialization job status. @@ -104,11 +103,11 @@ def pdt_build_status( """ resp = self.check_pdt_build(materialization_id=materialization_id) - status_json = resp['resp_text'] + status_json = resp["resp_text"] status_dict = json.loads(status_json) self.log.info( - "PDT materialization job id: %s. Status: '%s'.", materialization_id, status_dict['status'] + "PDT materialization job id: %s. Status: '%s'.", materialization_id, status_dict["status"] ) return status_dict @@ -135,7 +134,7 @@ def wait_for_job( self, materialization_id: str, wait_time: int = 10, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """ Helper method which polls a PDT materialization job to check if it finishes. @@ -145,7 +144,7 @@ def wait_for_job( :param timeout: Optional. How many seconds wait for job to be ready. Used only if ``asynchronous`` is False. """ - self.log.info('Waiting for PDT materialization job to complete. Job id: %s.', materialization_id) + self.log.info("Waiting for PDT materialization job to complete. Job id: %s.", materialization_id) status = None start = time.monotonic() @@ -167,25 +166,24 @@ def wait_for_job( time.sleep(wait_time) status_dict = self.pdt_build_status(materialization_id=materialization_id) - status = status_dict['status'] + status = status_dict["status"] if status == JobStatus.ERROR.value: - msg = status_dict['message'] + msg = status_dict["message"] raise AirflowException( f'PDT materialization job failed. Job id: {materialization_id}. Message:\n"{msg}"' ) if status == JobStatus.CANCELLED.value: - raise AirflowException(f'PDT materialization job was cancelled. Job id: {materialization_id}.') + raise AirflowException(f"PDT materialization job was cancelled. Job id: {materialization_id}.") if status == JobStatus.UNKNOWN.value: raise AirflowException( - f'PDT materialization job has unknown status. Job id: {materialization_id}.' + f"PDT materialization job has unknown status. Job id: {materialization_id}." ) - self.log.info('PDT materialization job completed successfully. Job id: %s.', materialization_id) + self.log.info("PDT materialization job completed successfully. 
Job id: %s.", materialization_id) def get_looker_sdk(self): """Returns Looker SDK client for Looker API 4.0.""" - conn = self.get_connection(self.looker_conn_id) settings = LookerApiSettings(conn) @@ -214,11 +212,10 @@ def read_config(self): Overrides the default logic of getting connection settings. Fetches the connection settings from Airflow's connection object. """ - config = {} if self.conn.host is None: - raise AirflowException(f'No `host` was supplied in connection: {self.conn.id}.') + raise AirflowException(f"No `host` was supplied in connection: {self.conn.id}.") if self.conn.port: config["base_url"] = f"{self.conn.host}:{self.conn.port}" # port is optional @@ -228,19 +225,19 @@ def read_config(self): if self.conn.login: config["client_id"] = self.conn.login else: - raise AirflowException(f'No `login` was supplied in connection: {self.conn.id}.') + raise AirflowException(f"No `login` was supplied in connection: {self.conn.id}.") if self.conn.password: config["client_secret"] = self.conn.password else: - raise AirflowException(f'No `password` was supplied in connection: {self.conn.id}.') + raise AirflowException(f"No `password` was supplied in connection: {self.conn.id}.") - extras = self.conn.extra_dejson # type: Dict + extras: dict = self.conn.extra_dejson - if 'verify_ssl' in extras: + if "verify_ssl" in extras: config["verify_ssl"] = extras["verify_ssl"] # optional - if 'timeout' in extras: + if "timeout" in extras: config["timeout"] = extras["timeout"] # optional return config @@ -250,9 +247,9 @@ class JobStatus(Enum): """The job status string.""" QUEUED = "added" - PENDING = 'pending' - RUNNING = 'running' - CANCELLED = 'killed' - DONE = 'complete' - ERROR = 'error' - UNKNOWN = 'unknown' + PENDING = "pending" + RUNNING = "running" + CANCELLED = "killed" + DONE = "complete" + ERROR = "error" + UNKNOWN = "unknown" diff --git a/airflow/providers/google/cloud/hooks/mlengine.py b/airflow/providers/google/cloud/hooks/mlengine.py index b1b6e83918aa4..0af6e15524702 100644 --- a/airflow/providers/google/cloud/hooks/mlengine.py +++ b/airflow/providers/google/cloud/hooks/mlengine.py @@ -16,23 +16,28 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google ML Engine Hook.""" +from __future__ import annotations + import logging import random import time -from typing import Callable, Dict, List, Optional +from typing import Callable from googleapiclient.discovery import Resource, build from googleapiclient.errors import HttpError +from httplib2 import Response from airflow.providers.google.common.hooks.base_google import GoogleBaseHook from airflow.version import version as airflow_version log = logging.getLogger(__name__) -_AIRFLOW_VERSION = 'v' + airflow_version.replace('.', '-').replace('+', '-') +_AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-") -def _poll_with_exponential_delay(request, execute_num_retries, max_n, is_done_func, is_error_func): +def _poll_with_exponential_delay( + request, execute_num_retries, max_n, is_done_func, is_error_func +) -> Response: """ Execute request with exponential delay. @@ -46,26 +51,25 @@ def _poll_with_exponential_delay(request, execute_num_retries, max_n, is_done_fu :param is_done_func: callable to determine if operation is done. :param is_error_func: callable to determine if operation is failed. 
:return: response - :rtype: httplib2.Response """ for i in range(0, max_n): try: response = request.execute(num_retries=execute_num_retries) if is_error_func(response): - raise ValueError(f'The response contained an error: {response}') + raise ValueError(f"The response contained an error: {response}") if is_done_func(response): - log.info('Operation is done: %s', response) + log.info("Operation is done: %s", response) return response time.sleep((2**i) + (random.randint(0, 1000) / 1000)) except HttpError as e: if e.resp.status != 429: - log.info('Something went wrong. Not retrying: %s', format(e)) + log.info("Something went wrong. Not retrying: %s", format(e)) raise else: time.sleep((2**i) + (random.randint(0, 1000) / 1000)) - raise ValueError(f'Connection could not be established after {max_n} retries.') + raise ValueError(f"Connection could not be established after {max_n} retries.") class MLEngineHook(GoogleBaseHook): @@ -83,10 +87,10 @@ def get_conn(self) -> Resource: :return: Google MLEngine services object. """ authed_http = self._authorize() - return build('ml', 'v1', http=authed_http, cache_discovery=False) + return build("ml", "v1", http=authed_http, cache_discovery=False) @GoogleBaseHook.fallback_to_default_project_id - def create_job(self, job: dict, project_id: str, use_existing_job_fn: Optional[Callable] = None) -> dict: + def create_job(self, job: dict, project_id: str, use_existing_job_fn: Callable | None = None) -> dict: """ Launches a MLEngine job and wait for it to reach a terminal state. @@ -113,15 +117,14 @@ def create_job(self, job: dict, project_id: str, use_existing_job_fn: Optional[C we by default reuse the existing MLEngine job. :return: The MLEngine job object if the job successfully reach a terminal state (which might be FAILED or CANCELLED state). - :rtype: dict """ hook = self.get_conn() self._append_label(job) self.log.info("Creating job.") - request = hook.projects().jobs().create(parent=f'projects/{project_id}', body=job) - job_id = job['jobId'] + request = hook.projects().jobs().create(parent=f"projects/{project_id}", body=job) + job_id = job["jobId"] try: request.execute(num_retries=self.num_retries) @@ -132,14 +135,14 @@ def create_job(self, job: dict, project_id: str, use_existing_job_fn: Optional[C existing_job = self._get_job(project_id, job_id) if not use_existing_job_fn(existing_job): self.log.error( - 'Job with job_id %s already exist, but it does not match our expectation: %s', + "Job with job_id %s already exist, but it does not match our expectation: %s", job_id, existing_job, ) raise - self.log.info('Job with job_id %s already exist. Will waiting for it to finish', job_id) + self.log.info("Job with job_id %s already exist. Will waiting for it to finish", job_id) else: - self.log.error('Failed to create MLEngine job: %s', e) + self.log.error("Failed to create MLEngine job: %s", e) raise return self._wait_for_job_done(project_id, job_id) @@ -159,24 +162,23 @@ def cancel_job( :param job_id: A unique id for the want-to-be cancelled Google MLEngine training job. :return: Empty dict if cancelled successfully - :rtype: dict :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() - request = hook.projects().jobs().cancel(name=f'projects/{project_id}/jobs/{job_id}') + request = hook.projects().jobs().cancel(name=f"projects/{project_id}/jobs/{job_id}") try: return request.execute(num_retries=self.num_retries) except HttpError as e: if e.resp.status == 404: - self.log.error('Job with job_id %s does not exist. 
', job_id) + self.log.error("Job with job_id %s does not exist. ", job_id) raise elif e.resp.status == 400: - self.log.info('Job with job_id %s is already complete, cancellation aborted.', job_id) + self.log.info("Job with job_id %s is already complete, cancellation aborted.", job_id) return {} else: - self.log.error('Failed to cancel MLEngine job: %s', e) + self.log.error("Failed to cancel MLEngine job: %s", e) raise def _get_job(self, project_id: str, job_id: str) -> dict: @@ -187,11 +189,10 @@ def _get_job(self, project_id: str, job_id: str) -> dict: project_id from the Google Cloud connection is used. (templated) :param job_id: A unique id for the Google MLEngine job. (templated) :return: MLEngine job object if succeed. - :rtype: dict :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() - job_name = f'projects/{project_id}/jobs/{job_id}' + job_name = f"projects/{project_id}/jobs/{job_id}" request = hook.projects().jobs().get(name=job_name) while True: try: @@ -201,7 +202,7 @@ def _get_job(self, project_id: str, job_id: str) -> dict: # polling after 30 seconds when quota failure occurs time.sleep(30) else: - self.log.error('Failed to get MLEngine job: %s', e) + self.log.error("Failed to get MLEngine job: %s", e) raise def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30): @@ -223,7 +224,7 @@ def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30): raise ValueError("Interval must be > 0") while True: job = self._get_job(project_id, job_id) - if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]: return job time.sleep(interval) @@ -231,7 +232,7 @@ def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30): def create_version( self, model_name: str, - version_spec: Dict, + version_spec: dict, project_id: str, ) -> dict: """ @@ -245,23 +246,22 @@ def create_version( (templated) :return: If the version was created successfully, returns the operation. Otherwise raises an error . - :rtype: dict """ hook = self.get_conn() - parent_name = f'projects/{project_id}/models/{model_name}' + parent_name = f"projects/{project_id}/models/{model_name}" self._append_label(version_spec) create_request = hook.projects().models().versions().create(parent=parent_name, body=version_spec) response = create_request.execute(num_retries=self.num_retries) - get_request = hook.projects().operations().get(name=response['name']) + get_request = hook.projects().operations().get(name=response["name"]) return _poll_with_exponential_delay( request=get_request, execute_num_retries=self.num_retries, max_n=9, - is_done_func=lambda resp: resp.get('done', False), - is_error_func=lambda resp: resp.get('error', None) is not None, + is_done_func=lambda resp: resp.get("done", False), + is_error_func=lambda resp: resp.get("error", None) is not None, ) @GoogleBaseHook.fallback_to_default_project_id @@ -281,20 +281,19 @@ def set_default_version( or missing, the default project_id from the Google Cloud connection is used. (templated) :return: If successful, return an instance of Version. Otherwise raises an error. 
- :rtype: dict :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() - full_version_name = f'projects/{project_id}/models/{model_name}/versions/{version_name}' + full_version_name = f"projects/{project_id}/models/{model_name}/versions/{version_name}" request = hook.projects().models().versions().setDefault(name=full_version_name, body={}) try: response = request.execute(num_retries=self.num_retries) - self.log.info('Successfully set version: %s to default', response) + self.log.info("Successfully set version: %s to default", response) return response except HttpError as e: - self.log.error('Something went wrong: %s', e) + self.log.error("Something went wrong: %s", e) raise @GoogleBaseHook.fallback_to_default_project_id @@ -302,7 +301,7 @@ def list_versions( self, model_name: str, project_id: str, - ) -> List[dict]: + ) -> list[dict]: """ Lists all available versions of a model. Blocks until finished. @@ -311,18 +310,17 @@ def list_versions( :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or missing, the default project_id from the Google Cloud connection is used. (templated) :return: return an list of instance of Version. - :rtype: List[Dict] :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() - result = [] # type: List[Dict] - full_parent_name = f'projects/{project_id}/models/{model_name}' + result: list[dict] = [] + full_parent_name = f"projects/{project_id}/models/{model_name}" request = hook.projects().models().versions().list(parent=full_parent_name, pageSize=100) while request is not None: response = request.execute(num_retries=self.num_retries) - result.extend(response.get('versions', [])) + result.extend(response.get("versions", [])) request = ( hook.projects() @@ -349,20 +347,19 @@ def delete_version( model belongs. :return: If the version was deleted successfully, returns the operation. Otherwise raises an error. - :rtype: Dict """ hook = self.get_conn() - full_name = f'projects/{project_id}/models/{model_name}/versions/{version_name}' + full_name = f"projects/{project_id}/models/{model_name}/versions/{version_name}" delete_request = hook.projects().models().versions().delete(name=full_name) response = delete_request.execute(num_retries=self.num_retries) - get_request = hook.projects().operations().get(name=response['name']) + get_request = hook.projects().operations().get(name=response["name"]) return _poll_with_exponential_delay( request=get_request, execute_num_retries=self.num_retries, max_n=9, - is_done_func=lambda resp: resp.get('done', False), - is_error_func=lambda resp: resp.get('error', None) is not None, + is_done_func=lambda resp: resp.get("done", False), + is_error_func=lambda resp: resp.get("error", None) is not None, ) @GoogleBaseHook.fallback_to_default_project_id @@ -379,13 +376,12 @@ def create_model( missing, the default project_id from the Google Cloud connection is used. (templated) :return: If the version was created successfully, returns the instance of Model. Otherwise raises an error. 
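For the ML Engine hook above, a hedged sketch combining create_version() and list_versions() as they appear in these hunks; the model name, version spec and project id are hypothetical placeholders, and the gcp_conn_id argument is assumed to be accepted via GoogleBaseHook:

    from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook

    hook = MLEngineHook(gcp_conn_id="google_cloud_default")
    # create_version() polls the returned operation with exponential delay
    # (see _poll_with_exponential_delay above) before returning.
    hook.create_version(
        model_name="my_model",                                              # hypothetical
        version_spec={"name": "v1", "deploymentUri": "gs://bucket/model"},  # hypothetical spec
        project_id="my-gcp-project",
    )
    for version in hook.list_versions(model_name="my_model", project_id="my-gcp-project"):
        print(version["name"])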
- :rtype: Dict :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() - if 'name' not in model or not model['name']: + if "name" not in model or not model["name"]: raise ValueError("Model name must be provided and could not be an empty string") - project = f'projects/{project_id}' + project = f"projects/{project_id}" self._append_label(model) try: @@ -399,19 +395,19 @@ def create_model( raise e error_detail = e.error_details[0] - if error_detail["@type"] != 'type.googleapis.com/google.rpc.BadRequest': + if error_detail["@type"] != "type.googleapis.com/google.rpc.BadRequest": raise e - if "fieldViolations" not in error_detail or len(error_detail['fieldViolations']) != 1: + if "fieldViolations" not in error_detail or len(error_detail["fieldViolations"]) != 1: raise e - field_violation = error_detail['fieldViolations'][0] + field_violation = error_detail["fieldViolations"][0] if ( field_violation["field"] != "model.name" or field_violation["description"] != "A model with the same name already exists." ): raise e - response = self.get_model(model_name=model['name'], project_id=project_id) + response = self.get_model(model_name=model["name"], project_id=project_id) return response @@ -420,7 +416,7 @@ def get_model( self, model_name: str, project_id: str, - ) -> Optional[dict]: + ) -> dict | None: """ Gets a Model. Blocks until finished. @@ -429,19 +425,18 @@ def get_model( or missing, the default project_id from the Google Cloud connection is used. (templated) :return: If the model exists, returns the instance of Model. Otherwise return None. - :rtype: Dict :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() if not model_name: raise ValueError("Model name must be provided and it could not be an empty string") - full_model_name = f'projects/{project_id}/models/{model_name}' + full_model_name = f"projects/{project_id}/models/{model_name}" request = hook.projects().models().get(name=full_model_name) try: return request.execute(num_retries=self.num_retries) except HttpError as e: if e.resp.status == 404: - self.log.error('Model was not found: %s', e) + self.log.error("Model was not found: %s", e) return None raise @@ -467,7 +462,7 @@ def delete_model( if not model_name: raise ValueError("Model name must be provided and it could not be an empty string") - model_path = f'projects/{project_id}/models/{model_name}' + model_path = f"projects/{project_id}/models/{model_name}" if delete_contents: self._delete_all_versions(model_name, project_id) request = hook.projects().models().delete(name=model_path) @@ -475,22 +470,22 @@ def delete_model( request.execute(num_retries=self.num_retries) except HttpError as e: if e.resp.status == 404: - self.log.error('Model was not found: %s', e) + self.log.error("Model was not found: %s", e) return raise def _delete_all_versions(self, model_name: str, project_id: str): versions = self.list_versions(project_id=project_id, model_name=model_name) # The default version can only be deleted when it is the last one in the model - non_default_versions = (version for version in versions if not version.get('isDefault', False)) + non_default_versions = (version for version in versions if not version.get("isDefault", False)) for version in non_default_versions: - _, _, version_name = version['name'].rpartition('/') + _, _, version_name = version["name"].rpartition("/") self.delete_version(project_id=project_id, model_name=model_name, version_name=version_name) - default_versions = (version for version in versions if version.get('isDefault', False)) + 
default_versions = (version for version in versions if version.get("isDefault", False)) for version in default_versions: - _, _, version_name = version['name'].rpartition('/') + _, _, version_name = version["name"].rpartition("/") self.delete_version(project_id=project_id, model_name=model_name, version_name=version_name) def _append_label(self, model: dict) -> None: - model['labels'] = model.get('labels', {}) - model['labels']['airflow-version'] = _AIRFLOW_VERSION + model["labels"] = model.get("labels", {}) + model["labels"]["airflow-version"] = _AIRFLOW_VERSION diff --git a/airflow/providers/google/cloud/hooks/natural_language.py b/airflow/providers/google/cloud/hooks/natural_language.py index b297700be75a9..010db53ffc0d2 100644 --- a/airflow/providers/google/cloud/hooks/natural_language.py +++ b/airflow/providers/google/cloud/hooks/natural_language.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Natural Language Hook.""" -from typing import Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -57,8 +59,8 @@ class CloudNaturalLanguageHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -72,20 +74,19 @@ def get_conn(self) -> LanguageServiceClient: Retrieves connection to Cloud Natural Language service. :return: Cloud Natural Language service object - :rtype: google.cloud.language_v1.LanguageServiceClient """ if not self._conn: - self._conn = LanguageServiceClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._conn = LanguageServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._conn @GoogleBaseHook.quota_retry() def analyze_entities( self, - document: Union[dict, Document], - encoding_type: Optional[enums.EncodingType] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + document: dict | Document, + encoding_type: enums.EncodingType | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> AnalyzeEntitiesResponse: """ Finds named entities in the text along with entity types, @@ -99,7 +100,6 @@ def analyze_entities( :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. 
- :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse """ client = self.get_conn() @@ -110,11 +110,11 @@ def analyze_entities( @GoogleBaseHook.quota_retry() def analyze_entity_sentiment( self, - document: Union[dict, Document], - encoding_type: Optional[enums.EncodingType] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + document: dict | Document, + encoding_type: enums.EncodingType | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> AnalyzeEntitySentimentResponse: """ Finds entities, similar to AnalyzeEntities in the text and analyzes sentiment associated with each @@ -128,7 +128,6 @@ def analyze_entity_sentiment( :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. - :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse """ client = self.get_conn() @@ -139,11 +138,11 @@ def analyze_entity_sentiment( @GoogleBaseHook.quota_retry() def analyze_sentiment( self, - document: Union[dict, Document], - encoding_type: Optional[enums.EncodingType] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + document: dict | Document, + encoding_type: enums.EncodingType | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> AnalyzeSentimentResponse: """ Analyzes the sentiment of the provided text. @@ -156,7 +155,6 @@ def analyze_sentiment( :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. - :rtype: google.cloud.language_v1.types.AnalyzeSentimentResponse """ client = self.get_conn() @@ -167,11 +165,11 @@ def analyze_sentiment( @GoogleBaseHook.quota_retry() def analyze_syntax( self, - document: Union[dict, Document], - encoding_type: Optional[enums.EncodingType] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + document: dict | Document, + encoding_type: enums.EncodingType | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> AnalyzeSyntaxResponse: """ Analyzes the syntax of the text and provides sentence boundaries and tokenization along with part @@ -185,7 +183,6 @@ def analyze_syntax( :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. 
- :rtype: google.cloud.language_v1.types.AnalyzeSyntaxResponse """ client = self.get_conn() @@ -196,12 +193,12 @@ def analyze_syntax( @GoogleBaseHook.quota_retry() def annotate_text( self, - document: Union[dict, Document], - features: Union[dict, AnnotateTextRequest.Features], + document: dict | Document, + features: dict | AnnotateTextRequest.Features, encoding_type: enums.EncodingType = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> AnnotateTextResponse: """ A convenience method that provides all the features that analyzeSentiment, @@ -217,7 +214,6 @@ def annotate_text( :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. - :rtype: google.cloud.language_v1.types.AnnotateTextResponse """ client = self.get_conn() @@ -233,10 +229,10 @@ def annotate_text( @GoogleBaseHook.quota_retry() def classify_text( self, - document: Union[dict, Document], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + document: dict | Document, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ClassifyTextResponse: """ Classifies a document into categories. @@ -248,7 +244,6 @@ def classify_text( :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. 
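The Natural Language hook methods above share one shape: they take a document (dict or Document), optional encoding_type, retry, timeout and metadata, and forward the call to the client. A hedged sketch for analyze_entities(); the document payload keys ("content", "type_") are assumed from the google-cloud-language client, not from this diff:

    from airflow.providers.google.cloud.hooks.natural_language import CloudNaturalLanguageHook

    hook = CloudNaturalLanguageHook(gcp_conn_id="google_cloud_default")
    response = hook.analyze_entities(
        document={"content": "Apache Airflow orchestrates workflows.", "type_": "PLAIN_TEXT"},
        timeout=30.0,
    )
    for entity in response.entities:
        print(entity.name)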
- :rtype: google.cloud.language_v1.types.ClassifyTextResponse """ client = self.get_conn() diff --git a/airflow/providers/google/cloud/hooks/os_login.py b/airflow/providers/google/cloud/hooks/os_login.py index 4d9666cab7b30..570410f73636c 100644 --- a/airflow/providers/google/cloud/hooks/os_login.py +++ b/airflow/providers/google/cloud/hooks/os_login.py @@ -20,9 +20,9 @@ ImportSshPublicKeyResponse oslogin """ +from __future__ import annotations - -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -42,34 +42,34 @@ class OSLoginHook(GoogleBaseHook): def __init__( self, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._conn = None # type: Optional[OsLoginServiceClient] + self._conn: OsLoginServiceClient | None = None def get_conn(self) -> OsLoginServiceClient: """Return OS Login service client""" if self._conn: return self._conn - self._conn = OsLoginServiceClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._conn = OsLoginServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._conn @GoogleBaseHook.fallback_to_default_project_id def import_ssh_public_key( self, user: str, - ssh_public_key: Dict, + ssh_public_key: dict, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ImportSshPublicKeyResponse: """ Adds an SSH public key and returns the profile information. 
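For the OS Login hook above, a hedged sketch of import_ssh_public_key(); the user, project id, the "key" field inside ssh_public_key and the login_profile attribute on the response are assumptions based on the OS Login API rather than on this hunk:

    from airflow.providers.google.cloud.hooks.os_login import OSLoginHook

    hook = OSLoginHook(gcp_conn_id="google_cloud_default")
    response = hook.import_ssh_public_key(
        user="user@example.com",                              # hypothetical user
        ssh_public_key={"key": "ssh-rsa AAAA... user@host"},  # hypothetical key material
        project_id="my-gcp-project",
    )
    print(response.login_profile)                             # assumed response field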
Default POSIX diff --git a/airflow/providers/google/cloud/hooks/pubsub.py b/airflow/providers/google/cloud/hooks/pubsub.py index 535f912f8b777..fe717ec1c5cd6 100644 --- a/airflow/providers/google/cloud/hooks/pubsub.py +++ b/airflow/providers/google/cloud/hooks/pubsub.py @@ -23,19 +23,13 @@ MessageStoragePolicy ReceivedMessage """ -import sys +from __future__ import annotations + import warnings from base64 import b64decode -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Sequence from uuid import uuid4 -from airflow.providers.google.common.consts import CLIENT_INFO - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - from google.api_core.exceptions import AlreadyExists, GoogleAPICallError from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -52,6 +46,8 @@ ) from googleapiclient.errors import HttpError +from airflow.compat.functools import cached_property +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook from airflow.version import version @@ -71,8 +67,8 @@ class PubSubHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -86,10 +82,9 @@ def get_conn(self) -> PublisherClient: Retrieves connection to Google Cloud Pub/Sub. :return: Google Cloud Pub/Sub client object. - :rtype: google.cloud.pubsub_v1.PublisherClient """ if not self._client: - self._client = PublisherClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = PublisherClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @cached_property @@ -98,15 +93,14 @@ def subscriber_client(self) -> SubscriberClient: Creates SubscriberClient. :return: Google Cloud Pub/Sub client object. 
- :rtype: google.cloud.pubsub_v1.SubscriberClient """ - return SubscriberClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + return SubscriberClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) @GoogleBaseHook.fallback_to_default_project_id def publish( self, topic: str, - messages: List[dict], + messages: list[dict], project_id: str = PROVIDE_PROJECT_ID, ) -> None: """ @@ -129,11 +123,11 @@ def publish( try: for message in messages: future = publisher.publish( - topic=topic_path, data=message.get("data", b''), **message.get('attributes', {}) + topic=topic_path, data=message.get("data", b""), **message.get("attributes", {}) ) future.result() except GoogleAPICallError as e: - raise PubSubException(f'Error publishing to topic {topic_path}', e) + raise PubSubException(f"Error publishing to topic {topic_path}", e) self.log.info("Published %d messages to topic (path) %s", len(messages), topic_path) @@ -173,12 +167,12 @@ def create_topic( topic: str, project_id: str = PROVIDE_PROJECT_ID, fail_if_exists: bool = False, - labels: Optional[Dict[str, str]] = None, - message_storage_policy: Union[Dict, MessageStoragePolicy] = None, - kms_key_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + labels: dict[str, str] | None = None, + message_storage_policy: dict | MessageStoragePolicy = None, + kms_key_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Creates a Pub/Sub topic, if it does not already exist. @@ -195,7 +189,7 @@ def create_topic( of Google Cloud regions where messages published to the topic may be stored. If not present, then no constraints are in effect. - Union[Dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] + Union[dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] :param kms_key_name: The resource name of the Cloud KMS CryptoKey to be used to protect access to messages published on this topic. The expected format is @@ -212,7 +206,7 @@ def create_topic( # Add airflow-version label to the topic labels = labels or {} - labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-') + labels["airflow-version"] = "v" + version.replace(".", "-").replace("+", "-") self.log.info("Creating topic (path) %s", topic_path) try: @@ -229,11 +223,11 @@ def create_topic( metadata=metadata, ) except AlreadyExists: - self.log.warning('Topic already exists: %s', topic) + self.log.warning("Topic already exists: %s", topic) if fail_if_exists: - raise PubSubException(f'Topic already exists: {topic}') + raise PubSubException(f"Topic already exists: {topic}") except GoogleAPICallError as e: - raise PubSubException(f'Error creating topic {topic}', e) + raise PubSubException(f"Error creating topic {topic}", e) self.log.info("Created topic (path) %s", topic_path) @@ -243,9 +237,9 @@ def delete_topic( topic: str, project_id: str = PROVIDE_PROJECT_ID, fail_if_not_exists: bool = False, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a Pub/Sub topic if it exists. 
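The publish() body above shows the expected message shape: a "data" bytes field plus an optional "attributes" mapping that is unpacked into keyword arguments. A hedged sketch pairing create_topic() with publish(); topic and project names are hypothetical:

    from airflow.providers.google.cloud.hooks.pubsub import PubSubHook

    hook = PubSubHook(gcp_conn_id="google_cloud_default")
    # With fail_if_exists=False the AlreadyExists branch above only logs a warning.
    hook.create_topic(topic="my-topic", project_id="my-gcp-project", fail_if_exists=False)
    hook.publish(
        topic="my-topic",
        project_id="my-gcp-project",
        messages=[{"data": b"hello", "attributes": {"source": "example"}}],
    )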
@@ -273,11 +267,11 @@ def delete_topic( request={"topic": topic_path}, retry=retry, timeout=timeout, metadata=metadata or () ) except NotFound: - self.log.warning('Topic does not exist: %s', topic_path) + self.log.warning("Topic does not exist: %s", topic_path) if fail_if_not_exists: - raise PubSubException(f'Topic does not exist: {topic_path}') + raise PubSubException(f"Topic does not exist: {topic_path}") except GoogleAPICallError as e: - raise PubSubException(f'Error deleting topic {topic}', e) + raise PubSubException(f"Error deleting topic {topic}", e) self.log.info("Deleted topic (path) %s", topic_path) @GoogleBaseHook.fallback_to_default_project_id @@ -285,22 +279,22 @@ def create_subscription( self, topic: str, project_id: str = PROVIDE_PROJECT_ID, - subscription: Optional[str] = None, - subscription_project_id: Optional[str] = None, + subscription: str | None = None, + subscription_project_id: str | None = None, ack_deadline_secs: int = 10, fail_if_exists: bool = False, - push_config: Optional[Union[dict, PushConfig]] = None, - retain_acked_messages: Optional[bool] = None, - message_retention_duration: Optional[Union[dict, Duration]] = None, - labels: Optional[Dict[str, str]] = None, + push_config: dict | PushConfig | None = None, + retain_acked_messages: bool | None = None, + message_retention_duration: dict | Duration | None = None, + labels: dict[str, str] | None = None, enable_message_ordering: bool = False, - expiration_policy: Optional[Union[dict, ExpirationPolicy]] = None, - filter_: Optional[str] = None, - dead_letter_policy: Optional[Union[dict, DeadLetterPolicy]] = None, - retry_policy: Optional[Union[dict, RetryPolicy]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + expiration_policy: dict | ExpirationPolicy | None = None, + filter_: str | None = None, + dead_letter_policy: dict | DeadLetterPolicy | None = None, + retry_policy: dict | RetryPolicy | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> str: """ Creates a Pub/Sub subscription, if it does not already exist. @@ -339,7 +333,7 @@ def create_subscription( in which they are received by the Pub/Sub system. Otherwise, they may be delivered in any order. :param expiration_policy: A policy that specifies the conditions for this - subscription’s expiration. A subscription is considered active as long as any + subscription's expiration. A subscription is considered active as long as any connected subscriber is successfully consuming messages from the subscription or is issuing operations on the subscription. If expiration_policy is not set, a default policy with ttl of 31 days will be used. The minimum allowed value for @@ -363,18 +357,17 @@ def create_subscription( :param metadata: (Optional) Additional metadata that is provided to the method. 
:return: subscription name which will be the system-generated value if the ``subscription`` parameter is not supplied - :rtype: str """ subscriber = self.subscriber_client if not subscription: - subscription = f'sub-{uuid4()}' + subscription = f"sub-{uuid4()}" if not subscription_project_id: subscription_project_id = project_id # Add airflow-version label to the subscription labels = labels or {} - labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-') + labels["airflow-version"] = "v" + version.replace(".", "-").replace("+", "-") subscription_path = f"projects/{subscription_project_id}/subscriptions/{subscription}" topic_path = f"projects/{project_id}/topics/{topic}" @@ -401,11 +394,11 @@ def create_subscription( metadata=metadata, ) except AlreadyExists: - self.log.warning('Subscription already exists: %s', subscription_path) + self.log.warning("Subscription already exists: %s", subscription_path) if fail_if_exists: - raise PubSubException(f'Subscription already exists: {subscription_path}') + raise PubSubException(f"Subscription already exists: {subscription_path}") except GoogleAPICallError as e: - raise PubSubException(f'Error creating subscription {subscription_path}', e) + raise PubSubException(f"Error creating subscription {subscription_path}", e) self.log.info("Created subscription (path) %s for topic (path) %s", subscription_path, topic_path) return subscription @@ -416,9 +409,9 @@ def delete_subscription( subscription: str, project_id: str = PROVIDE_PROJECT_ID, fail_if_not_exists: bool = False, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a Pub/Sub subscription, if it exists. @@ -450,11 +443,11 @@ def delete_subscription( ) except NotFound: - self.log.warning('Subscription does not exist: %s', subscription_path) + self.log.warning("Subscription does not exist: %s", subscription_path) if fail_if_not_exists: - raise PubSubException(f'Subscription does not exist: {subscription_path}') + raise PubSubException(f"Subscription does not exist: {subscription_path}") except GoogleAPICallError as e: - raise PubSubException(f'Error deleting subscription {subscription_path}', e) + raise PubSubException(f"Error deleting subscription {subscription_path}", e) self.log.info("Deleted subscription (path) %s", subscription_path) @@ -465,10 +458,10 @@ def pull( max_messages: int, project_id: str = PROVIDE_PROJECT_ID, return_immediately: bool = False, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[ReceivedMessage]: + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[ReceivedMessage]: """ Pulls up to ``max_messages`` messages from Pub/Sub subscription. 
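create_subscription() above returns the subscription name, generating a "sub-<uuid>" name when none is supplied, and pull() then reads from that subscription. A hedged sketch; the topic and project names are hypothetical, and the ack_id/message fields on ReceivedMessage come from the Pub/Sub client types rather than being spelled out in this hunk:

    from airflow.providers.google.cloud.hooks.pubsub import PubSubHook

    hook = PubSubHook(gcp_conn_id="google_cloud_default")
    subscription = hook.create_subscription(topic="my-topic", project_id="my-gcp-project")
    received = hook.pull(subscription=subscription, max_messages=5, project_id="my-gcp-project")
    for message in received:
        print(message.ack_id, message.message.data)  # assumed ReceivedMessage fields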
@@ -509,22 +502,22 @@ def pull( timeout=timeout, metadata=metadata, ) - result = getattr(response, 'received_messages', []) + result = getattr(response, "received_messages", []) self.log.info("Pulled %d messages from subscription (path) %s", len(result), subscription_path) return result except (HttpError, GoogleAPICallError) as e: - raise PubSubException(f'Error pulling messages from subscription {subscription_path}', e) + raise PubSubException(f"Error pulling messages from subscription {subscription_path}", e) @GoogleBaseHook.fallback_to_default_project_id def acknowledge( self, subscription: str, project_id: str, - ack_ids: Optional[List[str]] = None, - messages: Optional[List[ReceivedMessage]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + ack_ids: list[str] | None = None, + messages: list[ReceivedMessage] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Acknowledges the messages associated with the ``ack_ids`` from Pub/Sub subscription. @@ -566,7 +559,7 @@ def acknowledge( ) except (HttpError, GoogleAPICallError) as e: raise PubSubException( - f'Error acknowledging {len(ack_ids)} messages pulled from subscription {subscription_path}', + f"Error acknowledging {len(ack_ids)} messages pulled from subscription {subscription_path}", e, ) diff --git a/airflow/providers/google/cloud/hooks/secret_manager.py b/airflow/providers/google/cloud/hooks/secret_manager.py index a584da267e05c..d005a6f657eb4 100644 --- a/airflow/providers/google/cloud/hooks/secret_manager.py +++ b/airflow/providers/google/cloud/hooks/secret_manager.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """Hook for Secrets Manager service""" -from typing import Optional, Sequence, Union +from __future__ import annotations + +from typing import Sequence from airflow.providers.google.cloud._internal_client.secret_manager_client import _SecretManagerClient from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -48,29 +50,28 @@ class SecretsManagerHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self.client = _SecretManagerClient(credentials=self._get_credentials()) + self.client = _SecretManagerClient(credentials=self.get_credentials()) def get_conn(self) -> _SecretManagerClient: """ Retrieves the connection to Secret Manager. :return: Secret Manager client. - :rtype: airflow.providers.google.cloud._internal_client.secret_manager_client._SecretManagerClient """ return self.client @GoogleBaseHook.fallback_to_default_project_id def get_secret( - self, secret_id: str, secret_version: str = 'latest', project_id: Optional[str] = None - ) -> Optional[str]: + self, secret_id: str, secret_version: str = "latest", project_id: str | None = None + ) -> str | None: """ Get secret value from the Secret Manager. 
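For the Secret Manager hook above, a hedged sketch of get_secret(); the secret and project ids are hypothetical, and the None return for a missing secret is inferred from the str | None annotation rather than shown in this hunk:

    from airflow.providers.google.cloud.hooks.secret_manager import SecretsManagerHook

    hook = SecretsManagerHook(gcp_conn_id="google_cloud_default")
    value = hook.get_secret(
        secret_id="db-password", secret_version="latest", project_id="my-gcp-project"
    )
    if value is None:
        print("secret not found or inaccessible")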
diff --git a/airflow/providers/google/cloud/hooks/spanner.py b/airflow/providers/google/cloud/hooks/spanner.py index 3c76b9c9157cf..31aa3bba1feed 100644 --- a/airflow/providers/google/cloud/hooks/spanner.py +++ b/airflow/providers/google/cloud/hooks/spanner.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Spanner Hook.""" -from typing import Callable, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import Callable, Sequence from google.api_core.exceptions import AlreadyExists, GoogleAPICallError from google.cloud.spanner_v1.client import Client @@ -41,8 +43,8 @@ class SpannerHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -57,11 +59,10 @@ def _get_client(self, project_id: str) -> Client: :param project_id: The ID of the Google Cloud project. :return: Client - :rtype: google.cloud.spanner_v1.client.Client """ if not self._client: self._client = Client( - project=project_id, credentials=self._get_credentials(), client_info=CLIENT_INFO + project=project_id, credentials=self.get_credentials(), client_info=CLIENT_INFO ) return self._client @@ -79,7 +80,6 @@ def get_instance( is used. :param instance_id: The ID of the Cloud Spanner instance. :return: Spanner instance - :rtype: google.cloud.spanner_v1.instance.Instance """ instance = self._get_client(project_id=project_id).instance(instance_id=instance_id) if not instance.exists(): @@ -115,9 +115,9 @@ def _apply_to_instance( display_name=display_name, ) try: - operation = func(instance) # type: Operation + operation: Operation = func(instance) except GoogleAPICallError as e: - self.log.error('An error occurred: %s. Exiting.', e.message) + self.log.error("An error occurred: %s. Exiting.", e.message) raise e if operation: @@ -200,7 +200,7 @@ def delete_instance(self, instance_id: str, project_id: str) -> None: instance.delete() return except GoogleAPICallError as e: - self.log.error('An error occurred: %s. Exiting.', e.message) + self.log.error("An error occurred: %s. Exiting.", e.message) raise e @GoogleBaseHook.fallback_to_default_project_id @@ -209,7 +209,7 @@ def get_database( instance_id: str, database_id: str, project_id: str, - ) -> Optional[Database]: + ) -> Database | None: """ Retrieves a database in Cloud Spanner. If the database does not exist in the specified instance, it returns None. @@ -220,7 +220,6 @@ def get_database( database. If set to None or missing, the default project_id from the Google Cloud connection is used. 
:return: Database object or None if database does not exist - :rtype: google.cloud.spanner_v1.database.Database or None """ instance = self._get_client(project_id=project_id).instance(instance_id=instance_id) if not instance.exists(): @@ -236,7 +235,7 @@ def create_database( self, instance_id: str, database_id: str, - ddl_statements: List[str], + ddl_statements: list[str], project_id: str, ) -> None: """ @@ -255,9 +254,9 @@ def create_database( raise AirflowException(f"The instance {instance_id} does not exist in project {project_id} !") database = instance.database(database_id=database_id, ddl_statements=ddl_statements) try: - operation = database.create() # type: Operation + operation: Operation = database.create() except GoogleAPICallError as e: - self.log.error('An error occurred: %s. Exiting.', e.message) + self.log.error("An error occurred: %s. Exiting.", e.message) raise e if operation: @@ -269,9 +268,9 @@ def update_database( self, instance_id: str, database_id: str, - ddl_statements: List[str], + ddl_statements: list[str], project_id: str, - operation_id: Optional[str] = None, + operation_id: str | None = None, ) -> None: """ Updates DDL of a database in Cloud Spanner. @@ -304,7 +303,7 @@ def update_database( ) return except GoogleAPICallError as e: - self.log.error('An error occurred: %s. Exiting.', e.message) + self.log.error("An error occurred: %s. Exiting.", e.message) raise e @GoogleBaseHook.fallback_to_default_project_id @@ -318,7 +317,6 @@ def delete_database(self, instance_id: str, database_id, project_id: str) -> boo database. If set to None or missing, the default project_id from the Google Cloud connection is used. :return: True if everything succeeded - :rtype: bool """ instance = self._get_client(project_id=project_id).instance(instance_id=instance_id) if not instance.exists(): @@ -332,7 +330,7 @@ def delete_database(self, instance_id: str, database_id, project_id: str) -> boo try: database.drop() except GoogleAPICallError as e: - self.log.error('An error occurred: %s. Exiting.', e.message) + self.log.error("An error occurred: %s. Exiting.", e.message) raise e return True @@ -342,7 +340,7 @@ def execute_dml( self, instance_id: str, database_id: str, - queries: List[str], + queries: list[str], project_id: str, ) -> None: """ @@ -360,6 +358,6 @@ def execute_dml( ).run_in_transaction(lambda transaction: self._execute_sql_in_transaction(transaction, queries)) @staticmethod - def _execute_sql_in_transaction(transaction: Transaction, queries: List[str]): + def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]): for sql in queries: transaction.execute_update(sql) diff --git a/airflow/providers/google/cloud/hooks/speech_to_text.py b/airflow/providers/google/cloud/hooks/speech_to_text.py index c9d5b4c373812..912d15b43131d 100644 --- a/airflow/providers/google/cloud/hooks/speech_to_text.py +++ b/airflow/providers/google/cloud/hooks/speech_to_text.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
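For the Spanner hook above, a hedged sketch chaining create_database() and execute_dml(); the instance, database, DDL and DML strings are hypothetical placeholders:

    from airflow.providers.google.cloud.hooks.spanner import SpannerHook

    hook = SpannerHook(gcp_conn_id="google_cloud_default")
    hook.create_database(
        instance_id="my-instance",
        database_id="my-db",
        ddl_statements=["CREATE TABLE users (id INT64 NOT NULL) PRIMARY KEY (id)"],
        project_id="my-gcp-project",
    )
    # execute_dml() runs the statements inside a single read-write transaction,
    # as shown by _execute_sql_in_transaction() above.
    hook.execute_dml(
        instance_id="my-instance",
        database_id="my-db",
        queries=["INSERT INTO users (id) VALUES (1)"],
        project_id="my-gcp-project",
    )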
"""This module contains a Google Cloud Speech Hook.""" -from typing import Dict, Optional, Sequence, Union +from __future__ import annotations + +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -48,8 +50,8 @@ class CloudSpeechToTextHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -63,19 +65,18 @@ def get_conn(self) -> SpeechClient: Retrieves connection to Cloud Speech. :return: Google Cloud Speech client object. - :rtype: google.cloud.speech_v1.SpeechClient """ if not self._client: - self._client = SpeechClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = SpeechClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @GoogleBaseHook.quota_retry() def recognize_speech( self, - config: Union[Dict, RecognitionConfig], - audio: Union[Dict, RecognitionAudio], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + config: dict | RecognitionConfig, + audio: dict | RecognitionAudio, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, ): """ Recognizes audio input diff --git a/airflow/providers/google/cloud/hooks/stackdriver.py b/airflow/providers/google/cloud/hooks/stackdriver.py index 851823a7e609d..7b6e3fe94b603 100644 --- a/airflow/providers/google/cloud/hooks/stackdriver.py +++ b/airflow/providers/google/cloud/hooks/stackdriver.py @@ -15,11 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module contains Google Cloud Stackdriver operators.""" +from __future__ import annotations import json -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Sequence from google.api_core.exceptions import InvalidArgument from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -39,8 +39,8 @@ class StackdriverHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -64,13 +64,13 @@ def _get_channel_client(self): def list_alert_policies( self, project_id: str = PROVIDE_PROJECT_ID, - format_: Optional[str] = None, - filter_: Optional[str] = None, - order_by: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + format_: str | None = None, + filter_: str | None = None, + order_by: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Any: """ Fetches all the Alert Policies identified by the filter passed as @@ -105,10 +105,10 @@ def list_alert_policies( client = self._get_policy_client() policies_ = client.list_alert_policies( request={ - 'name': f'projects/{project_id}', - 'filter': filter_, - 'order_by': order_by, - 'page_size': page_size, + "name": f"projects/{project_id}", + "filter": filter_, + "order_by": order_by, + "page_size": page_size, }, retry=retry, timeout=timeout, @@ -126,19 +126,19 @@ def _toggle_policy_status( self, new_state: bool, project_id: str = PROVIDE_PROJECT_ID, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): client = self._get_policy_client() policies_ = self.list_alert_policies(project_id=project_id, filter_=filter_) for policy in policies_: if policy.enabled != bool(new_state): policy.enabled = bool(new_state) - mask = FieldMask(paths=['enabled']) + mask = FieldMask(paths=["enabled"]) client.update_alert_policy( - request={'alert_policy': policy, 'update_mask': mask}, + request={"alert_policy": policy, "update_mask": mask}, retry=retry, timeout=timeout, metadata=metadata, @@ -148,10 +148,10 @@ def _toggle_policy_status( def enable_alert_policies( self, project_id: str = PROVIDE_PROJECT_ID, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Enables one or more disabled alerting policies identified by filter @@ -181,10 +181,10 @@ def enable_alert_policies( def disable_alert_policies( self, project_id: str = PROVIDE_PROJECT_ID, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + 
timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Disables one or more enabled alerting policies identified by filter @@ -215,9 +215,9 @@ def upsert_alert( self, alerts: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Creates a new alert or updates an existing policy identified @@ -240,11 +240,11 @@ def upsert_alert( record = json.loads(alerts) existing_policies = [ - policy['name'] for policy in self.list_alert_policies(project_id=project_id, format_='dict') + policy["name"] for policy in self.list_alert_policies(project_id=project_id, format_="dict") ] existing_channels = [ - channel['name'] - for channel in self.list_notification_channels(project_id=project_id, format_='dict') + channel["name"] + for channel in self.list_notification_channels(project_id=project_id, format_="dict") ] policies_ = [] channels = [] @@ -262,7 +262,7 @@ def upsert_alert( if channel.name in existing_channels: channel_client.update_notification_channel( - request={'notification_channel': channel}, + request={"notification_channel": channel}, retry=retry, timeout=timeout, metadata=metadata, @@ -271,7 +271,7 @@ def upsert_alert( old_name = channel.name channel.name = None new_channel = channel_client.create_notification_channel( - request={'name': f'projects/{project_id}', 'notification_channel': channel}, + request={"name": f"projects/{project_id}", "notification_channel": channel}, retry=retry, timeout=timeout, metadata=metadata, @@ -290,7 +290,7 @@ def upsert_alert( if policy.name in existing_policies: try: policy_client.update_alert_policy( - request={'alert_policy': policy}, + request={"alert_policy": policy}, retry=retry, timeout=timeout, metadata=metadata, @@ -302,7 +302,7 @@ def upsert_alert( for condition in policy.conditions: condition.name = None policy_client.create_alert_policy( - request={'name': f'projects/{project_id}', 'alert_policy': policy}, + request={"name": f"projects/{project_id}", "alert_policy": policy}, retry=retry, timeout=timeout, metadata=metadata, @@ -311,9 +311,9 @@ def upsert_alert( def delete_alert_policy( self, name: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes an alerting policy. @@ -330,22 +330,22 @@ def delete_alert_policy( policy_client = self._get_policy_client() try: policy_client.delete_alert_policy( - request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or () ) except HttpError as err: - raise AirflowException(f'Delete alerting policy failed. Error was {err.content}') + raise AirflowException(f"Delete alerting policy failed. 
Error was {err.content}") @GoogleBaseHook.fallback_to_default_project_id def list_notification_channels( self, project_id: str = PROVIDE_PROJECT_ID, - format_: Optional[str] = None, - filter_: Optional[str] = None, - order_by: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + format_: str | None = None, + filter_: str | None = None, + order_by: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Any: """ Fetches all the Notification Channels identified by the filter passed as @@ -380,10 +380,10 @@ def list_notification_channels( client = self._get_channel_client() channels = client.list_notification_channels( request={ - 'name': f'projects/{project_id}', - 'filter': filter_, - 'order_by': order_by, - 'page_size': page_size, + "name": f"projects/{project_id}", + "filter": filter_, + "order_by": order_by, + "page_size": page_size, }, retry=retry, timeout=timeout, @@ -401,21 +401,21 @@ def _toggle_channel_status( self, new_state: bool, project_id: str = PROVIDE_PROJECT_ID, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: client = self._get_channel_client() channels = client.list_notification_channels( - request={'name': f'projects/{project_id}', 'filter': filter_} + request={"name": f"projects/{project_id}", "filter": filter_} ) for channel in channels: if channel.enabled != bool(new_state): channel.enabled = bool(new_state) - mask = FieldMask(paths=['enabled']) + mask = FieldMask(paths=["enabled"]) client.update_notification_channel( - request={'notification_channel': channel, 'update_mask': mask}, + request={"notification_channel": channel, "update_mask": mask}, retry=retry, timeout=timeout, metadata=metadata, @@ -425,10 +425,10 @@ def _toggle_channel_status( def enable_notification_channels( self, project_id: str = PROVIDE_PROJECT_ID, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Enables one or more disabled alerting policies identified by filter @@ -458,10 +458,10 @@ def enable_notification_channels( def disable_notification_channels( self, project_id: str, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Disables one or more enabled notification channels identified by filter @@ -492,9 +492,9 @@ def upsert_channel( self, channels: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> dict: """ Creates a new 
notification or updates an existing notification channel @@ -532,7 +532,7 @@ def upsert_channel( if channel.name in existing_channels: channel_client.update_notification_channel( - request={'notification_channel': channel}, + request={"notification_channel": channel}, retry=retry, timeout=timeout, metadata=metadata, @@ -541,7 +541,7 @@ def upsert_channel( old_name = channel.name channel.name = None new_channel = channel_client.create_notification_channel( - request={'name': f'projects/{project_id}', 'notification_channel': channel}, + request={"name": f"projects/{project_id}", "notification_channel": channel}, retry=retry, timeout=timeout, metadata=metadata, @@ -553,9 +553,9 @@ def upsert_channel( def delete_notification_channel( self, name: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a notification channel. @@ -572,7 +572,7 @@ def delete_notification_channel( channel_client = self._get_channel_client() try: channel_client.delete_notification_channel( - request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or () ) except HttpError as err: - raise AirflowException(f'Delete notification channel failed. Error was {err.content}') + raise AirflowException(f"Delete notification channel failed. Error was {err.content}") diff --git a/airflow/providers/google/cloud/hooks/tasks.py b/airflow/providers/google/cloud/hooks/tasks.py index 2265e36a79a69..6d685c0506aaf 100644 --- a/airflow/providers/google/cloud/hooks/tasks.py +++ b/airflow/providers/google/cloud/hooks/tasks.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module contains a CloudTasksHook which allows you to connect to Google Cloud Tasks service, performing actions to queues or tasks. """ +from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -60,37 +60,36 @@ class CloudTasksHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client: Optional[CloudTasksClient] = None + self._client: CloudTasksClient | None = None def get_conn(self) -> CloudTasksClient: """ Provides a client for interacting with the Google Cloud Tasks API. 
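The hunks above (and most of the remaining ones in this patch) apply the same typing modernisation. A minimal, hypothetical sketch of the pattern, assuming the Python 3.7+ runtimes Airflow supports: with `from __future__ import annotations`, annotations are kept as strings (PEP 563) and never evaluated at import time, so PEP 604 unions such as `str | None` and built-in generics such as `tuple[str, str]` can replace `Optional`, `Union` and `Tuple` without requiring Python 3.10. The helper below is invented for illustration only.

# Illustrative sketch, not part of the patch.
from __future__ import annotations

from typing import Sequence


def example_call(
    project_id: str,
    filter_: str | None = None,                 # was Optional[str]
    timeout: float | None = None,               # was Optional[float]
    metadata: Sequence[tuple[str, str]] = (),   # was Sequence[Tuple[str, str]]
) -> dict | None:                               # was Optional[dict]
    """Hypothetical helper mirroring the rewritten hook signatures."""
    return {
        "project_id": project_id,
        "filter": filter_,
        "timeout": timeout,
        "metadata": list(metadata),
    }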
:return: Google Cloud Tasks API Client - :rtype: google.cloud.tasks_v2.CloudTasksClient """ if self._client is None: - self._client = CloudTasksClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = CloudTasksClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @GoogleBaseHook.fallback_to_default_project_id def create_queue( self, location: str, - task_queue: Union[dict, Queue], + task_queue: dict | Queue, project_id: str = PROVIDE_PROJECT_ID, - queue_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + queue_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Queue: """ Creates a queue in Cloud Tasks. @@ -109,7 +108,6 @@ def create_queue( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.tasks_v2.types.Queue """ client = self.get_conn() @@ -118,12 +116,12 @@ def create_queue( if isinstance(task_queue, Queue): task_queue.name = full_queue_name elif isinstance(task_queue, dict): - task_queue['name'] = full_queue_name + task_queue["name"] = full_queue_name else: - raise AirflowException('Unable to set queue_name.') + raise AirflowException("Unable to set queue_name.") full_location_path = f"projects/{project_id}/locations/{location}" return client.create_queue( - request={'parent': full_location_path, 'queue': task_queue}, + request={"parent": full_location_path, "queue": task_queue}, retry=retry, timeout=timeout, metadata=metadata, @@ -134,12 +132,12 @@ def update_queue( self, task_queue: Queue, project_id: str = PROVIDE_PROJECT_ID, - location: Optional[str] = None, - queue_name: Optional[str] = None, - update_mask: Optional[FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + queue_name: str | None = None, + update_mask: FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Queue: """ Updates a queue in Cloud Tasks. @@ -162,7 +160,6 @@ def update_queue( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: google.cloud.tasks_v2.types.Queue """ client = self.get_conn() @@ -171,11 +168,11 @@ def update_queue( if isinstance(task_queue, Queue): task_queue.name = full_queue_name elif isinstance(task_queue, dict): - task_queue['name'] = full_queue_name + task_queue["name"] = full_queue_name else: - raise AirflowException('Unable to set queue_name.') + raise AirflowException("Unable to set queue_name.") return client.update_queue( - request={'queue': task_queue, 'update_mask': update_mask}, + request={"queue": task_queue, "update_mask": update_mask}, retry=retry, timeout=timeout, metadata=metadata, @@ -187,9 +184,9 @@ def get_queue( location: str, queue_name: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Queue: """ Gets a queue from Cloud Tasks. @@ -204,13 +201,12 @@ def get_queue( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.tasks_v2.types.Queue """ client = self.get_conn() full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" return client.get_queue( - request={'name': full_queue_name}, + request={"name": full_queue_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -221,12 +217,12 @@ def list_queues( self, location: str, project_id: str = PROVIDE_PROJECT_ID, - results_filter: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[Queue]: + results_filter: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[Queue]: """ Lists queues from Cloud Tasks. @@ -242,13 +238,12 @@ def list_queues( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: list[google.cloud.tasks_v2.types.Queue] """ client = self.get_conn() full_location_path = f"projects/{project_id}/locations/{location}" queues = client.list_queues( - request={'parent': full_location_path, 'filter': results_filter, 'page_size': page_size}, + request={"parent": full_location_path, "filter": results_filter, "page_size": page_size}, retry=retry, timeout=timeout, metadata=metadata, @@ -261,9 +256,9 @@ def delete_queue( location: str, queue_name: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a queue from Cloud Tasks, even if it has tasks in it. 
@@ -283,7 +278,7 @@ def delete_queue( full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" client.delete_queue( - request={'name': full_queue_name}, + request={"name": full_queue_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -295,9 +290,9 @@ def purge_queue( location: str, queue_name: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Queue: """ Purges a queue by deleting all of its tasks from Cloud Tasks. @@ -312,13 +307,12 @@ def purge_queue( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: list[google.cloud.tasks_v2.types.Queue] """ client = self.get_conn() full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" return client.purge_queue( - request={'name': full_queue_name}, + request={"name": full_queue_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -330,9 +324,9 @@ def pause_queue( location: str, queue_name: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Queue: """ Pauses a queue in Cloud Tasks. @@ -347,13 +341,12 @@ def pause_queue( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: list[google.cloud.tasks_v2.types.Queue] """ client = self.get_conn() full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" return client.pause_queue( - request={'name': full_queue_name}, + request={"name": full_queue_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -365,9 +358,9 @@ def resume_queue( location: str, queue_name: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Queue: """ Resumes a queue in Cloud Tasks. @@ -382,13 +375,12 @@ def resume_queue( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. 
- :rtype: list[google.cloud.tasks_v2.types.Queue] """ client = self.get_conn() full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" return client.resume_queue( - request={'name': full_queue_name}, + request={"name": full_queue_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -399,13 +391,13 @@ def create_task( self, location: str, queue_name: str, - task: Union[Dict, Task], + task: dict | Task, project_id: str = PROVIDE_PROJECT_ID, - task_name: Optional[str] = None, - response_view: Optional[Task.View] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + task_name: str | None = None, + response_view: Task.View | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Task: """ Creates a task in Cloud Tasks. @@ -426,7 +418,6 @@ def create_task( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.tasks_v2.types.Task """ client = self.get_conn() @@ -437,12 +428,12 @@ def create_task( if isinstance(task, Task): task.name = full_task_name elif isinstance(task, dict): - task['name'] = full_task_name + task["name"] = full_task_name else: - raise AirflowException('Unable to set task_name.') + raise AirflowException("Unable to set task_name.") full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" return client.create_task( - request={'parent': full_queue_name, 'task': task, 'response_view': response_view}, + request={"parent": full_queue_name, "task": task, "response_view": response_view}, retry=retry, timeout=timeout, metadata=metadata, @@ -455,10 +446,10 @@ def get_task( queue_name: str, task_name: str, project_id: str = PROVIDE_PROJECT_ID, - response_view: Optional[Task.View] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + response_view: Task.View | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Task: """ Gets a task from Cloud Tasks. @@ -476,13 +467,12 @@ def get_task( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.tasks_v2.types.Task """ client = self.get_conn() full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" return client.get_task( - request={'name': full_task_name, 'response_view': response_view}, + request={"name": full_task_name, "response_view": response_view}, retry=retry, timeout=timeout, metadata=metadata, @@ -494,12 +484,12 @@ def list_tasks( location: str, queue_name: str, project_id: str, - response_view: Optional[Task.View] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> List[Task]: + response_view: Task.View | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[Task]: """ Lists the tasks in Cloud Tasks. @@ -517,12 +507,11 @@ def list_tasks( to complete. 
Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: list[google.cloud.tasks_v2.types.Task] """ client = self.get_conn() full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" tasks = client.list_tasks( - request={'parent': full_queue_name, 'response_view': response_view, 'page_size': page_size}, + request={"parent": full_queue_name, "response_view": response_view, "page_size": page_size}, retry=retry, timeout=timeout, metadata=metadata, @@ -536,9 +525,9 @@ def delete_task( queue_name: str, task_name: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Deletes a task from Cloud Tasks. @@ -559,7 +548,7 @@ def delete_task( full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" client.delete_task( - request={'name': full_task_name}, + request={"name": full_task_name}, retry=retry, timeout=timeout, metadata=metadata, @@ -572,10 +561,10 @@ def run_task( queue_name: str, task_name: str, project_id: str, - response_view: Optional[Task.View] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + response_view: Task.View | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Task: """ Forces to run a task in Cloud Tasks. @@ -593,13 +582,12 @@ def run_task( to complete. Note that if retry is specified, the timeout applies to each individual attempt. :param metadata: (Optional) Additional metadata that is provided to the method. - :rtype: google.cloud.tasks_v2.types.Task """ client = self.get_conn() full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" return client.run_task( - request={'name': full_task_name, 'response_view': response_view}, + request={"name": full_task_name, "response_view": response_view}, retry=retry, timeout=timeout, metadata=metadata, diff --git a/airflow/providers/google/cloud/hooks/text_to_speech.py b/airflow/providers/google/cloud/hooks/text_to_speech.py index 72e97ab81668f..9b838d8547dfb 100644 --- a/airflow/providers/google/cloud/hooks/text_to_speech.py +++ b/airflow/providers/google/cloud/hooks/text_to_speech.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
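For context, a hypothetical usage sketch of the CloudTasksHook methods touched above. The project id, queue name, task name and HTTP payload are invented for the example, and the task dict shape is assumed to follow the Cloud Tasks v2 Task message; only the hook methods and their parameters come from the diff.

# Illustrative usage sketch, not part of the patch.
from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook

hook = CloudTasksHook(gcp_conn_id="google_cloud_default")

# A plain dict is accepted for the queue; the hook fills in the fully-qualified name.
queue = hook.create_queue(
    location="europe-west1",
    task_queue={},
    project_id="example-project",   # assumed project id
    queue_name="example-queue",     # assumed queue name
)

# Assumed payload shape for an HTTP target task.
task = hook.create_task(
    location="europe-west1",
    queue_name="example-queue",
    task={"http_request": {"http_method": "POST", "url": "https://example.com/handler"}},
    project_id="example-project",
    task_name="example-task",
)

# Force immediate execution of the task just created.
hook.run_task(
    location="europe-west1",
    queue_name="example-queue",
    task_name="example-task",
    project_id="example-project",
)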
"""This module contains a Google Cloud Text to Speech Hook.""" -from typing import Dict, Optional, Sequence, Union +from __future__ import annotations + +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -56,37 +58,36 @@ class CloudTextToSpeechHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client = None # type: Optional[TextToSpeechClient] + self._client: TextToSpeechClient | None = None def get_conn(self) -> TextToSpeechClient: """ Retrieves connection to Cloud Text to Speech. :return: Google Cloud Text to Speech client object. - :rtype: google.cloud.texttospeech_v1.TextToSpeechClient """ if not self._client: - self._client = TextToSpeechClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = TextToSpeechClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @GoogleBaseHook.quota_retry() def synthesize_speech( self, - input_data: Union[Dict, SynthesisInput], - voice: Union[Dict, VoiceSelectionParams], - audio_config: Union[Dict, AudioConfig], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + input_data: dict | SynthesisInput, + voice: dict | VoiceSelectionParams, + audio_config: dict | AudioConfig, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, ) -> SynthesizeSpeechResponse: """ Synthesizes text input @@ -103,7 +104,6 @@ def synthesize_speech( Note that if retry is specified, the timeout applies to each individual attempt. :return: SynthesizeSpeechResponse See more: https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesizeSpeechResponse - :rtype: object """ client = self.get_conn() self.log.info("Synthesizing input: %s", input_data) diff --git a/airflow/providers/google/cloud/hooks/translate.py b/airflow/providers/google/cloud/hooks/translate.py index 037c230641ef7..d0a51fdb6060d 100644 --- a/airflow/providers/google/cloud/hooks/translate.py +++ b/airflow/providers/google/cloud/hooks/translate.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Translate Hook.""" -from typing import List, Optional, Sequence, Union +from __future__ import annotations + +from typing import Sequence from google.cloud.translate_v2 import Client @@ -35,35 +37,34 @@ class CloudTranslateHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._client = None # type: Optional[Client] + self._client: Client | None = None def get_conn(self) -> Client: """ Retrieves connection to Cloud Translate :return: Google Cloud Translate client object. 
- :rtype: google.cloud.translate_v2.Client """ if not self._client: - self._client = Client(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = Client(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @GoogleBaseHook.quota_retry() def translate( self, - values: Union[str, List[str]], + values: str | list[str], target_language: str, - format_: Optional[str] = None, - source_language: Optional[str] = None, - model: Optional[Union[str, List[str]]] = None, + format_: str | None = None, + source_language: str | None = None, + model: str | list[str] | None = None, ) -> dict: """Translate a string or list of strings. @@ -79,7 +80,6 @@ def translate( be translated. :param model: (Optional) The model used to translate the text, such as ``'base'`` or ``'nmt'``. - :rtype: str or list :returns: A list of dictionaries for each queried value. Each dictionary typically contains three keys (though not all will be present in all cases) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py index 26d3425368288..1a31e68cc8353 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains a Google Cloud Vertex AI hook. @@ -43,8 +42,10 @@ targetColumn optimizationObjective """ +from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Tuple, Union +import warnings +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -73,66 +74,64 @@ class AutoMLHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._job: Optional[ - Union[ - AutoMLForecastingTrainingJob, - AutoMLImageTrainingJob, - AutoMLTabularTrainingJob, - AutoMLTextTrainingJob, - AutoMLVideoTrainingJob, - ] - ] = None + self._job: None | ( + AutoMLForecastingTrainingJob + | AutoMLImageTrainingJob + | AutoMLTabularTrainingJob + | AutoMLTextTrainingJob + | AutoMLVideoTrainingJob + ) = None def get_pipeline_service_client( self, - region: Optional[str] = None, + region: str | None = None, ) -> PipelineServiceClient: """Returns PipelineServiceClient.""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return PipelineServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options ) def get_job_service_client( self, - region: Optional[str] = None, + region: str | None = None, ) -> JobServiceClient: """Returns JobServiceClient""" - if region and region != 'global': - client_options = 
ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return JobServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options ) def get_auto_ml_tabular_training_job( self, display_name: str, optimization_prediction_type: str, - optimization_objective: Optional[str] = None, - column_specs: Optional[Dict[str, str]] = None, - column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, - optimization_objective_recall_value: Optional[float] = None, - optimization_objective_precision_value: Optional[float] = None, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, + optimization_objective: str | None = None, + column_specs: dict[str, str] | None = None, + column_transformations: list[dict[str, dict[str, str]]] | None = None, + optimization_objective_recall_value: float | None = None, + optimization_objective_precision_value: float | None = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, ) -> AutoMLTabularTrainingJob: """Returns AutoMLTabularTrainingJob object""" return AutoMLTabularTrainingJob( @@ -145,7 +144,7 @@ def get_auto_ml_tabular_training_job( optimization_objective_precision_value=optimization_objective_precision_value, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, @@ -154,14 +153,14 @@ def get_auto_ml_tabular_training_job( def get_auto_ml_forecasting_training_job( self, display_name: str, - optimization_objective: Optional[str] = None, - column_specs: Optional[Dict[str, str]] = None, - column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, + optimization_objective: str | None = None, + column_specs: dict[str, str] | None = None, + column_transformations: list[dict[str, dict[str, str]]] | None = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, ) -> AutoMLForecastingTrainingJob: """Returns AutoMLForecastingTrainingJob object""" return AutoMLForecastingTrainingJob( @@ -171,7 +170,7 @@ def get_auto_ml_forecasting_training_job( column_transformations=column_transformations, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, @@ -183,12 +182,12 @@ def get_auto_ml_image_training_job( prediction_type: 
str = "classification", multi_label: bool = False, model_type: str = "CLOUD", - base_model: Optional[models.Model] = None, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, + base_model: models.Model | None = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, ) -> AutoMLImageTrainingJob: """Returns AutoMLImageTrainingJob object""" return AutoMLImageTrainingJob( @@ -199,7 +198,7 @@ def get_auto_ml_image_training_job( base_model=base_model, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, @@ -211,11 +210,11 @@ def get_auto_ml_text_training_job( prediction_type: str, multi_label: bool = False, sentiment_max: int = 10, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, ) -> AutoMLTextTrainingJob: """Returns AutoMLTextTrainingJob object""" return AutoMLTextTrainingJob( @@ -225,7 +224,7 @@ def get_auto_ml_text_training_job( sentiment_max=sentiment_max, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, @@ -236,11 +235,11 @@ def get_auto_ml_video_training_job( display_name: str, prediction_type: str = "classification", model_type: str = "CLOUD", - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, ) -> AutoMLVideoTrainingJob: """Returns AutoMLVideoTrainingJob object""" return AutoMLVideoTrainingJob( @@ -249,18 +248,23 @@ def get_auto_ml_video_training_job( model_type=model_type, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, ) @staticmethod - def extract_model_id(obj: Dict) -> str: + def extract_model_id(obj: dict) -> str: """Returns unique id of the Model.""" return obj["name"].rpartition("/")[-1] - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + @staticmethod + def extract_training_id(resource_name: str) -> str: + """Returns unique id of the Training pipeline.""" + return resource_name.rpartition("/")[-1] + + def wait_for_operation(self, operation: 
Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -282,29 +286,29 @@ def create_auto_ml_tabular_training_job( dataset: datasets.TabularDataset, target_column: str, optimization_prediction_type: str, - optimization_objective: Optional[str] = None, - column_specs: Optional[Dict[str, str]] = None, - column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, - optimization_objective_recall_value: Optional[float] = None, - optimization_objective_precision_value: Optional[float] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - weight_column: Optional[str] = None, + optimization_objective: str | None = None, + column_specs: dict[str, str] | None = None, + column_transformations: list[dict[str, dict[str, str]]] | None = None, + optimization_objective_recall_value: float | None = None, + optimization_objective_precision_value: float | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + weight_column: str | None = None, budget_milli_node_hours: int = 1000, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, disable_early_stopping: bool = False, export_evaluated_data_items: bool = False, - export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_bigquery_destination_uri: str | None = None, export_evaluated_data_items_override_destination: bool = False, sync: bool = True, - ) -> models.Model: + ) -> tuple[models.Model | None, str]: """ Create an AutoML Tabular Training Job. @@ -444,6 +448,13 @@ def create_auto_ml_tabular_training_job( concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. """ + if column_transformations: + warnings.warn( + "Consider using column_specs as column_transformations will be deprecated eventually.", + DeprecationWarning, + stacklevel=2, + ) + self._job = self.get_auto_ml_tabular_training_job( project=project_id, location=region, @@ -482,9 +493,15 @@ def create_auto_ml_tabular_training_job( export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination, sync=sync, ) - model.wait() - - return model + training_id = self.extract_training_id(self._job.resource_name) + if model: + model.wait() + else: + self.log.warning( + "Training did not produce a Managed Model returning None. Training Pipeline is not " + "configured to upload a Model." 
+ ) + return model, training_id @GoogleBaseHook.fallback_to_default_project_id def create_auto_ml_forecasting_training_job( @@ -496,34 +513,34 @@ def create_auto_ml_forecasting_training_job( target_column: str, time_column: str, time_series_identifier_column: str, - unavailable_at_forecast_columns: List[str], - available_at_forecast_columns: List[str], + unavailable_at_forecast_columns: list[str], + available_at_forecast_columns: list[str], forecast_horizon: int, data_granularity_unit: str, data_granularity_count: int, - optimization_objective: Optional[str] = None, - column_specs: Optional[Dict[str, str]] = None, - column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - predefined_split_column_name: Optional[str] = None, - weight_column: Optional[str] = None, - time_series_attribute_columns: Optional[List[str]] = None, - context_window: Optional[int] = None, + optimization_objective: str | None = None, + column_specs: dict[str, str] | None = None, + column_transformations: list[dict[str, dict[str, str]]] | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + predefined_split_column_name: str | None = None, + weight_column: str | None = None, + time_series_attribute_columns: list[str] | None = None, + context_window: int | None = None, export_evaluated_data_items: bool = False, - export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_bigquery_destination_uri: str | None = None, export_evaluated_data_items_override_destination: bool = False, - quantiles: Optional[List[float]] = None, - validation_options: Optional[str] = None, + quantiles: list[float] | None = None, + validation_options: str | None = None, budget_milli_node_hours: int = 1000, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, sync: bool = True, - ) -> models.Model: + ) -> tuple[models.Model | None, str]: """ Create an AutoML Forecasting Training Job. @@ -658,6 +675,13 @@ def create_auto_ml_forecasting_training_job( concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. """ + if column_transformations: + warnings.warn( + "Consider using column_specs as column_transformations will be deprecated eventually.", + DeprecationWarning, + stacklevel=2, + ) + self._job = self.get_auto_ml_forecasting_training_job( project=project_id, location=region, @@ -702,9 +726,15 @@ def create_auto_ml_forecasting_training_job( model_labels=model_labels, sync=sync, ) - model.wait() - - return model + training_id = self.extract_training_id(self._job.resource_name) + if model: + model.wait() + else: + self.log.warning( + "Training did not produce a Managed Model returning None. Training Pipeline is not " + "configured to upload a Model." 
+ ) + return model, training_id @GoogleBaseHook.fallback_to_default_project_id def create_auto_ml_image_training_job( @@ -716,22 +746,22 @@ def create_auto_ml_image_training_job( prediction_type: str = "classification", multi_label: bool = False, model_type: str = "CLOUD", - base_model: Optional[models.Model] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - budget_milli_node_hours: Optional[int] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, + base_model: models.Model | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + budget_milli_node_hours: int | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, disable_early_stopping: bool = False, sync: bool = True, - ) -> models.Model: + ) -> tuple[models.Model | None, str]: """ Create an AutoML Image Training Job. @@ -872,9 +902,15 @@ def create_auto_ml_image_training_job( disable_early_stopping=disable_early_stopping, sync=sync, ) - model.wait() - - return model + training_id = self.extract_training_id(self._job.resource_name) + if model: + model.wait() + else: + self.log.warning( + "Training did not produce a Managed Model returning None. AutoML Image Training " + "Pipeline is not configured to upload a Model." + ) + return model, training_id @GoogleBaseHook.fallback_to_default_project_id def create_auto_ml_text_training_job( @@ -886,19 +922,19 @@ def create_auto_ml_text_training_job( prediction_type: str, multi_label: bool = False, sentiment_max: int = 10, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, sync: bool = True, - ) -> models.Model: + ) -> tuple[models.Model | None, str]: """ Create an AutoML Text Training Job. 
@@ -1003,9 +1039,15 @@ def create_auto_ml_text_training_job( model_labels=model_labels, sync=sync, ) - model.wait() - - return model + training_id = self.extract_training_id(self._job.resource_name) + if model: + model.wait() + else: + self.log.warning( + "Training did not produce a Managed Model returning None. AutoML Text Training " + "Pipeline is not configured to upload a Model." + ) + return model, training_id @GoogleBaseHook.fallback_to_default_project_id def create_auto_ml_video_training_job( @@ -1016,17 +1058,17 @@ def create_auto_ml_video_training_job( dataset: datasets.VideoDataset, prediction_type: str = "classification", model_type: str = "CLOUD", - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - training_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + training_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + test_filter_split: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, sync: bool = True, - ) -> models.Model: + ) -> tuple[models.Model | None, str]: """ Create an AutoML Video Training Job. @@ -1128,9 +1170,15 @@ def create_auto_ml_video_training_job( model_labels=model_labels, sync=sync, ) - model.wait() - - return model + training_id = self.extract_training_id(self._job.resource_name) + if model: + model.wait() + else: + self.log.warning( + "Training did not produce a Managed Model returning None. AutoML Video Training " + "Pipeline is not configured to upload a Model." + ) + return model, training_id @GoogleBaseHook.fallback_to_default_project_id def delete_training_pipeline( @@ -1138,9 +1186,9 @@ def delete_training_pipeline( project_id: str, region: str, training_pipeline: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a TrainingPipeline. @@ -1157,7 +1205,7 @@ def delete_training_pipeline( result = client.delete_training_pipeline( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1171,9 +1219,9 @@ def get_training_pipeline( project_id: str, region: str, training_pipeline: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TrainingPipeline: """ Gets a TrainingPipeline. 
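Because the create_auto_ml_* methods now return a (model, training_id) tuple instead of a bare model, a caller-side sketch may help. Everything concrete here (project, region, display name, the dataset placeholder) is an assumption made for the example; only the unpacking and get_training_pipeline follow the signatures in the diff.

# Illustrative sketch, not part of the patch.
from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook

hook = AutoMLHook(gcp_conn_id="google_cloud_default")

text_dataset = ...  # placeholder: an existing Vertex AI TextDataset object (assumed)

model, training_id = hook.create_auto_ml_text_training_job(
    project_id="example-project",       # assumed
    region="us-central1",               # assumed
    display_name="example-text-job",    # assumed
    dataset=text_dataset,
    prediction_type="classification",
)

# ``model`` is None when the pipeline is not configured to upload a Model;
# ``training_id`` always identifies the training pipeline that was started.
if model is None:
    pipeline = hook.get_training_pipeline(
        project_id="example-project",
        region="us-central1",
        training_pipeline=training_id,
    )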
@@ -1190,7 +1238,7 @@ def get_training_pipeline( result = client.get_training_pipeline( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1203,13 +1251,13 @@ def list_training_pipelines( self, project_id: str, region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListTrainingPipelinesPager: """ Lists TrainingPipelines in a Location. @@ -1247,11 +1295,11 @@ def list_training_pipelines( result = client.list_training_pipelines( request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'read_mask': read_mask, + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "read_mask": read_mask, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py index c6e74ca15bd93..8671fa31a265d 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Cloud Vertex AI hook. .. 
spelling:: @@ -25,8 +24,9 @@ aiplatform gapic """ +from __future__ import annotations -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -46,28 +46,28 @@ class BatchPredictionJobHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._batch_prediction_job: Optional[BatchPredictionJob] = None + self._batch_prediction_job: BatchPredictionJob | None = None - def get_job_service_client(self, region: Optional[str] = None) -> JobServiceClient: + def get_job_service_client(self, region: str | None = None) -> JobServiceClient: """Returns JobServiceClient.""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return JobServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options ) - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -76,7 +76,7 @@ def wait_for_operation(self, operation: Operation, timeout: Optional[float] = No raise AirflowException(error) @staticmethod - def extract_batch_prediction_job_id(obj: Dict) -> str: + def extract_batch_prediction_job_id(obj: dict) -> str: """Returns unique id of the batch_prediction_job.""" return obj["name"].rpartition("/")[-1] @@ -91,24 +91,24 @@ def create_batch_prediction_job( project_id: str, region: str, job_display_name: str, - model_name: Union[str, "Model"], + model_name: str | Model, instances_format: str = "jsonl", predictions_format: str = "jsonl", - gcs_source: Optional[Union[str, Sequence[str]]] = None, - bigquery_source: Optional[str] = None, - gcs_destination_prefix: Optional[str] = None, - bigquery_destination_prefix: Optional[str] = None, - model_parameters: Optional[Dict] = None, - machine_type: Optional[str] = None, - accelerator_type: Optional[str] = None, - accelerator_count: Optional[int] = None, - starting_replica_count: Optional[int] = None, - max_replica_count: Optional[int] = None, - generate_explanation: Optional[bool] = False, - explanation_metadata: Optional["explain.ExplanationMetadata"] = None, - explanation_parameters: Optional["explain.ExplanationParameters"] = None, - labels: Optional[Dict[str, str]] = None, - encryption_spec_key_name: Optional[str] = None, + gcs_source: str | Sequence[str] | None = None, + bigquery_source: str | None = None, + gcs_destination_prefix: str | None = None, + bigquery_destination_prefix: str | None = None, + model_parameters: dict | None = None, + machine_type: str | None = None, + accelerator_type: str | None = None, + accelerator_count: int | None = None, + starting_replica_count: int 
| None = None, + max_replica_count: int | None = None, + generate_explanation: bool | None = False, + explanation_metadata: explain.ExplanationMetadata | None = None, + explanation_parameters: explain.ExplanationParameters | None = None, + labels: dict[str, str] | None = None, + encryption_spec_key_name: str | None = None, sync: bool = True, ) -> BatchPredictionJob: """ @@ -225,7 +225,7 @@ def create_batch_prediction_job( labels=labels, project=project_id, location=region, - credentials=self._get_credentials(), + credentials=self.get_credentials(), encryption_spec_key_name=encryption_spec_key_name, sync=sync, ) @@ -237,9 +237,9 @@ def delete_batch_prediction_job( project_id: str, region: str, batch_prediction_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -256,7 +256,7 @@ def delete_batch_prediction_job( result = client.delete_batch_prediction_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -270,9 +270,9 @@ def get_batch_prediction_job( project_id: str, region: str, batch_prediction_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> BatchPredictionJob: """ Gets a BatchPredictionJob @@ -289,7 +289,7 @@ def get_batch_prediction_job( result = client.get_batch_prediction_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -302,13 +302,13 @@ def list_batch_prediction_jobs( self, project_id: str, region: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListBatchPredictionJobsPager: """ Lists BatchPredictionJobs in a Location. @@ -337,11 +337,11 @@ def list_batch_prediction_jobs( result = client.list_batch_prediction_jobs( request={ - 'parent': parent, - 'filter': filter, - 'page_size': page_size, - 'page_token': page_token, - 'read_mask': read_mask, + "parent": parent, + "filter": filter, + "page_size": page_size, + "page_token": page_token, + "read_mask": read_mask, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index bd6987868512c..77fb7d2cc6f04 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
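The regional-endpoint selection repeated in these Vertex AI hooks can be read in isolation. A small sketch of the same pattern, assuming only the endpoint format shown in the diff; the helper name and the usage line are invented here.

# Illustrative sketch, not part of the patch.
from __future__ import annotations

from google.api_core.client_options import ClientOptions


def vertex_ai_client_options(region: str | None) -> ClientOptions:
    """Pick a regional Vertex AI endpoint unless the location is unset or 'global'."""
    if region and region != "global":
        # e.g. "us-central1-aiplatform.googleapis.com:443"
        return ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
    return ClientOptions()


# Hypothetical usage mirroring get_job_service_client in the hooks above:
options = vertex_ai_client_options("us-central1")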
-# """This module contains a Google Cloud Vertex AI hook.""" +from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -50,47 +50,43 @@ class CustomJobHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._job: Optional[ - Union[ - CustomContainerTrainingJob, - CustomPythonPackageTrainingJob, - CustomTrainingJob, - ] - ] = None + self._job: None | ( + CustomContainerTrainingJob | CustomPythonPackageTrainingJob | CustomTrainingJob + ) = None def get_pipeline_service_client( self, - region: Optional[str] = None, + region: str | None = None, ) -> PipelineServiceClient: """Returns PipelineServiceClient.""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return PipelineServiceClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) def get_job_service_client( self, - region: Optional[str] = None, + region: str | None = None, ) -> JobServiceClient: """Returns JobServiceClient""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return JobServiceClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) def get_custom_container_training_job( @@ -98,23 +94,23 @@ def get_custom_container_training_job( display_name: str, container_uri: str, command: Sequence[str] = [], - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + 
model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, ) -> CustomContainerTrainingJob: """Returns CustomContainerTrainingJob object""" return CustomContainerTrainingJob( @@ -134,7 +130,7 @@ def get_custom_container_training_job( model_prediction_schema_uri=model_prediction_schema_uri, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, @@ -147,23 +143,23 @@ def get_custom_python_package_training_job( python_package_gcs_uri: str, python_module_name: str, container_uri: str, - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, ): """Returns CustomPythonPackageTrainingJob object""" return CustomPythonPackageTrainingJob( @@ -184,7 +180,7 @@ def get_custom_python_package_training_job( model_prediction_schema_uri=model_prediction_schema_uri, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, 
training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, @@ -196,24 +192,24 @@ def get_custom_training_job( display_name: str, script_path: str, container_uri: str, - requirements: Optional[Sequence[str]] = None, - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + requirements: Sequence[str] | None = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, ): """Returns CustomTrainingJob object""" return CustomTrainingJob( @@ -234,7 +230,7 @@ def get_custom_training_job( model_prediction_schema_uri=model_prediction_schema_uri, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, @@ -242,11 +238,21 @@ def get_custom_training_job( ) @staticmethod - def extract_model_id(obj: Dict) -> str: + def extract_model_id(obj: dict) -> str: """Returns unique id of the Model.""" return obj["name"].rpartition("/")[-1] - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + @staticmethod + def extract_training_id(resource_name: str) -> str: + """Returns unique id of the Training pipeline.""" + return resource_name.rpartition("/")[-1] + + @staticmethod + def extract_custom_job_id(custom_job_name: str) -> str: + """Returns unique id of the Custom Job pipeline.""" + return custom_job_name.rpartition("/")[-1] + + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -261,45 +267,37 @@ def cancel_job(self) -> None: def _run_job( self, - job: Union[ - CustomTrainingJob, - CustomContainerTrainingJob, - 
CustomPythonPackageTrainingJob, - ], - dataset: Optional[ - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ] = None, - annotation_schema_uri: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, - base_output_dir: Optional[str] = None, - service_account: Optional[str] = None, - network: Optional[str] = None, - bigquery_destination: Optional[str] = None, - args: Optional[List[Union[str, float, int]]] = None, - environment_variables: Optional[Dict[str, str]] = None, + job: (CustomTrainingJob | CustomContainerTrainingJob | CustomPythonPackageTrainingJob), + dataset: None + | ( + datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset + ) = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - tensorboard: Optional[str] = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, sync=True, - ) -> models.Model: + ) -> tuple[models.Model | None, str, str]: """Run Job for training pipeline""" model = job.run( dataset=dataset, @@ -329,11 +327,20 @@ def _run_job( tensorboard=tensorboard, sync=sync, ) + training_id = self.extract_training_id(job.resource_name) + custom_job_id = self.extract_custom_job_id( + job.gca_resource.training_task_metadata.get("backingCustomJob") + ) if model: model.wait() - return model else: - raise AirflowException("Training did not produce a Managed Model returning None.") + self.log.warning( + "Training did not produce a Managed Model, returning None. Training Pipeline is not " + "configured to upload a Model. Create the Training Pipeline with " + "model_serving_container_image_uri and model_display_name passed in. " + "Ensure that your training script saves the model to os.environ['AIP_MODEL_DIR']."
+ ) + return model, training_id, custom_job_id @GoogleBaseHook.fallback_to_default_project_id def cancel_pipeline_job( @@ -341,9 +348,9 @@ def cancel_pipeline_job( project_id: str, region: str, pipeline_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Cancels a PipelineJob. Starts asynchronous cancellation on the PipelineJob. The server makes a best @@ -367,7 +374,7 @@ def cancel_pipeline_job( client.cancel_pipeline_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -380,9 +387,9 @@ def cancel_training_pipeline( project_id: str, region: str, training_pipeline: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes @@ -406,7 +413,7 @@ def cancel_training_pipeline( client.cancel_training_pipeline( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -419,9 +426,9 @@ def cancel_custom_job( project_id: str, region: str, custom_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort @@ -445,7 +452,7 @@ def cancel_custom_job( client.cancel_custom_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -459,9 +466,9 @@ def create_pipeline_job( region: str, pipeline_job: PipelineJob, pipeline_job_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> PipelineJob: """ Creates a PipelineJob. A PipelineJob will run immediately when created. @@ -482,9 +489,9 @@ def create_pipeline_job( result = client.create_pipeline_job( request={ - 'parent': parent, - 'pipeline_job': pipeline_job, - 'pipeline_job_id': pipeline_job_id, + "parent": parent, + "pipeline_job": pipeline_job, + "pipeline_job_id": pipeline_job_id, }, retry=retry, timeout=timeout, @@ -498,9 +505,9 @@ def create_training_pipeline( project_id: str, region: str, training_pipeline: TrainingPipeline, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TrainingPipeline: """ Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. 
@@ -517,8 +524,8 @@ def create_training_pipeline( result = client.create_training_pipeline( request={ - 'parent': parent, - 'training_pipeline': training_pipeline, + "parent": parent, + "training_pipeline": training_pipeline, }, retry=retry, timeout=timeout, @@ -532,9 +539,9 @@ def create_custom_job( project_id: str, region: str, custom_job: CustomJob, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> CustomJob: """ Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -552,8 +559,8 @@ def create_custom_job( result = client.create_custom_job( request={ - 'parent': parent, - 'custom_job': custom_job, + "parent": parent, + "custom_job": custom_job, }, retry=retry, timeout=timeout, @@ -569,56 +576,52 @@ def create_custom_container_training_job( display_name: str, container_uri: str, command: Sequence[str] = [], - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, # RUN - dataset: Optional[ - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ] = None, - annotation_schema_uri: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, - base_output_dir: Optional[str] = None, - service_account: Optional[str] = None, - network: Optional[str] = None, - bigquery_destination: Optional[str] = None, - args: Optional[List[Union[str, float, int]]] = None, - environment_variables: Optional[Dict[str, str]] = None, + dataset: None + | ( + datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset + ) = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = 
None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - tensorboard: Optional[str] = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, sync=True, - ) -> models.Model: + ) -> tuple[models.Model | None, str, str]: """ Create Custom Container Training Job @@ -890,7 +893,7 @@ def create_custom_container_training_job( if not self._job: raise AirflowException("CustomJob was not created") - model = self._run_job( + model, training_id, custom_job_id = self._run_job( job=self._job, dataset=dataset, annotation_schema_uri=annotation_schema_uri, @@ -920,7 +923,7 @@ def create_custom_container_training_job( sync=sync, ) - return model + return model, training_id, custom_job_id @GoogleBaseHook.fallback_to_default_project_id def create_custom_python_package_training_job( @@ -931,56 +934,52 @@ def create_custom_python_package_training_job( python_package_gcs_uri: str, python_module_name: str, container_uri: str, - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str 
| None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, # RUN - dataset: Optional[ - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ] = None, - annotation_schema_uri: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, - base_output_dir: Optional[str] = None, - service_account: Optional[str] = None, - network: Optional[str] = None, - bigquery_destination: Optional[str] = None, - args: Optional[List[Union[str, float, int]]] = None, - environment_variables: Optional[Dict[str, str]] = None, + dataset: None + | ( + datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset + ) = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - tensorboard: Optional[str] = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, sync=True, - ) -> models.Model: + ) -> tuple[models.Model | None, str, str]: """ Create Custom Python Package Training Job @@ -1252,7 +1251,7 @@ def create_custom_python_package_training_job( if not self._job: raise AirflowException("CustomJob was not created") - model = self._run_job( + model, training_id, custom_job_id = self._run_job( job=self._job, dataset=dataset, annotation_schema_uri=annotation_schema_uri, @@ -1282,7 +1281,7 @@ def create_custom_python_package_training_job( sync=sync, ) - return model + return model, training_id, custom_job_id @GoogleBaseHook.fallback_to_default_project_id def create_custom_training_job( @@ -1292,57 +1291,53 @@ def create_custom_training_job( display_name: str, script_path: str, container_uri: str, - requirements: Optional[Sequence[str]] = None, - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - 
model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + requirements: Sequence[str] | None = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, # RUN - dataset: Optional[ - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ] = None, - annotation_schema_uri: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, - base_output_dir: Optional[str] = None, - service_account: Optional[str] = None, - network: Optional[str] = None, - bigquery_destination: Optional[str] = None, - args: Optional[List[Union[str, float, int]]] = None, - environment_variables: Optional[Dict[str, str]] = None, + dataset: None + | ( + datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset + ) = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - tensorboard: Optional[str] = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, sync=True, - ) -> models.Model: + ) -> tuple[models.Model | None, str, str]: """ Create Custom Training 
Job @@ -1614,7 +1609,7 @@ def create_custom_training_job( if not self._job: raise AirflowException("CustomJob was not created") - model = self._run_job( + model, training_id, custom_job_id = self._run_job( job=self._job, dataset=dataset, annotation_schema_uri=annotation_schema_uri, @@ -1644,7 +1639,7 @@ def create_custom_training_job( sync=sync, ) - return model + return model, training_id, custom_job_id @GoogleBaseHook.fallback_to_default_project_id def delete_pipeline_job( @@ -1652,9 +1647,9 @@ def delete_pipeline_job( project_id: str, region: str, pipeline_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a PipelineJob. @@ -1671,7 +1666,7 @@ def delete_pipeline_job( result = client.delete_pipeline_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1685,9 +1680,9 @@ def delete_training_pipeline( project_id: str, region: str, training_pipeline: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a TrainingPipeline. @@ -1704,7 +1699,7 @@ def delete_training_pipeline( result = client.delete_training_pipeline( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1718,9 +1713,9 @@ def delete_custom_job( project_id: str, region: str, custom_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a CustomJob. @@ -1737,7 +1732,7 @@ def delete_custom_job( result = client.delete_custom_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1751,9 +1746,9 @@ def get_pipeline_job( project_id: str, region: str, pipeline_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> PipelineJob: """ Gets a PipelineJob. @@ -1770,7 +1765,7 @@ def get_pipeline_job( result = client.get_pipeline_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1784,9 +1779,9 @@ def get_training_pipeline( project_id: str, region: str, training_pipeline: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> TrainingPipeline: """ Gets a TrainingPipeline. 
@@ -1803,7 +1798,7 @@ def get_training_pipeline( result = client.get_training_pipeline( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1817,9 +1812,9 @@ def get_custom_job( project_id: str, region: str, custom_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> CustomJob: """ Gets a CustomJob. @@ -1836,7 +1831,7 @@ def get_custom_job( result = client.get_custom_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -1849,13 +1844,13 @@ def list_pipeline_jobs( self, project_id: str, region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListPipelineJobsPager: """ Lists PipelineJobs in a Location. @@ -1920,11 +1915,11 @@ def list_pipeline_jobs( result = client.list_pipeline_jobs( request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'order_by': order_by, + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -1937,13 +1932,13 @@ def list_training_pipelines( self, project_id: str, region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListTrainingPipelinesPager: """ Lists TrainingPipelines in a Location. @@ -1981,11 +1976,11 @@ def list_training_pipelines( result = client.list_training_pipelines( request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'read_mask': read_mask, + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "read_mask": read_mask, }, retry=retry, timeout=timeout, @@ -1998,13 +1993,13 @@ def list_custom_jobs( self, project_id: str, region: str, - page_size: Optional[int], - page_token: Optional[str], - filter: Optional[str], - read_mask: Optional[str], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None, + page_token: str | None, + filter: str | None, + read_mask: str | None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListCustomJobsPager: """ Lists CustomJobs in a Location. 
@@ -2042,11 +2037,11 @@ def list_custom_jobs( result = client.list_custom_jobs( request={ - 'parent': parent, - 'page_size': page_size, - 'page_token': page_token, - 'filter': filter, - 'read_mask': read_mask, + "parent": parent, + "page_size": page_size, + "page_token": page_token, + "filter": filter, + "read_mask": read_mask, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py index e693424c08b29..aadc0733bb804 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Cloud Vertex AI hook.""" +from __future__ import annotations -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -41,18 +41,18 @@ class DatasetHook(GoogleBaseHook): """Hook for Google Cloud Vertex AI Dataset APIs.""" - def get_dataset_service_client(self, region: Optional[str] = None) -> DatasetServiceClient: + def get_dataset_service_client(self, region: str | None = None) -> DatasetServiceClient: """Returns DatasetServiceClient.""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return DatasetServiceClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO, client_options=client_options + credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options ) - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -61,7 +61,7 @@ def wait_for_operation(self, operation: Operation, timeout: Optional[float] = No raise AirflowException(error) @staticmethod - def extract_dataset_id(obj: Dict) -> str: + def extract_dataset_id(obj: dict) -> str: """Returns unique id of the dataset.""" return obj["name"].rpartition("/")[-1] @@ -70,10 +70,10 @@ def create_dataset( self, project_id: str, region: str, - dataset: Union[Dataset, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + dataset: Dataset | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Creates a Dataset. @@ -90,8 +90,8 @@ def create_dataset( result = client.create_dataset( request={ - 'parent': parent, - 'dataset': dataset, + "parent": parent, + "dataset": dataset, }, retry=retry, timeout=timeout, @@ -105,9 +105,9 @@ def delete_dataset( project_id: str, region: str, dataset: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a Dataset. 
@@ -124,7 +124,7 @@ def delete_dataset( result = client.delete_dataset( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -138,10 +138,10 @@ def export_data( project_id: str, region: str, dataset: str, - export_config: Union[ExportDataConfig, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + export_config: ExportDataConfig | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Exports data from a Dataset. @@ -159,8 +159,8 @@ def export_data( result = client.export_data( request={ - 'name': name, - 'export_config': export_config, + "name": name, + "export_config": export_config, }, retry=retry, timeout=timeout, @@ -175,10 +175,10 @@ def get_annotation_spec( region: str, dataset: str, annotation_spec: str, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> AnnotationSpec: """ Gets an AnnotationSpec. @@ -197,8 +197,8 @@ def get_annotation_spec( result = client.get_annotation_spec( request={ - 'name': name, - 'read_mask': read_mask, + "name": name, + "read_mask": read_mask, }, retry=retry, timeout=timeout, @@ -212,10 +212,10 @@ def get_dataset( project_id: str, region: str, dataset: str, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Dataset: """ Gets a Dataset. @@ -233,8 +233,8 @@ def get_dataset( result = client.get_dataset( request={ - 'name': name, - 'read_mask': read_mask, + "name": name, + "read_mask": read_mask, }, retry=retry, timeout=timeout, @@ -249,9 +249,9 @@ def import_data( region: str, dataset: str, import_configs: Sequence[ImportDataConfig], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Imports data into a Dataset. 
@@ -270,8 +270,8 @@ def import_data( result = client.import_data( request={ - 'name': name, - 'import_configs': import_configs, + "name": name, + "import_configs": import_configs, }, retry=retry, timeout=timeout, @@ -286,14 +286,14 @@ def list_annotations( region: str, dataset: str, data_item: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListAnnotationsPager: """ Lists Annotations belongs to a data item @@ -317,12 +317,12 @@ def list_annotations( result = client.list_annotations( request={ - 'parent': parent, - 'filter': filter, - 'page_size': page_size, - 'page_token': page_token, - 'read_mask': read_mask, - 'order_by': order_by, + "parent": parent, + "filter": filter, + "page_size": page_size, + "page_token": page_token, + "read_mask": read_mask, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -336,14 +336,14 @@ def list_data_items( project_id: str, region: str, dataset: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListDataItemsPager: """ Lists DataItems in a Dataset. @@ -366,12 +366,12 @@ def list_data_items( result = client.list_data_items( request={ - 'parent': parent, - 'filter': filter, - 'page_size': page_size, - 'page_token': page_token, - 'read_mask': read_mask, - 'order_by': order_by, + "parent": parent, + "filter": filter, + "page_size": page_size, + "page_token": page_token, + "read_mask": read_mask, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -384,14 +384,14 @@ def list_datasets( self, project_id: str, region: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListDatasetsPager: """ Lists Datasets in a Location. 
@@ -413,12 +413,12 @@ def list_datasets( result = client.list_datasets( request={ - 'parent': parent, - 'filter': filter, - 'page_size': page_size, - 'page_token': page_token, - 'read_mask': read_mask, - 'order_by': order_by, + "parent": parent, + "filter": filter, + "page_size": page_size, + "page_token": page_token, + "read_mask": read_mask, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -431,11 +431,11 @@ def update_dataset( project_id: str, region: str, dataset_id: str, - dataset: Union[Dataset, Dict], - update_mask: Union[FieldMask, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + dataset: Dataset | dict, + update_mask: FieldMask | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Dataset: """ Updates a Dataset. @@ -454,8 +454,8 @@ def update_dataset( result = client.update_dataset( request={ - 'dataset': dataset, - 'update_mask': update_mask, + "dataset": dataset, + "update_mask": update_mask, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py b/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py index 63307fed15715..6bee752463d16 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Cloud Vertex AI hook. .. spelling:: @@ -27,8 +26,9 @@ FieldMask unassigns """ +from __future__ import annotations -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -46,18 +46,18 @@ class EndpointServiceHook(GoogleBaseHook): """Hook for Google Cloud Vertex AI Endpoint Service APIs.""" - def get_endpoint_service_client(self, region: Optional[str] = None) -> EndpointServiceClient: + def get_endpoint_service_client(self, region: str | None = None) -> EndpointServiceClient: """Returns EndpointServiceClient.""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return EndpointServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options ) - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -66,12 +66,12 @@ def wait_for_operation(self, operation: Operation, timeout: Optional[float] = No raise AirflowException(error) @staticmethod - def extract_endpoint_id(obj: Dict) -> str: + def extract_endpoint_id(obj: dict) -> str: """Returns unique id of the endpoint.""" return obj["name"].rpartition("/")[-1] @staticmethod - def extract_deployed_model_id(obj: Dict) -> str: + def extract_deployed_model_id(obj: dict) -> str: 
"""Returns unique id of the deploy model.""" return obj["deployed_model"]["id"] @@ -80,11 +80,11 @@ def create_endpoint( self, project_id: str, region: str, - endpoint: Union[Endpoint, Dict], - endpoint_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + endpoint: Endpoint | dict, + endpoint_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Creates an Endpoint. @@ -103,9 +103,9 @@ def create_endpoint( result = client.create_endpoint( request={ - 'parent': parent, - 'endpoint': endpoint, - 'endpoint_id': endpoint_id, + "parent": parent, + "endpoint": endpoint, + "endpoint_id": endpoint_id, }, retry=retry, timeout=timeout, @@ -119,9 +119,9 @@ def delete_endpoint( project_id: str, region: str, endpoint: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes an Endpoint. @@ -138,7 +138,7 @@ def delete_endpoint( result = client.delete_endpoint( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -152,11 +152,11 @@ def deploy_model( project_id: str, region: str, endpoint: str, - deployed_model: Union[DeployedModel, Dict], - traffic_split: Optional[Union[Sequence, Dict]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + deployed_model: DeployedModel | dict, + traffic_split: Sequence | dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -189,9 +189,9 @@ def deploy_model( result = client.deploy_model( request={ - 'endpoint': endpoint_path, - 'deployed_model': deployed_model, - 'traffic_split': traffic_split, + "endpoint": endpoint_path, + "deployed_model": deployed_model, + "traffic_split": traffic_split, }, retry=retry, timeout=timeout, @@ -205,9 +205,9 @@ def get_endpoint( project_id: str, region: str, endpoint: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Endpoint: """ Gets an Endpoint. @@ -224,7 +224,7 @@ def get_endpoint( result = client.get_endpoint( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -237,14 +237,14 @@ def list_endpoints( self, project_id: str, region: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListEndpointsPager: """ Lists Endpoints in a Location. 
@@ -282,12 +282,12 @@ def list_endpoints( result = client.list_endpoints( request={ - 'parent': parent, - 'filter': filter, - 'page_size': page_size, - 'page_token': page_token, - 'read_mask': read_mask, - 'order_by': order_by, + "parent": parent, + "filter": filter, + "page_size": page_size, + "page_token": page_token, + "read_mask": read_mask, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -302,10 +302,10 @@ def undeploy_model( region: str, endpoint: str, deployed_model_id: str, - traffic_split: Optional[Union[Sequence, Dict]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + traffic_split: Sequence | dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's @@ -330,9 +330,9 @@ def undeploy_model( result = client.undeploy_model( request={ - 'endpoint': endpoint_path, - 'deployed_model_id': deployed_model_id, - 'traffic_split': traffic_split, + "endpoint": endpoint_path, + "deployed_model_id": deployed_model_id, + "traffic_split": traffic_split, }, retry=retry, timeout=timeout, @@ -346,11 +346,11 @@ def update_endpoint( project_id: str, region: str, endpoint_id: str, - endpoint: Union[Endpoint, Dict], - update_mask: Union[FieldMask, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + endpoint: Endpoint | dict, + update_mask: FieldMask | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Endpoint: """ Updates an Endpoint. @@ -369,8 +369,8 @@ def update_endpoint( result = client.update_endpoint( request={ - 'endpoint': endpoint, - 'update_mask': update_mask, + "endpoint": endpoint, + "update_mask": update_mask, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py index 2659636c76bdf..23aaaa1f81100 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Cloud Vertex AI hook. .. 
spelling:: @@ -26,8 +25,9 @@ aiplatform myVPC """ +from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -47,42 +47,42 @@ class HyperparameterTuningJobHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) - self._hyperparameter_tuning_job: Optional[HyperparameterTuningJob] = None + self._hyperparameter_tuning_job: HyperparameterTuningJob | None = None - def get_job_service_client(self, region: Optional[str] = None) -> JobServiceClient: + def get_job_service_client(self, region: str | None = None) -> JobServiceClient: """Returns JobServiceClient.""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return JobServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options ) def get_hyperparameter_tuning_job_object( self, display_name: str, custom_job: CustomJob, - metric_spec: Dict[str, str], - parameter_spec: Dict[str, hyperparameter_tuning._ParameterSpec], + metric_spec: dict[str, str], + parameter_spec: dict[str, hyperparameter_tuning._ParameterSpec], max_trial_count: int, parallel_trial_count: int, max_failed_trial_count: int = 0, - search_algorithm: Optional[str] = None, - measurement_selection: Optional[str] = "best", - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - encryption_spec_key_name: Optional[str] = None, + search_algorithm: str | None = None, + measurement_selection: str | None = "best", + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + encryption_spec_key_name: str | None = None, ) -> HyperparameterTuningJob: """Returns HyperparameterTuningJob object""" return HyperparameterTuningJob( @@ -97,7 +97,7 @@ def get_hyperparameter_tuning_job_object( measurement_selection=measurement_selection, project=project, location=location, - credentials=self._get_credentials(), + credentials=self.get_credentials(), labels=labels, encryption_spec_key_name=encryption_spec_key_name, ) @@ -105,13 +105,13 @@ def get_hyperparameter_tuning_job_object( def get_custom_job_object( self, display_name: str, - worker_pool_specs: Union[List[Dict], List[gapic.WorkerPoolSpec]], - base_output_dir: Optional[str] = None, - project: Optional[str] = None, - location: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + worker_pool_specs: list[dict] | list[gapic.WorkerPoolSpec], + base_output_dir: str | None = None, + project: str | None = None, + location: str | None = None, + labels: dict[str, str] | None = None, + encryption_spec_key_name: str | None = None, + 
staging_bucket: str | None = None, ) -> CustomJob: """Returns CustomJob object""" return CustomJob( @@ -120,18 +120,18 @@ def get_custom_job_object( base_output_dir=base_output_dir, project=project, location=location, - credentials=self._get_credentials, + credentials=self.get_credentials, labels=labels, encryption_spec_key_name=encryption_spec_key_name, staging_bucket=staging_bucket, ) @staticmethod - def extract_hyperparameter_tuning_job_id(obj: Dict) -> str: + def extract_hyperparameter_tuning_job_id(obj: dict) -> str: """Returns unique id of the hyperparameter_tuning_job.""" return obj["name"].rpartition("/")[-1] - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -150,29 +150,29 @@ def create_hyperparameter_tuning_job( project_id: str, region: str, display_name: str, - metric_spec: Dict[str, str], - parameter_spec: Dict[str, hyperparameter_tuning._ParameterSpec], + metric_spec: dict[str, str], + parameter_spec: dict[str, hyperparameter_tuning._ParameterSpec], max_trial_count: int, parallel_trial_count: int, # START: CustomJob param - worker_pool_specs: Union[List[Dict], List[gapic.WorkerPoolSpec]], - base_output_dir: Optional[str] = None, - custom_job_labels: Optional[Dict[str, str]] = None, - custom_job_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + worker_pool_specs: list[dict] | list[gapic.WorkerPoolSpec], + base_output_dir: str | None = None, + custom_job_labels: dict[str, str] | None = None, + custom_job_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, # END: CustomJob param max_failed_trial_count: int = 0, - search_algorithm: Optional[str] = None, - measurement_selection: Optional[str] = "best", - hyperparameter_tuning_job_labels: Optional[Dict[str, str]] = None, - hyperparameter_tuning_job_encryption_spec_key_name: Optional[str] = None, + search_algorithm: str | None = None, + measurement_selection: str | None = "best", + hyperparameter_tuning_job_labels: dict[str, str] | None = None, + hyperparameter_tuning_job_encryption_spec_key_name: str | None = None, # START: run param - service_account: Optional[str] = None, - network: Optional[str] = None, - timeout: Optional[int] = None, # seconds + service_account: str | None = None, + network: str | None = None, + timeout: int | None = None, # seconds restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, - tensorboard: Optional[str] = None, + tensorboard: str | None = None, sync: bool = True, # END: run param ) -> HyperparameterTuningJob: @@ -304,9 +304,9 @@ def get_hyperparameter_tuning_job( project_id: str, region: str, hyperparameter_tuning_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> types.HyperparameterTuningJob: """ Gets a HyperparameterTuningJob @@ -323,7 +323,7 @@ def get_hyperparameter_tuning_job( result = client.get_hyperparameter_tuning_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -336,13 +336,13 @@ def list_hyperparameter_tuning_jobs( self, project_id: str, region: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: 
Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListHyperparameterTuningJobsPager: """ Lists HyperparameterTuningJobs in a Location. @@ -371,11 +371,11 @@ def list_hyperparameter_tuning_jobs( result = client.list_hyperparameter_tuning_jobs( request={ - 'parent': parent, - 'filter': filter, - 'page_size': page_size, - 'page_token': page_token, - 'read_mask': read_mask, + "parent": parent, + "filter": filter, + "page_size": page_size, + "page_token": page_token, + "read_mask": read_mask, }, retry=retry, timeout=timeout, @@ -389,9 +389,9 @@ def delete_hyperparameter_tuning_job( project_id: str, region: str, hyperparameter_tuning_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a HyperparameterTuningJob. @@ -409,7 +409,7 @@ def delete_hyperparameter_tuning_job( result = client.delete_hyperparameter_tuning_job( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py b/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py index c8f58eb394ea9..ca7f6aca85a89 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Cloud Vertex AI hook. .. 
spelling:: @@ -23,8 +22,9 @@ aiplatform camelCase """ +from __future__ import annotations -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -41,23 +41,23 @@ class ModelServiceHook(GoogleBaseHook): """Hook for Google Cloud Vertex AI Endpoint Service APIs.""" - def get_model_service_client(self, region: Optional[str] = None) -> ModelServiceClient: + def get_model_service_client(self, region: str | None = None) -> ModelServiceClient: """Returns ModelServiceClient.""" - if region and region != 'global': - client_options = ClientOptions(api_endpoint=f'{region}-aiplatform.googleapis.com:443') + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") else: client_options = ClientOptions() return ModelServiceClient( - credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options + credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options ) @staticmethod - def extract_model_id(obj: Dict) -> str: + def extract_model_id(obj: dict) -> str: """Returns unique id of the model.""" return obj["model"].rpartition("/")[-1] - def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None): + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Waits for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -71,9 +71,9 @@ def delete_model( project_id: str, region: str, model: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a Model. @@ -90,7 +90,7 @@ def delete_model( result = client.delete_model( request={ - 'name': name, + "name": name, }, retry=retry, timeout=timeout, @@ -104,10 +104,10 @@ def export_model( project_id: str, region: str, model: str, - output_config: Union[model_service.ExportModelRequest.OutputConfig, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + output_config: model_service.ExportModelRequest.OutputConfig | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Exports a trained, exportable Model to a location specified by the user. @@ -139,14 +139,14 @@ def list_models( self, project_id: str, region: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListModelsPager: r""" Lists Models in a Location. 
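The hunks above, like their counterparts in the other Vertex AI hooks touched by this change, replace typing.Optional/Union/Dict annotations with PEP 604 unions and built-in generics. The file-level "from __future__ import annotations" added at the top of each module is what keeps this safe on the older interpreters Airflow still supported at the time: annotations are stored as strings and never evaluated at runtime. A minimal stand-alone sketch of the pattern follows; the function and its parameters are invented for illustration and are not taken from the provider code.

    from __future__ import annotations

    from typing import Sequence


    def list_things(
        parent: str,
        page_size: int | None = None,              # was Optional[int]
        metadata: Sequence[tuple[str, str]] = (),  # was Sequence[Tuple[str, str]]
    ) -> list[dict]:                               # was List[Dict]
        # With the __future__ import these annotations are plain strings at
        # runtime, so the "X | None" syntax does not require Python 3.10.
        return [{"parent": parent, "page_size": page_size, "metadata": list(metadata)}]


    print(list_things("projects/demo-project/locations/us-central1"))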
@@ -180,12 +180,12 @@ def list_models( result = client.list_models( request={ - 'parent': parent, - 'filter': filter, - 'page_size': page_size, - 'page_token': page_token, - 'read_mask': read_mask, - 'order_by': order_by, + "parent": parent, + "filter": filter, + "page_size": page_size, + "page_token": page_token, + "read_mask": read_mask, + "order_by": order_by, }, retry=retry, timeout=timeout, @@ -198,10 +198,10 @@ def upload_model( self, project_id: str, region: str, - model: Union[Model, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + model: Model | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Uploads a Model artifact into Vertex AI. diff --git a/airflow/providers/google/cloud/hooks/video_intelligence.py b/airflow/providers/google/cloud/hooks/video_intelligence.py index cc3f00e94eed8..dce9aa2934305 100644 --- a/airflow/providers/google/cloud/hooks/video_intelligence.py +++ b/airflow/providers/google/cloud/hooks/video_intelligence.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Video Intelligence Hook.""" -from typing import Dict, List, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation @@ -52,8 +54,8 @@ class CloudVideoIntelligenceHook(GoogleBaseHook): def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -63,29 +65,25 @@ def __init__( self._conn = None def get_conn(self) -> VideoIntelligenceServiceClient: - """ - Returns Gcp Video Intelligence Service client - - :rtype: google.cloud.videointelligence_v1.VideoIntelligenceServiceClient - """ + """Returns Gcp Video Intelligence Service client""" if not self._conn: self._conn = VideoIntelligenceServiceClient( - credentials=self._get_credentials(), client_info=CLIENT_INFO + credentials=self.get_credentials(), client_info=CLIENT_INFO ) return self._conn @GoogleBaseHook.quota_retry() def annotate_video( self, - input_uri: Optional[str] = None, - input_content: Optional[bytes] = None, - features: Optional[List[VideoIntelligenceServiceClient.enums.Feature]] = None, - video_context: Union[Dict, VideoContext] = None, - output_uri: Optional[str] = None, - location: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + input_uri: str | None = None, + input_content: bytes | None = None, + features: list[VideoIntelligenceServiceClient.enums.Feature] | None = None, + video_context: dict | VideoContext = None, + output_uri: str | None = None, + location: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Performs video annotation. 
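A second pattern repeated across these hook diffs is client construction: credentials now come from the public get_credentials() accessor rather than the private _get_credentials(), and regional API endpoints are still selected through ClientOptions. The sketch below lifts that endpoint selection out of the hooks for illustration; the helper name is invented, while ClientOptions and the "{region}-aiplatform.googleapis.com:443" endpoint string come straight from the diff.

    from __future__ import annotations

    from google.api_core.client_options import ClientOptions


    def client_options_for(region: str | None) -> ClientOptions:
        # Use a regional Vertex AI endpoint unless the region is absent or "global".
        if region and region != "global":
            return ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
        return ClientOptions()


    # In the hooks this is combined with the public credentials accessor, e.g.:
    #     JobServiceClient(
    #         credentials=self.get_credentials(),
    #         client_info=self.client_info,
    #         client_options=client_options_for(region),
    #     )
    print(client_options_for("europe-west4").api_endpoint)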
diff --git a/airflow/providers/google/cloud/hooks/vision.py b/airflow/providers/google/cloud/hooks/vision.py index 1d4d232a1eb4b..5ebdb0e9146dc 100644 --- a/airflow/providers/google/cloud/hooks/vision.py +++ b/airflow/providers/google/cloud/hooks/vision.py @@ -16,16 +16,10 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Vision Hook.""" -import sys -from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union - -from airflow.providers.google.common.consts import CLIENT_INFO +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from copy import deepcopy +from typing import Any, Callable, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -40,7 +34,9 @@ ) from google.protobuf.json_format import MessageToDict +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook ERR_DIFF_NAMES = """The {label} name provided in the object ({explicit_name}) is different @@ -64,7 +60,7 @@ def __init__(self, label: str, id_label: str, get_path: Callable[[str, str, str] self.get_path = get_path def get_entity_with_name( - self, entity: Any, entity_id: Optional[str], location: Optional[str], project_id: str + self, entity: Any, entity_id: str | None, location: str | None, project_id: str ) -> Any: """ Check if entity has the `name` attribute set: @@ -86,11 +82,10 @@ def get_entity_with_name( :param location: Location :param project_id: The id of Google Cloud Vision project. :return: The same entity or entity with new name - :rtype: str :raises: AirflowException """ entity = deepcopy(entity) - explicit_name = getattr(entity, 'name') + explicit_name = getattr(entity, "name") if location and entity_id: # Necessary parameters to construct the name are present. Checking for conflict with explicit name constructed_name = self.get_path(project_id, location, entity_id) @@ -123,16 +118,16 @@ class CloudVisionHook(GoogleBaseHook): keyword arguments rather than positional. """ - product_name_determiner = NameDeterminer('Product', 'product_id', ProductSearchClient.product_path) + product_name_determiner = NameDeterminer("Product", "product_id", ProductSearchClient.product_path) product_set_name_determiner = NameDeterminer( - 'ProductSet', 'productset_id', ProductSearchClient.product_set_path + "ProductSet", "productset_id", ProductSearchClient.product_set_path ) def __init__( self, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -146,10 +141,9 @@ def get_conn(self) -> ProductSearchClient: Retrieves connection to Cloud Vision. :return: Google Cloud Vision client object. 
- :rtype: google.cloud.vision_v1.ProductSearchClient """ if not self._client: - self._client = ProductSearchClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + self._client = ProductSearchClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) return self._client @cached_property @@ -158,12 +152,11 @@ def annotator_client(self) -> ImageAnnotatorClient: Creates ImageAnnotatorClient. :return: Google Image Annotator client object. - :rtype: google.cloud.vision_v1.ImageAnnotatorClient """ - return ImageAnnotatorClient(credentials=self._get_credentials()) + return ImageAnnotatorClient(credentials=self.get_credentials()) @staticmethod - def _check_for_error(response: Dict) -> None: + def _check_for_error(response: dict) -> None: if "error" in response: raise AirflowException(response) @@ -171,12 +164,12 @@ def _check_for_error(response: Dict) -> None: def create_product_set( self, location: str, - product_set: Union[dict, ProductSet], + product_set: dict | ProductSet, project_id: str = PROVIDE_PROJECT_ID, - product_set_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + product_set_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> str: """ For the documentation see: @@ -184,7 +177,7 @@ def create_product_set( """ client = self.get_conn() parent = ProductSearchClient.location_path(project_id, location) - self.log.info('Creating a new ProductSet under the parent: %s', parent) + self.log.info("Creating a new ProductSet under the parent: %s", parent) response = client.create_product_set( parent=parent, product_set=product_set, @@ -193,13 +186,13 @@ def create_product_set( timeout=timeout, metadata=metadata, ) - self.log.info('ProductSet created: %s', response.name if response else '') - self.log.debug('ProductSet created:\n%s', response) + self.log.info("ProductSet created: %s", response.name if response else "") + self.log.debug("ProductSet created:\n%s", response) if not product_set_id: # Product set id was generated by the API product_set_id = self._get_autogenerated_id(response) - self.log.info('Extracted autogenerated ProductSet ID from the response: %s', product_set_id) + self.log.info("Extracted autogenerated ProductSet ID from the response: %s", product_set_id) return product_set_id @@ -209,9 +202,9 @@ def get_product_set( location: str, product_set_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> dict: """ For the documentation see: @@ -219,23 +212,23 @@ def get_product_set( """ client = self.get_conn() name = ProductSearchClient.product_set_path(project_id, location, product_set_id) - self.log.info('Retrieving ProductSet: %s', name) + self.log.info("Retrieving ProductSet: %s", name) response = client.get_product_set(name=name, retry=retry, timeout=timeout, metadata=metadata) - self.log.info('ProductSet retrieved.') - self.log.debug('ProductSet retrieved:\n%s', response) + self.log.info("ProductSet retrieved.") + self.log.debug("ProductSet retrieved:\n%s", response) return MessageToDict(response) @GoogleBaseHook.fallback_to_default_project_id def update_product_set( self, - product_set: Union[dict, ProductSet], + 
product_set: dict | ProductSet, project_id: str = PROVIDE_PROJECT_ID, - location: Optional[str] = None, - product_set_id: Optional[str] = None, - update_mask: Union[dict, FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + product_set_id: str | None = None, + update_mask: dict | FieldMask = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> dict: """ For the documentation see: @@ -245,12 +238,12 @@ def update_product_set( product_set = self.product_set_name_determiner.get_entity_with_name( product_set, product_set_id, location, project_id ) - self.log.info('Updating ProductSet: %s', product_set.name) + self.log.info("Updating ProductSet: %s", product_set.name) response = client.update_product_set( product_set=product_set, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata ) - self.log.info('ProductSet updated: %s', response.name if response else '') - self.log.debug('ProductSet updated:\n%s', response) + self.log.info("ProductSet updated: %s", response.name if response else "") + self.log.debug("ProductSet updated:\n%s", response) return MessageToDict(response) @GoogleBaseHook.fallback_to_default_project_id @@ -259,9 +252,9 @@ def delete_product_set( location: str, product_set_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ For the documentation see: @@ -269,20 +262,20 @@ def delete_product_set( """ client = self.get_conn() name = ProductSearchClient.product_set_path(project_id, location, product_set_id) - self.log.info('Deleting ProductSet: %s', name) + self.log.info("Deleting ProductSet: %s", name) client.delete_product_set(name=name, retry=retry, timeout=timeout, metadata=metadata) - self.log.info('ProductSet with the name [%s] deleted.', name) + self.log.info("ProductSet with the name [%s] deleted.", name) @GoogleBaseHook.fallback_to_default_project_id def create_product( self, location: str, - product: Union[dict, Product], + product: dict | Product, project_id: str = PROVIDE_PROJECT_ID, - product_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + product_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ For the documentation see: @@ -290,7 +283,7 @@ def create_product( """ client = self.get_conn() parent = ProductSearchClient.location_path(project_id, location) - self.log.info('Creating a new Product under the parent: %s', parent) + self.log.info("Creating a new Product under the parent: %s", parent) response = client.create_product( parent=parent, product=product, @@ -299,13 +292,13 @@ def create_product( timeout=timeout, metadata=metadata, ) - self.log.info('Product created: %s', response.name if response else '') - self.log.debug('Product created:\n%s', response) + self.log.info("Product created: %s", response.name if response else "") + self.log.debug("Product created:\n%s", response) if not product_id: # Product id was generated by the API product_id = self._get_autogenerated_id(response) - 
self.log.info('Extracted autogenerated Product ID from the response: %s', product_id) + self.log.info("Extracted autogenerated Product ID from the response: %s", product_id) return product_id @@ -315,9 +308,9 @@ def get_product( location: str, product_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ For the documentation see: @@ -325,23 +318,23 @@ def get_product( """ client = self.get_conn() name = ProductSearchClient.product_path(project_id, location, product_id) - self.log.info('Retrieving Product: %s', name) + self.log.info("Retrieving Product: %s", name) response = client.get_product(name=name, retry=retry, timeout=timeout, metadata=metadata) - self.log.info('Product retrieved.') - self.log.debug('Product retrieved:\n%s', response) + self.log.info("Product retrieved.") + self.log.debug("Product retrieved:\n%s", response) return MessageToDict(response) @GoogleBaseHook.fallback_to_default_project_id def update_product( self, - product: Union[dict, Product], + product: dict | Product, project_id: str = PROVIDE_PROJECT_ID, - location: Optional[str] = None, - product_id: Optional[str] = None, - update_mask: Optional[Dict[str, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + product_id: str | None = None, + update_mask: dict[str, FieldMask] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ): """ For the documentation see: @@ -349,12 +342,12 @@ def update_product( """ client = self.get_conn() product = self.product_name_determiner.get_entity_with_name(product, product_id, location, project_id) - self.log.info('Updating ProductSet: %s', product.name) + self.log.info("Updating ProductSet: %s", product.name) response = client.update_product( product=product, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata ) - self.log.info('Product updated: %s', response.name if response else '') - self.log.debug('Product updated:\n%s', response) + self.log.info("Product updated: %s", response.name if response else "") + self.log.debug("Product updated:\n%s", response) return MessageToDict(response) @GoogleBaseHook.fallback_to_default_project_id @@ -363,9 +356,9 @@ def delete_product( location: str, product_id: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ For the documentation see: @@ -373,28 +366,28 @@ def delete_product( """ client = self.get_conn() name = ProductSearchClient.product_path(project_id, location, product_id) - self.log.info('Deleting ProductSet: %s', name) + self.log.info("Deleting ProductSet: %s", name) client.delete_product(name=name, retry=retry, timeout=timeout, metadata=metadata) - self.log.info('Product with the name [%s] deleted:', name) + self.log.info("Product with the name [%s] deleted:", name) @GoogleBaseHook.fallback_to_default_project_id def create_reference_image( self, location: str, product_id: str, - reference_image: Union[dict, 
ReferenceImage], + reference_image: dict | ReferenceImage, project_id: str, - reference_image_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + reference_image_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> str: """ For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator` """ client = self.get_conn() - self.log.info('Creating ReferenceImage') + self.log.info("Creating ReferenceImage") parent = ProductSearchClient.product_path(project=project_id, location=location, product=product_id) response = client.create_reference_image( @@ -406,14 +399,14 @@ def create_reference_image( metadata=metadata, ) - self.log.info('ReferenceImage created: %s', response.name if response else '') - self.log.debug('ReferenceImage created:\n%s', response) + self.log.info("ReferenceImage created: %s", response.name if response else "") + self.log.debug("ReferenceImage created:\n%s", response) if not reference_image_id: # Reference image id was generated by the API reference_image_id = self._get_autogenerated_id(response) self.log.info( - 'Extracted autogenerated ReferenceImage ID from the response: %s', reference_image_id + "Extracted autogenerated ReferenceImage ID from the response: %s", reference_image_id ) return reference_image_id @@ -425,29 +418,28 @@ def delete_reference_image( product_id: str, reference_image_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dict: + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: """ For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDeleteReferenceImageOperator` """ client = self.get_conn() - self.log.info('Deleting ReferenceImage') + self.log.info("Deleting ReferenceImage") name = ProductSearchClient.reference_image_path( project=project_id, location=location, product=product_id, reference_image=reference_image_id ) - response = client.delete_reference_image( + client.delete_reference_image( name=name, retry=retry, timeout=timeout, metadata=metadata, ) - self.log.info('ReferenceImage with the name [%s] deleted.', name) - return MessageToDict(response) + self.log.info("ReferenceImage with the name [%s] deleted.", name) @GoogleBaseHook.fallback_to_default_project_id def add_product_to_product_set( @@ -455,10 +447,10 @@ def add_product_to_product_set( product_set_id: str, product_id: str, project_id: str, - location: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ For the documentation see: @@ -469,13 +461,13 @@ def add_product_to_product_set( product_name = ProductSearchClient.product_path(project_id, location, product_id) product_set_name = ProductSearchClient.product_set_path(project_id, location, product_set_id) - self.log.info('Add Product[name=%s] to Product Set[name=%s]', product_name, product_set_name) + self.log.info("Add Product[name=%s] to Product Set[name=%s]", product_name, product_set_name) 
client.add_product_to_product_set( name=product_set_name, product=product_name, retry=retry, timeout=timeout, metadata=metadata ) - self.log.info('Product added to Product Set') + self.log.info("Product added to Product Set") @GoogleBaseHook.fallback_to_default_project_id def remove_product_from_product_set( @@ -483,10 +475,10 @@ def remove_product_from_product_set( product_set_id: str, product_id: str, project_id: str, - location: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> None: """ For the documentation see: @@ -497,40 +489,40 @@ def remove_product_from_product_set( product_name = ProductSearchClient.product_path(project_id, location, product_id) product_set_name = ProductSearchClient.product_set_path(project_id, location, product_set_id) - self.log.info('Remove Product[name=%s] from Product Set[name=%s]', product_name, product_set_name) + self.log.info("Remove Product[name=%s] from Product Set[name=%s]", product_name, product_set_name) client.remove_product_from_product_set( name=product_set_name, product=product_name, retry=retry, timeout=timeout, metadata=metadata ) - self.log.info('Product removed from Product Set') + self.log.info("Product removed from Product Set") def annotate_image( self, - request: Union[dict, AnnotateImageRequest], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - ) -> Dict: + request: dict | AnnotateImageRequest, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + ) -> dict: """ For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator` """ client = self.annotator_client - self.log.info('Annotating image') + self.log.info("Annotating image") response = client.annotate_image(request=request, retry=retry, timeout=timeout) - self.log.info('Image annotated') + self.log.info("Image annotated") return MessageToDict(response) @GoogleBaseHook.quota_retry() def batch_annotate_images( self, - requests: Union[List[dict], List[AnnotateImageRequest]], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + requests: list[dict] | list[AnnotateImageRequest], + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, ) -> dict: """ For the documentation see: @@ -538,22 +530,22 @@ def batch_annotate_images( """ client = self.annotator_client - self.log.info('Annotating images') + self.log.info("Annotating images") response = client.batch_annotate_images(requests=requests, retry=retry, timeout=timeout) - self.log.info('Images annotated') + self.log.info("Images annotated") return MessageToDict(response) @GoogleBaseHook.quota_retry() def text_detection( self, - image: Union[dict, Image], - max_results: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - additional_properties: Optional[Dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + additional_properties: dict | None = None, ) -> dict: """ For the documentation see: @@ -579,11 +571,11 @@ def text_detection( @GoogleBaseHook.quota_retry() def document_text_detection( self, - image: Union[dict, Image], - max_results: Optional[int] = None, - retry: 
Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - additional_properties: Optional[dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + additional_properties: dict | None = None, ) -> dict: """ For the documentation see: @@ -609,11 +601,11 @@ def document_text_detection( @GoogleBaseHook.quota_retry() def label_detection( self, - image: Union[dict, Image], - max_results: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - additional_properties: Optional[dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + additional_properties: dict | None = None, ) -> dict: """ For the documentation see: @@ -639,11 +631,11 @@ def label_detection( @GoogleBaseHook.quota_retry() def safe_search_detection( self, - image: Union[dict, Image], - max_results: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - additional_properties: Optional[dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + additional_properties: dict | None = None, ) -> dict: """ For the documentation see: @@ -670,7 +662,7 @@ def _get_autogenerated_id(response) -> str: try: name = response.name except AttributeError as e: - raise AirflowException(f'Unable to get name from response... [{response}]\n{e}') - if '/' not in name: - raise AirflowException(f'Unable to get id from name... [{name}]') - return name.rsplit('/', 1)[1] + raise AirflowException(f"Unable to get name from response... [{response}]\n{e}") + if "/" not in name: + raise AirflowException(f"Unable to get id from name... [{name}]") + return name.rsplit("/", 1)[1] diff --git a/airflow/providers/google/cloud/hooks/workflows.py b/airflow/providers/google/cloud/hooks/workflows.py index fb9963160f074..a8a99252c77c5 100644 --- a/airflow/providers/google/cloud/hooks/workflows.py +++ b/airflow/providers/google/cloud/hooks/workflows.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.operation import Operation @@ -40,22 +41,22 @@ class WorkflowsHook(GoogleBaseHook): def get_workflows_client(self) -> WorkflowsClient: """Returns WorkflowsClient.""" - return WorkflowsClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + return WorkflowsClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) def get_executions_client(self) -> ExecutionsClient: """Returns ExecutionsClient.""" - return ExecutionsClient(credentials=self._get_credentials(), client_info=CLIENT_INFO) + return ExecutionsClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) @GoogleBaseHook.fallback_to_default_project_id def create_workflow( self, - workflow: Dict, + workflow: dict, workflow_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Creates a new workflow. If a workflow with the specified name @@ -89,9 +90,9 @@ def get_workflow( workflow_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Workflow: """ Gets details of a single Workflow. @@ -112,11 +113,11 @@ def get_workflow( def update_workflow( self, - workflow: Union[Dict, Workflow], - update_mask: Optional[FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + workflow: dict | Workflow, + update_mask: FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Updates an existing workflow. @@ -150,9 +151,9 @@ def delete_workflow( workflow_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Operation: """ Deletes a workflow with the specified name. @@ -178,11 +179,11 @@ def list_workflows( self, location: str, project_id: str = PROVIDE_PROJECT_ID, - filter_: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter_: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListWorkflowsPager: """ Lists Workflows in a given project and location. 
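The create_execution hunk that follows adds a normalization step: dict-valued entries in the execution payload are converted to strings before the request is assembled, presumably to match string-typed fields such as "argument" on the Executions API side. A small stand-alone illustration of that comprehension; the sample payload is invented.

    # Invented sample payload; only the dict-valued entry is flattened to a string.
    execution = {
        "argument": {"url": "https://example.com", "retries": 3},
        "call_log_level": "LOG_ALL_CALLS",
    }
    execution = {k: str(v) if isinstance(v, dict) else v for k, v in execution.items()}
    print(execution)
    # {'argument': "{'url': 'https://example.com', 'retries': 3}", 'call_log_level': 'LOG_ALL_CALLS'}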
@@ -217,11 +218,11 @@ def create_execution( self, workflow_id: str, location: str, - execution: Dict, + execution: dict, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Execution: """ Creates a new execution using the latest revision of @@ -240,6 +241,7 @@ def create_execution( metadata = metadata or () client = self.get_executions_client() parent = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}" + execution = {k: str(v) if isinstance(v, dict) else v for k, v in execution.items()} return client.create_execution( request={"parent": parent, "execution": execution}, retry=retry, @@ -254,9 +256,9 @@ def get_execution( execution_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Execution: """ Returns an execution for the given ``workflow_id`` and ``execution_id``. @@ -283,9 +285,9 @@ def cancel_execution( execution_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> Execution: """ Cancels an execution using the given ``workflow_id`` and ``execution_id``. @@ -313,9 +315,9 @@ def list_executions( workflow_id: str, location: str, project_id: str = PROVIDE_PROJECT_ID, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), ) -> ListExecutionsPager: """ Returns a list of executions which belong to the diff --git a/airflow/providers/google/cloud/links/base.py b/airflow/providers/google/cloud/links/base.py index fab7c7d00fc33..755266758e8e8 100644 --- a/airflow/providers/google/cloud/links/base.py +++ b/airflow/providers/google/cloud/links/base.py @@ -15,15 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from datetime import datetime -from typing import TYPE_CHECKING, ClassVar, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar from airflow.models import BaseOperatorLink, XCom if TYPE_CHECKING: + from airflow.models import BaseOperator from airflow.models.taskinstance import TaskInstanceKey +BASE_LINK = "https://console.cloud.google.com" + + class BaseGoogleLink(BaseOperatorLink): """:meta private:""" @@ -33,19 +38,13 @@ class BaseGoogleLink(BaseOperatorLink): def get_link( self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, ) -> str: - if ti_key is not None: - conf = XCom.get_value(key=self.key, ti_key=ti_key) - else: - assert dttm - conf = XCom.get_one( - key=self.key, - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - ) - - return self.format_str.format(**conf) if conf else "" + conf = XCom.get_value(key=self.key, ti_key=ti_key) + if not conf: + return "" + if self.format_str.startswith("http"): + return self.format_str.format(**conf) + return BASE_LINK + self.format_str.format(**conf) diff --git a/airflow/providers/google/cloud/links/bigquery.py b/airflow/providers/google/cloud/links/bigquery.py index a80818e2034ed..8c8795c2f7de7 100644 --- a/airflow/providers/google/cloud/links/bigquery.py +++ b/airflow/providers/google/cloud/links/bigquery.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """This module contains Google BigQuery links.""" +from __future__ import annotations + from typing import TYPE_CHECKING from airflow.models import BaseOperator @@ -24,7 +26,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -BIGQUERY_BASE_LINK = "https://console.cloud.google.com/bigquery" +BIGQUERY_BASE_LINK = "/bigquery" BIGQUERY_DATASET_LINK = ( BIGQUERY_BASE_LINK + "?referrer=search&project={project_id}&d={dataset_id}&p={project_id}&page=dataset" ) @@ -43,7 +45,7 @@ class BigQueryDatasetLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance: BaseOperator, dataset_id: str, project_id: str, @@ -64,11 +66,11 @@ class BigQueryTableLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance: BaseOperator, - dataset_id: str, project_id: str, table_id: str, + dataset_id: str | None = None, ): task_instance.xcom_push( context, diff --git a/airflow/providers/google/cloud/links/bigquery_dts.py b/airflow/providers/google/cloud/links/bigquery_dts.py index 4a73be51d896a..a7ebde5dd5b97 100644 --- a/airflow/providers/google/cloud/links/bigquery_dts.py +++ b/airflow/providers/google/cloud/links/bigquery_dts.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google BigQuery Data Transfer links.""" +from __future__ import annotations + from typing import TYPE_CHECKING from airflow.models import BaseOperator @@ -24,7 +26,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -BIGQUERY_BASE_LINK = "https://console.cloud.google.com/bigquery/transfers" +BIGQUERY_BASE_LINK = "/bigquery/transfers" BIGQUERY_DTS_LINK = BIGQUERY_BASE_LINK + "/locations/{region}/configs/{config_id}/runs?project={project_id}" @@ -37,7 +39,7 @@ class BigQueryDataTransferConfigLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance: BaseOperator, region: str, config_id: str, diff --git a/airflow/providers/google/cloud/links/bigtable.py b/airflow/providers/google/cloud/links/bigtable.py index cc06129021798..962339a3256c1 100644 --- a/airflow/providers/google/cloud/links/bigtable.py +++ b/airflow/providers/google/cloud/links/bigtable.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import TYPE_CHECKING @@ -22,8 +23,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -BASE_LINK = "https://console.cloud.google.com" -BIGTABLE_BASE_LINK = BASE_LINK + "/bigtable" +BIGTABLE_BASE_LINK = "/bigtable" BIGTABLE_INSTANCE_LINK = BIGTABLE_BASE_LINK + "/instances/{instance_id}/overview?project={project_id}" BIGTABLE_CLUSTER_LINK = ( BIGTABLE_BASE_LINK + "/instances/{instance_id}/clusters/{cluster_id}?project={project_id}" @@ -40,7 +40,7 @@ class BigtableInstanceLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -62,7 +62,7 @@ class BigtableClusterLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -85,7 +85,7 @@ class BigtableTablesLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( diff --git a/airflow/providers/google/cloud/links/cloud_build.py b/airflow/providers/google/cloud/links/cloud_build.py new file mode 100644 index 0000000000000..f42216acba628 --- /dev/null +++ b/airflow/providers/google/cloud/links/cloud_build.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +BUILD_BASE_LINK = "/cloud-build" + +BUILD_LINK = BUILD_BASE_LINK + "/builds/{build_id}?project={project_id}" + +BUILD_LIST_LINK = BUILD_BASE_LINK + "/builds?project={project_id}" + +BUILD_TRIGGERS_LIST_LINK = BUILD_BASE_LINK + "/triggers?project={project_id}" + +BUILD_TRIGGER_DETAILS_LINK = BUILD_BASE_LINK + "/triggers/edit/{trigger_id}?project={project_id}" + + +class CloudBuildLink(BaseGoogleLink): + """Helper class for constructing Cloud Build link""" + + name = "Cloud Build Details" + key = "cloud_build_key" + format_str = BUILD_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + build_id: str, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudBuildLink.key, + value={ + "project_id": project_id, + "build_id": build_id, + }, + ) + + +class CloudBuildListLink(BaseGoogleLink): + """Helper class for constructing Cloud Build List link""" + + name = "Cloud Builds List" + key = "cloud_build_list_key" + format_str = BUILD_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudBuildListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class CloudBuildTriggersListLink(BaseGoogleLink): + """Helper class for constructing Cloud Build Triggers List link""" + + name = "Cloud Build Triggers List" + key = "cloud_build_triggers_list_key" + format_str = BUILD_TRIGGERS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudBuildTriggersListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class CloudBuildTriggerDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Build Trigger Details link""" + + name = "Cloud Build Triggers Details" + key = "cloud_build_triggers_details_key" + format_str = BUILD_TRIGGER_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + trigger_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudBuildTriggerDetailsLink.key, + value={ + "project_id": project_id, + "trigger_id": trigger_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/cloud_functions.py b/airflow/providers/google/cloud/links/cloud_functions.py new file mode 100644 index 0000000000000..1cb8349607c93 --- /dev/null +++ b/airflow/providers/google/cloud/links/cloud_functions.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google Cloud Functions links.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +CLOUD_FUNCTIONS_BASE_LINK = "https://console.cloud.google.com/functions" + +CLOUD_FUNCTIONS_DETAILS_LINK = ( + CLOUD_FUNCTIONS_BASE_LINK + "/details/{location}/{function_name}?project={project_id}" +) + +CLOUD_FUNCTIONS_LIST_LINK = CLOUD_FUNCTIONS_BASE_LINK + "/list?project={project_id}" + + +class CloudFunctionsDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Functions Details Link""" + + name = "Cloud Functions Details" + key = "cloud_functions_details" + format_str = CLOUD_FUNCTIONS_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + function_name: str, + location: str, + project_id: str, + ): + + task_instance.xcom_push( + context, + key=CloudFunctionsDetailsLink.key, + value={"function_name": function_name, "location": location, "project_id": project_id}, + ) + + +class CloudFunctionsListLink(BaseGoogleLink): + """Helper class for constructing Cloud Functions Details Link""" + + name = "Cloud Functions List" + key = "cloud_functions_list" + format_str = CLOUD_FUNCTIONS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + project_id: str, + ): + task_instance.xcom_push( + context, + key=CloudFunctionsDetailsLink.key, + value={"project_id": project_id}, + ) diff --git a/airflow/providers/google/cloud/links/cloud_memorystore.py b/airflow/providers/google/cloud/links/cloud_memorystore.py new file mode 100644 index 0000000000000..d91c4bdac2e56 --- /dev/null +++ b/airflow/providers/google/cloud/links/cloud_memorystore.py @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Cloud Memorystore links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +BASE_LINK = "/memorystore" +MEMCACHED_LINK = ( + BASE_LINK + "/memcached/locations/{location_id}/instances/{instance_id}/details?project={project_id}" +) +MEMCACHED_LIST_LINK = BASE_LINK + "/memcached/instances?project={project_id}" +REDIS_LINK = ( + BASE_LINK + "/redis/locations/{location_id}/instances/{instance_id}/details/overview?project={project_id}" +) +REDIS_LIST_LINK = BASE_LINK + "/redis/instances?project={project_id}" + + +class MemcachedInstanceDetailsLink(BaseGoogleLink): + """Helper class for constructing Memorystore Memcached Instance Link""" + + name = "Memorystore Memcached Instance" + key = "memcached_instance" + format_str = MEMCACHED_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + instance_id: str, + location_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=MemcachedInstanceDetailsLink.key, + value={"instance_id": instance_id, "location_id": location_id, "project_id": project_id}, + ) + + +class MemcachedInstanceListLink(BaseGoogleLink): + """Helper class for constructing Memorystore Memcached List of Instances Link""" + + name = "Memorystore Memcached List of Instances" + key = "memcached_instances" + format_str = MEMCACHED_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=MemcachedInstanceListLink.key, + value={"project_id": project_id}, + ) + + +class RedisInstanceDetailsLink(BaseGoogleLink): + """Helper class for constructing Memorystore Redis Instance Link""" + + name = "Memorystore Redis Instance" + key = "redis_instance" + format_str = REDIS_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + instance_id: str, + location_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=RedisInstanceDetailsLink.key, + value={"instance_id": instance_id, "location_id": location_id, "project_id": project_id}, + ) + + +class RedisInstanceListLink(BaseGoogleLink): + """Helper class for constructing Memorystore Redis List of Instances Link""" + + name = "Memorystore Redis List of Instances" + key = "redis_instances" + format_str = REDIS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=RedisInstanceListLink.key, + value={"project_id": project_id}, + ) diff --git a/airflow/providers/google/cloud/links/cloud_sql.py b/airflow/providers/google/cloud/links/cloud_sql.py index 58a0b3c3a5c6d..1b8f8028d08f4 100644 --- a/airflow/providers/google/cloud/links/cloud_sql.py +++ b/airflow/providers/google/cloud/links/cloud_sql.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud SQL links.""" -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.google.cloud.links.base import BaseGoogleLink @@ -25,7 +27,7 @@ from airflow.utils.context import Context -CLOUD_SQL_BASE_LINK = "https://console.cloud.google.com/sql" +CLOUD_SQL_BASE_LINK = "/sql" CLOUD_SQL_INSTANCE_LINK = CLOUD_SQL_BASE_LINK + "/instances/{instance}/overview?project={project_id}" CLOUD_SQL_INSTANCE_DATABASE_LINK = ( CLOUD_SQL_BASE_LINK + "/instances/{instance}/databases?project={project_id}" @@ -41,10 +43,10 @@ class CloudSQLInstanceLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance: BaseOperator, cloud_sql_instance: str, - project_id: Optional[str], + project_id: str | None, ): task_instance.xcom_push( context, @@ -62,10 +64,10 @@ class CloudSQLInstanceDatabaseLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance: BaseOperator, cloud_sql_instance: str, - project_id: Optional[str], + project_id: str | None, ): task_instance.xcom_push( context, diff --git a/airflow/providers/google/cloud/links/cloud_storage_transfer.py b/airflow/providers/google/cloud/links/cloud_storage_transfer.py new file mode 100644 index 0000000000000..4a7db25b64bf5 --- /dev/null +++ b/airflow/providers/google/cloud/links/cloud_storage_transfer.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google Storage Transfer Service links.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +CLOUD_STORAGE_TRANSFER_BASE_LINK = "https://console.cloud.google.com/transfer" + +CLOUD_STORAGE_TRANSFER_LIST_LINK = CLOUD_STORAGE_TRANSFER_BASE_LINK + "/jobs?project={project_id}" + +CLOUD_STORAGE_TRANSFER_JOB_LINK = ( + CLOUD_STORAGE_TRANSFER_BASE_LINK + "/jobs/transferJobs%2F{transfer_job}/runs?project={project_id}" +) + +CLOUD_STORAGE_TRANSFER_OPERATION_LINK = ( + CLOUD_STORAGE_TRANSFER_BASE_LINK + + "/jobs/transferJobs%2F{transfer_job}/runs/transferOperations%2F{transfer_operation}" + + "?project={project_id}" +) + + +class CloudStorageTransferLinkHelper: + """Helper class for Storage Transfer links""" + + @staticmethod + def extract_parts(operation_name: str | None): + if not operation_name: + return "", "" + transfer_operation = operation_name.split("/")[1] + transfer_job = operation_name.split("-")[1] + return transfer_operation, transfer_job + + +class CloudStorageTransferListLink(BaseGoogleLink): + """Helper class for constructing Cloud Storage Transfer Link""" + + name = "Cloud Storage Transfer" + key = "cloud_storage_transfer" + format_str = CLOUD_STORAGE_TRANSFER_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=CloudStorageTransferListLink.key, + value={"project_id": project_id}, + ) + + +class CloudStorageTransferJobLink(BaseGoogleLink): + """Helper class for constructing Storage Transfer Job Link""" + + name = "Cloud Storage Transfer Job" + key = "cloud_storage_transfer_job" + format_str = CLOUD_STORAGE_TRANSFER_JOB_LINK + + @staticmethod + def persist( + task_instance, + context: Context, + project_id: str, + job_name: str, + ): + + job_name = job_name.split("/")[1] if job_name else "" + + task_instance.xcom_push( + context, + key=CloudStorageTransferJobLink.key, + value={ + "project_id": project_id, + "transfer_job": job_name, + }, + ) + + +class CloudStorageTransferDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Storage Transfer Operation Link""" + + name = "Cloud Storage Transfer Details" + key = "cloud_storage_transfer_details" + format_str = CLOUD_STORAGE_TRANSFER_OPERATION_LINK + + @staticmethod + def persist( + task_instance, + context: Context, + project_id: str, + operation_name: str, + ): + transfer_operation, transfer_job = CloudStorageTransferLinkHelper.extract_parts(operation_name) + + task_instance.xcom_push( + context, + key=CloudStorageTransferDetailsLink.key, + value={ + "project_id": project_id, + "transfer_job": transfer_job, + "transfer_operation": transfer_operation, + }, + ) diff --git a/airflow/providers/google/cloud/links/cloud_tasks.py b/airflow/providers/google/cloud/links/cloud_tasks.py index 16a28ee4b62eb..738b119c4ebc2 100644 --- a/airflow/providers/google/cloud/links/cloud_tasks.py +++ b/airflow/providers/google/cloud/links/cloud_tasks.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud Tasks links.""" -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.google.cloud.links.base import BaseGoogleLink @@ -24,7 +26,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -CLOUD_TASKS_BASE_LINK = "https://pantheon.corp.google.com/cloudtasks" +CLOUD_TASKS_BASE_LINK = "/cloudtasks" CLOUD_TASKS_QUEUE_LINK = CLOUD_TASKS_BASE_LINK + "/queue/{location}/{queue_id}/tasks?project={project_id}" CLOUD_TASKS_LINK = CLOUD_TASKS_BASE_LINK + "?project={project_id}" @@ -37,7 +39,7 @@ class CloudTasksQueueLink(BaseGoogleLink): format_str = CLOUD_TASKS_QUEUE_LINK @staticmethod - def extract_parts(queue_name: Optional[str]): + def extract_parts(queue_name: str | None): """ Extract project_id, location and queue id from queue name: projects/PROJECT_ID/locations/LOCATION_ID/queues/QUEUE_ID @@ -50,8 +52,8 @@ def extract_parts(queue_name: Optional[str]): @staticmethod def persist( operator_instance: BaseOperator, - context: "Context", - queue_name: Optional[str], + context: Context, + queue_name: str | None, ): project_id, location, queue_id = CloudTasksQueueLink.extract_parts(queue_name) operator_instance.xcom_push( @@ -71,8 +73,8 @@ class CloudTasksLink(BaseGoogleLink): @staticmethod def persist( operator_instance: BaseOperator, - context: "Context", - project_id: Optional[str], + context: Context, + project_id: str | None, ): operator_instance.xcom_push( context, diff --git a/airflow/providers/google/cloud/links/compute.py b/airflow/providers/google/cloud/links/compute.py new file mode 100644 index 0000000000000..c2f15b273004b --- /dev/null +++ b/airflow/providers/google/cloud/links/compute.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google Compute Engine links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +COMPUTE_BASE_LINK = "https://console.cloud.google.com/compute" +COMPUTE_LINK = ( + COMPUTE_BASE_LINK + "/instancesDetail/zones/{location_id}/instances/{resource_id}?project={project_id}" +) +COMPUTE_TEMPLATE_LINK = COMPUTE_BASE_LINK + "/instanceTemplates/details/{resource_id}?project={project_id}" +COMPUTE_GROUP_MANAGER_LINK = ( + COMPUTE_BASE_LINK + "/instanceGroups/details/{location_id}/{resource_id}?project={project_id}" +) + + +class ComputeInstanceDetailsLink(BaseGoogleLink): + """Helper class for constructing Compute Instance details Link""" + + name = "Compute Instance details" + key = "compute_instance_details" + format_str = COMPUTE_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + location_id: str, + resource_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=ComputeInstanceDetailsLink.key, + value={ + "location_id": location_id, + "resource_id": resource_id, + "project_id": project_id, + }, + ) + + +class ComputeInstanceTemplateDetailsLink(BaseGoogleLink): + """Helper class for constructing Compute Instance Template details Link""" + + name = "Compute Instance Template details" + key = "compute_instance_template_details" + format_str = COMPUTE_TEMPLATE_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + resource_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=ComputeInstanceTemplateDetailsLink.key, + value={ + "resource_id": resource_id, + "project_id": project_id, + }, + ) + + +class ComputeInstanceGroupManagerDetailsLink(BaseGoogleLink): + """Helper class for constructing Compute Instance Group Manager details Link""" + + name = "Compute Instance Group Manager" + key = "compute_instance_group_manager_details" + format_str = COMPUTE_GROUP_MANAGER_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + location_id: str, + resource_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=ComputeInstanceGroupManagerDetailsLink.key, + value={ + "location_id": location_id, + "resource_id": resource_id, + "project_id": project_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/data_loss_prevention.py b/airflow/providers/google/cloud/links/data_loss_prevention.py new file mode 100644 index 0000000000000..46c824f9074ab --- /dev/null +++ b/airflow/providers/google/cloud/links/data_loss_prevention.py @@ -0,0 +1,318 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +BASE_LINK = "https://console.cloud.google.com" + +DLP_BASE_LINK = BASE_LINK + "/security/dlp" + +DLP_DEIDENTIFY_TEMPLATES_LIST_LINK = ( + DLP_BASE_LINK + "/landing/configuration/templates/deidentify?project={project_id}" +) +DLP_DEIDENTIFY_TEMPLATE_DETAILS_LINK = ( + DLP_BASE_LINK + + "/projects/{project_id}/locations/global/deidentifyTemplates/{template_name}?project={project_id}" +) + +DLP_JOB_TRIGGER_LIST_LINK = DLP_BASE_LINK + "/landing/inspection/triggers?project={project_id}" +DLP_JOB_TRIGGER_DETAILS_LINK = ( + DLP_BASE_LINK + "/projects/{project_id}/locations/global/jobTriggers/{trigger_name}?project={project_id}" +) + +DLP_JOBS_LIST_LINK = DLP_BASE_LINK + "/landing/inspection/jobs?project={project_id}" +DLP_JOB_DETAILS_LINK = ( + DLP_BASE_LINK + "/projects/{project_id}/locations/global/dlpJobs/{job_name}?project={project_id}" +) + +DLP_INSPECT_TEMPLATES_LIST_LINK = ( + DLP_BASE_LINK + "/landing/configuration/templates/inspect?project={project_id}" +) +DLP_INSPECT_TEMPLATE_DETAILS_LINK = ( + DLP_BASE_LINK + + "/projects/{project_id}/locations/global/inspectTemplates/{template_name}?project={project_id}" +) + +DLP_INFO_TYPES_LIST_LINK = ( + DLP_BASE_LINK + "/landing/configuration/infoTypes/stored?cloudshell=false&project={project_id}" +) +DLP_INFO_TYPE_DETAILS_LINK = ( + DLP_BASE_LINK + + "/projects/{project_id}/locations/global/storedInfoTypes/{info_type_name}?project={project_id}" +) +DLP_POSSIBLE_INFO_TYPES_LIST_LINK = ( + DLP_BASE_LINK + "/landing/configuration/infoTypes/built-in?project={project_id}" +) + + +class CloudDLPDeidentifyTemplatesListLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Deidentify Templates List" + key = "cloud_dlp_deidentify_templates_list_key" + format_str = DLP_DEIDENTIFY_TEMPLATES_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPDeidentifyTemplatesListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class CloudDLPDeidentifyTemplateDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Deidentify Template Details" + key = "cloud_dlp_deidentify_template_details_key" + format_str = DLP_DEIDENTIFY_TEMPLATE_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + template_name: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPDeidentifyTemplateDetailsLink.key, + value={ + "project_id": project_id, + "template_name": template_name, + }, + ) + + +class CloudDLPJobTriggersListLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Job Triggers List" + key = "cloud_dlp_job_triggers_list_key" + format_str = DLP_JOB_TRIGGER_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPJobTriggersListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class CloudDLPJobTriggerDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention 
link""" + + name = "Cloud DLP Job Triggers Details" + key = "cloud_dlp_job_trigger_details_key" + format_str = DLP_JOB_TRIGGER_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + trigger_name: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPJobTriggerDetailsLink.key, + value={ + "project_id": project_id, + "trigger_name": trigger_name, + }, + ) + + +class CloudDLPJobsListLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Jobs List" + key = "cloud_dlp_jobs_list_key" + format_str = DLP_JOBS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPJobsListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class CloudDLPJobDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Job Details" + key = "cloud_dlp_job_details_key" + format_str = DLP_JOB_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + job_name: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPJobDetailsLink.key, + value={ + "project_id": project_id, + "job_name": job_name, + }, + ) + + +class CloudDLPInspectTemplatesListLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Inspect Templates List" + key = "cloud_dlp_inspect_templates_list_key" + format_str = DLP_INSPECT_TEMPLATES_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPInspectTemplatesListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class CloudDLPInspectTemplateDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Inspect Template Details" + key = "cloud_dlp_inspect_template_details_key" + format_str = DLP_INSPECT_TEMPLATE_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + template_name: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPInspectTemplateDetailsLink.key, + value={ + "project_id": project_id, + "template_name": template_name, + }, + ) + + +class CloudDLPInfoTypesListLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Info Types List" + key = "cloud_dlp_info_types_list_key" + format_str = DLP_INFO_TYPES_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPInfoTypesListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class CloudDLPInfoTypeDetailsLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" + + name = "Cloud DLP Info Type Details" + key = "cloud_dlp_info_type_details_key" + format_str = DLP_INFO_TYPE_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + info_type_name: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPInfoTypeDetailsLink.key, + value={ + "project_id": project_id, + "info_type_name": info_type_name, + }, + ) + + +class CloudDLPPossibleInfoTypesListLink(BaseGoogleLink): + """Helper class for constructing Cloud Data Loss Prevention link""" 
+ + name = "Cloud DLP Possible Info Types List" + key = "cloud_dlp_possible_info_types_list_key" + format_str = DLP_POSSIBLE_INFO_TYPES_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=CloudDLPPossibleInfoTypesListLink.key, + value={ + "project_id": project_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/datacatalog.py b/airflow/providers/google/cloud/links/datacatalog.py new file mode 100644 index 0000000000000..a5e4fcef776ad --- /dev/null +++ b/airflow/providers/google/cloud/links/datacatalog.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Data Catalog links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +DATACATALOG_BASE_LINK = "/datacatalog" +ENTRY_GROUP_LINK = ( + DATACATALOG_BASE_LINK + + "/groups/{entry_group_id};container={project_id};location={location_id}?project={project_id}" +) +ENTRY_LINK = ( + DATACATALOG_BASE_LINK + + "/projects/{project_id}/locations/{location_id}/entryGroups/{entry_group_id}/entries/{entry_id}\ + ?project={project_id}" +) +TAG_TEMPLATE_LINK = ( + DATACATALOG_BASE_LINK + + "/projects/{project_id}/locations/{location_id}/tagTemplates/{tag_template_id}?project={project_id}" +) + + +class DataCatalogEntryGroupLink(BaseGoogleLink): + """Helper class for constructing Data Catalog Entry Group Link""" + + name = "Data Catalog Entry Group" + key = "data_catalog_entry_group" + format_str = ENTRY_GROUP_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + entry_group_id: str, + location_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=DataCatalogEntryGroupLink.key, + value={"entry_group_id": entry_group_id, "location_id": location_id, "project_id": project_id}, + ) + + +class DataCatalogEntryLink(BaseGoogleLink): + """Helper class for constructing Data Catalog Entry Link""" + + name = "Data Catalog Entry" + key = "data_catalog_entry" + format_str = ENTRY_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + entry_id: str, + entry_group_id: str, + location_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=DataCatalogEntryLink.key, + value={ + "entry_id": entry_id, + "entry_group_id": entry_group_id, + "location_id": location_id, + "project_id": project_id, + }, + ) + + +class DataCatalogTagTemplateLink(BaseGoogleLink): + """Helper class for constructing Data Catalog Tag Template Link""" + + name = "Data Catalog Tag 
Template" + key = "data_catalog_tag_template" + format_str = TAG_TEMPLATE_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + tag_template_id: str, + location_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=DataCatalogTagTemplateLink.key, + value={"tag_template_id": tag_template_id, "location_id": location_id, "project_id": project_id}, + ) diff --git a/airflow/providers/google/cloud/links/dataflow.py b/airflow/providers/google/cloud/links/dataflow.py index 1bccacabbd8d2..1f0f6f87e8e99 100644 --- a/airflow/providers/google/cloud/links/dataflow.py +++ b/airflow/providers/google/cloud/links/dataflow.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Dataflow links.""" -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.google.cloud.links.base import BaseGoogleLink @@ -24,7 +26,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -DATAFLOW_BASE_LINK = "https://pantheon.corp.google.com/dataflow/jobs" +DATAFLOW_BASE_LINK = "/dataflow/jobs" DATAFLOW_JOB_LINK = DATAFLOW_BASE_LINK + "/{region}/{job_id}?project={project_id}" @@ -38,10 +40,10 @@ class DataflowJobLink(BaseGoogleLink): @staticmethod def persist( operator_instance: BaseOperator, - context: "Context", - project_id: Optional[str], - region: Optional[str], - job_id: Optional[str], + context: Context, + project_id: str | None, + region: str | None, + job_id: str | None, ): operator_instance.xcom_push( context, diff --git a/airflow/providers/google/cloud/links/dataform.py b/airflow/providers/google/cloud/links/dataform.py new file mode 100644 index 0000000000000..5e8e8bd658f9b --- /dev/null +++ b/airflow/providers/google/cloud/links/dataform.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google Dataflow links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +DATAFORM_BASE_LINK = "/bigquery/dataform" +DATAFORM_WORKFLOW_INVOCATION_LINK = ( + DATAFORM_BASE_LINK + + "/locations/{region}/repositories/{repository_id}/workflows/" + + "{workflow_invocation_id}?project={project_id}" +) +DATAFORM_REPOSITORY_LINK = ( + DATAFORM_BASE_LINK + + "/locations/{region}/repositories/{repository_id}/" + + "details/workspaces?project={project_id}" +) +DATAFORM_WORKSPACE_LINK = ( + DATAFORM_BASE_LINK + + "/locations/{region}/repositories/{repository_id}/" + + "workspaces/{workspace_id}/" + + "files/?project={project_id}" +) + + +class DataformWorkflowInvocationLink(BaseGoogleLink): + """Helper class for constructing Dataflow Job Link""" + + name = "Dataform Workflow Invocation" + key = "dataform_workflow_invocation_config" + format_str = DATAFORM_WORKFLOW_INVOCATION_LINK + + @staticmethod + def persist( + operator_instance: BaseOperator, + context: Context, + project_id: str, + region: str, + repository_id: str, + workflow_invocation_id: str, + ): + operator_instance.xcom_push( + context, + key=DataformWorkflowInvocationLink.key, + value={ + "project_id": project_id, + "region": region, + "repository_id": repository_id, + "workflow_invocation_id": workflow_invocation_id, + }, + ) + + +class DataformRepositoryLink(BaseGoogleLink): + """Helper class for constructing Dataflow repository link.""" + + name = "Dataform Repository" + key = "dataform_repository" + format_str = DATAFORM_REPOSITORY_LINK + + @staticmethod + def persist( + operator_instance: BaseOperator, + context: Context, + project_id: str, + region: str, + repository_id: str, + ) -> None: + operator_instance.xcom_push( + context=context, + key=DataformRepositoryLink.key, + value={ + "project_id": project_id, + "region": region, + "repository_id": repository_id, + }, + ) + + +class DataformWorkspaceLink(BaseGoogleLink): + """Helper class for constructing Dataform workspace link.""" + + name = "Dataform Workspace" + key = "dataform_workspace" + format_str = DATAFORM_WORKSPACE_LINK + + @staticmethod + def persist( + operator_instance: BaseOperator, + context: Context, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + ) -> None: + operator_instance.xcom_push( + context=context, + key=DataformWorkspaceLink.key, + value={ + "project_id": project_id, + "region": region, + "repository_id": repository_id, + "workspace_id": workspace_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/dataplex.py b/airflow/providers/google/cloud/links/dataplex.py index 8c08d83e1e195..dcf3c8755848a 100644 --- a/airflow/providers/google/cloud/links/dataplex.py +++ b/airflow/providers/google/cloud/links/dataplex.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Dataplex links.""" +from __future__ import annotations from typing import TYPE_CHECKING @@ -24,10 +25,14 @@ if TYPE_CHECKING: from airflow.utils.context import Context -DATAPLEX_BASE_LINK = "https://console.cloud.google.com/dataplex/process/tasks" +DATAPLEX_BASE_LINK = "/dataplex/process/tasks" DATAPLEX_TASK_LINK = DATAPLEX_BASE_LINK + "/{lake_id}.{task_id};location={region}/jobs?project={project_id}" DATAPLEX_TASKS_LINK = DATAPLEX_BASE_LINK + "?project={project_id}&qLake={lake_id}.{region}" +DATAPLEX_LAKE_LINK = ( + "https://console.cloud.google.com/dataplex/lakes/{lake_id};location={region}?project={project_id}" +) + class DataplexTaskLink(BaseGoogleLink): """Helper class for constructing Dataplex Task link""" @@ -38,7 +43,7 @@ class DataplexTaskLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -62,7 +67,7 @@ class DataplexTasksLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -74,3 +79,26 @@ def persist( "region": task_instance.region, }, ) + + +class DataplexLakeLink(BaseGoogleLink): + """Helper class for constructing Dataplex Lake link""" + + name = "Dataplex Lake" + key = "dataplex_lake_key" + format_str = DATAPLEX_LAKE_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + ): + task_instance.xcom_push( + context=context, + key=DataplexLakeLink.key, + value={ + "lake_id": task_instance.lake_id, + "region": task_instance.region, + "project_id": task_instance.project_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/dataprep.py b/airflow/providers/google/cloud/links/dataprep.py new file mode 100644 index 0000000000000..66caf1cfe8933 --- /dev/null +++ b/airflow/providers/google/cloud/links/dataprep.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +BASE_LINK = "https://clouddataprep.com" +DATAPREP_FLOW_LINK = BASE_LINK + "/flows/{flow_id}?projectId={project_id}" +DATAPREP_JOB_GROUP_LINK = BASE_LINK + "/jobs/{job_group_id}?projectId={project_id}" + + +class DataprepFlowLink(BaseGoogleLink): + """Helper class for constructing Dataprep flow link.""" + + name = "Flow details page" + key = "dataprep_flow_page" + format_str = DATAPREP_FLOW_LINK + + @staticmethod + def persist(context: Context, task_instance, project_id: str, flow_id: int): + task_instance.xcom_push( + context=context, + key=DataprepFlowLink.key, + value={"project_id": project_id, "flow_id": flow_id}, + ) + + +class DataprepJobGroupLink(BaseGoogleLink): + """Helper class for constructing Dataprep job group link.""" + + name = "Job group details page" + key = "dataprep_job_group_page" + format_str = DATAPREP_JOB_GROUP_LINK + + @staticmethod + def persist(context: Context, task_instance, project_id: str, job_group_id: int): + task_instance.xcom_push( + context=context, + key=DataprepJobGroupLink.key, + value={ + "project_id": project_id, + "job_group_id": job_group_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/dataproc.py b/airflow/providers/google/cloud/links/dataproc.py index e45105d4be43e..573621aa1491f 100644 --- a/airflow/providers/google/cloud/links/dataproc.py +++ b/airflow/providers/google/cloud/links/dataproc.py @@ -16,17 +16,19 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Dataproc links.""" +from __future__ import annotations -from datetime import datetime -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from airflow.models import BaseOperatorLink, XCom +from airflow.providers.google.cloud.links.base import BASE_LINK if TYPE_CHECKING: + from airflow.models import BaseOperator from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context -DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc" +DATAPROC_BASE_LINK = BASE_LINK + "/dataproc" DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{resource}?region={region}&project={project_id}" DATAPROC_CLUSTER_LINK = ( DATAPROC_BASE_LINK + "/clusters/{resource}/monitoring?region={region}&project={project_id}" @@ -47,7 +49,7 @@ class DataprocLink(BaseOperatorLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, url: str, resource: str, @@ -65,17 +67,11 @@ def persist( def get_link( self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, ) -> str: - if ti_key is not None: - conf = XCom.get_value(key=self.key, ti_key=ti_key) - else: - assert dttm - conf = XCom.get_one( - key=self.key, dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm - ) + conf = XCom.get_value(key=self.key, ti_key=ti_key) return ( conf["url"].format( region=conf["region"], project_id=conf["project_id"], resource=conf["resource"] @@ -93,7 +89,7 @@ class DataprocListLink(BaseOperatorLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, url: str, ): @@ -108,20 +104,11 @@ def persist( def get_link( self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, + 
operator: BaseOperator, + *, + ti_key: TaskInstanceKey, ) -> str: - if ti_key is not None: - list_conf = XCom.get_value(key=self.key, ti_key=ti_key) - else: - assert dttm - list_conf = XCom.get_one( - key=self.key, - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - ) + list_conf = XCom.get_value(key=self.key, ti_key=ti_key) return ( list_conf["url"].format( project_id=list_conf["project_id"], diff --git a/airflow/providers/google/cloud/links/datastore.py b/airflow/providers/google/cloud/links/datastore.py index d17a6a8ef0ef2..08f57862312e5 100644 --- a/airflow/providers/google/cloud/links/datastore.py +++ b/airflow/providers/google/cloud/links/datastore.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import TYPE_CHECKING @@ -22,12 +23,9 @@ if TYPE_CHECKING: from airflow.utils.context import Context -BASE_LINK = "https://console.cloud.google.com" -DATASTORE_BASE_LINK = BASE_LINK + "/datastore" +DATASTORE_BASE_LINK = "/datastore" DATASTORE_IMPORT_EXPORT_LINK = DATASTORE_BASE_LINK + "/import-export?project={project_id}" -DATASTORE_EXPORT_ENTITIES_LINK = ( - BASE_LINK + "/storage/browser/{bucket_name}/{export_name}?project={project_id}" -) +DATASTORE_EXPORT_ENTITIES_LINK = "/storage/browser/{bucket_name}/{export_name}?project={project_id}" DATASTORE_ENTITIES_LINK = DATASTORE_BASE_LINK + "/entities/query/kind?project={project_id}" @@ -40,7 +38,7 @@ class CloudDatastoreImportExportLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -61,7 +59,7 @@ class CloudDatastoreEntitiesLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( diff --git a/airflow/providers/google/cloud/links/kubernetes_engine.py b/airflow/providers/google/cloud/links/kubernetes_engine.py new file mode 100644 index 0000000000000..1beb4a0a84f4d --- /dev/null +++ b/airflow/providers/google/cloud/links/kubernetes_engine.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
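The Dataproc hunks above replace the old `dttm` fallback in `get_link` with a keyword-only `ti_key` lookup via `XCom.get_value`. A custom `BaseOperatorLink` following the same pattern now reduces to something like the sketch below; the class name and URL handling are illustrative, not part of this change.

    from __future__ import annotations

    from airflow.models import BaseOperator, BaseOperatorLink, XCom
    from airflow.models.taskinstance import TaskInstanceKey


    class ExampleConsoleLink(BaseOperatorLink):
        """Illustrative extra link resolving its URL from a dict persisted in XCom."""

        name = "Example Console"
        key = "example_console_conf"

        def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
            # Same lookup the Dataproc links now use: no execution_date fallback, just the TI key.
            conf = XCom.get_value(key=self.key, ti_key=ti_key)
            return conf["url"].format(project_id=conf["project_id"]) if conf else ""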
+from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +from google.cloud.container_v1.types import Cluster + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +KUBERNETES_BASE_LINK = "/kubernetes" +KUBERNETES_CLUSTER_LINK = ( + KUBERNETES_BASE_LINK + "/clusters/details/{location}/{cluster_name}/details?project={project_id}" +) +KUBERNETES_POD_LINK = ( + KUBERNETES_BASE_LINK + + "/pod/{location}/{cluster_name}/{namespace}/{pod_name}/details?project={project_id}" +) + + +class KubernetesEngineClusterLink(BaseGoogleLink): + """Helper class for constructing Kubernetes Engine Cluster Link""" + + name = "Kubernetes Cluster" + key = "kubernetes_cluster_conf" + format_str = KUBERNETES_CLUSTER_LINK + + @staticmethod + def persist(context: Context, task_instance, cluster: dict | Cluster | None): + if isinstance(cluster, dict): + cluster = Cluster.from_json(json.dumps(cluster)) + + task_instance.xcom_push( + context=context, + key=KubernetesEngineClusterLink.key, + value={ + "location": task_instance.location, + "cluster_name": cluster.name, # type: ignore + "project_id": task_instance.project_id, + }, + ) + + +class KubernetesEnginePodLink(BaseGoogleLink): + """Helper class for constructing Kubernetes Engine Pod Link""" + + name = "Kubernetes Pod" + key = "kubernetes_pod_conf" + format_str = KUBERNETES_POD_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + ): + task_instance.xcom_push( + context=context, + key=KubernetesEnginePodLink.key, + value={ + "location": task_instance.location, + "cluster_name": task_instance.cluster_name, + "namespace": task_instance.pod.metadata.namespace, + "pod_name": task_instance.pod.metadata.name, + "project_id": task_instance.project_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/life_sciences.py b/airflow/providers/google/cloud/links/life_sciences.py new file mode 100644 index 0000000000000..50d783da32506 --- /dev/null +++ b/airflow/providers/google/cloud/links/life_sciences.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
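`KubernetesEngineClusterLink.persist` above accepts either a `Cluster` proto or a plain dict and normalises the dict with `Cluster.from_json(json.dumps(...))` before reading `cluster.name`. That round-trip in isolation, with made-up field values:

    import json

    from google.cloud.container_v1.types import Cluster

    cluster_dict = {"name": "example-cluster", "location": "europe-west1-b"}  # illustrative values

    # The same normalisation persist() applies when handed a dict.
    cluster = Cluster.from_json(json.dumps(cluster_dict))
    print(cluster.name)  # "example-cluster"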
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +BASE_LINK = "https://console.cloud.google.com/lifesciences" +LIFESCIENCES_LIST_LINK = BASE_LINK + "/pipelines?project={project_id}" + + +class LifeSciencesLink(BaseGoogleLink): + """Helper class for constructing Life Sciences List link""" + + name = "Life Sciences" + key = "lifesciences_key" + format_str = LIFESCIENCES_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context=context, + key=LifeSciencesLink.key, + value={ + "project_id": project_id, + }, + ) diff --git a/airflow/providers/google/cloud/links/mlengine.py b/airflow/providers/google/cloud/links/mlengine.py new file mode 100644 index 0000000000000..bbfe0cc5385ce --- /dev/null +++ b/airflow/providers/google/cloud/links/mlengine.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google ML Engine links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +MLENGINE_BASE_LINK = "https://console.cloud.google.com/ai-platform" +MLENGINE_MODEL_DETAILS_LINK = MLENGINE_BASE_LINK + "/models/{model_id}/versions?project={project_id}" +MLENGINE_MODEL_VERSION_DETAILS_LINK = ( + MLENGINE_BASE_LINK + "/models/{model_id}/versions/{version_id}/performance?project={project_id}" +) +MLENGINE_MODELS_LIST_LINK = MLENGINE_BASE_LINK + "/models/?project={project_id}" +MLENGINE_JOB_DETAILS_LINK = MLENGINE_BASE_LINK + "/jobs/{job_id}?project={project_id}" +MLENGINE_JOBS_LIST_LINK = MLENGINE_BASE_LINK + "/jobs?project={project_id}" + + +class MLEngineModelLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Model" + key = "ml_engine_model" + format_str = MLENGINE_MODEL_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + model_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineModelLink.key, + value={"model_id": model_id, "project_id": project_id}, + ) + + +class MLEngineModelsListLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Models List" + key = "ml_engine_models_list" + format_str = MLENGINE_MODELS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineModelsListLink.key, + value={"project_id": project_id}, + ) + + +class MLEngineJobDetailsLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Job Details" + key = "ml_engine_job_details" + format_str = MLENGINE_JOB_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + job_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineJobDetailsLink.key, + value={"job_id": job_id, "project_id": project_id}, + ) + + +class MLEngineModelVersionDetailsLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Version Details" + key = "ml_engine_version_details" + format_str = MLENGINE_MODEL_VERSION_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + model_id: str, + project_id: str, + version_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineModelVersionDetailsLink.key, + value={"model_id": model_id, "project_id": project_id, "version_id": version_id}, + ) + + +class MLEngineJobSListLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Jobs List" + key = "ml_engine_jobs_list" + format_str = MLENGINE_JOBS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineJobSListLink.key, + value={"project_id": project_id}, + ) diff --git a/airflow/providers/google/cloud/links/pubsub.py b/airflow/providers/google/cloud/links/pubsub.py new file mode 100644 index 0000000000000..83de5ba00769f --- /dev/null +++ b/airflow/providers/google/cloud/links/pubsub.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Pub/Sub links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +PUBSUB_BASE_LINK = "/cloudpubsub" +PUBSUB_TOPIC_LINK = PUBSUB_BASE_LINK + "/topic/detail/{topic_id}?project={project_id}" +PUBSUB_SUBSCRIPTION_LINK = PUBSUB_BASE_LINK + "/subscription/detail/{subscription_id}?project={project_id}" + + +class PubSubTopicLink(BaseGoogleLink): + """Helper class for constructing Pub/Sub Topic Link""" + + name = "Pub/Sub Topic" + key = "pubsub_topic" + format_str = PUBSUB_TOPIC_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + topic_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=PubSubTopicLink.key, + value={"topic_id": topic_id, "project_id": project_id}, + ) + + +class PubSubSubscriptionLink(BaseGoogleLink): + """Helper class for constructing Pub/Sub Subscription Link""" + + name = "Pub/Sub Subscription" + key = "pubsub_subscription" + format_str = PUBSUB_SUBSCRIPTION_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + subscription_id: str | None, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=PubSubSubscriptionLink.key, + value={"subscription_id": subscription_id, "project_id": project_id}, + ) diff --git a/airflow/providers/google/cloud/links/spanner.py b/airflow/providers/google/cloud/links/spanner.py index 0834944dc177f..0306e46233c93 100644 --- a/airflow/providers/google/cloud/links/spanner.py +++ b/airflow/providers/google/cloud/links/spanner.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Spanner links.""" -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.google.cloud.links.base import BaseGoogleLink @@ -24,7 +26,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -SPANNER_BASE_LINK = "https://console.cloud.google.com/spanner/instances" +SPANNER_BASE_LINK = "/spanner/instances" SPANNER_INSTANCE_LINK = SPANNER_BASE_LINK + "/{instance_id}/details/databases?project={project_id}" SPANNER_DATABASE_LINK = ( SPANNER_BASE_LINK + "/{instance_id}/databases/{database_id}/details/tables?project={project_id}" @@ -40,10 +42,10 @@ class SpannerInstanceLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance: BaseOperator, instance_id: str, - project_id: Optional[str], + project_id: str | None, ): task_instance.xcom_push( context, @@ -61,11 +63,11 @@ class SpannerDatabaseLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance: BaseOperator, instance_id: str, database_id: str, - project_id: Optional[str], + project_id: str | None, ): task_instance.xcom_push( context, diff --git a/airflow/providers/google/cloud/links/stackdriver.py b/airflow/providers/google/cloud/links/stackdriver.py index 5870266422bc2..1dec31ccfed2b 100644 --- a/airflow/providers/google/cloud/links/stackdriver.py +++ b/airflow/providers/google/cloud/links/stackdriver.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Stackdriver links.""" -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.google.cloud.links.base import BaseGoogleLink @@ -24,7 +26,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -STACKDRIVER_BASE_LINK = "https://pantheon.corp.google.com/monitoring/alerting" +STACKDRIVER_BASE_LINK = "/monitoring/alerting" STACKDRIVER_NOTIFICATIONS_LINK = STACKDRIVER_BASE_LINK + "/notifications?project={project_id}" STACKDRIVER_POLICIES_LINK = STACKDRIVER_BASE_LINK + "/policies?project={project_id}" @@ -39,8 +41,8 @@ class StackdriverNotificationsLink(BaseGoogleLink): @staticmethod def persist( operator_instance: BaseOperator, - context: "Context", - project_id: Optional[str], + context: Context, + project_id: str | None, ): operator_instance.xcom_push( context, @@ -59,8 +61,8 @@ class StackdriverPoliciesLink(BaseGoogleLink): @staticmethod def persist( operator_instance: BaseOperator, - context: "Context", - project_id: Optional[str], + context: Context, + project_id: str | None, ): operator_instance.xcom_push( context, diff --git a/airflow/providers/google/cloud/links/vertex_ai.py b/airflow/providers/google/cloud/links/vertex_ai.py index 910049c81a5a0..a251305c45e7d 100644 --- a/airflow/providers/google/cloud/links/vertex_ai.py +++ b/airflow/providers/google/cloud/links/vertex_ai.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations from typing import TYPE_CHECKING @@ -22,15 +23,12 @@ if TYPE_CHECKING: from airflow.utils.context import Context -BASE_LINK = "https://console.cloud.google.com" -VERTEX_AI_BASE_LINK = BASE_LINK + "/vertex-ai" +VERTEX_AI_BASE_LINK = "/vertex-ai" VERTEX_AI_MODEL_LINK = ( VERTEX_AI_BASE_LINK + "/locations/{region}/models/{model_id}/deploy?project={project_id}" ) VERTEX_AI_MODEL_LIST_LINK = VERTEX_AI_BASE_LINK + "/models?project={project_id}" -VERTEX_AI_MODEL_EXPORT_LINK = ( - BASE_LINK + "/storage/browser/{bucket_name}/model-{model_id}?project={project_id}" -) +VERTEX_AI_MODEL_EXPORT_LINK = "/storage/browser/{bucket_name}/model-{model_id}?project={project_id}" VERTEX_AI_TRAINING_LINK = ( VERTEX_AI_BASE_LINK + "/locations/{region}/training/{training_id}/cpu?project={project_id}" ) @@ -62,7 +60,7 @@ class VertexAIModelLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, model_id: str, ): @@ -86,7 +84,7 @@ class VertexAIModelListLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -112,7 +110,7 @@ def extract_bucket_name(config): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -135,7 +133,7 @@ class VertexAITrainingLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, training_id: str, ): @@ -159,7 +157,7 @@ class VertexAITrainingPipelinesLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -179,7 +177,7 @@ class VertexAIDatasetLink(BaseGoogleLink): format_str = VERTEX_AI_DATASET_LINK @staticmethod - def persist(context: "Context", task_instance, dataset_id: str): + def persist(context: Context, task_instance, dataset_id: str): task_instance.xcom_push( context=context, key=VertexAIDatasetLink.key, @@ -200,7 +198,7 @@ class VertexAIDatasetListLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -221,7 +219,7 @@ class VertexAIHyperparameterTuningJobListLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -242,7 +240,7 @@ class VertexAIBatchPredictionJobLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, batch_prediction_job_id: str, ): @@ -266,7 +264,7 @@ class VertexAIBatchPredictionJobListLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( @@ -287,7 +285,7 @@ class VertexAIEndpointLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, endpoint_id: str, ): @@ -311,7 +309,7 @@ class VertexAIEndpointListLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", + context: Context, task_instance, ): task_instance.xcom_push( diff --git a/airflow/providers/google/cloud/links/workflows.py b/airflow/providers/google/cloud/links/workflows.py new file mode 100644 index 0000000000000..8563739a99f7f --- /dev/null +++ b/airflow/providers/google/cloud/links/workflows.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Workflows links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +WORKFLOWS_BASE_LINK = "workflows" +WORKFLOW_LINK = WORKFLOWS_BASE_LINK + "/workflow/{location_id}/{workflow_id}/executions?project={project_id}" +WORKFLOWS_LINK = WORKFLOWS_BASE_LINK + "?project={project_id}" +EXECUTION_LINK = ( + WORKFLOWS_BASE_LINK + + "/workflow/{location_id}/{workflow_id}/execution/{execution_id}?project={project_id}" +) + + +class WorkflowsWorkflowDetailsLink(BaseGoogleLink): + """Helper class for constructing Workflow details Link""" + + name = "Workflow details" + key = "workflow_details" + format_str = WORKFLOW_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + location_id: str, + workflow_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=WorkflowsWorkflowDetailsLink.key, + value={"location_id": location_id, "workflow_id": workflow_id, "project_id": project_id}, + ) + + +class WorkflowsListOfWorkflowsLink(BaseGoogleLink): + """Helper class for constructing list of Workflows Link""" + + name = "List of workflows" + key = "list_of_workflows" + format_str = WORKFLOWS_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=WorkflowsListOfWorkflowsLink.key, + value={"project_id": project_id}, + ) + + +class WorkflowsExecutionLink(BaseGoogleLink): + """Helper class for constructing Workflows Execution Link""" + + name = "Workflow Execution" + key = "workflow_execution" + format_str = EXECUTION_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + location_id: str, + workflow_id: str, + execution_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=WorkflowsExecutionLink.key, + value={ + "location_id": location_id, + "workflow_id": workflow_id, + "execution_id": execution_id, + "project_id": project_id, + }, + ) diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py b/airflow/providers/google/cloud/log/gcs_task_handler.py index 92d133d109af5..5fbba80798498 100644 --- a/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -15,21 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import os -import sys -from typing import Collection, Optional - -from airflow.providers.google.common.consts import CLIENT_INFO +from __future__ import annotations -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +import os +from typing import Collection # not sure why but mypy complains on missing `storage` but it is clearly there and is importable from google.cloud import storage # type: ignore[attr-defined] +from airflow.compat.functools import cached_property from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin @@ -67,15 +63,15 @@ def __init__( *, base_log_folder: str, gcs_log_folder: str, - filename_template: str, - gcp_key_path: Optional[str] = None, - gcp_keyfile_dict: Optional[dict] = None, - gcp_scopes: Optional[Collection[str]] = _DEFAULT_SCOPESS, - project_id: Optional[str] = None, + filename_template: str | None = None, + gcp_key_path: str | None = None, + gcp_keyfile_dict: dict | None = None, + gcp_scopes: Collection[str] | None = _DEFAULT_SCOPESS, + project_id: str | None = None, ): super().__init__(base_log_folder, filename_template) self.remote_base = gcs_log_folder - self.log_relative_path = '' + self.log_relative_path = "" self._hook = None self.closed = False self.upload_on_close = True @@ -151,12 +147,12 @@ def _read(self, ti, try_number, metadata=None): try: blob = storage.Blob.from_string(remote_loc, self.client) remote_log = blob.download_as_bytes().decode() - log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n' - return log, {'end_of_log': True} + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} except Exception as e: - log = f'*** Unable to read remote log from {remote_loc}\n*** {str(e)}\n\n' + log = f"*** Unable to read remote log from {remote_loc}\n*** {str(e)}\n\n" self.log.error(log) - local_log, metadata = super()._read(ti, try_number) + local_log, metadata = super()._read(ti, try_number, metadata) log += local_log return log, metadata @@ -171,14 +167,14 @@ def gcs_write(self, log, remote_log_location): try: blob = storage.Blob.from_string(remote_log_location, self.client) old_log = blob.download_as_bytes().decode() - log = '\n'.join([old_log, log]) if old_log else log + log = "\n".join([old_log, log]) if old_log else log except Exception as e: - if not hasattr(e, 'resp') or e.resp.get('status') != '404': - log = f'*** Previous log discarded: {str(e)}\n\n' + log + if not hasattr(e, "resp") or e.resp.get("status") != "404": + log = f"*** Previous log discarded: {str(e)}\n\n" + log self.log.info("Previous log discarded: %s", e) try: blob = storage.Blob.from_string(remote_log_location, self.client) blob.upload_from_string(log, content_type="text/plain") except Exception as e: - self.log.error('Could not write logs to %s: %s', remote_log_location, e) + self.log.error("Could not write logs to %s: %s", remote_log_location, e) diff --git a/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/airflow/providers/google/cloud/log/stackdriver_task_handler.py index 94ce5e57b0c6f..0478693761e11 100644 --- a/airflow/providers/google/cloud/log/stackdriver_task_handler.py +++ b/airflow/providers/google/cloud/log/stackdriver_task_handler.py @@ -15,18 +15,12 @@ # specific language governing 
permissions and limitations # under the License. """Handler that integrates with Stackdriver""" +from __future__ import annotations + import logging -import sys -from typing import Collection, Dict, List, Optional, Tuple, Type, Union +from typing import Collection from urllib.parse import urlencode -from airflow.providers.google.common.consts import CLIENT_INFO - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - from google.auth.credentials import Credentials from google.cloud import logging as gcp_logging from google.cloud.logging import Resource @@ -34,8 +28,10 @@ from google.cloud.logging_v2.services.logging_service_v2 import LoggingServiceV2Client from google.cloud.logging_v2.types import ListLogEntriesRequest, ListLogEntriesResponse +from airflow.compat.functools import cached_property from airflow.models import TaskInstance from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id +from airflow.providers.google.common.consts import CLIENT_INFO DEFAULT_LOGGER_NAME = "airflow" _GLOBAL_RESOURCE = Resource(type="global", labels={}) @@ -80,29 +76,29 @@ class StackdriverTaskHandler(logging.Handler): LABEL_EXECUTION_DATE = "execution_date" LABEL_TRY_NUMBER = "try_number" LOG_VIEWER_BASE_URL = "https://console.cloud.google.com/logs/viewer" - LOG_NAME = 'Google Stackdriver' + LOG_NAME = "Google Stackdriver" def __init__( self, - gcp_key_path: Optional[str] = None, - scopes: Optional[Collection[str]] = _DEFAULT_SCOPESS, + gcp_key_path: str | None = None, + scopes: Collection[str] | None = _DEFAULT_SCOPESS, name: str = DEFAULT_LOGGER_NAME, - transport: Type[Transport] = BackgroundThreadTransport, + transport: type[Transport] = BackgroundThreadTransport, resource: Resource = _GLOBAL_RESOURCE, - labels: Optional[Dict[str, str]] = None, + labels: dict[str, str] | None = None, ): super().__init__() - self.gcp_key_path: Optional[str] = gcp_key_path - self.scopes: Optional[Collection[str]] = scopes + self.gcp_key_path: str | None = gcp_key_path + self.scopes: Collection[str] | None = scopes self.name: str = name - self.transport_type: Type[Transport] = transport + self.transport_type: type[Transport] = transport self.resource: Resource = resource - self.labels: Optional[Dict[str, str]] = labels - self.task_instance_labels: Optional[Dict[str, str]] = {} - self.task_instance_hostname = 'default-hostname' + self.labels: dict[str, str] | None = labels + self.task_instance_labels: dict[str, str] | None = {} + self.task_instance_hostname = "default-hostname" @cached_property - def _credentials_and_project(self) -> Tuple[Credentials, str]: + def _credentials_and_project(self) -> tuple[Credentials, str]: credentials, project = get_credentials_and_project_id( key_path=self.gcp_key_path, scopes=self.scopes, disable_logging=True ) @@ -142,7 +138,7 @@ def emit(self, record: logging.LogRecord) -> None: :param record: The record to be logged. 
""" message = self.format(record) - labels: Optional[Dict[str, str]] + labels: dict[str, str] | None if self.labels and self.task_instance_labels: labels = {} labels.update(self.labels) @@ -165,8 +161,8 @@ def set_context(self, task_instance: TaskInstance) -> None: self.task_instance_hostname = task_instance.hostname def read( - self, task_instance: TaskInstance, try_number: Optional[int] = None, metadata: Optional[Dict] = None - ) -> Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, Union[str, bool]]]]: + self, task_instance: TaskInstance, try_number: int | None = None, metadata: dict | None = None + ) -> tuple[list[tuple[tuple[str, str]]], list[dict[str, str | bool]]]: """ Read logs of given task instance from Stackdriver logging. @@ -177,7 +173,6 @@ def read( :return: a tuple of ( list of (one element tuple with two element tuple - hostname and logs) and list of metadata) - :rtype: Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]] """ if try_number is not None and try_number < 1: logs = f"Error fetching the logs. Try number {try_number} is invalid." @@ -195,18 +190,18 @@ def read( log_filter = self._prepare_log_filter(ti_labels) next_page_token = metadata.get("next_page_token", None) - all_pages = 'download_logs' in metadata and metadata['download_logs'] + all_pages = "download_logs" in metadata and metadata["download_logs"] messages, end_of_log, next_page_token = self._read_logs(log_filter, next_page_token, all_pages) - new_metadata: Dict[str, Union[str, bool]] = {"end_of_log": end_of_log} + new_metadata: dict[str, str | bool] = {"end_of_log": end_of_log} if next_page_token: - new_metadata['next_page_token'] = next_page_token + new_metadata["next_page_token"] = next_page_token return [((self.task_instance_hostname, messages),)], [new_metadata] - def _prepare_log_filter(self, ti_labels: Dict[str, str]) -> str: + def _prepare_log_filter(self, ti_labels: dict[str, str]) -> str: """ Prepares the filter that chooses which log entries to fetch. @@ -227,20 +222,20 @@ def escale_label_value(value: str) -> str: _, project = self._credentials_and_project log_filters = [ - f'resource.type={escale_label_value(self.resource.type)}', + f"resource.type={escale_label_value(self.resource.type)}", f'logName="projects/{project}/logs/{self.name}"', ] for key, value in self.resource.labels.items(): - log_filters.append(f'resource.labels.{escape_label_key(key)}={escale_label_value(value)}') + log_filters.append(f"resource.labels.{escape_label_key(key)}={escale_label_value(value)}") for key, value in ti_labels.items(): - log_filters.append(f'labels.{escape_label_key(key)}={escale_label_value(value)}') + log_filters.append(f"labels.{escape_label_key(key)}={escale_label_value(value)}") return "\n".join(log_filters) def _read_logs( - self, log_filter: str, next_page_token: Optional[str], all_pages: bool - ) -> Tuple[str, bool, Optional[str]]: + self, log_filter: str, next_page_token: str | None, all_pages: bool + ) -> tuple[str, bool, str | None]: """ Sends requests to the Stackdriver service and downloads logs. 
@@ -253,7 +248,6 @@ def _read_logs( * string with logs * Boolean value describing whether there are more logs, * token of the next page - :rtype: Tuple[str, bool, str] """ messages = [] new_messages, next_page_token = self._read_single_logs_page( @@ -276,7 +270,7 @@ def _read_logs( end_of_log = not bool(next_page_token) return "\n".join(messages), end_of_log, next_page_token - def _read_single_logs_page(self, log_filter: str, page_token: Optional[str] = None) -> Tuple[str, str]: + def _read_single_logs_page(self, log_filter: str, page_token: str | None = None) -> tuple[str, str]: """ Sends requests to the Stackdriver service and downloads single pages with logs. @@ -284,14 +278,13 @@ def _read_single_logs_page(self, log_filter: str, page_token: Optional[str] = No :param page_token: The token of the page to be downloaded. If None is passed, the first page will be downloaded. :return: Downloaded logs and next page token - :rtype: Tuple[str, str] """ _, project = self._credentials_and_project request = ListLogEntriesRequest( - resource_names=[f'projects/{project}'], + resource_names=[f"projects/{project}"], filter=log_filter, page_token=page_token, - order_by='timestamp asc', + order_by="timestamp asc", page_size=1000, ) response = self._logging_service_client.list_log_entries(request=request) @@ -303,7 +296,7 @@ def _read_single_logs_page(self, log_filter: str, page_token: Optional[str] = No return "\n".join(messages), page.next_page_token @classmethod - def _task_instance_to_labels(cls, ti: TaskInstance) -> Dict[str, str]: + def _task_instance_to_labels(cls, ti: TaskInstance) -> dict[str, str]: return { cls.LABEL_TASK_ID: ti.task_id, cls.LABEL_DAG_ID: ti.dag_id, @@ -332,7 +325,6 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) -> :param task_instance: task instance object :param try_number: task instance try_number to read logs from. :return: URL to the external log collection service - :rtype: str """ _, project_id = self._credentials_and_project @@ -342,10 +334,10 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) -> log_filter = self._prepare_log_filter(ti_labels) url_query_string = { - 'project': project_id, - 'interval': 'NO_LIMIT', - 'resource': self._resource_path, - 'advancedFilter': log_filter, + "project": project_id, + "interval": "NO_LIMIT", + "resource": self._resource_path, + "advancedFilter": log_filter, } url = f"{self.LOG_VIEWER_BASE_URL}?{urlencode(url_query_string)}" diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py index d42b4f1a248dc..ff1ee00c3e9e6 100644 --- a/airflow/providers/google/cloud/operators/automl.py +++ b/airflow/providers/google/cloud/operators/automl.py @@ -15,11 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
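# Editor's illustrative sketch (not part of this patch): StackdriverTaskHandler, refactored
# above, is a regular logging.Handler, so it can be attached to a logger directly. The key
# path and labels below are assumptions for the example; emitting records requires valid
# Google Cloud credentials.
import logging

from airflow.providers.google.cloud.log.stackdriver_task_handler import StackdriverTaskHandler

handler = StackdriverTaskHandler(
    gcp_key_path="/files/keys/service-account.json",  # assumed credentials path
    labels={"environment": "example"},                # assumed static labels
)
logger = logging.getLogger("airflow.task")
logger.addHandler(handler)
logger.info("This record is sent to Cloud Logging via the background transport")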
-# - """This module contains Google AutoML operators.""" +from __future__ import annotations + import ast -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence, Tuple from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -81,12 +81,12 @@ def __init__( *, model: dict, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -100,7 +100,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -165,13 +165,13 @@ def __init__( model_id: str, location: str, payload: dict, - operation_params: Optional[Dict[str, str]] = None, - project_id: Optional[str] = None, + operation_params: dict[str, str] | None = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -187,7 +187,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -260,13 +260,13 @@ def __init__( input_config: dict, output_config: dict, location: str, - project_id: Optional[str] = None, - prediction_params: Optional[Dict[str, str]] = None, + project_id: str | None = None, + prediction_params: dict[str, str] | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -283,7 +283,7 @@ def __init__( self.input_config = input_config self.output_config = output_config - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -347,12 +347,12 @@ def __init__( *, dataset: dict, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, 
**kwargs, ) -> None: super().__init__(**kwargs) @@ -366,7 +366,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -433,12 +433,12 @@ def __init__( dataset_id: str, location: str, input_config: dict, - project_id: Optional[str] = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -453,7 +453,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -525,15 +525,15 @@ def __init__( dataset_id: str, table_spec_id: str, location: str, - field_mask: Optional[dict] = None, - filter_: Optional[str] = None, - page_size: Optional[int] = None, - project_id: Optional[str] = None, + field_mask: dict | None = None, + filter_: str | None = None, + page_size: int | None = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -550,7 +550,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -616,12 +616,12 @@ def __init__( *, dataset: dict, location: str, - update_mask: Optional[dict] = None, + update_mask: dict | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -635,7 +635,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -693,12 +693,12 @@ def __init__( *, model_id: str, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None 
= None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -712,7 +712,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -769,12 +769,12 @@ def __init__( *, model_id: str, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -788,7 +788,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -808,7 +808,7 @@ class AutoMLDeployModelOperator(BaseOperator): """ Deploys a model. If a model is already deployed, deploying it with the same parameters has no effect. Deploying with different parameters (as e.g. changing node_number) will - reset the deployment state without pausing the model_id’s availability. + reset the deployment state without pausing the model_id's availability. Only applicable for Text Classification, Image Object Detection and Tables; all other domains manage deployment automatically. @@ -853,13 +853,13 @@ def __init__( *, model_id: str, location: str, - project_id: Optional[str] = None, - image_detection_metadata: Optional[dict] = None, - metadata: Sequence[Tuple[str, str]] = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + project_id: str | None = None, + image_detection_metadata: dict | None = None, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -874,7 +874,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -941,14 +941,14 @@ def __init__( *, dataset_id: str, location: str, - page_size: Optional[int] = None, - filter_: Optional[str] = None, - project_id: Optional[str] = None, + page_size: int | None = None, + filter_: str | None = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -963,7 +963,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): 
hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1022,12 +1022,12 @@ def __init__( self, *, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1039,7 +1039,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1102,14 +1102,14 @@ class AutoMLDeleteDatasetOperator(BaseOperator): def __init__( self, *, - dataset_id: Union[str, List[str]], + dataset_id: str | list[str], location: str, - project_id: Optional[str] = None, + project_id: str | None = None, metadata: MetaData = (), - timeout: Optional[float] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1124,7 +1124,7 @@ def __init__( self.impersonation_chain = impersonation_chain @staticmethod - def _parse_dataset_id(dataset_id: Union[str, List[str]]) -> List[str]: + def _parse_dataset_id(dataset_id: str | list[str]) -> list[str]: if not isinstance(dataset_id, str): return dataset_id try: @@ -1132,7 +1132,7 @@ def _parse_dataset_id(dataset_id: Union[str, List[str]]) -> List[str]: except (SyntaxError, ValueError): return dataset_id.split(",") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudAutoMLHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 268bbc1cad8c2..b167038388c5b 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -15,30 +15,40 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
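# Editor's illustrative sketch (not part of this patch): the _parse_dataset_id helper on
# AutoMLDeleteDatasetOperator, retyped above, accepts either a list of dataset ids or a
# single string; strings are parsed with ast.literal_eval and fall back to splitting on
# commas. The dataset ids below are made-up examples.
from airflow.providers.google.cloud.operators.automl import AutoMLDeleteDatasetOperator

assert AutoMLDeleteDatasetOperator._parse_dataset_id(["ds_1", "ds_2"]) == ["ds_1", "ds_2"]
assert AutoMLDeleteDatasetOperator._parse_dataset_id("ds_1,ds_2") == ["ds_1", "ds_2"]
assert AutoMLDeleteDatasetOperator._parse_dataset_id("['ds_1', 'ds_2']") == ["ds_1", "ds_2"]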
- - """This module contains Google BigQuery operators.""" +from __future__ import annotations + import enum -import hashlib import json -import re -import uuid import warnings -from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union +from typing import TYPE_CHECKING, Any, Iterable, Sequence, SupportsAbs import attr from google.api_core.exceptions import Conflict from google.api_core.retry import Retry -from google.cloud.bigquery import DEFAULT_RETRY +from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.xcom import XCom -from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator +from airflow.providers.common.sql.operators.sql import ( + SQLCheckOperator, + SQLColumnCheckOperator, + SQLIntervalCheckOperator, + SQLTableCheckOperator, + SQLValueCheckOperator, + _parse_boolean, +) from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url from airflow.providers.google.cloud.links.bigquery import BigQueryDatasetLink, BigQueryTableLink +from airflow.providers.google.cloud.triggers.bigquery import ( + BigQueryCheckTrigger, + BigQueryGetDataTrigger, + BigQueryInsertJobTrigger, + BigQueryIntervalCheckTrigger, + BigQueryValueCheckTrigger, +) if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey @@ -60,25 +70,16 @@ class BigQueryUIColors(enum.Enum): class BigQueryConsoleLink(BaseOperatorLink): """Helper class for constructing BigQuery link.""" - name = 'BigQuery Console' + name = "BigQuery Console" def get_link( self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, ): - if ti_key is not None: - job_id = XCom.get_value(key='job_id', ti_key=ti_key) - else: - assert dttm is not None - job_id = XCom.get_one( - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - key='job_id', - ) - return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else '' + job_id = XCom.get_value(key="job_id", ti_key=ti_key) + return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else "" @attr.s(auto_attribs=True) @@ -89,21 +90,15 @@ class BigQueryConsoleIndexableLink(BaseOperatorLink): @property def name(self) -> str: - return f'BigQuery Console #{self.index + 1}' + return f"BigQuery Console #{self.index + 1}" def get_link( self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, ): - if ti_key is not None: - job_ids = XCom.get_value(key='job_id', ti_key=ti_key) - else: - assert dttm is not None - job_ids = XCom.get_one( - key='job_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm - ) + job_ids = XCom.get_value(key="job_id", ti_key=ti_key) if not job_ids: return None if len(job_ids) < self.index: @@ -113,7 +108,7 @@ def get_link( class _BigQueryDbHookMixin: - def get_db_hook(self: 'BigQueryCheckOperator') -> BigQueryHook: # type:ignore[misc] + def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook: # type:ignore[misc] """Get BigQuery DB Hook""" return BigQueryHook( gcp_conn_id=self.gcp_conn_id, @@ -171,26 +166,28 @@ class 
BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). :param labels: a dictionary containing labels for the table, passed to BigQuery + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( - 'sql', - 'gcp_conn_id', - 'impersonation_chain', - 'labels', + "sql", + "gcp_conn_id", + "impersonation_chain", + "labels", ) - template_ext: Sequence[str] = ('.sql',) + template_ext: Sequence[str] = (".sql",) ui_color = BigQueryUIColors.CHECK.value def __init__( self, *, sql: str, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", use_legacy_sql: bool = True, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - labels: Optional[dict] = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + labels: dict | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(sql=sql, **kwargs) @@ -200,6 +197,59 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.labels = labels + self.deferrable = deferrable + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Trigger.""" + configuration = {"query": {"query": self.sql}} + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Context): + if not self.deferrable: + super().execute(context=context) + else: + hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + ) + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryCheckTrigger( + conn_id=self.gcp_conn_id, + job_id=job.job_id, + project_id=hook.project_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + + records = event["records"] + if not records: + raise AirflowException("The query returned empty results") + elif not all(bool(r) for r in records): + self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") + self.log.info("Record: %s", event["records"]) + self.log.info("Success.") class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator): @@ -225,16 +275,17 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
:param labels: a dictionary containing labels for the table, passed to BigQuery + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( - 'sql', - 'gcp_conn_id', - 'pass_value', - 'impersonation_chain', - 'labels', + "sql", + "gcp_conn_id", + "pass_value", + "impersonation_chain", + "labels", ) - template_ext: Sequence[str] = ('.sql',) + template_ext: Sequence[str] = (".sql",) ui_color = BigQueryUIColors.CHECK.value def __init__( @@ -243,11 +294,12 @@ def __init__( sql: str, pass_value: Any, tolerance: Any = None, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", use_legacy_sql: bool = True, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - labels: Optional[dict] = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + labels: dict | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs) @@ -256,6 +308,65 @@ def __init__( self.use_legacy_sql = use_legacy_sql self.impersonation_chain = impersonation_chain self.labels = labels + self.deferrable = deferrable + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Triggerer.""" + configuration = { + "query": { + "query": self.sql, + "useLegacySql": False, + } + } + if self.use_legacy_sql: + configuration["query"]["useLegacySql"] = self.use_legacy_sql + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Context) -> None: # type: ignore[override] + if not self.deferrable: + super().execute(context=context) + else: + hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryValueCheckTrigger( + conn_id=self.gcp_conn_id, + job_id=job.job_id, + project_id=hook.project_id, + sql=self.sql, + pass_value=self.pass_value, + tolerance=self.tol, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperator): @@ -292,15 +403,16 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
:param labels: a dictionary containing labels for the table, passed to BigQuery + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( - 'table', - 'gcp_conn_id', - 'sql1', - 'sql2', - 'impersonation_chain', - 'labels', + "table", + "gcp_conn_id", + "sql1", + "sql2", + "impersonation_chain", + "labels", ) ui_color = BigQueryUIColors.CHECK.value @@ -309,13 +421,14 @@ def __init__( *, table: str, metrics_thresholds: dict, - date_filter_column: str = 'ds', + date_filter_column: str = "ds", days_back: SupportsAbs[int] = -7, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", use_legacy_sql: bool = True, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - labels: Optional[Dict] = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + labels: dict | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__( @@ -331,6 +444,292 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.labels = labels + self.deferrable = deferrable + + def _submit_job( + self, + hook: BigQueryHook, + sql: str, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Triggerer.""" + configuration = {"query": {"query": sql}} + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Context): + if not self.deferrable: + super().execute(context) + else: + hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + self.log.info("Using ratio formula: %s", self.ratio_formula) + + self.log.info("Executing SQL check: %s", self.sql1) + job_1 = self._submit_job(hook, sql=self.sql1, job_id="") + context["ti"].xcom_push(key="job_id", value=job_1.job_id) + + self.log.info("Executing SQL check: %s", self.sql2) + job_2 = self._submit_job(hook, sql=self.sql2, job_id="") + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryIntervalCheckTrigger( + conn_id=self.gcp_conn_id, + first_job_id=job_1.job_id, + second_job_id=job_2.job_id, + project_id=hook.project_id, + table=self.table, + metrics_thresholds=self.metrics_thresholds, + date_filter_column=self.date_filter_column, + days_back=self.days_back, + ratio_formula=self.ratio_formula, + ignore_zero=self.ignore_zero, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) + + +class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator): + """ + BigQueryColumnCheckOperator subclasses the SQLColumnCheckOperator + in order to provide a job id for OpenLineage to parse. See base class + docstring for usage. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryColumnCheckOperator` + + :param table: the table name + :param column_mapping: a dictionary relating columns to their checks + :param partition_clause: a string SQL statement added to a WHERE clause + to partition data + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :param location: The geographic location of the job. See details at: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param labels: a dictionary containing labels for the table, passed to BigQuery + """ + + def __init__( + self, + *, + table: str, + column_mapping: dict, + partition_clause: str | None = None, + database: str | None = None, + accept_none: bool = True, + gcp_conn_id: str = "google_cloud_default", + use_legacy_sql: bool = True, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + labels: dict | None = None, + **kwargs, + ) -> None: + super().__init__( + table=table, + column_mapping=column_mapping, + partition_clause=partition_clause, + database=database, + accept_none=accept_none, + **kwargs, + ) + self.table = table + self.column_mapping = column_mapping + self.partition_clause = partition_clause + self.database = database + self.accept_none = accept_none + self.gcp_conn_id = gcp_conn_id + self.use_legacy_sql = use_legacy_sql + self.location = location + self.impersonation_chain = impersonation_chain + self.labels = labels + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Trigger.""" + configuration = {"query": {"query": self.sql}} + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=False, + ) + + def execute(self, context=None): + """Perform checks on the given columns.""" + hook = self.get_db_hook() + failed_tests = [] + + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + records = job.result().to_dataframe() + + if records.empty: + raise AirflowException(f"The following query returned zero rows: {self.sql}") + + records.columns = records.columns.str.lower() + self.log.info("Record: %s", records) + + for row in records.iterrows(): + column = row[1].get("col_name") + check = row[1].get("check_type") + result = row[1].get("check_result") + tolerance = self.column_mapping[column][check].get("tolerance") + + self.column_mapping[column][check]["result"] = result + self.column_mapping[column][check]["success"] = self._get_match( + self.column_mapping[column][check], result, tolerance + ) + + failed_tests.extend( + f"Column: {col}\n\tCheck: {check},\n\tCheck Values: {check_values}\n" + for col,
checks in self.column_mapping.items() + for check, check_values in checks.items() + if not check_values["success"] + ) + if failed_tests: + exception_string = ( + f"Test failed.\nResults:\n{records!s}\n" + f"The following tests have failed:" + f"\n{''.join(failed_tests)}" + ) + self._raise_exception(exception_string) + + self.log.info("All tests have passed") + + +class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator): + """ + BigQueryTableCheckOperator subclasses the SQLTableCheckOperator + in order to provide a job id for OpenLineage to parse. See base class + for usage. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryTableCheckOperator` + + :param table: the table name + :param checks: a dictionary of check names and boolean SQL statements + :param partition_clause: a string SQL statement added to a WHERE clause + to partition data + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :param location: The geographic location of the job. See details at: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ :param labels: a dictionary containing labels for the table, passed to BigQuery + """ + + def __init__( + self, + *, + table: str, + checks: dict, + partition_clause: str | None = None, + gcp_conn_id: str = "google_cloud_default", + use_legacy_sql: bool = True, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + labels: dict | None = None, + **kwargs, + ) -> None: + super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs) + self.table = table + self.checks = checks + self.partition_clause = partition_clause + self.gcp_conn_id = gcp_conn_id + self.use_legacy_sql = use_legacy_sql + self.location = location + self.impersonation_chain = impersonation_chain + self.labels = labels + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Trigger.""" + configuration = {"query": {"query": self.sql}} + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=False, + ) + + def execute(self, context=None): + """Execute the given checks on the table.""" + hook = self.get_db_hook() + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + records = job.result().to_dataframe() + + if records.empty: + raise AirflowException(f"The following query returned zero rows: {self.sql}") + + records.columns = records.columns.str.lower() + self.log.info("Record:\n%s", records) + + for row in records.iterrows(): + check = row[1].get("check_name") + result = row[1].get("check_result") + self.checks[check]["success"] = _parse_boolean(str(result)) + + failed_tests = [ + f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + for check, check_values in self.checks.items() + if not check_values["success"] + ] + if failed_tests: + exception_string = ( + f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n" + f"The following tests have failed:\n{', '.join(failed_tests)}" + ) + self._raise_exception(exception_string) + + self.log.info("All tests have passed") class BigQueryGetDataOperator(BaseOperator): @@ -360,6 +759,7 @@ class BigQueryGetDataOperator(BaseOperator): task_id='get_data_from_bq', dataset_id='test_dataset', table_id='Transaction_partitions', + project_id='internal-gcp-project', max_results=100, selected_fields='DATE', gcp_conn_id='airflow-conn-id' @@ -367,6 +767,8 @@ class BigQueryGetDataOperator(BaseOperator): :param dataset_id: The dataset ID of the requested table. (templated) :param table_id: The table ID of the requested table. (templated) + :param project_id: (Optional) The name of the project where the data + will be returned from. (templated) :param max_results: The maximum number of records (rows) to be fetched from the table. (templated) :param selected_fields: List of fields to return (comma-separated). If @@ -384,14 +786,16 @@ class BigQueryGetDataOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
+ :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( - 'dataset_id', - 'table_id', - 'max_results', - 'selected_fields', - 'impersonation_chain', + "dataset_id", + "table_id", + "project_id", + "max_results", + "selected_fields", + "impersonation_chain", ) ui_color = BigQueryUIColors.QUERY.value @@ -400,12 +804,14 @@ def __init__( *, dataset_id: str, table_id: str, + project_id: str | None = None, max_results: int = 100, - selected_fields: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + selected_fields: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -418,38 +824,98 @@ def __init__( self.delegate_to = delegate_to self.location = location self.impersonation_chain = impersonation_chain + self.project_id = project_id + self.deferrable = deferrable - def execute(self, context: 'Context') -> list: - self.log.info( - 'Fetching Data from %s.%s max results: %s', self.dataset_id, self.table_id, self.max_results + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + get_query = self.generate_query() + configuration = {"query": {"query": get_query}} + """Submit a new job and get the job id for polling the status using Triggerer.""" + return hook.insert_job( + configuration=configuration, + location=self.location, + project_id=hook.project_id, + job_id=job_id, + nowait=True, ) + def generate_query(self) -> str: + """ + Generate a select query if selected fields are given or with * + for the given dataset and table id + """ + query = "select " + if self.selected_fields: + query += self.selected_fields + else: + query += "*" + query += f" from {self.dataset_id}.{self.table_id} limit {self.max_results}" + return query + + def execute(self, context: Context): hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) + self.hook = hook - if not self.selected_fields: - schema: Dict[str, list] = hook.get_schema( + if not self.deferrable: + self.log.info( + "Fetching Data from %s.%s max results: %s", self.dataset_id, self.table_id, self.max_results + ) + if not self.selected_fields: + schema: dict[str, list] = hook.get_schema( + dataset_id=self.dataset_id, + table_id=self.table_id, + ) + if "fields" in schema: + self.selected_fields = ",".join([field["name"] for field in schema["fields"]]) + + rows = hook.list_rows( dataset_id=self.dataset_id, table_id=self.table_id, + max_results=self.max_results, + selected_fields=self.selected_fields, + location=self.location, + project_id=self.project_id, ) - if "fields" in schema: - self.selected_fields = ','.join([field["name"] for field in schema["fields"]]) - rows = hook.list_rows( - dataset_id=self.dataset_id, - table_id=self.table_id, - max_results=self.max_results, - selected_fields=self.selected_fields, - location=self.location, + self.log.info("Total extracted rows: %s", len(rows)) + + table_data = [row.values() for row in rows] + return table_data + + job = self._submit_job(hook, job_id="") + self.job_id = job.job_id + context["ti"].xcom_push(key="job_id", value=self.job_id) + self.defer( + timeout=self.execution_timeout, + 
trigger=BigQueryGetDataTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + dataset_id=self.dataset_id, + table_id=self.table_id, + project_id=hook.project_id, + ), + method_name="execute_complete", ) - self.log.info('Total extracted rows: %s', len(rows)) + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) - table_data = [row.values() for row in rows] - return table_data + self.log.info("Total extracted rows: %s", len(event["records"])) + return event["records"] class BigQueryExecuteQueryOperator(BaseOperator): @@ -533,14 +999,14 @@ class BigQueryExecuteQueryOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'sql', - 'destination_dataset_table', - 'labels', - 'query_params', - 'impersonation_chain', + "sql", + "destination_dataset_table", + "labels", + "query_params", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} ui_color = BigQueryUIColors.QUERY.value @property @@ -553,28 +1019,28 @@ def operator_extra_links(self): def __init__( self, *, - sql: Union[str, Iterable], - destination_dataset_table: Optional[str] = None, - write_disposition: str = 'WRITE_EMPTY', + sql: str | Iterable[str], + destination_dataset_table: str | None = None, + write_disposition: str = "WRITE_EMPTY", allow_large_results: bool = False, - flatten_results: Optional[bool] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - udf_config: Optional[list] = None, + flatten_results: bool | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + udf_config: list | None = None, use_legacy_sql: bool = True, - maximum_billing_tier: Optional[int] = None, - maximum_bytes_billed: Optional[float] = None, - create_disposition: str = 'CREATE_IF_NEEDED', - schema_update_options: Optional[Union[list, tuple, set]] = None, - query_params: Optional[list] = None, - labels: Optional[dict] = None, - priority: str = 'INTERACTIVE', - time_partitioning: Optional[dict] = None, - api_resource_configs: Optional[dict] = None, - cluster_fields: Optional[List[str]] = None, - location: Optional[str] = None, - encryption_configuration: Optional[dict] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + maximum_billing_tier: int | None = None, + maximum_bytes_billed: float | None = None, + create_disposition: str = "CREATE_IF_NEEDED", + schema_update_options: list | tuple | set | None = None, + query_params: list | None = None, + labels: dict | None = None, + priority: str = "INTERACTIVE", + time_partitioning: dict | None = None, + api_resource_configs: dict | None = None, + cluster_fields: list[str] | None = None, + location: str | None = None, + encryption_configuration: dict | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -605,12 +1071,12 @@ def __init__( self.cluster_fields = cluster_fields self.location = location self.encryption_configuration = encryption_configuration - self.hook = None # type: Optional[BigQueryHook] + self.hook: BigQueryHook | None = None self.impersonation_chain = impersonation_chain - def execute(self, context: 
'Context'): + def execute(self, context: Context): if self.hook is None: - self.log.info('Executing: %s', self.sql) + self.log.info("Executing: %s", self.sql) self.hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, use_legacy_sql=self.use_legacy_sql, @@ -619,7 +1085,7 @@ def execute(self, context: 'Context'): impersonation_chain=self.impersonation_chain, ) if isinstance(self.sql, str): - job_id: Union[str, List[str]] = self.hook.run_query( + job_id: str | list[str] = self.hook.run_query( sql=self.sql, destination_dataset_table=self.destination_dataset_table, write_disposition=self.write_disposition, @@ -663,13 +1129,13 @@ def execute(self, context: 'Context'): ] else: raise AirflowException(f"argument 'sql' of type {type(str)} is neither a string nor an iterable") - context['task_instance'].xcom_push(key='job_id', value=job_id) + context["task_instance"].xcom_push(key="job_id", value=job_id) def on_kill(self) -> None: super().on_kill() if self.hook is not None: - self.log.info('Cancelling running query') - self.hook.cancel_query() + self.log.info("Cancelling running query") + self.hook.cancel_job(self.hook.running_job_id) class BigQueryCreateEmptyTableOperator(BaseOperator): @@ -709,7 +1175,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): .. seealso:: https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning - :param bigquery_conn_id: [Optional] The connection ID used to connect to Google Cloud and + :param gcp_conn_id: [Optional] The connection ID used to connect to Google Cloud and interact with the Bigquery service. :param google_cloud_storage_conn_id: [Optional] The connection ID used to connect to Google Cloud. and interact with the Google Cloud Storage service. @@ -726,7 +1192,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): table_id='Employees', project_id='internal-gcp-project', gcs_schema_object='gs://schema-bucket/employee_schema.json', - bigquery_conn_id='airflow-conn-id', + gcp_conn_id='airflow-conn-id', google_cloud_storage_conn_id='airflow-conn-id' ) @@ -754,7 +1220,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): project_id='internal-gcp-project', schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}], - bigquery_conn_id='airflow-conn-id-account', + gcp_conn_id='airflow-conn-id-account', google_cloud_storage_conn_id='airflow-conn-id' ) :param view: [Optional] A dictionary containing definition for the view. 
@@ -788,14 +1254,14 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'dataset_id', - 'table_id', - 'project_id', - 'gcs_schema_object', - 'labels', - 'view', - 'materialized_view', - 'impersonation_chain', + "dataset_id", + "table_id", + "project_id", + "gcs_schema_object", + "labels", + "view", + "materialized_view", + "impersonation_chain", ) template_fields_renderers = {"table_resource": "json", "materialized_view": "json"} ui_color = BigQueryUIColors.TABLE.value @@ -806,24 +1272,33 @@ def __init__( *, dataset_id: str, table_id: str, - table_resource: Optional[Dict[str, Any]] = None, - project_id: Optional[str] = None, - schema_fields: Optional[List] = None, - gcs_schema_object: Optional[str] = None, - time_partitioning: Optional[Dict] = None, - bigquery_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict] = None, - view: Optional[Dict] = None, - materialized_view: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None, - location: Optional[str] = None, - cluster_fields: Optional[List[str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + table_resource: dict[str, Any] | None = None, + project_id: str | None = None, + schema_fields: list | None = None, + gcs_schema_object: str | None = None, + time_partitioning: dict | None = None, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: str | None = None, + google_cloud_storage_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + labels: dict | None = None, + view: dict | None = None, + materialized_view: dict | None = None, + encryption_configuration: dict | None = None, + location: str | None = None, + cluster_fields: list[str] | None = None, + impersonation_chain: str | Sequence[str] | None = None, exists_ok: bool = False, **kwargs, ) -> None: + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. 
Use the gcp_conn_id parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + gcp_conn_id = bigquery_conn_id + super().__init__(**kwargs) self.project_id = project_id @@ -831,7 +1306,7 @@ def __init__( self.table_id = table_id self.schema_fields = schema_fields self.gcs_schema_object = gcs_schema_object - self.bigquery_conn_id = bigquery_conn_id + self.gcp_conn_id = gcp_conn_id self.google_cloud_storage_conn_id = google_cloud_storage_conn_id self.delegate_to = delegate_to self.time_partitioning = {} if time_partitioning is None else time_partitioning @@ -845,9 +1320,9 @@ def __init__( self.impersonation_chain = impersonation_chain self.exists_ok = exists_ok - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: bq_hook = BigQueryHook( - gcp_conn_id=self.bigquery_conn_id, + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, location=self.location, impersonation_chain=self.impersonation_chain, @@ -866,7 +1341,7 @@ def execute(self, context: 'Context') -> None: schema_fields = self.schema_fields try: - self.log.info('Creating table') + self.log.info("Creating table") table = bq_hook.create_empty_table( project_id=self.project_id, dataset_id=self.dataset_id, @@ -889,10 +1364,10 @@ def execute(self, context: 'Context') -> None: table_id=table.to_api_repr()["tableReference"]["tableId"], ) self.log.info( - 'Table %s.%s.%s created successfully', table.project, table.dataset_id, table.table_id + "Table %s.%s.%s created successfully", table.project, table.dataset_id, table.table_id ) except Conflict: - self.log.info('Table %s.%s already exists.', self.dataset_id, self.table_id) + self.log.info("Table %s.%s already exists.", self.dataset_id, self.table_id) class BigQueryCreateExternalTableOperator(BaseOperator): @@ -949,7 +1424,7 @@ class BigQueryCreateExternalTableOperator(BaseOperator): columns are treated as bad records, and if there are too many bad records, an invalid error is returned in the job result. Only applicable to CSV, ignored for other formats. - :param bigquery_conn_id: (Optional) The connection ID used to connect to Google Cloud and + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud and interact with the Bigquery service. :param google_cloud_storage_conn_id: (Optional) The connection ID used to connect to Google Cloud and interact with the Google Cloud Storage service. 
@@ -976,13 +1451,13 @@ class BigQueryCreateExternalTableOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'source_objects', - 'schema_object', - 'destination_project_dataset_table', - 'labels', - 'table_resource', - 'impersonation_chain', + "bucket", + "source_objects", + "schema_object", + "destination_project_dataset_table", + "labels", + "table_resource", + "impersonation_chain", ) template_fields_renderers = {"table_resource": "json"} ui_color = BigQueryUIColors.TABLE.value @@ -991,31 +1466,40 @@ class BigQueryCreateExternalTableOperator(BaseOperator): def __init__( self, *, - bucket: Optional[str] = None, - source_objects: Optional[List[str]] = None, - destination_project_dataset_table: Optional[str] = None, - table_resource: Optional[Dict[str, Any]] = None, - schema_fields: Optional[List] = None, - schema_object: Optional[str] = None, - source_format: Optional[str] = None, + bucket: str | None = None, + source_objects: list[str] | None = None, + destination_project_dataset_table: str | None = None, + table_resource: dict[str, Any] | None = None, + schema_fields: list | None = None, + schema_object: str | None = None, + source_format: str | None = None, autodetect: bool = False, - compression: Optional[str] = None, - skip_leading_rows: Optional[int] = None, - field_delimiter: Optional[str] = None, + compression: str | None = None, + skip_leading_rows: int | None = None, + field_delimiter: str | None = None, max_bad_records: int = 0, - quote_character: Optional[str] = None, + quote_character: str | None = None, allow_quoted_newlines: bool = False, allow_jagged_rows: bool = False, - bigquery_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - src_fmt_configs: Optional[dict] = None, - labels: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + bigquery_conn_id: str | None = None, + google_cloud_storage_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + src_fmt_configs: dict | None = None, + labels: dict | None = None, + encryption_configuration: dict | None = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: + if bigquery_conn_id: + warnings.warn( + "The bigquery_conn_id parameter has been deprecated. 
Use the gcp_conn_id parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + gcp_conn_id = bigquery_conn_id + super().__init__(**kwargs) # BQ config @@ -1050,9 +1534,9 @@ def __init__( if not source_objects: raise ValueError("`source_objects` is required when not using `table_resource`.") if not source_format: - source_format = 'CSV' + source_format = "CSV" if not compression: - compression = 'NONE' + compression = "NONE" if not skip_leading_rows: skip_leading_rows = 0 if not field_delimiter: @@ -1085,7 +1569,7 @@ def __init__( self.quote_character = quote_character self.allow_quoted_newlines = allow_quoted_newlines self.allow_jagged_rows = allow_jagged_rows - self.bigquery_conn_id = bigquery_conn_id + self.gcp_conn_id = gcp_conn_id self.google_cloud_storage_conn_id = google_cloud_storage_conn_id self.delegate_to = delegate_to self.autodetect = autodetect @@ -1096,9 +1580,9 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: bq_hook = BigQueryHook( - gcp_conn_id=self.bigquery_conn_id, + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, location=self.location, impersonation_chain=self.impersonation_chain, @@ -1116,7 +1600,7 @@ def execute(self, context: 'Context') -> None: ) return - if not self.schema_fields and self.schema_object and self.source_format != 'DATASTORE_BACKUP': + if not self.schema_fields and self.schema_object and self.source_format != "DATASTORE_BACKUP": gcs_hook = GCSHook( gcp_conn_id=self.google_cloud_storage_conn_id, delegate_to=self.delegate_to, @@ -1128,23 +1612,41 @@ def execute(self, context: 'Context') -> None: source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects] - table = bq_hook.create_external_table( - external_project_dataset_table=self.destination_project_dataset_table, - schema_fields=schema_fields, - source_uris=source_uris, - source_format=self.source_format, - autodetect=self.autodetect, - compression=self.compression, - skip_leading_rows=self.skip_leading_rows, - field_delimiter=self.field_delimiter, - max_bad_records=self.max_bad_records, - quote_character=self.quote_character, - allow_quoted_newlines=self.allow_quoted_newlines, - allow_jagged_rows=self.allow_jagged_rows, - src_fmt_configs=self.src_fmt_configs, - labels=self.labels, - encryption_configuration=self.encryption_configuration, + project_id, dataset_id, table_id = bq_hook.split_tablename( + table_input=self.destination_project_dataset_table, + default_project_id=bq_hook.project_id or "", ) + + table_resource = { + "tableReference": { + "projectId": project_id, + "datasetId": dataset_id, + "tableId": table_id, + }, + "labels": self.labels, + "schema": {"fields": schema_fields}, + "externalDataConfiguration": { + "sourceUris": source_uris, + "sourceFormat": self.source_format, + "maxBadRecords": self.max_bad_records, + "autodetect": self.autodetect, + "compression": self.compression, + "csvOptions": { + "fieldDelimiter": self.field_delimiter, + "skipLeadingRows": self.skip_leading_rows, + "quote": self.quote_character, + "allowQuotedNewlines": self.allow_quoted_newlines, + "allowJaggedRows": self.allow_jagged_rows, + }, + }, + "location": self.location, + "encryptionConfiguration": self.encryption_configuration, + } + + table = bq_hook.create_empty_table( + table_resource=table_resource, + ) + BigQueryTableLink.persist( context=context, task_instance=self, @@ -1194,9 +1696,9 @@ class 
BigQueryDeleteDatasetOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'dataset_id', - 'project_id', - 'impersonation_chain', + "dataset_id", + "project_id", + "impersonation_chain", ) ui_color = BigQueryUIColors.DATASET.value @@ -1204,11 +1706,11 @@ def __init__( self, *, dataset_id: str, - project_id: Optional[str] = None, + project_id: str | None = None, delete_contents: bool = False, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.dataset_id = dataset_id @@ -1220,8 +1722,8 @@ def __init__( super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: - self.log.info('Dataset id: %s Project id: %s', self.dataset_id, self.project_id) + def execute(self, context: Context) -> None: + self.log.info("Dataset id: %s Project id: %s", self.dataset_id, self.project_id) bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, @@ -1274,10 +1776,10 @@ class BigQueryCreateEmptyDatasetOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'dataset_id', - 'project_id', - 'dataset_reference', - 'impersonation_chain', + "dataset_id", + "project_id", + "dataset_reference", + "impersonation_chain", ) template_fields_renderers = {"dataset_reference": "json"} ui_color = BigQueryUIColors.DATASET.value @@ -1286,13 +1788,13 @@ class BigQueryCreateEmptyDatasetOperator(BaseOperator): def __init__( self, *, - dataset_id: Optional[str] = None, - project_id: Optional[str] = None, - dataset_reference: Optional[Dict] = None, - location: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + dataset_id: str | None = None, + project_id: str | None = None, + dataset_reference: dict | None = None, + location: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, exists_ok: bool = False, **kwargs, ) -> None: @@ -1308,7 +1810,7 @@ def __init__( super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1332,7 +1834,7 @@ def execute(self, context: 'Context') -> None: ) except Conflict: dataset_id = self.dataset_reference.get("datasetReference", {}).get("datasetId", self.dataset_id) - self.log.info('Dataset %s already exists.', dataset_id) + self.log.info("Dataset %s already exists.", dataset_id) class BigQueryGetDatasetOperator(BaseOperator): @@ -1359,15 +1861,12 @@ class BigQueryGetDatasetOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
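A short usage sketch of the dataset operators changed above, with hypothetical dataset and location values; exists_ok=True makes the create task tolerant of a dataset that already exists, and delete_contents=True drops a non-empty dataset.

from airflow.providers.google.cloud.operators.bigquery import (
    BigQueryCreateEmptyDatasetOperator,
    BigQueryDeleteDatasetOperator,
)

create_dataset = BigQueryCreateEmptyDatasetOperator(
    task_id="create_dataset",
    dataset_id="my_dataset",  # hypothetical dataset
    location="EU",  # hypothetical location
    exists_ok=True,
)

delete_dataset = BigQueryDeleteDatasetOperator(
    task_id="delete_dataset",
    dataset_id="my_dataset",
    delete_contents=True,  # remove the dataset even if it still holds tables
)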
- - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ template_fields: Sequence[str] = ( - 'dataset_id', - 'project_id', - 'impersonation_chain', + "dataset_id", + "project_id", + "impersonation_chain", ) ui_color = BigQueryUIColors.DATASET.value operator_extra_links = (BigQueryDatasetLink(),) @@ -1376,10 +1875,10 @@ def __init__( self, *, dataset_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.dataset_id = dataset_id @@ -1389,14 +1888,14 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - self.log.info('Start getting dataset: %s:%s', self.project_id, self.dataset_id) + self.log.info("Start getting dataset: %s:%s", self.project_id, self.dataset_id) dataset = bq_hook.get_dataset(dataset_id=self.dataset_id, project_id=self.project_id) dataset = dataset.to_api_repr() @@ -1436,9 +1935,9 @@ class BigQueryGetDatasetTablesOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'dataset_id', - 'project_id', - 'impersonation_chain', + "dataset_id", + "project_id", + "impersonation_chain", ) ui_color = BigQueryUIColors.DATASET.value @@ -1446,11 +1945,11 @@ def __init__( self, *, dataset_id: str, - project_id: Optional[str] = None, - max_results: Optional[int] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + max_results: int | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.dataset_id = dataset_id @@ -1461,7 +1960,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1501,15 +2000,12 @@ class BigQueryPatchDatasetOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
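A minimal sketch of consuming BigQueryGetDatasetTablesOperator output, assuming the table list returned by execute() lands under the default return_value XCom key; the dataset name and the downstream task are hypothetical.

from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDatasetTablesOperator

get_tables = BigQueryGetDatasetTablesOperator(
    task_id="get_dataset_tables",
    dataset_id="my_dataset",  # hypothetical dataset
    max_results=50,
)


def _log_tables(ti):
    # Pull the list the upstream task returned from execute().
    tables = ti.xcom_pull(task_ids="get_dataset_tables")
    print(tables)


log_tables = PythonOperator(task_id="log_tables", python_callable=_log_tables)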
- - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ template_fields: Sequence[str] = ( - 'dataset_id', - 'project_id', - 'impersonation_chain', + "dataset_id", + "project_id", + "impersonation_chain", ) template_fields_renderers = {"dataset_resource": "json"} ui_color = BigQueryUIColors.DATASET.value @@ -1519,10 +2015,10 @@ def __init__( *, dataset_id: str, dataset_resource: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: warnings.warn( @@ -1538,7 +2034,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1584,16 +2080,13 @@ class BigQueryUpdateTableOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - - :rtype: table - https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#resource """ template_fields: Sequence[str] = ( - 'dataset_id', - 'table_id', - 'project_id', - 'impersonation_chain', + "dataset_id", + "table_id", + "project_id", + "impersonation_chain", ) template_fields_renderers = {"table_resource": "json"} ui_color = BigQueryUIColors.TABLE.value @@ -1602,14 +2095,14 @@ class BigQueryUpdateTableOperator(BaseOperator): def __init__( self, *, - table_resource: Dict[str, Any], - fields: Optional[List[str]] = None, - dataset_id: Optional[str] = None, - table_id: Optional[str] = None, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + table_resource: dict[str, Any], + fields: list[str] | None = None, + dataset_id: str | None = None, + table_id: str | None = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.dataset_id = dataset_id @@ -1622,7 +2115,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1679,15 +2172,12 @@ class BigQueryUpdateDatasetOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
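A minimal sketch of BigQueryUpdateTableOperator as typed above, with hypothetical identifiers; only the keys listed in fields are patched, everything else on the table is left untouched.

from airflow.providers.google.cloud.operators.bigquery import BigQueryUpdateTableOperator

update_table = BigQueryUpdateTableOperator(
    task_id="update_table_description",
    dataset_id="my_dataset",  # hypothetical dataset
    table_id="my_table",  # hypothetical table
    table_resource={"description": "Nightly aggregate table."},
    fields=["description"],  # restrict the update to this property
)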
- - :rtype: dataset - https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ template_fields: Sequence[str] = ( - 'dataset_id', - 'project_id', - 'impersonation_chain', + "dataset_id", + "project_id", + "impersonation_chain", ) template_fields_renderers = {"dataset_resource": "json"} ui_color = BigQueryUIColors.DATASET.value @@ -1696,13 +2186,13 @@ class BigQueryUpdateDatasetOperator(BaseOperator): def __init__( self, *, - dataset_resource: Dict[str, Any], - fields: Optional[List[str]] = None, - dataset_id: Optional[str] = None, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + dataset_resource: dict[str, Any], + fields: list[str] | None = None, + dataset_id: str | None = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.dataset_id = dataset_id @@ -1714,7 +2204,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1768,8 +2258,8 @@ class BigQueryDeleteTableOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'deletion_dataset_table', - 'impersonation_chain', + "deletion_dataset_table", + "impersonation_chain", ) ui_color = BigQueryUIColors.TABLE.value @@ -1777,11 +2267,11 @@ def __init__( self, *, deletion_dataset_table: str, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, ignore_if_missing: bool = False, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1793,8 +2283,8 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: - self.log.info('Deleting: %s', self.deletion_dataset_table) + def execute(self, context: Context) -> None: + self.log.info("Deleting: %s", self.deletion_dataset_table) hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1835,9 +2325,10 @@ class BigQueryUpsertTableOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'dataset_id', - 'table_resource', - 'impersonation_chain', + "dataset_id", + "table_resource", + "impersonation_chain", + "project_id", ) template_fields_renderers = {"table_resource": "json"} ui_color = BigQueryUIColors.TABLE.value @@ -1848,11 +2339,11 @@ def __init__( *, dataset_id: str, table_resource: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1865,8 +2356,8 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain - def execute(self, context: 
'Context') -> None: - self.log.info('Upserting Dataset: %s with table_resource: %s', self.dataset_id, self.table_resource) + def execute(self, context: Context) -> None: + self.log.info("Upserting Dataset: %s with table_resource: %s", self.dataset_id, self.table_resource) hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1938,11 +2429,11 @@ class BigQueryUpdateTableSchemaOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'schema_fields_updates', - 'dataset_id', - 'table_id', - 'project_id', - 'impersonation_chain', + "schema_fields_updates", + "dataset_id", + "table_id", + "project_id", + "impersonation_chain", ) template_fields_renderers = {"schema_fields_updates": "json"} ui_color = BigQueryUIColors.TABLE.value @@ -1951,14 +2442,14 @@ class BigQueryUpdateTableSchemaOperator(BaseOperator): def __init__( self, *, - schema_fields_updates: List[Dict[str, Any]], + schema_fields_updates: list[dict[str, Any]], dataset_id: str, table_id: str, include_policy_tags: bool = False, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.schema_fields_updates = schema_fields_updates @@ -1971,7 +2462,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context'): + def execute(self, context: Context): bq_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -2020,7 +2511,7 @@ class BigQueryInsertJobOperator(BaseOperator): :param configuration: The configuration parameter maps directly to BigQuery's - configuration field in the job object. For more details see + configuration field in the job object. For more details see https://cloud.google.com/bigquery/docs/reference/v2/jobs :param job_id: The ID of the job. It will be suffixed with hash of job configuration unless ``force_rerun`` is True. 
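A minimal sketch of the configuration/job_id behaviour described above, using a hypothetical query; with force_rerun=False the generated job_id suffix is a hash of the configuration, so an identical rerun found in one of reattach_states is reattached instead of resubmitted.

from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

insert_query_job = BigQueryInsertJobOperator(
    task_id="insert_query_job",
    configuration={
        "query": {
            "query": "SELECT 1",  # hypothetical query
            "useLegacySql": False,
        }
    },
    force_rerun=False,
    reattach_states={"PENDING", "RUNNING"},
)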
@@ -2047,12 +2538,14 @@ class BigQueryInsertJobOperator(BaseOperator): :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called :param result_retry: How to retry the `result` call that retrieves rows :param result_timeout: The number of seconds to wait for `result` method before using `result_retry` + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( "configuration", "job_id", "impersonation_chain", + "project_id", ) template_ext: Sequence[str] = ( ".json", @@ -2064,18 +2557,19 @@ class BigQueryInsertJobOperator(BaseOperator): def __init__( self, - configuration: Dict[str, Any], - project_id: Optional[str] = None, - location: Optional[str] = None, - job_id: Optional[str] = None, + configuration: dict[str, Any], + project_id: str | None = None, + location: str | None = None, + job_id: str | None = None, force_rerun: bool = True, - reattach_states: Optional[Set[str]] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + reattach_states: set[str] | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, cancel_on_kill: bool = True, result_retry: Retry = DEFAULT_RETRY, - result_timeout: Optional[float] = None, + result_timeout: float | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2086,16 +2580,17 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.force_rerun = force_rerun - self.reattach_states: Set[str] = reattach_states or set() + self.reattach_states: set[str] = reattach_states or set() self.impersonation_chain = impersonation_chain self.cancel_on_kill = cancel_on_kill self.result_retry = result_retry self.result_timeout = result_timeout - self.hook: Optional[BigQueryHook] = None + self.hook: BigQueryHook | None = None + self.deferrable = deferrable def prepare_template(self) -> None: # If .json is passed then we have to read the file - if isinstance(self.configuration, str) and self.configuration.endswith('.json'): + if isinstance(self.configuration, str) and self.configuration.endswith(".json"): with open(self.configuration) as file: self.configuration = json.loads(file.read()) @@ -2104,7 +2599,7 @@ def _submit_job( hook: BigQueryHook, job_id: str, ) -> BigQueryJob: - # Submit a new job and wait for it to complete and get the result. + # Submit a new job without waiting for it to complete. 
return hook.insert_job( configuration=self.configuration, project_id=self.project_id, @@ -2112,6 +2607,7 @@ def _submit_job( job_id=job_id, timeout=self.result_timeout, retry=self.result_retry, + nowait=True, ) @staticmethod @@ -2119,21 +2615,6 @@ def _handle_job_error(job: BigQueryJob) -> None: if job.error_result: raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}") - def _job_id(self, context): - if self.force_rerun: - hash_base = str(uuid.uuid4()) - else: - hash_base = json.dumps(self.configuration, sort_keys=True) - - uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest() - - if self.job_id: - return f"{self.job_id}_{uniqueness_suffix}" - - exec_date = context['execution_date'].isoformat() - job_id = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{uniqueness_suffix}" - return re.sub(r"[:\-+.]", "_", job_id) - def execute(self, context: Any): hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, @@ -2142,12 +2623,18 @@ def execute(self, context: Any): ) self.hook = hook - job_id = self._job_id(context) + job_id = hook.generate_job_id( + job_id=self.job_id, + dag_id=self.dag_id, + task_id=self.task_id, + logical_date=context["logical_date"], + configuration=self.configuration, + force_rerun=self.force_rerun, + ) try: - self.log.info(f"Executing: {self.configuration}") + self.log.info("Executing: %s'", self.configuration) job = self._submit_job(hook, job_id) - self._handle_job_error(job) except Conflict: # If the job already exists retrieve it job = hook.get_job( @@ -2157,7 +2644,7 @@ def execute(self, context: Any): ) if job.state in self.reattach_states: # We are reattaching to a job - job.result(timeout=self.result_timeout, retry=self.result_retry) + job._begin() self._handle_job_error(job) else: # Same job configuration so we need force_rerun @@ -2167,19 +2654,69 @@ def execute(self, context: Any): f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" ) - table = job.to_api_repr()["configuration"]["query"]["destinationTable"] - BigQueryTableLink.persist( - context=context, - task_instance=self, - dataset_id=table["datasetId"], - project_id=table["projectId"], - table_id=table["tableId"], - ) + job_types = { + LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"], + CopyJob._JOB_TYPE: ["sourceTable", "destinationTable"], + ExtractJob._JOB_TYPE: ["sourceTable"], + QueryJob._JOB_TYPE: ["destinationTable"], + } + + if self.project_id: + for job_type, tables_prop in job_types.items(): + job_configuration = job.to_api_repr()["configuration"] + if job_type in job_configuration: + for table_prop in tables_prop: + if table_prop in job_configuration[job_type]: + table = job_configuration[job_type][table_prop] + persist_kwargs = { + "context": context, + "task_instance": self, + "project_id": self.project_id, + "table_id": table, + } + if not isinstance(table, str): + persist_kwargs["table_id"] = table["tableId"] + persist_kwargs["dataset_id"] = table["datasetId"] + + BigQueryTableLink.persist(**persist_kwargs) + self.job_id = job.job_id - return job.job_id + context["ti"].xcom_push(key="job_id", value=self.job_id) + # Wait for the job to complete + if not self.deferrable: + job.result(timeout=self.result_timeout, retry=self.result_retry) + self._handle_job_error(job) + + return self.job_id + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryInsertJobTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + project_id=self.project_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, 
context: Context, event: dict[str, Any]): + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) + return self.job_id def on_kill(self) -> None: if self.job_id and self.cancel_on_kill: self.hook.cancel_job( # type: ignore[union-attr] job_id=self.job_id, project_id=self.project_id, location=self.location ) + else: + self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id) diff --git a/airflow/providers/google/cloud/operators/bigquery_dts.py b/airflow/providers/google/cloud/operators/bigquery_dts.py index 54c87505897ca..7c5f9b0ac13a0 100644 --- a/airflow/providers/google/cloud/operators/bigquery_dts.py +++ b/airflow/providers/google/cloud/operators/bigquery_dts.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google BigQuery Data Transfer Service operators.""" -from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -80,14 +82,14 @@ def __init__( self, *, transfer_config: dict, - project_id: Optional[str] = None, - location: Optional[str] = None, - authorization_code: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + location: str | None = None, + authorization_code: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id="google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -101,7 +103,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = BiqQueryDataTransferServiceHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, location=self.location ) @@ -170,13 +172,13 @@ def __init__( self, *, transfer_config_id: str, - project_id: Optional[str] = None, - location: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + location: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id="google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -189,7 +191,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = BiqQueryDataTransferServiceHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, location=self.location ) @@ -255,15 +257,15 @@ def __init__( self, *, 
transfer_config_id: str, - project_id: Optional[str] = None, - location: Optional[str] = None, - requested_time_range: Optional[dict] = None, - requested_run_time: Optional[dict] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + location: str | None = None, + requested_time_range: dict | None = None, + requested_run_time: dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id="google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -278,11 +280,11 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = BiqQueryDataTransferServiceHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, location=self.location ) - self.log.info('Submitting manual transfer for %s', self.transfer_config_id) + self.log.info("Submitting manual transfer for %s", self.transfer_config_id) response = hook.start_manual_transfer_runs( transfer_config_id=self.transfer_config_id, requested_time_range=self.requested_time_range, @@ -303,7 +305,7 @@ def execute(self, context: 'Context'): ) result = StartManualTransferRunsResponse.to_dict(response) - run_id = get_object_id(result['runs'][0]) + run_id = get_object_id(result["runs"][0]) self.xcom_push(context, key="run_id", value=run_id) - self.log.info('Transfer run %s submitted successfully.', run_id) + self.log.info("Transfer run %s submitted successfully.", run_id) return result diff --git a/airflow/providers/google/cloud/operators/bigtable.py b/airflow/providers/google/cloud/operators/bigtable.py index fdecd22089a27..e2e54166d5469 100644 --- a/airflow/providers/google/cloud/operators/bigtable.py +++ b/airflow/providers/google/cloud/operators/bigtable.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud Bigtable operators.""" +from __future__ import annotations + import enum -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Sequence import google.api_core.exceptions from google.cloud.bigtable.column_family import GarbageCollectionRule @@ -44,7 +46,7 @@ class BigtableValidationMixin: def _validate_inputs(self): for attr_name in self.REQUIRED_ATTRIBUTES: if not getattr(self, attr_name): - raise AirflowException(f'Empty parameter: {attr_name}') + raise AirflowException(f"Empty parameter: {attr_name}") class BigtableCreateInstanceOperator(BaseOperator, BigtableValidationMixin): @@ -90,13 +92,13 @@ class BigtableCreateInstanceOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). 
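A minimal sketch of BigtableCreateInstanceOperator with its three required attributes (instance_id, main_cluster_id, main_cluster_zone); all identifiers and the zone are hypothetical.

from airflow.providers.google.cloud.operators.bigtable import BigtableCreateInstanceOperator

create_instance = BigtableCreateInstanceOperator(
    task_id="create_bigtable_instance",
    instance_id="my-instance",  # hypothetical instance
    main_cluster_id="my-instance-c1",  # hypothetical cluster
    main_cluster_zone="europe-west1-b",  # hypothetical zone
    cluster_nodes=3,
)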
""" - REQUIRED_ATTRIBUTES: Iterable[str] = ('instance_id', 'main_cluster_id', 'main_cluster_zone') + REQUIRED_ATTRIBUTES: Iterable[str] = ("instance_id", "main_cluster_id", "main_cluster_zone") template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'main_cluster_id', - 'main_cluster_zone', - 'impersonation_chain', + "project_id", + "instance_id", + "main_cluster_id", + "main_cluster_zone", + "impersonation_chain", ) operator_extra_links = (BigtableInstanceLink(),) @@ -106,16 +108,16 @@ def __init__( instance_id: str, main_cluster_id: str, main_cluster_zone: str, - project_id: Optional[str] = None, - replica_clusters: Optional[List[Dict[str, str]]] = None, - instance_display_name: Optional[str] = None, - instance_type: Optional[enums.Instance.Type] = None, - instance_labels: Optional[Dict] = None, - cluster_nodes: Optional[int] = None, - cluster_storage_type: Optional[enums.StorageType] = None, - timeout: Optional[float] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + replica_clusters: list[dict[str, str]] | None = None, + instance_display_name: str | None = None, + instance_type: enums.Instance.Type | None = None, + instance_labels: dict | None = None, + cluster_nodes: int | None = None, + cluster_storage_type: enums.StorageType | None = None, + timeout: float | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -134,7 +136,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -165,7 +167,7 @@ def execute(self, context: 'Context') -> None: ) BigtableInstanceLink.persist(context=context, task_instance=self) except google.api_core.exceptions.GoogleAPICallError as e: - self.log.error('An error occurred. Exiting.') + self.log.error("An error occurred. Exiting.") raise e @@ -200,11 +202,11 @@ class BigtableUpdateInstanceOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). 
""" - REQUIRED_ATTRIBUTES: Iterable[str] = ['instance_id'] + REQUIRED_ATTRIBUTES: Iterable[str] = ["instance_id"] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'impersonation_chain', + "project_id", + "instance_id", + "impersonation_chain", ) operator_extra_links = (BigtableInstanceLink(),) @@ -212,13 +214,13 @@ def __init__( self, *, instance_id: str, - project_id: Optional[str] = None, - instance_display_name: Optional[str] = None, - instance_type: Optional[Union[enums.Instance.Type, enum.IntEnum]] = None, - instance_labels: Optional[Dict] = None, - timeout: Optional[float] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + instance_display_name: str | None = None, + instance_type: enums.Instance.Type | enum.IntEnum | None = None, + instance_labels: dict | None = None, + timeout: float | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -232,7 +234,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -252,7 +254,7 @@ def execute(self, context: 'Context') -> None: ) BigtableInstanceLink.persist(context=context, task_instance=self) except google.api_core.exceptions.GoogleAPICallError as e: - self.log.error('An error occurred. Exiting.') + self.log.error("An error occurred. Exiting.") raise e @@ -281,20 +283,20 @@ class BigtableDeleteInstanceOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). """ - REQUIRED_ATTRIBUTES = ('instance_id',) # type: Iterable[str] + REQUIRED_ATTRIBUTES = ("instance_id",) # type: Iterable[str] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'impersonation_chain', + "project_id", + "instance_id", + "impersonation_chain", ) def __init__( self, *, instance_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -304,7 +306,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -318,7 +320,7 @@ def execute(self, context: 'Context') -> None: self.project_id, ) except google.api_core.exceptions.GoogleAPICallError as e: - self.log.error('An error occurred. Exiting.') + self.log.error("An error occurred. Exiting.") raise e @@ -354,12 +356,12 @@ class BigtableCreateTableOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). 
""" - REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') # type: Iterable[str] + REQUIRED_ATTRIBUTES = ("instance_id", "table_id") # type: Iterable[str] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'table_id', - 'impersonation_chain', + "project_id", + "instance_id", + "table_id", + "impersonation_chain", ) operator_extra_links = (BigtableTablesLink(),) @@ -368,11 +370,11 @@ def __init__( *, instance_id: str, table_id: str, - project_id: Optional[str] = None, - initial_split_keys: Optional[List] = None, - column_families: Optional[Dict[str, GarbageCollectionRule]] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + initial_split_keys: list | None = None, + column_families: dict[str, GarbageCollectionRule] | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -406,7 +408,7 @@ def _compare_column_families(self, hook, instance) -> bool: return False return True - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -459,12 +461,12 @@ class BigtableDeleteTableOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). """ - REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') # type: Iterable[str] + REQUIRED_ATTRIBUTES = ("instance_id", "table_id") # type: Iterable[str] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'table_id', - 'impersonation_chain', + "project_id", + "instance_id", + "table_id", + "impersonation_chain", ) def __init__( @@ -472,10 +474,10 @@ def __init__( *, instance_id: str, table_id: str, - project_id: Optional[str] = None, - app_profile_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + app_profile_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -487,7 +489,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -506,7 +508,7 @@ def execute(self, context: 'Context') -> None: # It's OK if table doesn't exists. self.log.info("The table '%s' no longer exists. Consider it as deleted", self.table_id) except google.api_core.exceptions.GoogleAPICallError as e: - self.log.error('An error occurred. Exiting.') + self.log.error("An error occurred. Exiting.") raise e @@ -537,13 +539,13 @@ class BigtableUpdateClusterOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). 
""" - REQUIRED_ATTRIBUTES = ('instance_id', 'cluster_id', 'nodes') # type: Iterable[str] + REQUIRED_ATTRIBUTES = ("instance_id", "cluster_id", "nodes") # type: Iterable[str] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'cluster_id', - 'nodes', - 'impersonation_chain', + "project_id", + "instance_id", + "cluster_id", + "nodes", + "impersonation_chain", ) operator_extra_links = (BigtableClusterLink(),) @@ -553,9 +555,9 @@ def __init__( instance_id: str, cluster_id: str, nodes: int, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -567,7 +569,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -584,5 +586,5 @@ def execute(self, context: 'Context') -> None: f"Dependency: cluster '{self.cluster_id}' does not exist for instance '{self.instance_id}'." ) except google.api_core.exceptions.GoogleAPICallError as e: - self.log.error('An error occurred. Exiting.') + self.log.error("An error occurred. Exiting.") raise e diff --git a/airflow/providers/google/cloud/operators/cloud_build.py b/airflow/providers/google/cloud/operators/cloud_build.py index c377af6732d7f..c33fa36c64a2b 100644 --- a/airflow/providers/google/cloud/operators/cloud_build.py +++ b/airflow/providers/google/cloud/operators/cloud_build.py @@ -15,16 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Operators that integrates with Google Cloud Build service.""" +from __future__ import annotations import json import re from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union -from urllib.parse import unquote, urlparse +from typing import TYPE_CHECKING, Any, Sequence +from urllib.parse import unquote, urlsplit -import yaml from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry from google.cloud.devtools.cloudbuild_v1.types import Build, BuildTrigger, RepoSource @@ -32,6 +31,13 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook +from airflow.providers.google.cloud.links.cloud_build import ( + CloudBuildLink, + CloudBuildListLink, + CloudBuildTriggerDetailsLink, + CloudBuildTriggersListLink, +) +from airflow.utils import yaml if TYPE_CHECKING: from airflow.utils.context import Context @@ -66,21 +72,21 @@ class CloudBuildCancelBuildOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: dict """ template_fields: Sequence[str] = ("project_id", "id_", "gcp_conn_id") + operator_extra_links = (CloudBuildLink(),) def __init__( self, *, id_: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -92,7 +98,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) result = hook.cancel_build( id_=self.id_, @@ -101,6 +107,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + self.xcom_push(context, key="id", value=result.id) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildLink.persist( + context=context, + task_instance=self, + project_id=project_id, + build_id=result.id, + ) return Build.to_dict(result) @@ -132,22 +148,22 @@ class CloudBuildCreateBuildOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: dict """ template_fields: Sequence[str] = ("project_id", "build", "gcp_conn_id", "impersonation_chain") + operator_extra_links = (CloudBuildLink(),) def __init__( self, *, - build: Union[Dict, Build], - project_id: Optional[str] = None, + build: dict | Build, + project_id: str | None = None, wait: bool = True, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -167,12 +183,12 @@ def prepare_template(self) -> None: if not isinstance(self.build_raw, str): return with open(self.build_raw) as file: - if any(self.build_raw.endswith(ext) for ext in ['.yaml', '.yml']): + if any(self.build_raw.endswith(ext) for ext in [".yaml", ".yml"]): self.build = yaml.safe_load(file.read()) - if self.build_raw.endswith('.json'): + if self.build_raw.endswith(".json"): self.build = json.loads(file.read()) - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) build = BuildProcessor(build=self.build).process_body() @@ -185,6 +201,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + self.xcom_push(context, key="id", value=result.id) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildLink.persist( + context=context, + task_instance=self, + project_id=project_id, + build_id=result.id, + ) return Build.to_dict(result) @@ -215,21 +241,24 @@ class 
CloudBuildCreateBuildTriggerOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: dict """ template_fields: Sequence[str] = ("project_id", "trigger", "gcp_conn_id") + operator_extra_links = ( + CloudBuildTriggersListLink(), + CloudBuildTriggerDetailsLink(), + ) def __init__( self, *, - trigger: Union[dict, BuildTrigger], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + trigger: dict | BuildTrigger, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -241,7 +270,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) result = hook.create_build_trigger( trigger=self.trigger, @@ -250,6 +279,20 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + self.xcom_push(context, key="id", value=result.id) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildTriggerDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + trigger_id=result.id, + ) + CloudBuildTriggersListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) return BuildTrigger.to_dict(result) @@ -281,17 +324,18 @@ class CloudBuildDeleteBuildTriggerOperator(BaseOperator): """ template_fields: Sequence[str] = ("project_id", "trigger_id", "gcp_conn_id") + operator_extra_links = (CloudBuildTriggersListLink(),) def __init__( self, *, trigger_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -303,7 +347,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) hook.delete_build_trigger( trigger_id=self.trigger_id, @@ -312,6 +356,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildTriggersListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) class CloudBuildGetBuildOperator(BaseOperator): @@ -340,21 +391,21 @@ class CloudBuildGetBuildOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to 
the originating account (templated). - :rtype: dict """ template_fields: Sequence[str] = ("project_id", "id_", "gcp_conn_id") + operator_extra_links = (CloudBuildLink(),) def __init__( self, *, id_: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -366,7 +417,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) result = hook.get_build( id_=self.id_, @@ -375,6 +426,14 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildLink.persist( + context=context, + task_instance=self, + project_id=project_id, + build_id=result.id, + ) return Build.to_dict(result) @@ -404,21 +463,21 @@ class CloudBuildGetBuildTriggerOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: dict """ template_fields: Sequence[str] = ("project_id", "trigger_id", "gcp_conn_id") + operator_extra_links = (CloudBuildTriggerDetailsLink(),) def __init__( self, *, trigger_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -430,7 +489,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) result = hook.get_build_trigger( trigger_id=self.trigger_id, @@ -439,6 +498,14 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildTriggerDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + trigger_id=result.id, + ) return BuildTrigger.to_dict(result) @@ -470,23 +537,23 @@ class CloudBuildListBuildTriggersOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
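A minimal chaining sketch built on the new "id" XCom key pushed by the Cloud Build operators above; the DAG, the steps-only build, and the task ids are hypothetical, and id_ is a templated field, so the pushed build id can be pulled directly.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.cloud_build import (
    CloudBuildCreateBuildOperator,
    CloudBuildGetBuildOperator,
)

with DAG(dag_id="cloud_build_links_example", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    create_build = CloudBuildCreateBuildOperator(
        task_id="create_build",
        build={"steps": [{"name": "ubuntu", "args": ["echo", "hello"]}]},  # hypothetical build
    )

    # The build id pushed under the "id" XCom key by the task above is
    # consumed through templating.
    get_build = CloudBuildGetBuildOperator(
        task_id="get_build",
        id_="{{ task_instance.xcom_pull('create_build', key='id') }}",
    )

    create_build >> get_build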
- :rtype: List[dict] """ template_fields: Sequence[str] = ("location", "project_id", "gcp_conn_id") + operator_extra_links = (CloudBuildTriggersListLink(),) def __init__( self, *, location: str, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -500,7 +567,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) results = hook.list_build_triggers( project_id=self.project_id, @@ -511,6 +578,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildTriggersListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) return [BuildTrigger.to_dict(result) for result in results] @@ -542,23 +616,23 @@ class CloudBuildListBuildsOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: List[dict] """ template_fields: Sequence[str] = ("location", "project_id", "gcp_conn_id") + operator_extra_links = (CloudBuildListLink(),) def __init__( self, *, location: str, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + page_size: int | None = None, + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -572,7 +646,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) results = hook.list_builds( project_id=self.project_id, @@ -583,6 +657,9 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildListLink.persist(context=context, task_instance=self, project_id=project_id) return [Build.to_dict(result) for result in results] @@ -614,22 +691,22 @@ class CloudBuildRetryBuildOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: dict """ template_fields: Sequence[str] = ("project_id", "id_", "gcp_conn_id") + operator_extra_links = (CloudBuildLink(),) def __init__( self, *, id_: str, - project_id: Optional[str] = None, + project_id: str | None = None, wait: bool = True, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -642,7 +719,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) result = hook.retry_build( id_=self.id_, @@ -652,6 +729,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + self.xcom_push(context, key="id", value=result.id) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildLink.persist( + context=context, + task_instance=self, + project_id=project_id, + build_id=result.id, + ) return Build.to_dict(result) @@ -684,23 +771,23 @@ class CloudBuildRunBuildTriggerOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: dict """ template_fields: Sequence[str] = ("project_id", "trigger_id", "source", "gcp_conn_id") + operator_extra_links = (CloudBuildLink(),) def __init__( self, *, trigger_id: str, - source: Union[dict, RepoSource], - project_id: Optional[str] = None, + source: dict | RepoSource, + project_id: str | None = None, wait: bool = True, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -714,7 +801,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) result = hook.run_build_trigger( trigger_id=self.trigger_id, @@ -725,6 +812,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + self.xcom_push(context, key="id", value=result.id) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildLink.persist( + context=context, + task_instance=self, + project_id=project_id, + build_id=result.id, + ) return Build.to_dict(result) @@ -756,22 +852,22 @@ class CloudBuildUpdateBuildTriggerOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: dict """ template_fields: Sequence[str] = ("project_id", "trigger_id", "trigger", "gcp_conn_id") + operator_extra_links = (CloudBuildTriggerDetailsLink(),) def __init__( self, *, trigger_id: str, - trigger: Union[dict, BuildTrigger], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + trigger: dict | BuildTrigger, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -784,7 +880,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) result = hook.update_build_trigger( trigger_id=self.trigger_id, @@ -794,6 +890,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + self.xcom_push(context, key="id", value=result.id) + project_id = self.project_id or hook.project_id + if project_id: + CloudBuildTriggerDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + trigger_id=result.id, + ) return BuildTrigger.to_dict(result) @@ -808,7 +913,7 @@ class BuildProcessor: See: https://cloud.google.com/cloud-build/docs/api/reference/rest/Shared.Types/Build """ - def __init__(self, build: Union[Dict, Build]) -> None: + def __init__(self, build: dict | Build) -> None: self.build = deepcopy(build) def _verify_source(self) -> None: @@ -849,15 +954,14 @@ def process_body(self) -> Build: Processes the body passed in the constructor :return: the body. 
- :rtype: `google.cloud.devtools.cloudbuild_v1.types.Build` """ - if 'source' in self.build: + if "source" in self.build: self._verify_source() self._reformat_source() return Build(self.build) @staticmethod - def _convert_repo_url_to_dict(source: str) -> Dict[str, Any]: + def _convert_repo_url_to_dict(source: str) -> dict[str, Any]: """ Convert url to repository in Google Cloud Source to a format supported by the API @@ -868,7 +972,7 @@ def _convert_repo_url_to_dict(source: str) -> Dict[str, Any]: https://source.cloud.google.com/airflow-project/airflow-repo/+/branch-name: """ - url_parts = urlparse(source) + url_parts = urlsplit(source) match = REGEX_REPO_PATH.search(url_parts.path) @@ -891,7 +995,7 @@ def _convert_repo_url_to_dict(source: str) -> Dict[str, Any]: return source_dict @staticmethod - def _convert_storage_url_to_dict(storage_url: str) -> Dict[str, Any]: + def _convert_storage_url_to_dict(storage_url: str) -> dict[str, Any]: """ Convert url to object in Google Cloud Storage to a format supported by the API @@ -902,7 +1006,7 @@ def _convert_storage_url_to_dict(storage_url: str) -> Dict[str, Any]: gs://bucket-name/object-name.tar.gz """ - url_parts = urlparse(storage_url) + url_parts = urlsplit(storage_url) if url_parts.scheme != "gs" or not url_parts.hostname or not url_parts.path or url_parts.path == "/": raise AirflowException( @@ -910,7 +1014,7 @@ def _convert_storage_url_to_dict(storage_url: str) -> Dict[str, Any]: "gs://bucket-name/object-name.tar.gz#24565443" ) - source_dict: Dict[str, Any] = { + source_dict: dict[str, Any] = { "bucket": url_parts.hostname, "object_": url_parts.path[1:], } diff --git a/airflow/providers/google/cloud/operators/cloud_composer.py b/airflow/providers/google/cloud/operators/cloud_composer.py index 8c4a8a533480a..1a245d2eab0ed 100644 --- a/airflow/providers/google/cloud/operators/cloud_composer.py +++ b/airflow/providers/google/cloud/operators/cloud_composer.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
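# The _convert_*_url_to_dict helpers above switch from urlparse to urlsplit. A minimal,
# standalone sketch (not part of the diff) of how urlsplit decomposes the two URL forms
# shown in the BuildProcessor docstrings. The bucket, project and branch names below are
# invented for illustration; only the bucket/object_ mapping mirrors the hunk itself.
from urllib.parse import urlsplit

storage = urlsplit("gs://bucket-name/object-name.tar.gz#24565443")
assert storage.scheme == "gs" and storage.hostname == "bucket-name"
source_dict = {"bucket": storage.hostname, "object_": storage.path[1:]}  # as in the diff
# storage.fragment == "24565443" -- presumably an object generation; not shown in this hunk.

repo = urlsplit("https://source.cloud.google.com/airflow-project/airflow-repo/+/branch-name")
# repo.path == "/airflow-project/airflow-repo/+/branch-name" is what REGEX_REPO_PATH matches.
print(source_dict, repo.path)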
-from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import AlreadyExists from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -38,7 +40,7 @@ CLOUD_COMPOSER_DETAILS_LINK = ( CLOUD_COMPOSER_BASE_LINK + "/detail/{region}/{environment_id}/monitoring?project={project_id}" ) -CLOUD_COMPOSER_ENVIRONMENTS_LINK = CLOUD_COMPOSER_BASE_LINK + '?project={project_id}' +CLOUD_COMPOSER_ENVIRONMENTS_LINK = CLOUD_COMPOSER_BASE_LINK + "?project={project_id}" class CloudComposerEnvironmentLink(BaseGoogleLink): @@ -50,12 +52,12 @@ class CloudComposerEnvironmentLink(BaseGoogleLink): @staticmethod def persist( - operator_instance: Union[ - "CloudComposerCreateEnvironmentOperator", - "CloudComposerUpdateEnvironmentOperator", - "CloudComposerGetEnvironmentOperator", - ], - context: "Context", + operator_instance: ( + CloudComposerCreateEnvironmentOperator + | CloudComposerUpdateEnvironmentOperator + | CloudComposerGetEnvironmentOperator + ), + context: Context, ) -> None: operator_instance.xcom_push( context, @@ -76,7 +78,7 @@ class CloudComposerEnvironmentsLink(BaseGoogleLink): format_str = CLOUD_COMPOSER_ENVIRONMENTS_LINK @staticmethod - def persist(operator_instance: "CloudComposerListEnvironmentsOperator", context: "Context") -> None: + def persist(operator_instance: CloudComposerListEnvironmentsOperator, context: Context) -> None: operator_instance.xcom_push( context, key=CloudComposerEnvironmentsLink.key, @@ -115,11 +117,11 @@ class CloudComposerCreateEnvironmentOperator(BaseOperator): """ template_fields = ( - 'project_id', - 'region', - 'environment_id', - 'environment', - 'impersonation_chain', + "project_id", + "region", + "environment_id", + "environment", + "impersonation_chain", ) operator_extra_links = (CloudComposerEnvironmentLink(),) @@ -130,13 +132,13 @@ def __init__( project_id: str, region: str, environment_id: str, - environment: Union[Environment, Dict], + environment: Environment | dict, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - delegate_to: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), deferrable: bool = False, pooling_period_seconds: int = 30, **kwargs, @@ -155,7 +157,7 @@ def __init__( self.deferrable = deferrable self.pooling_period_seconds = pooling_period_seconds - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudComposerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -178,6 +180,8 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + context["ti"].xcom_push(key="operation_id", value=result.operation.name) + if not self.deferrable: environment = hook.wait_for_operation(timeout=self.timeout, operation=result) return Environment.to_dict(environment) @@ -205,7 +209,7 @@ def execute(self, context: 'Context'): ) return Environment.to_dict(environment) - def execute_complete(self, context: "Context", event: dict): + def execute_complete(self, context: Context, event: dict): if event["operation_done"]: hook = CloudComposerHook( 
gcp_conn_id=self.gcp_conn_id, @@ -254,10 +258,10 @@ class CloudComposerDeleteEnvironmentOperator(BaseOperator): """ template_fields = ( - 'project_id', - 'region', - 'environment_id', - 'impersonation_chain', + "project_id", + "region", + "environment_id", + "impersonation_chain", ) def __init__( @@ -266,12 +270,12 @@ def __init__( project_id: str, region: str, environment_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - delegate_to: Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, deferrable: bool = False, pooling_period_seconds: int = 30, **kwargs, @@ -289,7 +293,7 @@ def __init__( self.deferrable = deferrable self.pooling_period_seconds = pooling_period_seconds - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudComposerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -319,7 +323,7 @@ def execute(self, context: 'Context'): method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, ) - def execute_complete(self, context: "Context", event: dict): + def execute_complete(self, context: Context, event: dict): pass @@ -348,10 +352,10 @@ class CloudComposerGetEnvironmentOperator(BaseOperator): """ template_fields = ( - 'project_id', - 'region', - 'environment_id', - 'impersonation_chain', + "project_id", + "region", + "environment_id", + "impersonation_chain", ) operator_extra_links = (CloudComposerEnvironmentLink(),) @@ -362,12 +366,12 @@ def __init__( project_id: str, region: str, environment_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - delegate_to: Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -381,7 +385,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.delegate_to = delegate_to - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudComposerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -428,9 +432,9 @@ class CloudComposerListEnvironmentsOperator(BaseOperator): """ template_fields = ( - 'project_id', - 'region', - 'impersonation_chain', + "project_id", + "region", + "impersonation_chain", ) operator_extra_links = (CloudComposerEnvironmentsLink(),) @@ -440,14 +444,14 @@ def __init__( *, project_id: str, region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - delegate_to: 
Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -462,7 +466,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.delegate_to = delegate_to - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudComposerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -514,10 +518,10 @@ class CloudComposerUpdateEnvironmentOperator(BaseOperator): """ template_fields = ( - 'project_id', - 'region', - 'environment_id', - 'impersonation_chain', + "project_id", + "region", + "environment_id", + "impersonation_chain", ) operator_extra_links = (CloudComposerEnvironmentLink(),) @@ -528,14 +532,14 @@ def __init__( project_id: str, region: str, environment_id: str, - environment: Union[Dict, Environment], - update_mask: Union[Dict, FieldMask], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + environment: dict | Environment, + update_mask: dict | FieldMask, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - delegate_to: Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, deferrable: bool = False, pooling_period_seconds: int = 30, **kwargs, @@ -555,7 +559,7 @@ def __init__( self.deferrable = deferrable self.pooling_period_seconds = pooling_period_seconds - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudComposerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -591,7 +595,7 @@ def execute(self, context: 'Context'): method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, ) - def execute_complete(self, context: "Context", event: dict): + def execute_complete(self, context: Context, event: dict): if event["operation_done"]: hook = CloudComposerHook( gcp_conn_id=self.gcp_conn_id, @@ -635,9 +639,9 @@ class CloudComposerListImageVersionsOperator(BaseOperator): """ template_fields = ( - 'project_id', - 'region', - 'impersonation_chain', + "project_id", + "region", + "impersonation_chain", ) def __init__( @@ -645,15 +649,15 @@ def __init__( *, project_id: str, region: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, + page_size: int | None = None, + page_token: str | None = None, include_past_releases: bool = False, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - delegate_to: Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -669,7 +673,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.delegate_to = delegate_to - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudComposerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/cloud_memorystore.py 
b/airflow/providers/google/cloud/operators/cloud_memorystore.py index cbb4ccd025fb3..c4bc61f87ab9c 100644 --- a/airflow/providers/google/cloud/operators/cloud_memorystore.py +++ b/airflow/providers/google/cloud/operators/cloud_memorystore.py @@ -23,7 +23,9 @@ FieldMask memcache """ -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -36,6 +38,12 @@ CloudMemorystoreHook, CloudMemorystoreMemcachedHook, ) +from airflow.providers.google.cloud.links.cloud_memorystore import ( + MemcachedInstanceDetailsLink, + MemcachedInstanceListLink, + RedisInstanceDetailsLink, + RedisInstanceListLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -94,19 +102,20 @@ class CloudMemorystoreCreateInstanceOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, *, location: str, instance_id: str, - instance: Union[Dict, Instance], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + instance: dict | Instance, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -120,7 +129,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -133,6 +142,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + RedisInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return Instance.to_dict(result) @@ -180,12 +196,12 @@ def __init__( *, location: str, instance: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -198,7 +214,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -257,19 +273,20 @@ class CloudMemorystoreExportInstanceOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, *, location: str, instance: str, - output_config: Union[Dict, OutputConfig], - project_id: Optional[str] = None, - 
retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + output_config: dict | OutputConfig, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -283,7 +300,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -297,6 +314,13 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + RedisInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) class CloudMemorystoreFailoverInstanceOperator(BaseOperator): @@ -341,6 +365,7 @@ class CloudMemorystoreFailoverInstanceOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, @@ -348,12 +373,12 @@ def __init__( location: str, instance: str, data_protection_mode: FailoverInstanceRequest.DataProtectionMode, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -367,7 +392,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -380,6 +405,13 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + RedisInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) class CloudMemorystoreGetInstanceOperator(BaseOperator): @@ -420,18 +452,19 @@ class CloudMemorystoreGetInstanceOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, *, location: str, instance: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -444,7 +477,7 @@ def __init__( 
self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -456,6 +489,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + RedisInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return Instance.to_dict(result) @@ -505,19 +545,20 @@ class CloudMemorystoreImportOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, *, location: str, instance: str, - input_config: Union[Dict, InputConfig], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + input_config: dict | InputConfig, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -531,7 +572,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -544,6 +585,13 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + RedisInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) class CloudMemorystoreListInstancesOperator(BaseOperator): @@ -588,18 +636,19 @@ class CloudMemorystoreListInstancesOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceListLink(),) def __init__( self, *, location: str, page_size: int, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -612,7 +661,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -624,6 +673,11 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + RedisInstanceListLink.persist( + context=context, + task_instance=self, + project_id=self.project_id or hook.project_id, + ) instances = [Instance.to_dict(a) for a in result] return instances @@ -683,20 +737,21 @@ class CloudMemorystoreUpdateInstanceOperator(BaseOperator): 
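# The Memorystore hunks above repeat one pattern: each operator gains operator_extra_links,
# and after the hook call it persists the parameters the link needs via XCom, guarded by a
# resolved project id. The self-contained sketch below only illustrates that persist/xcom_push
# contract; it is NOT the provider's BaseGoogleLink implementation, and the URL template and
# class/key names are invented for the sketch.
SKETCH_REDIS_LIST_LINK = "https://console.cloud.google.com/memorystore/redis/instances?project={project_id}"


class SketchRedisInstanceListLink:
    key = "sketch_redis_instances"  # XCom key used only by this sketch

    @staticmethod
    def persist(context, task_instance, project_id):
        # Store just enough state to rebuild the console URL later.
        task_instance.xcom_push(context, key=SketchRedisInstanceListLink.key, value={"project_id": project_id})

    @staticmethod
    def format_link(**params):
        return SKETCH_REDIS_LIST_LINK.format(**params)


# Inside execute(), the diff resolves the project id before persisting:
#     project_id = self.project_id or hook.project_id
#     if project_id:
#         RedisInstanceListLink.persist(context=context, task_instance=self, project_id=project_id)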
"gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, *, - update_mask: Union[Dict, FieldMask], - instance: Union[Dict, Instance], - location: Optional[str] = None, - instance_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + update_mask: dict | FieldMask, + instance: dict | Instance, + location: str | None = None, + instance_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -711,11 +766,11 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - hook.update_instance( + res = hook.update_instance( update_mask=self.update_mask, instance=self.instance, location=self.location, @@ -725,6 +780,15 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + # projects/PROJECT_NAME/locations/LOCATION/instances/INSTANCE + location_id, instance_id = res.name.split("/")[-3::2] + RedisInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id or instance_id, + location_id=self.location or location_id, + project_id=self.project_id or hook.project_id, + ) class CloudMemorystoreScaleInstanceOperator(BaseOperator): @@ -767,19 +831,20 @@ class CloudMemorystoreScaleInstanceOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, *, memory_size_gb: int, - location: Optional[str] = None, - instance_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + location: str | None = None, + instance_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -793,12 +858,12 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - hook.update_instance( + res = hook.update_instance( update_mask={"paths": ["memory_size_gb"]}, instance={"memory_size_gb": self.memory_size_gb}, location=self.location, @@ -808,6 +873,15 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + # projects/PROJECT_NAME/locations/LOCATION/instances/INSTANCE + location_id, instance_id = res.name.split("/")[-3::2] + RedisInstanceDetailsLink.persist( + context=context, + 
task_instance=self, + instance_id=self.instance_id or instance_id, + location_id=self.location or location_id, + project_id=self.project_id or hook.project_id, + ) class CloudMemorystoreCreateInstanceAndImportOperator(BaseOperator): @@ -869,20 +943,21 @@ class CloudMemorystoreCreateInstanceAndImportOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (RedisInstanceDetailsLink(),) def __init__( self, *, location: str, instance_id: str, - instance: Union[Dict, Instance], - input_config: Union[Dict, InputConfig], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + instance: dict | Instance, + input_config: dict | InputConfig, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -897,7 +972,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -921,6 +996,13 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + RedisInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) class CloudMemorystoreExportAndDeleteInstanceOperator(BaseOperator): @@ -975,13 +1057,13 @@ def __init__( *, location: str, instance: str, - output_config: Union[Dict, OutputConfig], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + output_config: dict | OutputConfig, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -995,7 +1077,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudMemorystoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1055,6 +1137,7 @@ class CloudMemorystoreMemcachedApplyParametersOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (MemcachedInstanceDetailsLink(),) def __init__( self, @@ -1064,11 +1147,11 @@ def __init__( location: str, instance_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, 
**kwargs, ) -> None: super().__init__(**kwargs) @@ -1083,7 +1166,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreMemcachedHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1097,6 +1180,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + MemcachedInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id, + location_id=self.location, + project_id=self.project_id, + ) class CloudMemorystoreMemcachedCreateInstanceOperator(BaseOperator): @@ -1143,16 +1233,17 @@ class CloudMemorystoreMemcachedCreateInstanceOperator(BaseOperator): "metadata", "gcp_conn_id", ) + operator_extra_links = (MemcachedInstanceDetailsLink(),) def __init__( self, location: str, instance_id: str, - instance: Union[Dict, cloud_memcache.Instance], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + instance: dict | cloud_memcache.Instance, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", *args, **kwargs, @@ -1167,7 +1258,7 @@ def __init__( self.metadata = metadata self.gcp_conn_id = gcp_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreMemcachedHook(gcp_conn_id=self.gcp_conn_id) result = hook.create_instance( location=self.location, @@ -1178,6 +1269,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + MemcachedInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return cloud_memcache.Instance.to_dict(result) @@ -1215,10 +1313,10 @@ def __init__( self, location: str, instance: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", *args, **kwargs, @@ -1232,7 +1330,7 @@ def __init__( self.metadata = metadata self.gcp_conn_id = gcp_conn_id - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreMemcachedHook(gcp_conn_id=self.gcp_conn_id) hook.delete_instance( location=self.location, @@ -1282,18 +1380,19 @@ class CloudMemorystoreMemcachedGetInstanceOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (MemcachedInstanceDetailsLink(),) def __init__( self, *, location: str, instance: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: 
super().__init__(**kwargs) @@ -1306,7 +1405,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreMemcachedHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1318,6 +1417,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + MemcachedInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return cloud_memcache.Instance.to_dict(result) @@ -1360,17 +1466,18 @@ class CloudMemorystoreMemcachedListInstancesOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (MemcachedInstanceListLink(),) def __init__( self, *, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1382,7 +1489,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreMemcachedHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1393,6 +1500,11 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + MemcachedInstanceListLink.persist( + context=context, + task_instance=self, + project_id=self.project_id or hook.project_id, + ) instances = [cloud_memcache.Instance.to_dict(a) for a in result] return instances @@ -1449,20 +1561,21 @@ class CloudMemorystoreMemcachedUpdateInstanceOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (MemcachedInstanceDetailsLink(),) def __init__( self, *, - update_mask: Union[Dict, FieldMask], - instance: Union[Dict, cloud_memcache.Instance], - location: Optional[str] = None, - instance_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + update_mask: dict | FieldMask, + instance: dict | cloud_memcache.Instance, + location: str | None = None, + instance_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1477,11 +1590,11 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreMemcachedHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - hook.update_instance( + res = hook.update_instance( update_mask=self.update_mask, instance=self.instance, 
location=self.location, @@ -1491,6 +1604,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + # projects/PROJECT_NAME/locations/LOCATION/instances/INSTANCE + location_id, instance_id = res.name.split("/")[-3::2] + MemcachedInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id or instance_id, + location_id=self.location or location_id, + project_id=self.project_id or hook.project_id, + ) class CloudMemorystoreMemcachedUpdateParametersOperator(BaseOperator): @@ -1532,20 +1654,21 @@ class CloudMemorystoreMemcachedUpdateParametersOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (MemcachedInstanceDetailsLink(),) def __init__( self, *, - update_mask: Union[Dict, FieldMask], - parameters: Union[Dict, cloud_memcache.MemcacheParameters], + update_mask: dict | FieldMask, + parameters: dict | cloud_memcache.MemcacheParameters, location: str, instance_id: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1560,7 +1683,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudMemorystoreMemcachedHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1574,3 +1697,10 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + MemcachedInstanceDetailsLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id, + location_id=self.location, + project_id=self.project_id, + ) diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py index d6c9fc7c10c51..494c9eaa5aeb3 100644 --- a/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/airflow/providers/google/cloud/operators/cloud_sql.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
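# The update/scale hunks above recover the location and instance ids from the returned
# resource name with res.name.split("/")[-3::2]. A quick standalone check of that slice,
# using an invented name of the documented shape
# projects/PROJECT_NAME/locations/LOCATION/instances/INSTANCE:
name = "projects/example-project/locations/us-central1/instances/example-instance"
parts = name.split("/")
# parts[-3] is the location and parts[-1] is the instance; a step of 2 picks exactly those two.
location_id, instance_id = parts[-3::2]
assert (location_id, instance_id) == ("us-central1", "example-instance")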
"""This module contains Google Cloud SQL operators.""" -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence from googleapiclient.errors import HttpError @@ -26,6 +28,7 @@ from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator +from airflow.providers.google.common.hooks.base_google import get_field from airflow.providers.google.common.links.storage import FileDetailsLink from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.providers.postgres.hooks.postgres import PostgresHook @@ -34,10 +37,10 @@ from airflow.utils.context import Context -SETTINGS = 'settings' -SETTINGS_VERSION = 'settingsVersion' +SETTINGS = "settings" +SETTINGS_VERSION = "settingsVersion" -CLOUD_SQL_CREATE_VALIDATION = [ +CLOUD_SQL_CREATE_VALIDATION: Sequence[dict] = [ dict(name="name", allow_empty=False), dict( name="settings", @@ -233,10 +236,10 @@ def __init__( self, *, instance: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -248,12 +251,12 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is empty") if not self.instance: raise AirflowException("The required parameter 'instance' is empty or None") - def _check_if_instance_exists(self, instance, hook: CloudSQLHook) -> Union[dict, bool]: + def _check_if_instance_exists(self, instance, hook: CloudSQLHook) -> dict | bool: try: return hook.get_instance(project_id=self.project_id, instance=instance) except HttpError as e: @@ -262,7 +265,7 @@ def _check_if_instance_exists(self, instance, hook: CloudSQLHook) -> Union[dict, return False raise e - def _check_if_db_exists(self, db_name, hook: CloudSQLHook) -> Union[dict, bool]: + def _check_if_db_exists(self, db_name, hook: CloudSQLHook) -> dict | bool: try: return hook.get_database(project_id=self.project_id, instance=self.instance, database=db_name) except HttpError as e: @@ -271,7 +274,7 @@ def _check_if_db_exists(self, db_name, hook: CloudSQLHook) -> Union[dict, bool]: return False raise e - def execute(self, context: 'Context'): + def execute(self, context: Context): pass @staticmethod @@ -310,15 +313,15 @@ class CloudSQLCreateInstanceOperator(CloudSQLBaseOperator): # [START gcp_sql_create_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'body', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_create_template_fields] - ui_color = '#FADBDA' + ui_color = "#FADBDA" operator_extra_links = (CloudSQLInstanceLink(),) def __init__( @@ -326,11 +329,11 @@ def __init__( *, body: dict, instance: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 
'v1beta4', + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.body = body @@ -355,7 +358,7 @@ def _validate_body_fields(self) -> None: self.body ) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -376,7 +379,7 @@ def execute(self, context: 'Context') -> None: instance_resource = hook.get_instance(project_id=self.project_id, instance=self.instance) service_account_email = instance_resource["serviceAccountEmailAddress"] - task_instance = context['task_instance'] + task_instance = context["task_instance"] task_instance.xcom_push(key="service_account_email", value=service_account_email) @@ -414,15 +417,15 @@ class CloudSQLInstancePatchOperator(CloudSQLBaseOperator): # [START gcp_sql_patch_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'body', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_patch_template_fields] - ui_color = '#FBDAC8' + ui_color = "#FBDAC8" operator_extra_links = (CloudSQLInstanceLink(),) def __init__( @@ -430,10 +433,10 @@ def __init__( *, body: dict, instance: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.body = body @@ -451,7 +454,7 @@ def _validate_inputs(self) -> None: if not self.body: raise AirflowException("The required parameter 'body' is empty") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -459,8 +462,8 @@ def execute(self, context: 'Context'): ) if not self._check_if_instance_exists(self.instance, hook): raise AirflowException( - f'Cloud SQL instance with ID {self.instance} does not exist. ' - 'Please specify another instance to patch.' + f"Cloud SQL instance with ID {self.instance} does not exist. " + "Please specify another instance to patch." 
) else: CloudSQLInstanceLink.persist( @@ -498,16 +501,16 @@ class CloudSQLDeleteInstanceOperator(CloudSQLBaseOperator): # [START gcp_sql_delete_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "instance", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_delete_template_fields] - ui_color = '#FEECD2' + ui_color = "#FEECD2" - def execute(self, context: 'Context') -> Optional[bool]: + def execute(self, context: Context) -> bool | None: hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -548,15 +551,15 @@ class CloudSQLCreateInstanceDatabaseOperator(CloudSQLBaseOperator): # [START gcp_sql_db_create_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'body', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_db_create_template_fields] - ui_color = '#FFFCDB' + ui_color = "#FFFCDB" operator_extra_links = (CloudSQLInstanceDatabaseLink(),) def __init__( @@ -564,11 +567,11 @@ def __init__( *, instance: str, body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.body = body @@ -593,7 +596,7 @@ def _validate_body_fields(self) -> None: CLOUD_SQL_DATABASE_CREATE_VALIDATION, api_version=self.api_version ).validate(self.body) - def execute(self, context: 'Context') -> Optional[bool]: + def execute(self, context: Context) -> bool | None: self._validate_body_fields() database = self.body.get("name") if not database: @@ -655,16 +658,16 @@ class CloudSQLPatchInstanceDatabaseOperator(CloudSQLBaseOperator): # [START gcp_sql_db_patch_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'body', - 'database', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "instance", + "body", + "database", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_db_patch_template_fields] - ui_color = '#ECF4D9' + ui_color = "#ECF4D9" operator_extra_links = (CloudSQLInstanceDatabaseLink(),) def __init__( @@ -673,11 +676,11 @@ def __init__( instance: str, database: str, body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.database = database @@ -705,7 +708,7 @@ def _validate_body_fields(self) -> None: self.body ) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self._validate_body_fields() hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, @@ -755,25 +758,25 @@ class CloudSQLDeleteInstanceDatabaseOperator(CloudSQLBaseOperator): # [START gcp_sql_db_delete_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'database', - 'gcp_conn_id', - 'api_version', 
- 'impersonation_chain', + "project_id", + "instance", + "database", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_db_delete_template_fields] - ui_color = '#D5EAD8' + ui_color = "#D5EAD8" def __init__( self, *, instance: str, database: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.database = database @@ -791,7 +794,7 @@ def _validate_inputs(self) -> None: if not self.database: raise AirflowException("The required parameter 'database' is empty") - def execute(self, context: 'Context') -> Optional[bool]: + def execute(self, context: Context) -> bool | None: hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -841,15 +844,15 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator): # [START gcp_sql_export_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'body', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_export_template_fields] - ui_color = '#D4ECEA' + ui_color = "#D4ECEA" operator_extra_links = (CloudSQLInstanceLink(), FileDetailsLink()) def __init__( @@ -857,11 +860,11 @@ def __init__( *, instance: str, body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.body = body @@ -886,7 +889,7 @@ def _validate_body_fields(self) -> None: self.body ) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self._validate_body_fields() hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, @@ -952,15 +955,15 @@ class CloudSQLImportInstanceOperator(CloudSQLBaseOperator): # [START gcp_sql_import_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance', - 'body', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "instance", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcp_sql_import_template_fields] - ui_color = '#D3EDFB' + ui_color = "#D3EDFB" operator_extra_links = (CloudSQLInstanceLink(), FileDetailsLink()) def __init__( @@ -968,11 +971,11 @@ def __init__( *, instance: str, body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1beta4", validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.body = body @@ -997,7 +1000,7 @@ def _validate_body_fields(self) -> None: self.body ) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self._validate_body_fields() hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, @@ -1045,20 +1048,20 @@ class 
CloudSQLExecuteQueryOperator(BaseOperator): """ # [START gcp_sql_query_template_fields] - template_fields: Sequence[str] = ('sql', 'gcp_cloudsql_conn_id', 'gcp_conn_id') - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} + template_fields: Sequence[str] = ("sql", "gcp_cloudsql_conn_id", "gcp_conn_id") + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} # [END gcp_sql_query_template_fields] - ui_color = '#D3DEF1' + ui_color = "#D3DEF1" def __init__( self, *, - sql: Union[List[str], str], + sql: str | Iterable[str], autocommit: bool = False, - parameters: Optional[Union[Dict, Iterable]] = None, - gcp_conn_id: str = 'google_cloud_default', - gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', + parameters: Iterable | Mapping | None = None, + gcp_conn_id: str = "google_cloud_default", + gcp_cloudsql_conn_id: str = "google_cloud_sql_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -1067,11 +1070,9 @@ def __init__( self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id self.autocommit = autocommit self.parameters = parameters - self.gcp_connection: Optional[Connection] = None + self.gcp_connection: Connection | None = None - def _execute_query( - self, hook: CloudSQLDatabaseHook, database_hook: Union[PostgresHook, MySqlHook] - ) -> None: + def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None: cloud_sql_proxy_runner = None try: if hook.use_proxy: @@ -1087,14 +1088,12 @@ def _execute_query( if cloud_sql_proxy_runner: cloud_sql_proxy_runner.stop_proxy() - def execute(self, context: 'Context'): + def execute(self, context: Context): self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id) hook = CloudSQLDatabaseHook( gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id, gcp_conn_id=self.gcp_conn_id, - default_gcp_project_id=self.gcp_connection.extra_dejson.get( - 'extra__google_cloud_platform__project' - ), + default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, "project"), ) hook.validate_ssl_certs() connection = hook.create_connection() diff --git a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py index 7a9bc4d4b5e8f..d90e6da6387d8 100644 --- a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
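# A minimal usage sketch (hypothetical task and connection values) for the CloudSQLExecuteQueryOperator
# updated above: `sql` now accepts `str | Iterable[str]`, `parameters` accepts `Iterable | Mapping | None`,
# and the default project is read from the connection extras via get_field().
from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExecuteQueryOperator

run_cloud_sql_queries = CloudSQLExecuteQueryOperator(
    task_id="run_cloud_sql_queries",
    gcp_conn_id="google_cloud_default",               # credentials used for the API and proxy
    gcp_cloudsql_conn_id="google_cloud_sql_default",  # describes the target Cloud SQL database
    sql=[
        "CREATE TABLE IF NOT EXISTS demo (id INT, name VARCHAR(100))",
        "INSERT INTO demo VALUES (1, 'example')",
    ],
    autocommit=True,
)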
-# """This module contains Google Cloud Transfer operators.""" +from __future__ import annotations + from copy import deepcopy from datetime import date, time -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -54,6 +55,11 @@ CloudDataTransferServiceHook, GcpTransferJobsStatus, ) +from airflow.providers.google.cloud.links.cloud_storage_transfer import ( + CloudStorageTransferDetailsLink, + CloudStorageTransferJobLink, + CloudStorageTransferListLink, +) from airflow.providers.google.cloud.utils.helpers import normalize_directory_path if TYPE_CHECKING: @@ -63,7 +69,7 @@ class TransferJobPreprocessor: """Helper class for preprocess of transfer job body.""" - def __init__(self, body: dict, aws_conn_id: str = 'aws_default', default_schedule: bool = False) -> None: + def __init__(self, body: dict, aws_conn_id: str = "aws_default", default_schedule: bool = False) -> None: self.body = body self.aws_conn_id = aws_conn_id self.default_schedule = default_schedule @@ -109,7 +115,6 @@ def process_body(self) -> dict: reformats schedule information. :return: Preprocessed body - :rtype: dict """ self._inject_aws_credentials() self._reformat_schedule() @@ -211,21 +216,23 @@ class CloudDataTransferServiceCreateJobOperator(BaseOperator): # [START gcp_transfer_job_create_template_fields] template_fields: Sequence[str] = ( - 'body', - 'gcp_conn_id', - 'aws_conn_id', - 'google_impersonation_chain', + "body", + "gcp_conn_id", + "aws_conn_id", + "google_impersonation_chain", ) # [END gcp_transfer_job_create_template_fields] + operator_extra_links = (CloudStorageTransferJobLink(),) def __init__( self, *, body: dict, - aws_conn_id: str = 'aws_default', - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + project_id: str | None = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -233,20 +240,32 @@ def __init__( self.aws_conn_id = aws_conn_id self.gcp_conn_id = gcp_conn_id self.api_version = api_version + self.project_id = project_id self.google_impersonation_chain = google_impersonation_chain self._validate_inputs() def _validate_inputs(self) -> None: TransferJobValidator(body=self.body).validate_body() - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: TransferJobPreprocessor(body=self.body, aws_conn_id=self.aws_conn_id).process_body() hook = CloudDataTransferServiceHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.google_impersonation_chain, ) - return hook.create_transfer_job(body=self.body) + result = hook.create_transfer_job(body=self.body) + + project_id = self.project_id or hook.project_id + if project_id: + CloudStorageTransferJobLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_name=result[NAME], + ) + + return result class CloudDataTransferServiceUpdateJobOperator(BaseOperator): @@ -283,28 +302,31 @@ class CloudDataTransferServiceUpdateJobOperator(BaseOperator): # [START gcp_transfer_job_update_template_fields] template_fields: Sequence[str] = ( - 'job_name', - 'body', - 'gcp_conn_id', - 'aws_conn_id', - 'google_impersonation_chain', + "job_name", + "body", + 
"gcp_conn_id", + "aws_conn_id", + "google_impersonation_chain", ) # [END gcp_transfer_job_update_template_fields] + operator_extra_links = (CloudStorageTransferJobLink(),) def __init__( self, *, job_name: str, body: dict, - aws_conn_id: str = 'aws_default', - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + project_id: str | None = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.job_name = job_name self.body = body + self.project_id = project_id self.gcp_conn_id = gcp_conn_id self.api_version = api_version self.aws_conn_id = aws_conn_id @@ -316,13 +338,23 @@ def _validate_inputs(self) -> None: if not self.job_name: raise AirflowException("The required parameter 'job_name' is empty or None") - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: TransferJobPreprocessor(body=self.body, aws_conn_id=self.aws_conn_id).process_body() hook = CloudDataTransferServiceHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.google_impersonation_chain, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudStorageTransferJobLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_name=self.job_name, + ) + return hook.update_transfer_job(job_name=self.job_name, body=self.body) @@ -355,11 +387,11 @@ class CloudDataTransferServiceDeleteJobOperator(BaseOperator): # [START gcp_transfer_job_delete_template_fields] template_fields: Sequence[str] = ( - 'job_name', - 'project_id', - 'gcp_conn_id', - 'api_version', - 'google_impersonation_chain', + "job_name", + "project_id", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", ) # [END gcp_transfer_job_delete_template_fields] @@ -369,8 +401,8 @@ def __init__( job_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", - project_id: Optional[str] = None, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -385,7 +417,7 @@ def _validate_inputs(self) -> None: if not self.job_name: raise AirflowException("The required parameter 'job_name' is empty or None") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self._validate_inputs() hook = CloudDataTransferServiceHook( api_version=self.api_version, @@ -420,23 +452,26 @@ class CloudDataTransferServiceGetOperationOperator(BaseOperator): # [START gcp_transfer_operation_get_template_fields] template_fields: Sequence[str] = ( - 'operation_name', - 'gcp_conn_id', - 'google_impersonation_chain', + "operation_name", + "gcp_conn_id", + "google_impersonation_chain", ) # [END gcp_transfer_operation_get_template_fields] + operator_extra_links = (CloudStorageTransferDetailsLink(),) def __init__( self, *, + project_id: str | None = None, operation_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.operation_name = operation_name + self.project_id = project_id 
self.gcp_conn_id = gcp_conn_id self.api_version = api_version self.google_impersonation_chain = google_impersonation_chain @@ -446,13 +481,23 @@ def _validate_inputs(self) -> None: if not self.operation_name: raise AirflowException("The required parameter 'operation_name' is empty or None") - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = CloudDataTransferServiceHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.google_impersonation_chain, ) operation = hook.get_transfer_operation(operation_name=self.operation_name) + + project_id = self.project_id or hook.project_id + if project_id: + CloudStorageTransferDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + operation_name=self.operation_name, + ) + return operation @@ -482,31 +527,34 @@ class CloudDataTransferServiceListOperationsOperator(BaseOperator): # [START gcp_transfer_operations_list_template_fields] template_fields: Sequence[str] = ( - 'filter', - 'gcp_conn_id', - 'google_impersonation_chain', + "filter", + "gcp_conn_id", + "google_impersonation_chain", ) # [END gcp_transfer_operations_list_template_fields] + operator_extra_links = (CloudStorageTransferListLink(),) def __init__( self, - request_filter: Optional[Dict] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + request_filter: dict | None = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: # To preserve backward compatibility # TODO: remove one day if request_filter is None: - if 'filter' in kwargs: - request_filter = kwargs['filter'] + if "filter" in kwargs: + request_filter = kwargs["filter"] DeprecationWarning("Use 'request_filter' instead 'filter' to pass the argument.") else: TypeError("__init__() missing 1 required positional argument: 'request_filter'") super().__init__(**kwargs) self.filter = request_filter + self.project_id = project_id self.gcp_conn_id = gcp_conn_id self.api_version = api_version self.google_impersonation_chain = google_impersonation_chain @@ -516,7 +564,7 @@ def _validate_inputs(self) -> None: if not self.filter: raise AirflowException("The required parameter 'filter' is empty or None") - def execute(self, context: 'Context') -> List[dict]: + def execute(self, context: Context) -> list[dict]: hook = CloudDataTransferServiceHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, @@ -524,6 +572,15 @@ def execute(self, context: 'Context') -> List[dict]: ) operations_list = hook.list_transfer_operations(request_filter=self.filter) self.log.info(operations_list) + + project_id = self.project_id or hook.project_id + if project_id: + CloudStorageTransferListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + return operations_list @@ -550,10 +607,10 @@ class CloudDataTransferServicePauseOperationOperator(BaseOperator): # [START gcp_transfer_operation_pause_template_fields] template_fields: Sequence[str] = ( - 'operation_name', - 'gcp_conn_id', - 'api_version', - 'google_impersonation_chain', + "operation_name", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", ) # [END gcp_transfer_operation_pause_template_fields] @@ -563,7 +620,7 @@ def __init__( operation_name: str, gcp_conn_id: str = "google_cloud_default", 
api_version: str = "v1", - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -577,7 +634,7 @@ def _validate_inputs(self) -> None: if not self.operation_name: raise AirflowException("The required parameter 'operation_name' is empty or None") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataTransferServiceHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, @@ -609,10 +666,10 @@ class CloudDataTransferServiceResumeOperationOperator(BaseOperator): # [START gcp_transfer_operation_resume_template_fields] template_fields: Sequence[str] = ( - 'operation_name', - 'gcp_conn_id', - 'api_version', - 'google_impersonation_chain', + "operation_name", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", ) # [END gcp_transfer_operation_resume_template_fields] @@ -622,7 +679,7 @@ def __init__( operation_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.operation_name = operation_name @@ -636,7 +693,7 @@ def _validate_inputs(self) -> None: if not self.operation_name: raise AirflowException("The required parameter 'operation_name' is empty or None") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataTransferServiceHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, @@ -669,10 +726,10 @@ class CloudDataTransferServiceCancelOperationOperator(BaseOperator): # [START gcp_transfer_operation_cancel_template_fields] template_fields: Sequence[str] = ( - 'operation_name', - 'gcp_conn_id', - 'api_version', - 'google_impersonation_chain', + "operation_name", + "gcp_conn_id", + "api_version", + "google_impersonation_chain", ) # [END gcp_transfer_operation_cancel_template_fields] @@ -682,7 +739,7 @@ def __init__( operation_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -696,7 +753,7 @@ def _validate_inputs(self) -> None: if not self.operation_name: raise AirflowException("The required parameter 'operation_name' is empty or None") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataTransferServiceHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, @@ -770,35 +827,35 @@ class CloudDataTransferServiceS3ToGCSOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'gcp_conn_id', - 's3_bucket', - 'gcs_bucket', - 's3_path', - 'gcs_path', - 'description', - 'object_conditions', - 'google_impersonation_chain', + "gcp_conn_id", + "s3_bucket", + "gcs_bucket", + "s3_path", + "gcs_path", + "description", + "object_conditions", + "google_impersonation_chain", ) - ui_color = '#e09411' + ui_color = "#e09411" def __init__( self, *, s3_bucket: str, gcs_bucket: str, - s3_path: Optional[str] = None, - gcs_path: Optional[str] = None, - project_id: Optional[str] = None, - aws_conn_id: str = 'aws_default', - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - description: Optional[str] = None, - schedule: Optional[Dict] = None, 
- object_conditions: Optional[Dict] = None, - transfer_options: Optional[Dict] = None, + s3_path: str | None = None, + gcs_path: str | None = None, + project_id: str | None = None, + aws_conn_id: str = "aws_default", + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + description: str | None = None, + schedule: dict | None = None, + object_conditions: dict | None = None, + transfer_options: dict | None = None, wait: bool = True, - timeout: Optional[float] = None, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + timeout: float | None = None, + google_impersonation_chain: str | Sequence[str] | None = None, delete_job_after_completion: bool = False, **kwargs, ) -> None: @@ -826,7 +883,7 @@ def _validate_inputs(self) -> None: if self.delete_job_after_completion and not self.wait: raise AirflowException("If 'delete_job_after_completion' is True, then 'wait' must also be True.") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataTransferServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -942,34 +999,34 @@ class CloudDataTransferServiceGCSToGCSOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'gcp_conn_id', - 'source_bucket', - 'destination_bucket', - 'source_path', - 'destination_path', - 'description', - 'object_conditions', - 'google_impersonation_chain', + "gcp_conn_id", + "source_bucket", + "destination_bucket", + "source_path", + "destination_path", + "description", + "object_conditions", + "google_impersonation_chain", ) - ui_color = '#e09411' + ui_color = "#e09411" def __init__( self, *, source_bucket: str, destination_bucket: str, - source_path: Optional[str] = None, - destination_path: Optional[str] = None, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - description: Optional[str] = None, - schedule: Optional[Dict] = None, - object_conditions: Optional[Dict] = None, - transfer_options: Optional[Dict] = None, + source_path: str | None = None, + destination_path: str | None = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + description: str | None = None, + schedule: dict | None = None, + object_conditions: dict | None = None, + transfer_options: dict | None = None, wait: bool = True, - timeout: Optional[float] = None, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + timeout: float | None = None, + google_impersonation_chain: str | Sequence[str] | None = None, delete_job_after_completion: bool = False, **kwargs, ) -> None: @@ -996,7 +1053,7 @@ def _validate_inputs(self) -> None: if self.delete_job_after_completion and not self.wait: raise AirflowException("If 'delete_job_after_completion' is True, then 'wait' must also be True.") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataTransferServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/compute.py b/airflow/providers/google/cloud/operators/compute.py index 7b45ede859918..37a6761b0b791 100644 --- a/airflow/providers/google/cloud/operators/compute.py +++ b/airflow/providers/google/cloud/operators/compute.py @@ -16,16 +16,24 @@ # specific language governing permissions and limitations # under the License. 
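# A minimal, hypothetical sketch for the CloudDataTransferServiceS3ToGCSOperator shown above;
# bucket names and connection IDs are placeholders.
from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
    CloudDataTransferServiceS3ToGCSOperator,
)

s3_to_gcs_transfer = CloudDataTransferServiceS3ToGCSOperator(
    task_id="s3_to_gcs_transfer",
    s3_bucket="example-source-bucket",
    gcs_bucket="example-target-bucket",
    description="example transfer created from a DAG",
    aws_conn_id="aws_default",
    gcp_conn_id="google_cloud_default",
    wait=True,                         # block until the transfer finishes
    delete_job_after_completion=True,  # requires wait=True (see _validate_inputs above)
)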
"""This module contains Google Compute Engine operators.""" +from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence -from googleapiclient.errors import HttpError +from google.api_core import exceptions +from google.api_core.retry import Retry +from google.cloud.compute_v1.types import Instance, InstanceGroupManager, InstanceTemplate from json_merge_patch import merge from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook +from airflow.providers.google.cloud.links.compute import ( + ComputeInstanceDetailsLink, + ComputeInstanceGroupManagerDetailsLink, + ComputeInstanceTemplateDetailsLink, +) from airflow.providers.google.cloud.utils.field_sanitizer import GcpBodyFieldSanitizer from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator @@ -41,10 +49,10 @@ def __init__( *, zone: str, resource_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -57,17 +65,479 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is missing") if not self.zone: raise AirflowException("The required parameter 'zone' is missing") - if not self.resource_id: - raise AirflowException("The required parameter 'resource_id' is missing") - def execute(self, context: 'Context'): + def execute(self, context: Context): pass +class ComputeEngineInsertInstanceOperator(ComputeEngineBaseOperator): + """ + Creates an Instance in Google Compute Engine based on specified parameters. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineInsertInstanceOperator` + + :param body: Instance representation as an object. Should at least include 'name', 'machine_type', + 'disks' and 'network_interfaces' fields but doesn't include 'zone' field, as it will be specified + in 'zone' parameter. + Full or partial URL and can be represented as examples below: + 1. "machine_type": "projects/your-project-name/zones/your-zone/machineTypes/your-machine-type" + 2. "disk_type": "projects/your-project-name/zones/your-zone/diskTypes/your-disk-type" + 3. "subnetwork": "projects/your-project-name/regions/your-region/subnetworks/your-subnetwork" + :param zone: Google Cloud zone where the Instance exists + :param project_id: Google Cloud project ID where the Compute Engine Instance exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param resource_id: Name of the Instance. 
If the name of Instance is not specified in body['name'], + the name will be taken from 'resource_id' parameter + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param gcp_conn_id: The connection ID used to connect to Google Cloud. Defaults to 'google_cloud_default'. + :param api_version: API version used (for example v1 - or beta). Defaults to v1. + :param impersonation_chain: Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + + operator_extra_links = (ComputeInstanceDetailsLink(),) + + # [START gce_instance_insert_fields] + template_fields: Sequence[str] = ( + "body", + "project_id", + "zone", + "request_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_insert_fields] + + def __init__( + self, + *, + body: dict, + zone: str, + resource_id: str | None = None, + project_id: str | None = None, + request_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + validate_body: bool = True, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + self.body = body + self.zone = zone + self.request_id = request_id + self.resource_id = self.body["name"] if "name" in body else resource_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + if validate_body: + self._field_validator = GcpBodyFieldValidator( + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version + ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) + super().__init__( + resource_id=self.resource_id, + zone=zone, + project_id=project_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def check_body_fields(self) -> None: + required_params = ["machine_type", "disks", "network_interfaces"] + for param in required_params: + if param in self.body: + continue + readable_param = param.replace("_", " ") + raise AirflowException( + f"The body '{self.body}' should contain at least {readable_param} for the new operator " + f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) " + f"for more details about body fields description." 
+ ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id and "name" not in self.body: + raise AirflowException( + "The required parameters 'resource_id' and body['name'] are missing. " + "Please, provide at least one of them." + ) + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body) + + def execute(self, context: Context) -> dict: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_all_body_fields() + self.check_body_fields() + try: + # Idempotence check (sort of) - we want to check if the new Instance + # is already created and if it is, then we assume it was created previously - we do + # not check if content of the Instance is as expected. + # We assume success if the Instance is simply present. + existing_instance = hook.get_instance( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + except exceptions.NotFound as e: + # We actually expect to get 404 / Not Found here as the Instance should not yet exist + if not e.code == 404: + raise e + else: + self.log.info("The %s Instance already exists", self.resource_id) + ComputeInstanceDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return Instance.to_dict(existing_instance) + self._field_sanitizer.sanitize(self.body) + self.log.info("Creating Instance with specified body: %s", self.body) + hook.insert_instance( + body=self.body, + request_id=self.request_id, + project_id=self.project_id, + zone=self.zone, + ) + self.log.info("The specified Instance has been created SUCCESSFULLY") + new_instance = hook.get_instance( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + ComputeInstanceDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return Instance.to_dict(new_instance) + + +class ComputeEngineInsertInstanceFromTemplateOperator(ComputeEngineBaseOperator): + """ + Creates an Instance in Google Compute Engine based on specified parameters from existing Template. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineInsertInstanceFromTemplateOperator` + + :param body: Instance representation as object. For this Operator only 'name' parameter is required for + creating new Instance since all other parameters will be passed through the Template. + :param source_instance_template: Existing Instance Template that will be used as a base while creating + new Instance. When specified, only name of new Instance should be provided as input arguments in + 'body' parameter when creating new Instance. All other parameters, such as 'machine_type', 'disks' + and 'network_interfaces' will be passed to Instance as they are specified in the Instance Template. + Full or partial URL and can be represented as examples below: + 1. "https://www.googleapis.com/compute/v1/projects/your-project-name/global/instanceTemplates/temp" + 2. "projects/your-project-name/global/instanceTemplates/temp" + 3. "global/instanceTemplates/temp" + :param zone: Google Cloud zone where the instance exists. + :param project_id: Google Cloud project ID where the Compute Engine Instance exists.
+ If set to None or missing, the default project_id from the Google Cloud connection is used. + :param resource_id: Name of the Instance. If the name of Instance is not specified in body['name'], + the name will be taken from 'resource_id' parameter + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param gcp_conn_id: The connection ID used to connect to Google Cloud. Defaults to 'google_cloud_default'. + :param api_version: API version used (for example v1 - or beta). Defaults to v1. + :param impersonation_chain: Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + + operator_extra_links = (ComputeInstanceDetailsLink(),) + + # [START gce_instance_insert_from_template_fields] + template_fields: Sequence[str] = ( + "body", + "source_instance_template", + "project_id", + "zone", + "request_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_insert_from_template_fields] + + def __init__( + self, + *, + source_instance_template: str, + body: dict, + zone: str, + resource_id: str | None = None, + project_id: str | None = None, + request_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + validate_body: bool = True, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + self.source_instance_template = source_instance_template + self.body = body + self.zone = zone + self.resource_id = self.body["name"] if "name" in body else resource_id + self.request_id = request_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + if validate_body: + self._field_validator = GcpBodyFieldValidator( + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version + ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) + super().__init__( + resource_id=self.resource_id, + zone=zone, + project_id=project_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id and "name" not in self.body: + raise AirflowException( + "The 
required parameters 'resource_id' and body['name'] are missing. " + "Please, provide at least one of them." + ) + + def execute(self, context: Context) -> dict: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_all_body_fields() + try: + # Idempotence check (sort of) - we want to check if the new Instance + # is already created and if is, then we assume it was created - we do + # not check if content of the Instance is as expected. + # We assume success if the Instance is simply present + existing_instance = hook.get_instance( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + except exceptions.NotFound as e: + # We actually expect to get 404 / Not Found here as the template should + # not yet exist + if not e.code == 404: + raise e + else: + self.log.info("The %s Instance already exists", self.resource_id) + ComputeInstanceDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return Instance.to_dict(existing_instance) + self._field_sanitizer.sanitize(self.body) + self.log.info("Creating Instance with specified body: %s", self.body) + hook.insert_instance( + body=self.body, + request_id=self.request_id, + project_id=self.project_id, + zone=self.zone, + source_instance_template=self.source_instance_template, + ) + self.log.info("The specified Instance has been created SUCCESSFULLY") + new_instance_from_template = hook.get_instance( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + ComputeInstanceDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return Instance.to_dict(new_instance_from_template) + + +class ComputeEngineDeleteInstanceOperator(ComputeEngineBaseOperator): + """ + Deletes an Instance in Google Compute Engine. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineDeleteInstanceOperator` + + :param project_id: Google Cloud project ID where the Compute Engine Instance exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param zone: Google Cloud zone where the instance exists. + :param resource_id: Name of the Instance. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param gcp_conn_id: The connection ID used to connect to Google Cloud. Defaults to 'google_cloud_default'. + :param api_version: API version used (for example v1 - or beta). Defaults to v1. + :param impersonation_chain: Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + + # [START gce_instance_delete_template_fields] + template_fields: Sequence[str] = ( + "zone", + "resource_id", + "request_id", + "project_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_delete_template_fields] + + def __init__( + self, + *, + resource_id: str, + zone: str, + request_id: str | None = None, + project_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + validate_body: bool = True, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + self.zone = zone + self.request_id = request_id + self.resource_id = resource_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + if validate_body: + self._field_validator = GcpBodyFieldValidator( + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version + ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) + super().__init__( + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing. ") + + def execute(self, context: Context) -> None: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + try: + # Checking if specified Instance exists and if it does, delete it + hook.get_instance( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + self.log.info("Successfully found Instance %s", self.resource_id) + hook.delete_instance( + resource_id=self.resource_id, + project_id=self.project_id, + request_id=self.request_id, + zone=self.zone, + ) + self.log.info("Successfully deleted Instance %s", self.resource_id) + except exceptions.NotFound as e: + # Expecting 404 Error in case if Instance doesn't exist. + if e.code == 404: + self.log.error("Instance %s doesn't exist", self.resource_id) + raise e + + class ComputeEngineStartInstanceOperator(ComputeEngineBaseOperator): """ Starts an instance in Google Compute Engine. @@ -95,24 +565,38 @@ class ComputeEngineStartInstanceOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). 
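# A minimal, hypothetical sketch of the new ComputeEngineInsertInstanceOperator and
# ComputeEngineDeleteInstanceOperator defined above. Per check_body_fields(), the body must carry
# at least 'machine_type', 'disks' and 'network_interfaces'; zone, names and image are placeholders.
from airflow.providers.google.cloud.operators.compute import (
    ComputeEngineDeleteInstanceOperator,
    ComputeEngineInsertInstanceOperator,
)

GCE_ZONE = "europe-west1-b"

insert_instance = ComputeEngineInsertInstanceOperator(
    task_id="gcp_compute_insert_instance",
    zone=GCE_ZONE,
    body={
        "name": "instance-example",
        "machine_type": f"zones/{GCE_ZONE}/machineTypes/e2-micro",
        "disks": [
            {
                "boot": True,
                "auto_delete": True,
                "initialize_params": {
                    "source_image": "projects/debian-cloud/global/images/family/debian-11",
                },
            }
        ],
        "network_interfaces": [{"network": "global/networks/default"}],
    },
)

delete_instance = ComputeEngineDeleteInstanceOperator(
    task_id="gcp_compute_delete_instance",
    zone=GCE_ZONE,
    resource_id="instance-example",
)

insert_instance >> delete_instance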
""" + operator_extra_links = (ComputeInstanceDetailsLink(),) + # [START gce_instance_start_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'zone', - 'resource_id', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "zone", + "resource_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gce_instance_start_template_fields] - def execute(self, context: 'Context') -> None: + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing. ") + + def execute(self, context: Context) -> None: hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) - return hook.start_instance(zone=self.zone, resource_id=self.resource_id, project_id=self.project_id) + ComputeInstanceDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + hook.start_instance(zone=self.zone, resource_id=self.resource_id, project_id=self.project_id) class ComputeEngineStopInstanceOperator(ComputeEngineBaseOperator): @@ -142,23 +626,37 @@ class ComputeEngineStopInstanceOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). """ + operator_extra_links = (ComputeInstanceDetailsLink(),) + # [START gce_instance_stop_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'zone', - 'resource_id', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "zone", + "resource_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gce_instance_stop_template_fields] - def execute(self, context: 'Context') -> None: + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing. ") + + def execute(self, context: Context) -> None: hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) + ComputeInstanceDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) hook.stop_instance(zone=self.zone, resource_id=self.resource_id, project_id=self.project_id) @@ -199,40 +697,372 @@ class ComputeEngineSetMachineTypeOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). 
""" - # [START gce_instance_set_machine_type_template_fields] + operator_extra_links = (ComputeInstanceDetailsLink(),) + + # [START gce_instance_set_machine_type_template_fields] + template_fields: Sequence[str] = ( + "project_id", + "zone", + "resource_id", + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_set_machine_type_template_fields] + + def __init__( + self, + *, + zone: str, + resource_id: str, + body: dict, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + validate_body: bool = True, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + self.body = body + self._field_validator: GcpBodyFieldValidator | None = None + if validate_body: + self._field_validator = GcpBodyFieldValidator( + SET_MACHINE_TYPE_VALIDATION_SPECIFICATION, api_version=api_version + ) + super().__init__( + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing. ") + + def execute(self, context: Context) -> None: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_all_body_fields() + ComputeInstanceDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + hook.set_machine_type( + zone=self.zone, resource_id=self.resource_id, body=self.body, project_id=self.project_id + ) + + +GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION: list[dict[str, Any]] = [ + dict(name="name", regexp="^.+$"), + dict(name="description", optional=True), + dict( + name="properties", + type="dict", + optional=True, + fields=[ + dict(name="description", optional=True), + dict(name="tags", optional=True, fields=[dict(name="items", optional=True)]), + dict(name="machineType", optional=True), + dict(name="canIpForward", optional=True), + dict(name="networkInterfaces", optional=True), # not validating deeper + dict(name="disks", optional=True), # not validating the array deeper + dict( + name="metadata", + optional=True, + fields=[ + dict(name="fingerprint", optional=True), + dict(name="items", optional=True), + dict(name="kind", optional=True), + ], + ), + dict(name="serviceAccounts", optional=True), # not validating deeper + dict( + name="scheduling", + optional=True, + fields=[ + dict(name="onHostMaintenance", optional=True), + dict(name="automaticRestart", optional=True), + dict(name="preemptible", optional=True), + dict(name="nodeAffinities", optional=True), # not validating deeper + ], + ), + dict(name="labels", optional=True), + dict(name="guestAccelerators", optional=True), # not validating deeper + dict(name="minCpuPlatform", optional=True), + ], + ), +] + +GCE_INSTANCE_FIELDS_TO_SANITIZE = [ + "kind", + "id", + "creationTimestamp", + "properties.disks.sha256", + "properties.disks.kind", + "properties.disks.sourceImageEncryptionKey.sha256", + "properties.disks.index", + "properties.disks.licenses", + "properties.networkInterfaces.kind", + 
"properties.networkInterfaces.accessConfigs.kind", + "properties.networkInterfaces.name", + "properties.metadata.kind", + "selfLink", +] + + +class ComputeEngineInsertInstanceTemplateOperator(ComputeEngineBaseOperator): + """ + Creates an Instance Template using specified fields. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineInsertInstanceTemplateOperator` + + :param body: Instance template representation as object. + :param project_id: Google Cloud project ID where the Compute Engine Instance exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param resource_id: Name of the Instance Template. If the name of Instance Template is not specified in + body['name'], the name will be taken from 'resource_id' parameter + :param gcp_conn_id: The connection ID used to connect to Google Cloud. Defaults to 'google_cloud_default'. + :param api_version: API version used (for example v1 - or beta). Defaults to v1. + :param impersonation_chain: Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. 
+ """ + + operator_extra_links = (ComputeInstanceTemplateDetailsLink(),) + + # [START gce_instance_template_insert_fields] + template_fields: Sequence[str] = ( + "body", + "project_id", + "request_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_instance_template_insert_fields] + + def __init__( + self, + *, + body: dict, + project_id: str | None = None, + resource_id: str | None = None, + request_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + validate_body: bool = True, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + self.body = body + self.request_id = request_id + self.resource_id = self.body["name"] if "name" in body else resource_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + if validate_body: + self._field_validator = GcpBodyFieldValidator( + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version + ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) + super().__init__( + project_id=project_id, + zone="global", + resource_id=self.resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def check_body_fields(self) -> None: + required_params = ["machine_type", "disks", "network_interfaces"] + for param in required_params: + if param in self.body["properties"]: + continue + readable_param = param.replace("_", " ") + raise AirflowException( + f"The body '{self.body}' should contain at least {readable_param} for the new operator " + f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) " + f"for more details about body fields description." + ) + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id and "name" not in self.body: + raise AirflowException( + "The required parameters 'resource_id' and body['name'] are missing. " + "Please, provide at least one of them." + ) + + def execute(self, context: Context) -> dict: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_all_body_fields() + self.check_body_fields() + try: + # Idempotence check (sort of) - we want to check if the new Template + # is already created and if is, then we assume it was created by previous run + # of operator - we do not check if content of the Template + # is as expected. Templates are immutable, so we cannot update it anyway + # and deleting/recreating is not worth the hassle especially + # that we cannot delete template if it is already used in some Instance + # Group Manager. 
We assume success if the template is simply present + existing_template = hook.get_instance_template( + resource_id=self.resource_id, project_id=self.project_id + ) + except exceptions.NotFound as e: + # We actually expect to get 404 / Not Found here as the template should + # not yet exist + if not e.code == 404: + raise e + else: + self.log.info("The %s Template already exists.", existing_template) + ComputeInstanceTemplateDetailsLink.persist( + context=context, + task_instance=self, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return InstanceTemplate.to_dict(existing_template) + self._field_sanitizer.sanitize(self.body) + self.log.info("Creating Instance Template with specified body: %s", self.body) + hook.insert_instance_template( + body=self.body, + request_id=self.request_id, + project_id=self.project_id, + ) + self.log.info("The specified Instance Template has been created SUCCESSFULLY: %s", self.body) + new_template = hook.get_instance_template( + resource_id=self.resource_id, + project_id=self.project_id, + ) + ComputeInstanceTemplateDetailsLink.persist( + context=context, + task_instance=self, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return InstanceTemplate.to_dict(new_template) + + +class ComputeEngineDeleteInstanceTemplateOperator(ComputeEngineBaseOperator): + """ + Deletes an Instance Template in Google Compute Engine. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineDeleteInstanceTemplateOperator` + + :param resource_id: Name of the Instance Template. + :param project_id: Google Cloud project ID where the Compute Engine Instance exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example when client call times out repeating the request + with the same request id will not create a new instance template again) + It should be in UUID format as defined in RFC 4122 + :param gcp_conn_id: The connection ID used to connect to Google Cloud. Defaults to 'google_cloud_default'. + :param api_version: API version used (for example v1 - or beta). Defaults to v1. + :param impersonation_chain: Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method.
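# A short, hypothetical usage sketch for the ComputeEngineDeleteInstanceTemplateOperator documented
# above; the template name is a placeholder. If the template does not exist, execute() logs the
# missing resource and re-raises the NotFound error.
from airflow.providers.google.cloud.operators.compute import ComputeEngineDeleteInstanceTemplateOperator

delete_template = ComputeEngineDeleteInstanceTemplateOperator(
    task_id="gcp_compute_delete_template",
    resource_id="instance-template-example",
)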
+ """ + + # [START gce_instance_template_delete_fields] template_fields: Sequence[str] = ( - 'project_id', - 'zone', - 'resource_id', - 'body', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "resource_id", + "request_id", + "project_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) - # [END gce_instance_set_machine_type_template_fields] + # [END gce_instance_template_delete_fields] def __init__( self, *, - zone: str, resource_id: str, - body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', + request_id: str | None = None, + project_id: str | None = None, + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - self.body = body - self._field_validator = None # type: Optional[GcpBodyFieldValidator] + self.request_id = request_id + self.resource_id = resource_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + self.retry = retry + self.timeout = timeout + self.metadata = metadata + if validate_body: self._field_validator = GcpBodyFieldValidator( - SET_MACHINE_TYPE_VALIDATION_SPECIFICATION, api_version=api_version + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) super().__init__( project_id=project_id, - zone=zone, + zone="global", resource_id=resource_id, gcp_conn_id=gcp_conn_id, api_version=api_version, @@ -240,79 +1070,35 @@ def __init__( **kwargs, ) - def _validate_all_body_fields(self) -> None: - if self._field_validator: - self._field_validator.validate(self.body) + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing.") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) - self._validate_all_body_fields() - return hook.set_machine_type( - zone=self.zone, resource_id=self.resource_id, body=self.body, project_id=self.project_id - ) - - -GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION = [ - dict(name="name", regexp="^.+$"), - dict(name="description", optional=True), - dict( - name="properties", - type='dict', - optional=True, - fields=[ - dict(name="description", optional=True), - dict(name="tags", optional=True, fields=[dict(name="items", optional=True)]), - dict(name="machineType", optional=True), - dict(name="canIpForward", optional=True), - dict(name="networkInterfaces", optional=True), # not validating deeper - dict(name="disks", optional=True), # not validating the array deeper - dict( - name="metadata", - optional=True, - fields=[ - dict(name="fingerprint", optional=True), - dict(name="items", optional=True), - dict(name="kind", optional=True), - ], - ), - dict(name="serviceAccounts", optional=True), # not validating deeper - dict( - name="scheduling", - optional=True, - fields=[ - dict(name="onHostMaintenance", optional=True), - dict(name="automaticRestart", optional=True), - dict(name="preemptible", optional=True), - dict(name="nodeAffinities", optional=True), # not 
validating deeper - ], - ), - dict(name="labels", optional=True), - dict(name="guestAccelerators", optional=True), # not validating deeper - dict(name="minCpuPlatform", optional=True), - ], - ), -] # type: List[Dict[str, Any]] - -GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE = [ - "kind", - "id", - "name", - "creationTimestamp", - "properties.disks.sha256", - "properties.disks.kind", - "properties.disks.sourceImageEncryptionKey.sha256", - "properties.disks.index", - "properties.disks.licenses", - "properties.networkInterfaces.kind", - "properties.networkInterfaces.accessConfigs.kind", - "properties.networkInterfaces.name", - "properties.metadata.kind", - "selfLink", -] + try: + # Checking if specified Instance Template exists and if it does, delete it + hook.get_instance_template( + resource_id=self.resource_id, + project_id=self.project_id, + ) + self.log.info("Successfully found Instance Template %s", self.resource_id) + hook.delete_instance_template( + resource_id=self.resource_id, + project_id=self.project_id, + request_id=self.request_id, + ) + self.log.info("Successfully deleted Instance template") + except exceptions.NotFound as e: + # Expecting 404 Error in case if Instance template doesn't exist. + if e.code == 404: + self.log.error("Instance template %s doesn't exist", self.resource_id) + raise e class ComputeEngineCopyInstanceTemplateOperator(ComputeEngineBaseOperator): @@ -354,14 +1140,16 @@ class ComputeEngineCopyInstanceTemplateOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). """ + operator_extra_links = (ComputeInstanceTemplateDetailsLink(),) + # [START gce_instance_template_copy_operator_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'resource_id', - 'request_id', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "resource_id", + "request_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gce_instance_template_copy_operator_template_fields] @@ -370,18 +1158,18 @@ def __init__( *, resource_id: str, body_patch: dict, - project_id: Optional[str] = None, - request_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', + project_id: str | None = None, + request_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.body_patch = body_patch self.request_id = request_id - self._field_validator = None # Optional[GcpBodyFieldValidator] - if 'name' not in self.body_patch: + self._field_validator = None # GcpBodyFieldValidator | None + if "name" not in self.body_patch: raise AirflowException( f"The body '{body_patch}' should contain at least name for the new operator " f"in the 'name' field" @@ -390,10 +1178,10 @@ def __init__( self._field_validator = GcpBodyFieldValidator( GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version ) - self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) super().__init__( project_id=project_id, - zone='global', + zone="global", resource_id=resource_id, gcp_conn_id=gcp_conn_id, api_version=api_version, @@ -405,7 +1193,12 @@ def _validate_all_body_fields(self) -> None: if self._field_validator: 
self._field_validator.validate(self.body_patch) - def execute(self, context: 'Context') -> dict: + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing.") + + def execute(self, context: Context) -> dict: hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -416,31 +1209,53 @@ def execute(self, context: 'Context') -> dict: # Idempotence check (sort of) - we want to check if the new template # is already created and if is, then we assume it was created by previous run # of CopyTemplate operator - we do not check if content of the template - # is as expected. Templates are immutable so we cannot update it anyway + # is as expected. Templates are immutable, so we cannot update it anyway # and deleting/recreating is not worth the hassle especially # that we cannot delete template if it is already used in some Instance # Group Manager. We assume success if the template is simply present existing_template = hook.get_instance_template( - resource_id=self.body_patch['name'], project_id=self.project_id + resource_id=self.body_patch["name"], + project_id=self.project_id, ) + except exceptions.NotFound as e: + # We actually expect to get 404 / Not Found here as the template should + # not yet exist + if not e.code == 404: + raise e + else: self.log.info( - "The %s template already existed. It was likely created by previous run of the operator. " + "The %s template already exists. It was likely created by previous run of the operator. " "Assuming success.", existing_template, ) - return existing_template - except HttpError as e: - # We actually expect to get 404 / Not Found here as the template should - # not yet exist - if not e.resp.status == 404: - raise e - old_body = hook.get_instance_template(resource_id=self.resource_id, project_id=self.project_id) + ComputeInstanceTemplateDetailsLink.persist( + context=context, + task_instance=self, + resource_id=self.body_patch["name"], + project_id=self.project_id or hook.project_id, + ) + return InstanceTemplate.to_dict(existing_template) + old_body = InstanceTemplate.to_dict( + hook.get_instance_template( + resource_id=self.resource_id, + project_id=self.project_id, + ) + ) new_body = deepcopy(old_body) self._field_sanitizer.sanitize(new_body) new_body = merge(new_body, self.body_patch) self.log.info("Calling insert instance template with updated body: %s", new_body) hook.insert_instance_template(body=new_body, request_id=self.request_id, project_id=self.project_id) - return hook.get_instance_template(resource_id=self.body_patch['name'], project_id=self.project_id) + instance_template = hook.get_instance_template( + resource_id=self.body_patch["name"], project_id=self.project_id + ) + ComputeInstanceTemplateDetailsLink.persist( + context=context, + task_instance=self, + resource_id=self.body_patch["name"], + project_id=self.project_id or hook.project_id, + ) + return InstanceTemplate.to_dict(instance_template) class ComputeEngineInstanceGroupUpdateManagerTemplateOperator(ComputeEngineBaseOperator): @@ -478,17 +1293,19 @@ class ComputeEngineInstanceGroupUpdateManagerTemplateOperator(ComputeEngineBaseO account from the list granting this role to the originating account (templated). 
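A minimal usage sketch for the copy operator above, relying only on the constructor arguments shown in this hunk; the project and template names are placeholder assumptions:

    from airflow.providers.google.cloud.operators.compute import (
        ComputeEngineCopyInstanceTemplateOperator,
    )

    copy_template = ComputeEngineCopyInstanceTemplateOperator(
        task_id="gcp_compute_copy_template_task",
        project_id="example-project",
        resource_id="instance-template-test",  # existing template to copy
        body_patch={"name": "instance-template-test-new"},  # 'name' is mandatory in body_patch
    )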
""" + operator_extra_links = (ComputeInstanceGroupManagerDetailsLink(),) + # [START gce_igm_update_template_operator_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'resource_id', - 'zone', - 'request_id', - 'source_template', - 'destination_template', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "project_id", + "resource_id", + "zone", + "request_id", + "source_template", + "destination_template", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gce_igm_update_template_operator_template_fields] @@ -499,12 +1316,12 @@ def __init__( zone: str, source_template: str, destination_template: str, - project_id: Optional[str] = None, - update_policy: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version='beta', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + update_policy: dict[str, Any] | None = None, + request_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version="beta", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.zone = zone @@ -513,7 +1330,7 @@ def __init__( self.request_id = request_id self.update_policy = update_policy self._change_performed = False - if api_version == 'v1': + if api_version == "v1": raise AirflowException( "Api version v1 does not have update/patch " "operations for Instance Group Managers. Use beta" @@ -529,12 +1346,17 @@ def __init__( **kwargs, ) + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing. ") + def _possibly_replace_template(self, dictionary: dict) -> None: - if dictionary.get('instanceTemplate') == self.source_template: - dictionary['instanceTemplate'] = self.destination_template + if dictionary.get("instanceTemplate") == self.source_template: + dictionary["instanceTemplate"] = self.destination_template self._change_performed = True - def execute(self, context: 'Context') -> Optional[bool]: + def execute(self, context: Context) -> bool | None: hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -544,18 +1366,26 @@ def execute(self, context: 'Context') -> Optional[bool]: zone=self.zone, resource_id=self.resource_id, project_id=self.project_id ) patch_body = {} - if 'versions' in old_instance_group_manager: - patch_body['versions'] = old_instance_group_manager['versions'] - if 'instanceTemplate' in old_instance_group_manager: - patch_body['instanceTemplate'] = old_instance_group_manager['instanceTemplate'] + igm_dict = InstanceGroupManager.to_dict(old_instance_group_manager) + if "versions" in igm_dict: + patch_body["versions"] = igm_dict["versions"] + if "instanceTemplate" in igm_dict: + patch_body["instanceTemplate"] = igm_dict["instanceTemplate"] if self.update_policy: - patch_body['updatePolicy'] = self.update_policy + patch_body["updatePolicy"] = self.update_policy self._possibly_replace_template(patch_body) - if 'versions' in patch_body: - for version in patch_body['versions']: + if "versions" in patch_body: + for version in patch_body["versions"]: self._possibly_replace_template(version) if self._change_performed or self.update_policy: self.log.info("Calling patch instance template with updated body: %s", patch_body) + ComputeInstanceGroupManagerDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + 
resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) return hook.patch_instance_group_manager( zone=self.zone, resource_id=self.resource_id, @@ -565,4 +1395,295 @@ def execute(self, context: 'Context') -> Optional[bool]: ) else: # Idempotence achieved + ComputeInstanceGroupManagerDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) return True + + +class ComputeEngineInsertInstanceGroupManagerOperator(ComputeEngineBaseOperator): + """ + Creates an Instance Group Manager using the body specified. + After the group is created, instances in the group are created using the specified Instance Template. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineInsertInstanceGroupManagerOperator` + + :param body: Instance Group Manager representation as an object. + :param project_id: Google Cloud project ID where the Compute Engine Instance Group Manager exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example, when the client call times out, repeating the request + with the same request_id will not create a new Instance Group Manager again). + It should be in UUID format as defined in RFC 4122 + :param resource_id: Name of the Instance Group Manager. If the name of the Instance Group Manager is + not specified in body['name'], the name will be taken from the 'resource_id' parameter. + :param gcp_conn_id: The connection ID used to connect to Google Cloud. Defaults to 'google_cloud_default'. + :param api_version: API version used (for example v1 - or beta). Defaults to v1. + :param impersonation_chain: Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method.
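A hedged usage sketch for the Instance Group Manager template-update operator above. The template self-link URLs below are illustrative assumptions; the operator only requires that source_template matches the value currently stored in the group, and api_version is left at its "beta" default because "v1" is rejected by the constructor:

    from airflow.providers.google.cloud.operators.compute import (
        ComputeEngineInstanceGroupUpdateManagerTemplateOperator,
    )

    update_igm_template = ComputeEngineInstanceGroupUpdateManagerTemplateOperator(
        task_id="gcp_compute_igm_update_template_task",
        project_id="example-project",
        zone="europe-west3-b",
        resource_id="instance-group-test",  # name of the Instance Group Manager
        source_template=(
            "https://www.googleapis.com/compute/beta/projects/example-project"
            "/global/instanceTemplates/instance-template-test"
        ),
        destination_template=(
            "https://www.googleapis.com/compute/beta/projects/example-project"
            "/global/instanceTemplates/instance-template-test-new"
        ),
        # api_version defaults to "beta"; "v1" has no patch/update API for Instance Group Managers
    )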
+ """ + + operator_extra_links = (ComputeInstanceGroupManagerDetailsLink(),) + + # [START gce_igm_insert_fields] + template_fields: Sequence[str] = ( + "project_id", + "body", + "zone", + "request_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_igm_insert_fields] + + def __init__( + self, + *, + body: dict, + zone: str, + project_id: str | None = None, + resource_id: str | None = None, + request_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version="v1", + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, + validate_body: bool = True, + **kwargs, + ) -> None: + self.body = body + self.zone = zone + self.request_id = request_id + self.resource_id = self.body["name"] if "name" in body else resource_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + self.retry = retry + self.timeout = timeout + self.metadata = metadata + if validate_body: + self._field_validator = GcpBodyFieldValidator( + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version + ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) + super().__init__( + project_id=project_id, + zone=zone, + resource_id=self.resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def check_body_fields(self) -> None: + required_params = ["base_instance_name", "target_size", "instance_template"] + for param in required_params: + if param in self.body: + continue + readable_param = param.replace("_", " ") + raise AirflowException( + f"The body '{self.body}' should contain at least {readable_param} for the new operator " + f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) " + f"for more details about body fields description." + ) + + def _validate_all_body_fields(self) -> None: + if self._field_validator: + self._field_validator.validate(self.body) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id and "name" not in self.body: + raise AirflowException( + "The required parameters 'resource_id' and body['name'] are missing. " + "Please, provide at least one of them." 
+ ) + + def execute(self, context: Context) -> dict: + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self._validate_all_body_fields() + self.check_body_fields() + try: + # Idempotence check (sort of) - we want to check if the new Instance Group Manager + # is already created and if it isn't, we create a new one + existing_instance_group_manager = hook.get_instance_group_manager( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + except exceptions.NotFound as e: + # We actually expect to get 404 / Not Found here as the Instance Group Manager should + # not yet exist + if not e.code == 404: + raise e + else: + self.log.info("The %s Instance Group Manager already exists", existing_instance_group_manager) + ComputeInstanceGroupManagerDetailsLink.persist( + context=context, + task_instance=self, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + location_id=self.zone, + ) + return InstanceGroupManager.to_dict(existing_instance_group_manager) + self._field_sanitizer.sanitize(self.body) + self.log.info("Creating Instance Group Manager with specified body: %s", self.body) + hook.insert_instance_group_manager( + body=self.body, + request_id=self.request_id, + project_id=self.project_id, + zone=self.zone, + ) + self.log.info("The specified Instance Group Manager has been created SUCCESSFULLY: %s", self.body) + new_instance_group_manager = hook.get_instance_group_manager( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + ComputeInstanceGroupManagerDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return InstanceGroupManager.to_dict(new_instance_group_manager) + + +class ComputeEngineDeleteInstanceGroupManagerOperator(ComputeEngineBaseOperator): + """ + Deletes an Instance Group Manager. + Deleting an Instance Group Manager is permanent and cannot be undone. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ComputeEngineDeleteInstanceGroupManagerOperator` + + :param resource_id: Name of the Instance Group Manager. + :param project_id: Google Cloud project ID where the Compute Engine Instance Group Manager exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param request_id: Unique request_id that you might add to achieve + full idempotence (for example, when the client call times out, repeating the request + with the same request_id will not create a new Instance Group Manager again). + It should be in UUID format as defined in RFC 4122 + :param gcp_conn_id: The connection ID used to connect to Google Cloud. Defaults to 'google_cloud_default'. + :param api_version: API version used (for example v1 - or beta). Defaults to v1. + :param impersonation_chain: Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role.
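A usage sketch for the insert operator above, built around the three body fields required by check_body_fields (base_instance_name, target_size, instance_template). The concrete values, the short template path, and the project and zone names are placeholder assumptions:

    from airflow.providers.google.cloud.operators.compute import (
        ComputeEngineInsertInstanceGroupManagerOperator,
    )

    insert_igm = ComputeEngineInsertInstanceGroupManagerOperator(
        task_id="gcp_compute_insert_igm_task",
        project_id="example-project",
        zone="europe-west3-b",
        body={
            "name": "instance-group-test",  # also becomes resource_id when resource_id is omitted
            "base_instance_name": "instance-group-test",
            "instance_template": "global/instanceTemplates/instance-template-test",
            "target_size": 1,
        },
    )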
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + """ + + # [START gce_igm_delete_fields] + template_fields: Sequence[str] = ( + "project_id", + "resource_id", + "zone", + "request_id", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) + # [END gce_igm_delete_fields] + + def __init__( + self, + *, + resource_id: str, + zone: str, + project_id: str | None = None, + request_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version="v1", + retry: Retry | None = None, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, + validate_body: bool = True, + **kwargs, + ) -> None: + self.zone = zone + self.request_id = request_id + self.resource_id = resource_id + self._field_validator = None # Optional[GcpBodyFieldValidator] + self.retry = retry + self.timeout = timeout + self.metadata = metadata + if validate_body: + self._field_validator = GcpBodyFieldValidator( + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version + ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_FIELDS_TO_SANITIZE) + super().__init__( + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def _validate_inputs(self) -> None: + super()._validate_inputs() + if not self.resource_id: + raise AirflowException("The required parameter 'resource_id' is missing. ") + + def execute(self, context: Context): + hook = ComputeEngineHook( + gcp_conn_id=self.gcp_conn_id, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + try: + # Checking if specified Instance Group Managers exists and if it does, delete it + hook.get_instance_group_manager( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + self.log.info("Successfully found Group Manager %s", self.resource_id) + hook.delete_instance_group_manager( + resource_id=self.resource_id, + project_id=self.project_id, + request_id=self.request_id, + zone=self.zone, + ) + self.log.info("Successfully deleted Instance Group Managers") + except exceptions.NotFound as e: + # Expecting 404 Error in case if Instance Group Managers doesn't exist. + if e.code == 404: + self.log.error("Instance Group Managers %s doesn't exist", self.resource_id) + raise e diff --git a/airflow/providers/google/cloud/operators/datacatalog.py b/airflow/providers/google/cloud/operators/datacatalog.py index 145aeaa4c6233..0e658c56be9a9 100644 --- a/airflow/providers/google/cloud/operators/datacatalog.py +++ b/airflow/providers/google/cloud/operators/datacatalog.py @@ -14,17 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
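A minimal usage sketch for the ComputeEngineDeleteInstanceGroupManagerOperator introduced above; the project, zone, and group name are placeholder assumptions:

    from airflow.providers.google.cloud.operators.compute import (
        ComputeEngineDeleteInstanceGroupManagerOperator,
    )

    delete_igm = ComputeEngineDeleteInstanceGroupManagerOperator(
        task_id="gcp_compute_delete_igm_task",
        project_id="example-project",
        zone="europe-west3-b",
        resource_id="instance-group-test",  # name of the Instance Group Manager to delete
    )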
+from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry -from google.cloud.datacatalog_v1beta1 import DataCatalogClient, SearchCatalogResult -from google.cloud.datacatalog_v1beta1.types import ( +from google.cloud.datacatalog import ( + DataCatalogClient, Entry, EntryGroup, SearchCatalogRequest, + SearchCatalogResult, Tag, TagTemplate, TagTemplateField, @@ -33,6 +35,11 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook +from airflow.providers.google.cloud.links.datacatalog import ( + DataCatalogEntryGroupLink, + DataCatalogEntryLink, + DataCatalogTagTemplateLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -87,6 +94,7 @@ class CloudDataCatalogCreateEntryOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryLink(),) def __init__( self, @@ -94,13 +102,13 @@ def __init__( location: str, entry_group: str, entry_id: str, - entry: Union[Dict, Entry], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + entry: dict | Entry, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -115,7 +123,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -143,7 +151,15 @@ def execute(self, context: 'Context'): ) _, _, entry_id = result.name.rpartition("/") self.log.info("Current entry_id ID: %s", entry_id) - context["task_instance"].xcom_push(key="entry_id", value=entry_id) + self.xcom_push(context, key="entry_id", value=entry_id) + DataCatalogEntryLink.persist( + context=context, + task_instance=self, + entry_id=self.entry_id, + entry_group_id=self.entry_group, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return Entry.to_dict(result) @@ -195,19 +211,20 @@ class CloudDataCatalogCreateEntryGroupOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryGroupLink(),) def __init__( self, *, location: str, entry_group_id: str, - entry_group: Union[Dict, EntryGroup], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + entry_group: dict | EntryGroup, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -221,7 +238,7 @@ def __init__( self.gcp_conn_id = 
gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -248,7 +265,14 @@ def execute(self, context: 'Context'): _, _, entry_group_id = result.name.rpartition("/") self.log.info("Current entry group ID: %s", entry_group_id) - context["task_instance"].xcom_push(key="entry_group_id", value=entry_group_id) + self.xcom_push(context, key="entry_group_id", value=entry_group_id) + DataCatalogEntryGroupLink.persist( + context=context, + task_instance=self, + entry_group_id=self.entry_group_id, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return EntryGroup.to_dict(result) @@ -301,6 +325,7 @@ class CloudDataCatalogCreateTagOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryLink(),) def __init__( self, @@ -308,14 +333,14 @@ def __init__( location: str, entry_group: str, entry: str, - tag: Union[Dict, Tag], - template_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + tag: dict | Tag, + template_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -331,7 +356,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -375,7 +400,15 @@ def execute(self, context: 'Context'): _, _, tag_id = tag.name.rpartition("/") self.log.info("Current Tag ID: %s", tag_id) - context["task_instance"].xcom_push(key="tag_id", value=tag_id) + self.xcom_push(context, key="tag_id", value=tag_id) + DataCatalogEntryLink.persist( + context=context, + task_instance=self, + entry_id=self.entry, + entry_group_id=self.entry_group, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return Tag.to_dict(tag) @@ -425,19 +458,20 @@ class CloudDataCatalogCreateTagTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogTagTemplateLink(),) def __init__( self, *, location: str, tag_template_id: str, - tag_template: Union[Dict, TagTemplate], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + tag_template: dict | TagTemplate, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -451,7 +485,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): 
hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -477,7 +511,14 @@ def execute(self, context: 'Context'): ) _, _, tag_template = result.name.rpartition("/") self.log.info("Current Tag ID: %s", tag_template) - context["task_instance"].xcom_push(key="tag_template_id", value=tag_template) + self.xcom_push(context, key="tag_template_id", value=tag_template) + DataCatalogTagTemplateLink.persist( + context=context, + task_instance=self, + tag_template_id=self.tag_template_id, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return TagTemplate.to_dict(result) @@ -532,6 +573,7 @@ class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogTagTemplateLink(),) def __init__( self, @@ -539,13 +581,13 @@ def __init__( location: str, tag_template: str, tag_template_field_id: str, - tag_template_field: Union[Dict, TagTemplateField], - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + tag_template_field: dict | TagTemplateField, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -560,7 +602,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -588,7 +630,14 @@ def execute(self, context: 'Context'): result = tag_template.fields[self.tag_template_field_id] self.log.info("Current Tag ID: %s", self.tag_template_field_id) - context["task_instance"].xcom_push(key="tag_template_field_id", value=self.tag_template_field_id) + self.xcom_push(context, key="tag_template_field_id", value=self.tag_template_field_id) + DataCatalogTagTemplateLink.persist( + context=context, + task_instance=self, + tag_template_id=self.tag_template, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return TagTemplateField.to_dict(result) @@ -640,12 +689,12 @@ def __init__( location: str, entry_group: str, entry: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -659,7 +708,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -724,12 +773,12 @@ def __init__( *, location: str, entry_group: str, - project_id: Optional[str] = None, - retry: Union[Retry, 
_MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -742,7 +791,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -810,12 +859,12 @@ def __init__( entry_group: str, entry: str, tag: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -830,7 +879,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -899,12 +948,12 @@ def __init__( location: str, tag_template: str, force: bool, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -918,7 +967,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -987,12 +1036,12 @@ def __init__( tag_template: str, field: str, force: bool, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1007,7 +1056,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain ) @@ -1067,6 +1116,7 @@ class CloudDataCatalogGetEntryOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryLink(),) def __init__( self, @@ -1074,12 +1124,12 @@ def __init__( location: str, entry_group: str, entry: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1093,7 +1143,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1106,6 +1156,14 @@ def execute(self, context: 'Context') -> dict: timeout=self.timeout, metadata=self.metadata, ) + DataCatalogEntryLink.persist( + context=context, + task_instance=self, + entry_id=self.entry, + entry_group_id=self.entry_group, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return Entry.to_dict(result) @@ -1153,6 +1211,7 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryGroupLink(),) def __init__( self, @@ -1160,12 +1219,12 @@ def __init__( location: str, entry_group: str, read_mask: FieldMask, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1179,7 +1238,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1192,6 +1251,13 @@ def execute(self, context: 'Context') -> dict: timeout=self.timeout, metadata=self.metadata, ) + DataCatalogEntryGroupLink.persist( + context=context, + task_instance=self, + entry_group_id=self.entry_group, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return EntryGroup.to_dict(result) @@ -1234,18 +1300,19 @@ class CloudDataCatalogGetTagTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogTagTemplateLink(),) def __init__( self, *, location: str, tag_template: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, 
str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1258,7 +1325,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1270,6 +1337,13 @@ def execute(self, context: 'Context') -> dict: timeout=self.timeout, metadata=self.metadata, ) + DataCatalogTagTemplateLink.persist( + context=context, + task_instance=self, + tag_template_id=self.tag_template, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return TagTemplate.to_dict(result) @@ -1319,6 +1393,7 @@ class CloudDataCatalogListTagsOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryLink(),) def __init__( self, @@ -1327,12 +1402,12 @@ def __init__( entry_group: str, entry: str, page_size: int = 100, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1347,7 +1422,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> list: + def execute(self, context: Context) -> list: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1361,6 +1436,14 @@ def execute(self, context: 'Context') -> list: timeout=self.timeout, metadata=self.metadata, ) + DataCatalogEntryLink.persist( + context=context, + task_instance=self, + entry_id=self.entry, + entry_group_id=self.entry_group, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) return [Tag.to_dict(item) for item in result] @@ -1406,18 +1489,19 @@ class CloudDataCatalogLookupEntryOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryLink(),) def __init__( self, *, - linked_resource: Optional[str] = None, - sql_resource: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + linked_resource: str | None = None, + sql_resource: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1430,7 +1514,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain ) @@ -1441,6 +1525,16 @@ def execute(self, context: 'Context') -> dict: timeout=self.timeout, metadata=self.metadata, ) + + project_id, location_id, entry_group_id, entry_id = result.name.split("/")[1::2] + DataCatalogEntryLink.persist( + context=context, + task_instance=self, + entry_id=entry_id, + entry_group_id=entry_group_id, + location_id=location_id, + project_id=project_id, + ) return Entry.to_dict(result) @@ -1489,6 +1583,7 @@ class CloudDataCatalogRenameTagTemplateFieldOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogTagTemplateLink(),) def __init__( self, @@ -1497,12 +1592,12 @@ def __init__( tag_template: str, field: str, new_tag_template_field_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1517,7 +1612,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1531,6 +1626,13 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + DataCatalogTagTemplateLink.persist( + context=context, + task_instance=self, + tag_template_id=self.tag_template, + location_id=self.location, + project_id=self.project_id or hook.project_id, + ) class CloudDataCatalogSearchCatalogOperator(BaseOperator): @@ -1605,15 +1707,15 @@ class CloudDataCatalogSearchCatalogOperator(BaseOperator): def __init__( self, *, - scope: Union[Dict, SearchCatalogRequest.Scope], + scope: dict | SearchCatalogRequest.Scope, query: str, page_size: int = 100, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1627,7 +1729,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> list: + def execute(self, context: Context) -> list: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1695,21 +1797,22 @@ class CloudDataCatalogUpdateEntryOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryLink(),) def __init__( self, *, - entry: Union[Dict, Entry], - update_mask: Union[Dict, FieldMask], - location: Optional[str] = None, - entry_group: Optional[str] = None, - entry_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = 
None, - metadata: Sequence[Tuple[str, str]] = (), + entry: dict | Entry, + update_mask: dict | FieldMask, + location: str | None = None, + entry_group: str | None = None, + entry_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1725,11 +1828,11 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - hook.update_entry( + result = hook.update_entry( entry=self.entry, update_mask=self.update_mask, location=self.location, @@ -1741,6 +1844,16 @@ def execute(self, context: 'Context') -> None: metadata=self.metadata, ) + location_id, entry_group_id, entry_id = result.name.split("/")[3::2] + DataCatalogEntryLink.persist( + context=context, + task_instance=self, + entry_id=self.entry_id or entry_id, + entry_group_id=self.entry_group or entry_group_id, + location_id=self.location or location_id, + project_id=self.project_id or hook.project_id, + ) + class CloudDataCatalogUpdateTagOperator(BaseOperator): """ @@ -1795,22 +1908,23 @@ class CloudDataCatalogUpdateTagOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogEntryLink(),) def __init__( self, *, - tag: Union[Dict, Tag], - update_mask: Union[Dict, FieldMask], - location: Optional[str] = None, - entry_group: Optional[str] = None, - entry: Optional[str] = None, - tag_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + tag: dict | Tag, + update_mask: dict | FieldMask, + location: str | None = None, + entry_group: str | None = None, + entry: str | None = None, + tag_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1827,11 +1941,11 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - hook.update_tag( + result = hook.update_tag( tag=self.tag, update_mask=self.update_mask, location=self.location, @@ -1844,6 +1958,16 @@ def execute(self, context: 'Context') -> None: metadata=self.metadata, ) + location_id, entry_group_id, entry_id = result.name.split("/")[3:8:2] + DataCatalogEntryLink.persist( + context=context, + task_instance=self, + entry_id=self.entry or entry_id, + entry_group_id=self.entry_group or entry_group_id, + location_id=self.location or location_id, + project_id=self.project_id or hook.project_id, + ) + class CloudDataCatalogUpdateTagTemplateOperator(BaseOperator): """ @@ -1900,20 +2024,21 @@ class 
CloudDataCatalogUpdateTagTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogTagTemplateLink(),) def __init__( self, *, - tag_template: Union[Dict, TagTemplate], - update_mask: Union[Dict, FieldMask], - location: Optional[str] = None, - tag_template_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + tag_template: dict | TagTemplate, + update_mask: dict | FieldMask, + location: str | None = None, + tag_template_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1928,11 +2053,11 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - hook.update_tag_template( + result = hook.update_tag_template( tag_template=self.tag_template, update_mask=self.update_mask, location=self.location, @@ -1943,6 +2068,15 @@ def execute(self, context: 'Context') -> None: metadata=self.metadata, ) + location_id, tag_template_id = result.name.split("/")[3::2] + DataCatalogTagTemplateLink.persist( + context=context, + task_instance=self, + tag_template_id=self.tag_template_id or tag_template_id, + location_id=self.location or location_id, + project_id=self.project_id or hook.project_id, + ) + class CloudDataCatalogUpdateTagTemplateFieldOperator(BaseOperator): """ @@ -2005,22 +2139,23 @@ class CloudDataCatalogUpdateTagTemplateFieldOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (DataCatalogTagTemplateLink(),) def __init__( self, *, - tag_template_field: Union[Dict, TagTemplateField], - update_mask: Union[Dict, FieldMask], - tag_template_field_name: Optional[str] = None, - location: Optional[str] = None, - tag_template: Optional[str] = None, - tag_template_field_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + tag_template_field: dict | TagTemplateField, + update_mask: dict | FieldMask, + tag_template_field_name: str | None = None, + location: str | None = None, + tag_template: str | None = None, + tag_template_field_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2037,11 +2172,11 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDataCatalogHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - hook.update_tag_template_field( + result = 
hook.update_tag_template_field( tag_template_field=self.tag_template_field, update_mask=self.update_mask, tag_template_field_name=self.tag_template_field_name, @@ -2053,3 +2188,12 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + + location_id, tag_template_id = result.name.split("/")[3:6:2] + DataCatalogTagTemplateLink.persist( + context=context, + task_instance=self, + tag_template_id=self.tag_template or tag_template_id, + location_id=self.location or location_id, + project_id=self.project_id or hook.project_id, + ) diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 98a17d800cbe5..b7dc2b149ced9 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -16,12 +16,14 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Dataflow operators.""" +from __future__ import annotations + import copy import re import warnings from contextlib import ExitStack from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType @@ -78,6 +80,11 @@ class DataflowConfiguration: If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + + .. warning:: + + This option requires Apache Beam 2.39.0 or newer. + :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it instead of canceling during killing task instance. 
See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline @@ -134,18 +141,18 @@ def __init__( *, job_name: str = "{{task.task_id}}", append_job_name: bool = True, - project_id: Optional[str] = None, - location: Optional[str] = DEFAULT_DATAFLOW_LOCATION, + project_id: str | None = None, + location: str | None = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, poll_sleep: int = 10, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, drain_pipeline: bool = False, - cancel_timeout: Optional[int] = 5 * 60, - wait_until_finished: Optional[bool] = None, - multiple_jobs: Optional[bool] = None, + cancel_timeout: int | None = 5 * 60, + wait_until_finished: bool | None = None, + multiple_jobs: bool | None = None, check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun, - service_account: Optional[str] = None, + service_account: str | None = None, ) -> None: self.job_name = job_name self.append_job_name = append_job_name @@ -332,18 +339,18 @@ def __init__( *, jar: str, job_name: str = "{{task.task_id}}", - dataflow_default_options: Optional[dict] = None, - options: Optional[dict] = None, - project_id: Optional[str] = None, + dataflow_default_options: dict | None = None, + options: dict | None = None, + project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, poll_sleep: int = 10, - job_class: Optional[str] = None, + job_class: str | None = None, check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun, multiple_jobs: bool = False, - cancel_timeout: Optional[int] = 10 * 60, - wait_until_finished: Optional[bool] = None, + cancel_timeout: int | None = 10 * 60, + wait_until_finished: bool | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -375,10 +382,10 @@ def __init__( self.cancel_timeout = cancel_timeout self.wait_until_finished = wait_until_finished self.job_id = None - self.beam_hook: Optional[BeamHook] = None - self.dataflow_hook: Optional[DataflowHook] = None + self.beam_hook: BeamHook | None = None + self.dataflow_hook: DataflowHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): """Execute the Apache Beam Pipeline.""" self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner) self.dataflow_hook = DataflowHook( @@ -499,6 +506,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator): `__ :param cancel_timeout: How long (in seconds) operator should wait for the pipeline to be successfully cancelled when task is being killed. + :param append_job_name: True if unique suffix has to be appended to job name. :param wait_until_finished: (Optional) If True, wait for the end of pipeline execution before exiting. If False, only submits job. 
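A usage sketch for DataflowTemplatedJobStartOperator showing the new append_job_name flag documented above; the template path and pipeline parameters are illustrative assumptions, not values from this change:

    from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator

    start_template_job = DataflowTemplatedJobStartOperator(
        task_id="start_template_job",
        project_id="example-project",
        location="europe-west3",
        template="gs://dataflow-templates/latest/Word_Count",  # assumed classic template path
        parameters={
            "inputFile": "gs://example-bucket/input.txt",
            "output": "gs://example-bucket/output",
        },
        append_job_name=True,  # append a unique suffix to the job name, as documented above
    )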
@@ -600,18 +608,19 @@ def __init__( *, template: str, job_name: str = "{{task.task_id}}", - options: Optional[Dict[str, Any]] = None, - dataflow_default_options: Optional[Dict[str, Any]] = None, - parameters: Optional[Dict[str, str]] = None, - project_id: Optional[str] = None, + options: dict[str, Any] | None = None, + dataflow_default_options: dict[str, Any] | None = None, + parameters: dict[str, str] | None = None, + project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, poll_sleep: int = 10, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - environment: Optional[Dict] = None, - cancel_timeout: Optional[int] = 10 * 60, - wait_until_finished: Optional[bool] = None, + impersonation_chain: str | Sequence[str] | None = None, + environment: dict | None = None, + cancel_timeout: int | None = 10 * 60, + wait_until_finished: bool | None = None, + append_job_name: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -626,13 +635,14 @@ def __init__( self.delegate_to = delegate_to self.poll_sleep = poll_sleep self.job = None - self.hook: Optional[DataflowHook] = None + self.hook: DataflowHook | None = None self.impersonation_chain = impersonation_chain self.environment = environment self.cancel_timeout = cancel_timeout self.wait_until_finished = wait_until_finished + self.append_job_name = append_job_name - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -657,6 +667,7 @@ def set_current_job(current_job): project_id=self.project_id, location=self.location, environment=self.environment, + append_job_name=self.append_job_name, ) return job @@ -727,6 +738,14 @@ class DataflowStartFlexTemplateOperator(BaseOperator): If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False to the operator, the second loop will check once is job not in terminal state and exit the loop. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
""" template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id") @@ -734,14 +753,15 @@ class DataflowStartFlexTemplateOperator(BaseOperator): def __init__( self, - body: Dict, + body: dict, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, drain_pipeline: bool = False, - cancel_timeout: Optional[int] = 10 * 60, - wait_until_finished: Optional[bool] = None, + cancel_timeout: int | None = 10 * 60, + wait_until_finished: bool | None = None, + impersonation_chain: str | Sequence[str] | None = None, *args, **kwargs, ) -> None: @@ -755,15 +775,17 @@ def __init__( self.cancel_timeout = cancel_timeout self.wait_until_finished = wait_until_finished self.job = None - self.hook: Optional[DataflowHook] = None + self.hook: DataflowHook | None = None + self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, drain_pipeline=self.drain_pipeline, cancel_timeout=self.cancel_timeout, wait_until_finished=self.wait_until_finished, + impersonation_chain=self.impersonation_chain, ) def set_current_job(current_job): @@ -821,6 +843,14 @@ class DataflowStartSqlJobOperator(BaseOperator): :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it instead of canceling during killing task instance. See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
""" template_fields: Sequence[str] = ( @@ -837,12 +867,13 @@ def __init__( self, job_name: str, query: str, - options: Dict[str, Any], + options: dict[str, Any], location: str = DEFAULT_DATAFLOW_LOCATION, - project_id: Optional[str] = None, + project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, drain_pipeline: bool = False, + impersonation_chain: str | Sequence[str] | None = None, *args, **kwargs, ) -> None: @@ -855,14 +886,16 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.drain_pipeline = drain_pipeline + self.impersonation_chain = impersonation_chain self.job = None - self.hook: Optional[DataflowHook] = None + self.hook: DataflowHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, drain_pipeline=self.drain_pipeline, + impersonation_chain=self.impersonation_chain, ) def set_current_job(current_job): @@ -998,20 +1031,20 @@ def __init__( *, py_file: str, job_name: str = "{{task.task_id}}", - dataflow_default_options: Optional[dict] = None, - options: Optional[dict] = None, + dataflow_default_options: dict | None = None, + options: dict | None = None, py_interpreter: str = "python3", - py_options: Optional[List[str]] = None, - py_requirements: Optional[List[str]] = None, + py_options: list[str] | None = None, + py_requirements: list[str] | None = None, py_system_site_packages: bool = False, - project_id: Optional[str] = None, + project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, poll_sleep: int = 10, drain_pipeline: bool = False, - cancel_timeout: Optional[int] = 10 * 60, - wait_until_finished: Optional[bool] = None, + cancel_timeout: int | None = 10 * 60, + wait_until_finished: bool | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -1043,10 +1076,10 @@ def __init__( self.cancel_timeout = cancel_timeout self.wait_until_finished = wait_until_finished self.job_id = None - self.beam_hook: Optional[BeamHook] = None - self.dataflow_hook: Optional[DataflowHook] = None + self.beam_hook: BeamHook | None = None + self.dataflow_hook: DataflowHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): """Execute the python dataflow job.""" self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner) self.dataflow_hook = DataflowHook( @@ -1109,3 +1142,96 @@ def on_kill(self) -> None: self.dataflow_hook.cancel_job( job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id ) + + +class DataflowStopJobOperator(BaseOperator): + """ + Stops the job with the specified name prefix or Job ID. + All jobs with provided name prefix will be stopped. + Streaming jobs are drained by default. + + Parameter ``job_name_prefix`` and ``job_id`` are mutually exclusive. + + .. seealso:: + For more details on stopping a pipeline see: + https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowStopJobOperator` + + :param job_name_prefix: Name prefix specifying which jobs are to be stopped. + :param job_id: Job ID specifying which jobs are to be stopped. 
+ :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param location: Optional, Job location. If set to None or missing, "us-central1" will be used. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status to confirm it's stopped. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param drain_pipeline: Optional, set to False if want to stop streaming job by canceling it + instead of draining. See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :param stop_timeout: wait time in seconds for successful job canceling/draining + """ + + def __init__( + self, + job_name_prefix: str | None = None, + job_id: str | None = None, + project_id: str | None = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + stop_timeout: int | None = 10 * 60, + drain_pipeline: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.poll_sleep = poll_sleep + self.stop_timeout = stop_timeout + self.job_name = job_name_prefix + self.job_id = job_id + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook: DataflowHook | None = None + self.drain_pipeline = drain_pipeline + + def execute(self, context: Context) -> None: + self.dataflow_hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + cancel_timeout=self.stop_timeout, + drain_pipeline=self.drain_pipeline, + ) + if self.job_id or self.dataflow_hook.is_job_dataflow_running( + name=self.job_name, + project_id=self.project_id, + location=self.location, + ): + self.dataflow_hook.cancel_job( + job_name=self.job_name, + project_id=self.project_id, + location=self.location, + job_id=self.job_id, + ) + else: + self.log.info("No jobs to stop") + + return None diff --git a/airflow/providers/google/cloud/operators/dataform.py b/airflow/providers/google/cloud/operators/dataform.py new file mode 100644 index 0000000000000..3c92d727dd5c7 --- /dev/null +++ b/airflow/providers/google/cloud/operators/dataform.py @@ -0,0 +1,1183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from airflow.providers.google.cloud.links.dataform import ( + DataformRepositoryLink, + DataformWorkflowInvocationLink, + DataformWorkspaceLink, +) + +if TYPE_CHECKING: + from airflow.utils.context import Context + +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.api_core.retry import Retry +from google.cloud.dataform_v1beta1.types import ( + CompilationResult, + InstallNpmPackagesResponse, + MakeDirectoryResponse, + Repository, + WorkflowInvocation, + Workspace, + WriteFileResponse, +) + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.hooks.dataform import DataformHook + + +class DataformCreateCompilationResultOperator(BaseOperator): + """ + Creates a new CompilationResult in a given project and location. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param compilation_result: Required. The compilation result to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + compilation_result: CompilationResult | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.compilation_result = compilation_result + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + result = hook.create_compilation_result( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + compilation_result=self.compilation_result, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return CompilationResult.to_dict(result) + + +class DataformGetCompilationResultOperator(BaseOperator): + """ + Fetches a single CompilationResult. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param compilation_result_id: The Id of the Dataform Compilation Result + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ("repository_id", "compilation_result_id", "delegate_to", "impersonation_chain") + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + compilation_result_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.compilation_result_id = compilation_result_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + result = hook.get_compilation_result( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + compilation_result_id=self.compilation_result_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return CompilationResult.to_dict(result) + + +class DataformCreateWorkflowInvocationOperator(BaseOperator): + """ + Creates a new WorkflowInvocation in a given Repository. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation: Required. The workflow invocation resource to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param asynchronous: Flag to return workflow_invocation_id from the Dataform API. 
+ This is useful for submitting long running workflows and + waiting on them asynchronously using the DataformWorkflowInvocationStateSensor + :param wait_time: Number of seconds between checks + """ + + template_fields = ("workflow_invocation", "delegate_to", "impersonation_chain") + operator_extra_links = (DataformWorkflowInvocationLink(),) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workflow_invocation: WorkflowInvocation | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: int | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + asynchronous: bool = False, + wait_time: int = 10, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workflow_invocation = workflow_invocation + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.asynchronous = asynchronous + self.wait_time = wait_time + + def execute(self, context: Context): + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + result = hook.create_workflow_invocation( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workflow_invocation=self.workflow_invocation, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + workflow_invocation_id = result.name.split("/")[-1] + DataformWorkflowInvocationLink.persist( + operator_instance=self, + context=context, + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workflow_invocation_id=workflow_invocation_id, + ) + if not self.asynchronous: + hook.wait_for_workflow_invocation( + workflow_invocation_id=workflow_invocation_id, + repository_id=self.repository_id, + project_id=self.project_id, + region=self.region, + timeout=self.timeout, + wait_time=self.wait_time, + ) + return WorkflowInvocation.to_dict(result) + + +class DataformGetWorkflowInvocationOperator(BaseOperator): + """ + Fetches a single WorkflowInvocation. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation_id: the workflow invocation resource's id. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields = ("repository_id", "workflow_invocation_id", "delegate_to", "impersonation_chain") + operator_extra_links = (DataformWorkflowInvocationLink(),) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workflow_invocation_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workflow_invocation_id = workflow_invocation_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + result = hook.get_workflow_invocation( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workflow_invocation_id=self.workflow_invocation_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return WorkflowInvocation.to_dict(result) + + +class DataformCancelWorkflowInvocationOperator(BaseOperator): + """ + Requests cancellation of a running WorkflowInvocation. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation_id: the workflow invocation resource's id. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
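# Illustrative sketch only, not part of this change: kicking off a workflow invocation from a
# previously created compilation result. All IDs are placeholders and the XCom/Jinja lookup is just
# one way to wire the tasks together; with ``asynchronous=True`` the task returns immediately and
# the run can be polled with DataformWorkflowInvocationStateSensor. Assumes a ``with DAG(...):`` block.
from airflow.providers.google.cloud.operators.dataform import DataformCreateWorkflowInvocationOperator

create_workflow_invocation = DataformCreateWorkflowInvocationOperator(
    task_id="create_workflow_invocation",
    project_id="my-project",
    region="us-central1",
    repository_id="my-dataform-repo",
    workflow_invocation={
        "compilation_result": "{{ task_instance.xcom_pull('create_compilation_result')['name'] }}"
    },
    asynchronous=True,  # do not block the worker while the invocation runs
)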
+ """ + + template_fields = ("repository_id", "workflow_invocation_id", "delegate_to", "impersonation_chain") + operator_extra_links = (DataformWorkflowInvocationLink(),) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workflow_invocation_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workflow_invocation_id = workflow_invocation_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + hook.cancel_workflow_invocation( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workflow_invocation_id=self.workflow_invocation_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class DataformCreateRepositoryOperator(BaseOperator): + """ + Creates repository. + + :param project_id: Required. The ID of the Google Cloud project that the task belongs to. + :param region: Required. The ID of the Google Cloud region that the task belongs to. + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + operator_extra_links = (DataformRepositoryLink(),) + template_fields = ( + "project_id", + "repository_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.region = region + self.repository_id = repository_id + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + repository = hook.create_repository( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + DataformRepositoryLink.persist( + operator_instance=self, + context=context, + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + ) + + return Repository.to_dict(repository) + + +class DataformDeleteRepositoryOperator(BaseOperator): + """ + Deletes repository. + + :param project_id: Required. The ID of the Google Cloud project where repository located. + :param region: Required. The ID of the Google Cloud region where repository located. + :param repository_id: Required. The ID of the Dataform repository that should be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ( + "project_id", + "repository_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + force: bool = True, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.repository_id = repository_id + self.project_id = project_id + self.region = region + self.force = force + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> None: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + hook.delete_repository( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + force=self.force, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class DataformCreateWorkspaceOperator(BaseOperator): + """ + Creates workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace should be in. + :param region: Required. Name of the Google Cloud region that where workspace should be in. + :param repository_id: Required. The ID of the Dataform repository that the workspace belongs to. + :param workspace_id: Required. The ID of the new workspace that will be created. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + operator_extra_links = (DataformWorkspaceLink(),) + template_fields = ( + "project_id", + "repository_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.workspace_id = workspace_id + self.repository_id = repository_id + self.region = region + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + workspace = hook.create_workspace( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + DataformWorkspaceLink.persist( + operator_instance=self, + context=context, + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + ) + + return Workspace.to_dict(workspace) + + +class DataformDeleteWorkspaceOperator(BaseOperator): + """ + Deletes workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace that should be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ( + "project_id", + "repository_id", + "workspace_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workspace_id = workspace_id + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> None: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + hook.delete_workspace( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class DataformWriteFileOperator(BaseOperator): + """ + Writes new file to specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where files should be created. + :param filepath: Required. Path to file including name of the file relative to workspace root. + :param contents: Required. Content of the file to be written. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ( + "project_id", + "repository_id", + "workspace_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + filepath: str, + contents: bytes, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workspace_id = workspace_id + self.filepath = filepath + self.contents = contents + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + write_file_response = hook.write_file( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + filepath=self.filepath, + contents=self.contents, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return WriteFileResponse.to_dict(write_file_response) + + +class DataformMakeDirectoryOperator(BaseOperator): + """ + Makes new directory in specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where directory should be created. + :param path: Required. The directory's full path including directory name, relative to the workspace root. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ( + "project_id", + "repository_id", + "workspace_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + directory_path: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workspace_id = workspace_id + self.directory_path = directory_path + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + make_directory_response = hook.make_directory( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + path=self.directory_path, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + return MakeDirectoryResponse.to_dict(make_directory_response) + + +class DataformRemoveFileOperator(BaseOperator): + """ + Removes file in specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where directory located. + :param filepath: Required. The full path including name of the file, relative to the workspace root. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ( + "project_id", + "repository_id", + "workspace_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + filepath: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workspace_id = workspace_id + self.filepath = filepath + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> None: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + hook.remove_file( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + filepath=self.filepath, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class DataformRemoveDirectoryOperator(BaseOperator): + """ + Removes directory in specified workspace. + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace where directory located. + :param path: Required. The directory's full path including directory name, relative to the workspace root. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ( + "project_id", + "repository_id", + "workspace_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + directory_path: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workspace_id = workspace_id + self.directory_path = directory_path + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> None: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + hook.remove_directory( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + path=self.directory_path, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + +class DataformInstallNpmPackagesOperator(BaseOperator): + """ + Installs npm dependencies in the provided workspace. Requires "package.json" to be created in workspace + + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service accountmaking the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields = ( + "project_id", + "repository_id", + "workspace_id", + "delegate_to", + "impersonation_chain", + ) + + def __init__( + self, + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.project_id = project_id + self.region = region + self.repository_id = repository_id + self.workspace_id = workspace_id + + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict: + hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + response = hook.install_npm_packages( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workspace_id=self.workspace_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + + return InstallNpmPackagesResponse.to_dict(response) diff --git a/airflow/providers/google/cloud/operators/datafusion.py b/airflow/providers/google/cloud/operators/datafusion.py index b461a899f59ec..f17255a42a144 100644 --- a/airflow/providers/google/cloud/operators/datafusion.py +++ b/airflow/providers/google/cloud/operators/datafusion.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module contains Google DataFusion operators.""" +from __future__ import annotations + from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from google.api_core.retry import exponential_sleep_generator from googleapiclient.errors import HttpError @@ -55,13 +56,13 @@ class DataFusionInstanceLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", - task_instance: Union[ - "CloudDataFusionRestartInstanceOperator", - "CloudDataFusionCreateInstanceOperator", - "CloudDataFusionUpdateInstanceOperator", - "CloudDataFusionGetInstanceOperator", - ], + context: Context, + task_instance: ( + CloudDataFusionRestartInstanceOperator + | CloudDataFusionCreateInstanceOperator + | CloudDataFusionUpdateInstanceOperator + | CloudDataFusionGetInstanceOperator + ), project_id: str, ): task_instance.xcom_push( @@ -84,12 +85,12 @@ class DataFusionPipelineLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", - task_instance: Union[ - "CloudDataFusionCreatePipelineOperator", - "CloudDataFusionStartPipelineOperator", - "CloudDataFusionStopPipelineOperator", - ], + context: Context, + task_instance: ( + CloudDataFusionCreatePipelineOperator + | CloudDataFusionStartPipelineOperator + | CloudDataFusionStopPipelineOperator + ), uri: str, ): task_instance.xcom_push( @@ -111,8 +112,8 @@ class DataFusionPipelinesLink(BaseGoogleLink): @staticmethod def persist( - context: "Context", - task_instance: "CloudDataFusionListPipelinesOperator", + context: Context, + task_instance: CloudDataFusionListPipelinesOperator, uri: str, ): task_instance.xcom_push( @@ -162,11 +163,11 @@ def __init__( *, instance_name: str, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -178,7 +179,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -234,11 +235,11 @@ def __init__( *, instance_name: str, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -250,7 +251,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -306,13 +307,13 @@ def __init__( self, *, instance_name: str, - instance: Dict[str, Any], + instance: dict[str, Any], location: str, - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] 
= None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -325,7 +326,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -343,7 +344,7 @@ def execute(self, context: 'Context') -> dict: instance = hook.wait_for_operation(operation) self.log.info("Instance %s created successfully", self.instance_name) except HttpError as err: - if err.resp.status not in (409, '409'): + if err.resp.status not in (409, "409"): raise self.log.info("Instance %s already exists", self.instance_name) instance = hook.get_instance( @@ -351,7 +352,7 @@ def execute(self, context: 'Context') -> dict: ) # Wait for instance to be ready for time_to_wait in exponential_sleep_generator(initial=10, maximum=120): - if instance['state'] != 'CREATING': + if instance["state"] != "CREATING": break sleep(time_to_wait) instance = hook.get_instance( @@ -408,14 +409,14 @@ def __init__( self, *, instance_name: str, - instance: Dict[str, Any], + instance: dict[str, Any], update_mask: str, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -429,7 +430,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -488,11 +489,11 @@ def __init__( *, instance_name: str, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -504,7 +505,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -565,15 +566,15 @@ def __init__( self, *, pipeline_name: str, - pipeline: Dict[str, Any], + pipeline: dict[str, Any], instance_name: str, location: str, namespace: str = "default", - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -588,7 +589,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, 
delegate_to=self.delegate_to, @@ -656,13 +657,13 @@ def __init__( pipeline_name: str, instance_name: str, location: str, - version_id: Optional[str] = None, + version_id: str | None = None, namespace: str = "default", - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -677,7 +678,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -744,14 +745,14 @@ def __init__( *, instance_name: str, location: str, - artifact_name: Optional[str] = None, - artifact_version: Optional[str] = None, + artifact_name: str | None = None, + artifact_version: str | None = None, namespace: str = "default", - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -766,7 +767,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -843,15 +844,15 @@ def __init__( pipeline_name: str, instance_name: str, location: str, - runtime_args: Optional[Dict[str, Any]] = None, - success_states: Optional[List[str]] = None, + runtime_args: dict[str, Any] | None = None, + success_states: list[str] | None = None, namespace: str = "default", pipeline_timeout: int = 5 * 60, - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, asynchronous=False, **kwargs, ) -> None: @@ -874,7 +875,7 @@ def __init__( else: self.success_states = SUCCESS_STATES + [PipelineStates.RUNNING] - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: hook = DataFusionHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -955,11 +956,11 @@ def __init__( instance_name: str, location: str, namespace: str = "default", - project_id: Optional[str] = None, + project_id: str | None = None, api_version: str = "v1beta1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -973,7 +974,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DataFusionHook( 
gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/dataplex.py b/airflow/providers/google/cloud/operators/dataplex.py index 7b225a0c010ad..12c3917690d07 100644 --- a/airflow/providers/google/cloud/operators/dataplex.py +++ b/airflow/providers/google/cloud/operators/dataplex.py @@ -14,22 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains Google Dataplex operators.""" + +from __future__ import annotations + from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Sequence if TYPE_CHECKING: from airflow.utils.context import Context from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry, exponential_sleep_generator -from google.cloud.dataplex_v1.types import Task +from google.cloud.dataplex_v1.types import Lake, Task from googleapiclient.errors import HttpError from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.dataplex import DataplexHook -from airflow.providers.google.cloud.links.dataplex import DataplexTaskLink, DataplexTasksLink +from airflow.providers.google.cloud.links.dataplex import ( + DataplexLakeLink, + DataplexTaskLink, + DataplexTasksLink, +) class DataplexCreateTaskOperator(BaseOperator): @@ -73,7 +79,7 @@ class DataplexCreateTaskOperator(BaseOperator): "delegate_to", "impersonation_chain", ) - template_fields_renderers = {'body': 'json'} + template_fields_renderers = {"body": "json"} operator_extra_links = (DataplexTaskLink(),) def __init__( @@ -81,16 +87,16 @@ def __init__( project_id: str, region: str, lake_id: str, - body: Dict[str, Any], + body: dict[str, Any], dataplex_task_id: str, - validate_only: Optional[bool] = None, + validate_only: bool | None = None, api_version: str = "v1", - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, asynchronous: bool = False, *args, **kwargs, @@ -111,7 +117,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.asynchronous = asynchronous - def execute(self, context: "Context") -> dict: + def execute(self, context: Context) -> dict: hook = DataplexHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -142,7 +148,7 @@ def execute(self, context: "Context") -> dict: self.log.info("Is operation done already? 
%s", is_done) return is_done except HttpError as err: - if err.resp.status not in (409, '409'): + if err.resp.status not in (409, "409"): raise self.log.info("Task %s already exists", self.dataplex_task_id) # Wait for task to be ready @@ -156,7 +162,7 @@ def execute(self, context: "Context") -> dict: timeout=self.timeout, metadata=self.metadata, ) - if task['state'] != 'CREATING': + if task["state"] != "CREATING": break sleep(time_to_wait) @@ -199,12 +205,12 @@ def __init__( lake_id: str, dataplex_task_id: str, api_version: str = "v1", - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, *args, **kwargs, ) -> None: @@ -221,7 +227,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = DataplexHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -293,17 +299,17 @@ def __init__( project_id: str, region: str, lake_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - order_by: Optional[str] = None, + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, api_version: str = "v1", - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, *args, **kwargs, ) -> None: @@ -323,7 +329,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> List[dict]: + def execute(self, context: Context) -> list[dict]: hook = DataplexHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -385,12 +391,12 @@ def __init__( lake_id: str, dataplex_task_id: str, api_version: str = "v1", - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, *args, **kwargs, ) -> None: @@ -407,7 +413,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> dict: + def execute(self, context: Context) -> dict: hook = DataplexHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -426,4 +432,214 @@ def execute(self, context: "Context") -> dict: timeout=self.timeout, metadata=self.metadata, ) + 
DataplexTasksLink.persist(context=context, task_instance=self) return Task.to_dict(task) + + +class DataplexCreateLakeOperator(BaseOperator): + """ + Creates a lake resource. + + :param project_id: Required. The ID of the Google Cloud project that the lake belongs to. + :param region: Required. The ID of the Google Cloud region that the lake belongs to. + :param lake_id: Required. Lake identifier. + :param body: Required. The Request body contains an instance of Lake. + :param validate_only: Optional. Only validate the request, but do not perform mutations. The default is + false. + :param api_version: The version of the api that will be requested for example 'v1'. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service account making the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param asynchronous: Flag informing whether the Dataplex lake should be created asynchronously.
+ This is useful for long running creating lakes and + waiting on them asynchronously using the DataplexLakeSensor + """ + + template_fields = ( + "project_id", + "lake_id", + "body", + "validate_only", + "delegate_to", + "impersonation_chain", + ) + template_fields_renderers = {"body": "json"} + operator_extra_links = (DataplexLakeLink(),) + + def __init__( + self, + project_id: str, + region: str, + lake_id: str, + body: dict[str, Any], + validate_only: bool | None = None, + api_version: str = "v1", + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + asynchronous: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.project_id = project_id + self.region = region + self.lake_id = lake_id + self.body = body + self.validate_only = validate_only + self.api_version = api_version + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.asynchronous = asynchronous + + def execute(self, context: Context) -> dict: + hook = DataplexHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Creating Dataplex lake %s", self.lake_id) + + try: + operation = hook.create_lake( + project_id=self.project_id, + region=self.region, + lake_id=self.lake_id, + body=self.body, + validate_only=self.validate_only, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + if not self.asynchronous: + self.log.info("Waiting for Dataplex lake %s to be created", self.lake_id) + lake = hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Lake %s created successfully", self.lake_id) + else: + is_done = operation.done() + self.log.info("Is operation done already? %s", is_done) + return is_done + except HttpError as err: + if err.resp.status not in (409, "409"): + raise + self.log.info("Lake %s already exists", self.lake_id) + # Wait for lake to be ready + for time_to_wait in exponential_sleep_generator(initial=10, maximum=120): + lake = hook.get_lake( + project_id=self.project_id, + region=self.region, + lake_id=self.lake_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + if lake["state"] != "CREATING": + break + sleep(time_to_wait) + DataplexLakeLink.persist( + context=context, + task_instance=self, + ) + return Lake.to_dict(lake) + + +class DataplexDeleteLakeOperator(BaseOperator): + """ + Delete the lake resource. + + :param project_id: Required. The ID of the Google Cloud project that the lake belongs to. + :param region: Required. The ID of the Google Cloud region that the lake belongs to. + :param lake_id: Required. Lake identifier. + :param api_version: The version of the api that will be requested for example 'v1'. + :param retry: A retry object used to retry requests. If `None` is specified, requests + will not be retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + :param gcp_conn_id: The connection ID to use when fetching connection info. 
+ :param delegate_to: The account to impersonate, if any. For this to work, the service account making the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields = ("project_id", "lake_id", "delegate_to", "impersonation_chain") + operator_extra_links = (DataplexLakeLink(),) + + def __init__( + self, + project_id: str, + region: str, + lake_id: str, + api_version: str = "v1", + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + + super().__init__(*args, **kwargs) + self.project_id = project_id + self.region = region + self.lake_id = lake_id + self.api_version = api_version + self.retry = retry + self.timeout = timeout + self.metadata = metadata + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> None: + + hook = DataplexHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + api_version=self.api_version, + impersonation_chain=self.impersonation_chain, + ) + + self.log.info("Deleting Dataplex lake %s", self.lake_id) + + operation = hook.delete_lake( + project_id=self.project_id, + region=self.region, + lake_id=self.lake_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + DataplexLakeLink.persist(context=context, task_instance=self) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Dataplex lake %s deleted successfully!", self.lake_id) diff --git a/airflow/providers/google/cloud/operators/dataprep.py b/airflow/providers/google/cloud/operators/dataprep.py index 54f76dc381b47..61340b07473f3 100644 --- a/airflow/providers/google/cloud/operators/dataprep.py +++ b/airflow/providers/google/cloud/operators/dataprep.py @@ -16,10 +16,13 @@ # specific language governing permissions and limitations # under the License. 
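# Illustrative usage sketch (assumption, not part of the patch itself): how the new
# DataplexCreateLakeOperator / DataplexDeleteLakeOperator added above might be wired
# into a DAG. The project, region, lake id, and Lake body below are placeholder values.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataplex import (
    DataplexCreateLakeOperator,
    DataplexDeleteLakeOperator,
)

with DAG("example_dataplex_lake", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    create_lake = DataplexCreateLakeOperator(
        task_id="create_lake",
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        lake_id="my-lake",  # placeholder
        body={"display_name": "My Lake"},  # minimal Lake body, illustrative only
    )
    delete_lake = DataplexDeleteLakeOperator(
        task_id="delete_lake",
        project_id="my-project",
        region="us-central1",
        lake_id="my-lake",
    )
    create_lake >> delete_lake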
"""This module contains a Google Dataprep operator.""" +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook +from airflow.providers.google.cloud.links.dataprep import DataprepFlowLink, DataprepJobGroupLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -34,22 +37,28 @@ class DataprepGetJobsForJobGroupOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DataprepGetJobsForJobGroupOperator` - :param job_id The ID of the job that will be requests + :param job_group_id The ID of the job group that will be requests """ - template_fields: Sequence[str] = ("job_id",) + template_fields: Sequence[str] = ("job_group_id",) - def __init__(self, *, dataprep_conn_id: str = "dataprep_default", job_id: int, **kwargs) -> None: + def __init__( + self, + *, + dataprep_conn_id: str = "dataprep_default", + job_group_id: int | str, + **kwargs, + ) -> None: super().__init__(**kwargs) self.dataprep_conn_id = (dataprep_conn_id,) - self.job_id = job_id + self.job_group_id = job_group_id - def execute(self, context: 'Context') -> dict: - self.log.info("Fetching data for job with id: %d ...", self.job_id) + def execute(self, context: Context) -> dict: + self.log.info("Fetching data for job with id: %d ...", self.job_group_id) hook = GoogleDataprepHook( dataprep_conn_id="dataprep_default", ) - response = hook.get_jobs_for_job_group(job_id=self.job_id) + response = hook.get_jobs_for_job_group(job_id=int(self.job_group_id)) return response @@ -63,33 +72,49 @@ class DataprepGetJobGroupOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DataprepGetJobGroupOperator` - :param job_group_id: The ID of the job that will be requests + :param job_group_id: The ID of the job group that will be requests :param embed: Comma-separated list of objects to pull in as part of the response :param include_deleted: if set to "true", will include deleted objects """ - template_fields: Sequence[str] = ("job_group_id", "embed") + template_fields: Sequence[str] = ( + "job_group_id", + "embed", + "project_id", + ) + operator_extra_links = (DataprepJobGroupLink(),) def __init__( self, *, dataprep_conn_id: str = "dataprep_default", - job_group_id: int, + project_id: str | None = None, + job_group_id: int | str, embed: str, include_deleted: bool, **kwargs, ) -> None: super().__init__(**kwargs) self.dataprep_conn_id: str = dataprep_conn_id + self.project_id = project_id self.job_group_id = job_group_id self.embed = embed self.include_deleted = include_deleted - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: self.log.info("Fetching data for job with id: %d ...", self.job_group_id) + + if self.project_id: + DataprepJobGroupLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + job_group_id=int(self.job_group_id), + ) + hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) response = hook.get_job_group( - job_group_id=self.job_group_id, + job_group_id=int(self.job_group_id), embed=self.embed, include_deleted=self.include_deleted, ) @@ -113,14 +138,166 @@ class DataprepRunJobGroupOperator(BaseOperator): """ template_fields: Sequence[str] = ("body_request",) + operator_extra_links = (DataprepJobGroupLink(),) - def __init__(self, *, dataprep_conn_id: str = 
"dataprep_default", body_request: dict, **kwargs) -> None: + def __init__( + self, + *, + project_id: str | None = None, + dataprep_conn_id: str = "dataprep_default", + body_request: dict, + **kwargs, + ) -> None: super().__init__(**kwargs) - self.body_request = body_request + self.project_id = project_id self.dataprep_conn_id = dataprep_conn_id + self.body_request = body_request - def execute(self, context: "Context") -> dict: + def execute(self, context: Context) -> dict: self.log.info("Creating a job...") hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) response = hook.run_job_group(body_request=self.body_request) + + job_group_id = response.get("id") + if self.project_id and job_group_id: + DataprepJobGroupLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + job_group_id=int(job_group_id), + ) + + return response + + +class DataprepCopyFlowOperator(BaseOperator): + """ + Create a copy of the provided flow id, as well as all contained recipes. + + :param dataprep_conn_id: The Dataprep connection ID + :param flow_id: ID of the flow to be copied + :param name: Name for the copy of the flow + :param description: Description of the copy of the flow + :param copy_datasources: Bool value to define should the copy of data inputs be made or not. + """ + + template_fields: Sequence[str] = ( + "flow_id", + "name", + "project_id", + "description", + ) + operator_extra_links = (DataprepFlowLink(),) + + def __init__( + self, + *, + project_id: str | None = None, + dataprep_conn_id: str = "dataprep_default", + flow_id: int | str, + name: str = "", + description: str = "", + copy_datasources: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.dataprep_conn_id = dataprep_conn_id + self.flow_id = flow_id + self.name = name + self.description = description + self.copy_datasources = copy_datasources + + def execute(self, context: Context) -> dict: + self.log.info("Copying flow with id %d...", self.flow_id) + hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + response = hook.copy_flow( + flow_id=int(self.flow_id), + name=self.name, + description=self.description, + copy_datasources=self.copy_datasources, + ) + + copied_flow_id = response.get("id") + if self.project_id and copied_flow_id: + DataprepFlowLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + flow_id=int(copied_flow_id), + ) + return response + + +class DataprepDeleteFlowOperator(BaseOperator): + """ + Delete the flow with provided id. + + :param dataprep_conn_id: The Dataprep connection ID + :param flow_id: ID of the flow to be copied + """ + + template_fields: Sequence[str] = ("flow_id",) + + def __init__( + self, + *, + dataprep_conn_id: str = "dataprep_default", + flow_id: int | str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataprep_conn_id = dataprep_conn_id + self.flow_id = flow_id + + def execute(self, context: Context) -> None: + self.log.info("Start delete operation of the flow with id: %d...", self.flow_id) + hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + hook.delete_flow(flow_id=int(self.flow_id)) + + +class DataprepRunFlowOperator(BaseOperator): + """ + Runs the flow with the provided id copy of the provided flow id. + + :param dataprep_conn_id: The Dataprep connection ID + :param flow_id: ID of the flow to be copied + :param body_request: Body of the POST request to be sent. 
+ """ + + template_fields: Sequence[str] = ( + "flow_id", + "project_id", + ) + operator_extra_links = (DataprepJobGroupLink(),) + + def __init__( + self, + *, + project_id: str | None = None, + flow_id: int | str, + body_request: dict, + dataprep_conn_id: str = "dataprep_default", + **kwargs, + ): + super().__init__(**kwargs) + self.project_id = project_id + self.flow_id = flow_id + self.body_request = body_request + self.dataprep_conn_id = dataprep_conn_id + + def execute(self, context: Context) -> dict: + self.log.info("Running the flow with id: %d...", self.flow_id) + hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + response = hooks.run_flow(flow_id=int(self.flow_id), body_request=self.body_request) + + if self.project_id: + job_group_id = response["data"][0]["id"] + DataprepJobGroupLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + job_group_id=int(job_group_id), + ) + return response diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 7a160d1ba97dd..24ca7de40100d 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Dataproc operators.""" +from __future__ import annotations import inspect import ntpath @@ -26,13 +26,13 @@ import uuid import warnings from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core import operation # type: ignore from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry, exponential_sleep_generator -from google.cloud.dataproc_v1 import Batch, Cluster +from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus from google.protobuf.duration_pb2 import Duration from google.protobuf.field_mask_pb2 import FieldMask @@ -50,6 +50,7 @@ DataprocLink, DataprocListLink, ) +from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger from airflow.utils import timezone if TYPE_CHECKING: @@ -133,38 +134,38 @@ class ClusterGenerator: def __init__( self, project_id: str, - num_workers: Optional[int] = None, - zone: Optional[str] = None, - network_uri: Optional[str] = None, - subnetwork_uri: Optional[str] = None, - internal_ip_only: Optional[bool] = None, - tags: Optional[List[str]] = None, - storage_bucket: Optional[str] = None, - init_actions_uris: Optional[List[str]] = None, + num_workers: int | None = None, + zone: str | None = None, + network_uri: str | None = None, + subnetwork_uri: str | None = None, + internal_ip_only: bool | None = None, + tags: list[str] | None = None, + storage_bucket: str | None = None, + init_actions_uris: list[str] | None = None, init_action_timeout: str = "10m", - metadata: Optional[Dict] = None, - custom_image: Optional[str] = None, - custom_image_project_id: Optional[str] = None, - custom_image_family: Optional[str] = None, - image_version: Optional[str] = None, - autoscaling_policy: Optional[str] = None, - properties: Optional[Dict] = None, - optional_components: Optional[List[str]] = None, + metadata: dict | None = None, + custom_image: str | None = None, + custom_image_project_id: str | 
None = None, + custom_image_family: str | None = None, + image_version: str | None = None, + autoscaling_policy: str | None = None, + properties: dict | None = None, + optional_components: list[str] | None = None, num_masters: int = 1, - master_machine_type: str = 'n1-standard-4', - master_disk_type: str = 'pd-standard', + master_machine_type: str = "n1-standard-4", + master_disk_type: str = "pd-standard", master_disk_size: int = 1024, - worker_machine_type: str = 'n1-standard-4', - worker_disk_type: str = 'pd-standard', + worker_machine_type: str = "n1-standard-4", + worker_disk_type: str = "pd-standard", worker_disk_size: int = 1024, num_preemptible_workers: int = 0, - service_account: Optional[str] = None, - service_account_scopes: Optional[List[str]] = None, - idle_delete_ttl: Optional[int] = None, - auto_delete_time: Optional[datetime] = None, - auto_delete_ttl: Optional[int] = None, - customer_managed_key: Optional[str] = None, - enable_component_gateway: Optional[bool] = False, + service_account: str | None = None, + service_account_scopes: list[str] | None = None, + idle_delete_ttl: int | None = None, + auto_delete_time: datetime | None = None, + auto_delete_ttl: int | None = None, + customer_managed_key: str | None = None, + enable_component_gateway: bool | None = False, **kwargs, ) -> None: @@ -231,45 +232,45 @@ def _get_init_action_timeout(self) -> dict: def _build_gce_cluster_config(self, cluster_data): if self.zone: - zone_uri = f'https://www.googleapis.com/compute/v1/projects/{self.project_id}/zones/{self.zone}' - cluster_data['gce_cluster_config']['zone_uri'] = zone_uri + zone_uri = f"https://www.googleapis.com/compute/v1/projects/{self.project_id}/zones/{self.zone}" + cluster_data["gce_cluster_config"]["zone_uri"] = zone_uri if self.metadata: - cluster_data['gce_cluster_config']['metadata'] = self.metadata + cluster_data["gce_cluster_config"]["metadata"] = self.metadata if self.network_uri: - cluster_data['gce_cluster_config']['network_uri'] = self.network_uri + cluster_data["gce_cluster_config"]["network_uri"] = self.network_uri if self.subnetwork_uri: - cluster_data['gce_cluster_config']['subnetwork_uri'] = self.subnetwork_uri + cluster_data["gce_cluster_config"]["subnetwork_uri"] = self.subnetwork_uri if self.internal_ip_only: if not self.subnetwork_uri: raise AirflowException("Set internal_ip_only to true only when you pass a subnetwork_uri.") - cluster_data['gce_cluster_config']['internal_ip_only'] = True + cluster_data["gce_cluster_config"]["internal_ip_only"] = True if self.tags: - cluster_data['gce_cluster_config']['tags'] = self.tags + cluster_data["gce_cluster_config"]["tags"] = self.tags if self.service_account: - cluster_data['gce_cluster_config']['service_account'] = self.service_account + cluster_data["gce_cluster_config"]["service_account"] = self.service_account if self.service_account_scopes: - cluster_data['gce_cluster_config']['service_account_scopes'] = self.service_account_scopes + cluster_data["gce_cluster_config"]["service_account_scopes"] = self.service_account_scopes return cluster_data def _build_lifecycle_config(self, cluster_data): if self.idle_delete_ttl: - cluster_data['lifecycle_config']['idle_delete_ttl'] = {"seconds": self.idle_delete_ttl} + cluster_data["lifecycle_config"]["idle_delete_ttl"] = {"seconds": self.idle_delete_ttl} if self.auto_delete_time: utc_auto_delete_time = timezone.convert_to_utc(self.auto_delete_time) - cluster_data['lifecycle_config']['auto_delete_time'] = utc_auto_delete_time.strftime( - '%Y-%m-%dT%H:%M:%S.%fZ' + 
cluster_data["lifecycle_config"]["auto_delete_time"] = utc_auto_delete_time.strftime( + "%Y-%m-%dT%H:%M:%S.%fZ" ) elif self.auto_delete_ttl: - cluster_data['lifecycle_config']['auto_delete_ttl'] = {"seconds": int(self.auto_delete_ttl)} + cluster_data["lifecycle_config"]["auto_delete_ttl"] = {"seconds": int(self.auto_delete_ttl)} return cluster_data @@ -286,66 +287,66 @@ def _build_cluster_data(self): worker_type_uri = self.worker_machine_type cluster_data = { - 'gce_cluster_config': {}, - 'master_config': { - 'num_instances': self.num_masters, - 'machine_type_uri': master_type_uri, - 'disk_config': { - 'boot_disk_type': self.master_disk_type, - 'boot_disk_size_gb': self.master_disk_size, + "gce_cluster_config": {}, + "master_config": { + "num_instances": self.num_masters, + "machine_type_uri": master_type_uri, + "disk_config": { + "boot_disk_type": self.master_disk_type, + "boot_disk_size_gb": self.master_disk_size, }, }, - 'worker_config': { - 'num_instances': self.num_workers, - 'machine_type_uri': worker_type_uri, - 'disk_config': { - 'boot_disk_type': self.worker_disk_type, - 'boot_disk_size_gb': self.worker_disk_size, + "worker_config": { + "num_instances": self.num_workers, + "machine_type_uri": worker_type_uri, + "disk_config": { + "boot_disk_type": self.worker_disk_type, + "boot_disk_size_gb": self.worker_disk_size, }, }, - 'secondary_worker_config': {}, - 'software_config': {}, - 'lifecycle_config': {}, - 'encryption_config': {}, - 'autoscaling_config': {}, - 'endpoint_config': {}, + "secondary_worker_config": {}, + "software_config": {}, + "lifecycle_config": {}, + "encryption_config": {}, + "autoscaling_config": {}, + "endpoint_config": {}, } if self.num_preemptible_workers > 0: - cluster_data['secondary_worker_config'] = { - 'num_instances': self.num_preemptible_workers, - 'machine_type_uri': worker_type_uri, - 'disk_config': { - 'boot_disk_type': self.worker_disk_type, - 'boot_disk_size_gb': self.worker_disk_size, + cluster_data["secondary_worker_config"] = { + "num_instances": self.num_preemptible_workers, + "machine_type_uri": worker_type_uri, + "disk_config": { + "boot_disk_type": self.worker_disk_type, + "boot_disk_size_gb": self.worker_disk_size, }, - 'is_preemptible': True, + "is_preemptible": True, } if self.storage_bucket: - cluster_data['config_bucket'] = self.storage_bucket + cluster_data["config_bucket"] = self.storage_bucket if self.image_version: - cluster_data['software_config']['image_version'] = self.image_version + cluster_data["software_config"]["image_version"] = self.image_version elif self.custom_image: project_id = self.custom_image_project_id or self.project_id custom_image_url = ( - f'https://www.googleapis.com/compute/beta/projects/{project_id}' - f'/global/images/{self.custom_image}' + f"https://www.googleapis.com/compute/beta/projects/{project_id}" + f"/global/images/{self.custom_image}" ) - cluster_data['master_config']['image_uri'] = custom_image_url + cluster_data["master_config"]["image_uri"] = custom_image_url if not self.single_node: - cluster_data['worker_config']['image_uri'] = custom_image_url + cluster_data["worker_config"]["image_uri"] = custom_image_url elif self.custom_image_family: project_id = self.custom_image_project_id or self.project_id custom_image_url = ( - 'https://www.googleapis.com/compute/beta/projects/' - f'{project_id}/global/images/family/{self.custom_image_family}' + "https://www.googleapis.com/compute/beta/projects/" + f"{project_id}/global/images/family/{self.custom_image_family}" ) - 
cluster_data['master_config']['image_uri'] = custom_image_url + cluster_data["master_config"]["image_uri"] = custom_image_url if not self.single_node: - cluster_data['worker_config']['image_uri'] = custom_image_url + cluster_data["worker_config"]["image_uri"] = custom_image_url cluster_data = self._build_gce_cluster_config(cluster_data) @@ -353,26 +354,26 @@ def _build_cluster_data(self): self.properties["dataproc:dataproc.allow.zero.workers"] = "true" if self.properties: - cluster_data['software_config']['properties'] = self.properties + cluster_data["software_config"]["properties"] = self.properties if self.optional_components: - cluster_data['software_config']['optional_components'] = self.optional_components + cluster_data["software_config"]["optional_components"] = self.optional_components cluster_data = self._build_lifecycle_config(cluster_data) if self.init_actions_uris: init_actions_dict = [ - {'executable_file': uri, 'execution_timeout': self._get_init_action_timeout()} + {"executable_file": uri, "execution_timeout": self._get_init_action_timeout()} for uri in self.init_actions_uris ] - cluster_data['initialization_actions'] = init_actions_dict + cluster_data["initialization_actions"] = init_actions_dict if self.customer_managed_key: - cluster_data['encryption_config'] = {'gce_pd_kms_key_name': self.customer_managed_key} + cluster_data["encryption_config"] = {"gce_pd_kms_key_name": self.customer_managed_key} if self.autoscaling_policy: - cluster_data['autoscaling_config'] = {'policy_uri': self.autoscaling_policy} + cluster_data["autoscaling_config"] = {"policy_uri": self.autoscaling_policy} if self.enable_component_gateway: - cluster_data['endpoint_config'] = {'enable_http_port_access': self.enable_component_gateway} + cluster_data["endpoint_config"] = {"enable_http_port_access": self.enable_component_gateway} return cluster_data @@ -440,15 +441,15 @@ class DataprocCreateClusterOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'region', - 'cluster_config', - 'virtual_cluster_config', - 'cluster_name', - 'labels', - 'impersonation_chain', + "project_id", + "region", + "cluster_config", + "virtual_cluster_config", + "cluster_name", + "labels", + "impersonation_chain", ) - template_fields_renderers = {'cluster_config': 'json', 'virtual_cluster_config': 'json'} + template_fields_renderers = {"cluster_config": "json", "virtual_cluster_config": "json"} operator_extra_links = (DataprocLink(),) @@ -457,18 +458,18 @@ def __init__( *, cluster_name: str, region: str, - project_id: Optional[str] = None, - cluster_config: Optional[Union[Dict, Cluster]] = None, - virtual_cluster_config: Optional[Dict] = None, - labels: Optional[Dict] = None, - request_id: Optional[str] = None, + project_id: str | None = None, + cluster_config: dict | Cluster | None = None, + virtual_cluster_config: dict | None = None, + labels: dict | None = None, + request_id: str | None = None, delete_on_error: bool = True, use_if_exists: bool = True, - retry: Union[Retry, _MethodDefault] = DEFAULT, + retry: Retry | _MethodDefault = DEFAULT, timeout: float = 1 * 60 * 60, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -483,8 +484,8 @@ def __init__( stacklevel=1, ) # Remove result of apply defaults - if 'params' in kwargs: - del kwargs['params'] + if "params" in 
kwargs: + del kwargs["params"] # Create cluster object from kwargs if project_id is None: @@ -555,7 +556,7 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: gcs_uri = hook.diagnose_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) - self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri) + self.log.info("Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri) if self.delete_on_error: self._delete_cluster(hook) raise AirflowException("Cluster was created but was in ERROR state.") @@ -586,8 +587,8 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: cluster = self._get_cluster(hook) return cluster - def execute(self, context: 'Context') -> dict: - self.log.info('Creating cluster: %s', self.cluster_name) + def execute(self, context: Context) -> dict: + self.log.info("Creating cluster: %s", self.cluster_name) hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) # Save data required to display extra link no matter what the cluster status will be DataprocLink.persist( @@ -657,7 +658,7 @@ class DataprocScaleClusterOperator(BaseOperator): account from the list granting this role to the originating account (templated). """ - template_fields: Sequence[str] = ('cluster_name', 'project_id', 'region', 'impersonation_chain') + template_fields: Sequence[str] = ("cluster_name", "project_id", "region", "impersonation_chain") operator_extra_links = (DataprocLink(),) @@ -665,13 +666,13 @@ def __init__( self, *, cluster_name: str, - project_id: Optional[str] = None, - region: str = 'global', + project_id: str | None = None, + region: str = "global", num_workers: int = 2, num_preemptible_workers: int = 0, - graceful_decommission_timeout: Optional[str] = None, + graceful_decommission_timeout: str | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -694,15 +695,15 @@ def __init__( def _build_scale_cluster_data(self) -> dict: scale_data = { - 'config': { - 'worker_config': {'num_instances': self.num_workers}, - 'secondary_worker_config': {'num_instances': self.num_preemptible_workers}, + "config": { + "worker_config": {"num_instances": self.num_workers}, + "secondary_worker_config": {"num_instances": self.num_preemptible_workers}, } } return scale_data @property - def _graceful_decommission_timeout_object(self) -> Optional[Dict[str, int]]: + def _graceful_decommission_timeout_object(self) -> dict[str, int] | None: if not self.graceful_decommission_timeout: return None @@ -728,9 +729,9 @@ def _graceful_decommission_timeout_object(self) -> Optional[Dict[str, int]]: " i.e. 
1d, 4h, 10m, 30s" ) - return {'seconds': timeout} + return {"seconds": timeout} - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Scale, up or down, a cluster on Google Cloud Dataproc.""" self.log.info("Scaling cluster: %s", self.cluster_name) @@ -748,7 +749,7 @@ def execute(self, context: 'Context') -> None: cluster_name=self.cluster_name, cluster=scaling_cluster_data, graceful_decommission_timeout=self._graceful_decommission_timeout_object, - update_mask={'paths': update_mask}, + update_mask={"paths": update_mask}, ) operation.result() self.log.info("Cluster scaling finished") @@ -782,21 +783,21 @@ class DataprocDeleteClusterOperator(BaseOperator): account from the list granting this role to the originating account (templated). """ - template_fields: Sequence[str] = ('project_id', 'region', 'cluster_name', 'impersonation_chain') + template_fields: Sequence[str] = ("project_id", "region", "cluster_name", "impersonation_chain") def __init__( self, *, region: str, cluster_name: str, - project_id: Optional[str] = None, - cluster_uuid: Optional[str] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + cluster_uuid: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -811,7 +812,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Deleting cluster: %s", self.cluster_name) operation = hook.delete_cluster( @@ -867,6 +868,9 @@ class DataprocJobBaseOperator(BaseOperator): :param asynchronous: Flag to return after submitting the job to the Dataproc API. This is useful for submitting long running jobs and waiting on them asynchronously using the DataprocJobSensor + :param deferrable: Run operator in the deferrable mode + :param polling_interval_seconds: time in seconds between polling for job completion. + The value is considered only when running in deferrable mode. Must be greater than 0. :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API. 
This is useful for identifying or linking to the job in the Google Cloud Console @@ -883,20 +887,24 @@ def __init__( self, *, region: str, - job_name: str = '{{task.task_id}}_{{ds_nodash}}', + job_name: str = "{{task.task_id}}_{{ds_nodash}}", cluster_name: str = "cluster-1", - project_id: Optional[str] = None, - dataproc_properties: Optional[Dict] = None, - dataproc_jars: Optional[List[str]] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict] = None, - job_error_states: Optional[Set[str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + dataproc_properties: dict | None = None, + dataproc_jars: list[str] | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + labels: dict | None = None, + job_error_states: set[str] | None = None, + impersonation_chain: str | Sequence[str] | None = None, asynchronous: bool = False, + deferrable: bool = False, + polling_interval_seconds: int = 10, **kwargs, ) -> None: super().__init__(**kwargs) + if deferrable and polling_interval_seconds <= 0: + raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0") self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.labels = labels @@ -906,14 +914,16 @@ def __init__( self.dataproc_jars = dataproc_jars self.region = region - self.job_error_states = job_error_states if job_error_states is not None else {'ERROR'} + self.job_error_states = job_error_states if job_error_states is not None else {"ERROR"} self.impersonation_chain = impersonation_chain self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain) self.project_id = self.hook.project_id if project_id is None else project_id - self.job_template: Optional[DataProcJobBuilder] = None - self.job: Optional[dict] = None + self.job_template: DataProcJobBuilder | None = None + self.job: dict | None = None self.dataproc_job_id = None self.asynchronous = asynchronous + self.deferrable = deferrable + self.polling_interval_seconds = polling_interval_seconds def create_job_template(self) -> DataProcJobBuilder: """Initialize `self.job_template` with default values""" @@ -938,34 +948,61 @@ def create_job_template(self) -> DataProcJobBuilder: def _generate_job_template(self) -> str: if self.job_template: job = self.job_template.build() - return job['job'] + return job["job"] raise Exception("Create a job template before") - def execute(self, context: 'Context'): + def execute(self, context: Context): if self.job_template: self.job = self.job_template.build() if self.job is None: raise Exception("The job should be set here.") self.dataproc_job_id = self.job["job"]["reference"]["job_id"] - self.log.info('Submitting %s job %s', self.job_type, self.dataproc_job_id) + self.log.info("Submitting %s job %s", self.job_type, self.dataproc_job_id) job_object = self.hook.submit_job( project_id=self.project_id, job=self.job["job"], region=self.region ) job_id = job_object.reference.job_id - self.log.info('Job %s submitted successfully.', job_id) + self.log.info("Job %s submitted successfully.", job_id) # Save data required for extra links no matter what the job status will be DataprocLink.persist( context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id ) + if self.deferrable: + self.defer( + trigger=DataprocBaseTrigger( + job_id=job_id, + project_id=self.project_id, + region=self.region, + delegate_to=self.delegate_to, + 
gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + polling_interval_seconds=self.polling_interval_seconds, + ), + method_name="execute_complete", + ) if not self.asynchronous: - self.log.info('Waiting for job %s to complete', job_id) + self.log.info("Waiting for job %s to complete", job_id) self.hook.wait_for_job(job_id=job_id, region=self.region, project_id=self.project_id) - self.log.info('Job %s completed successfully.', job_id) + self.log.info("Job %s completed successfully.", job_id) return job_id else: raise AirflowException("Create a job template before") + def execute_complete(self, context, event=None) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + job_state = event["job_state"] + job_id = event["job_id"] + if job_state == JobStatus.State.ERROR: + raise AirflowException(f"Job failed:\n{job_id}") + if job_state == JobStatus.State.CANCELLED: + raise AirflowException(f"Job was cancelled:\n{job_id}") + self.log.info("%s completed successfully.", self.task_id) + def on_kill(self) -> None: """ Callback called when the operator is killed. @@ -1016,27 +1053,27 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator): """ template_fields: Sequence[str] = ( - 'query', - 'variables', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ) - template_ext = ('.pg', '.pig') - ui_color = '#0273d4' - job_type = 'pig_job' + template_ext = (".pg", ".pig") + ui_color = "#0273d4" + job_type = "pig_job" operator_extra_links = (DataprocLink(),) def __init__( self, *, - query: Optional[str] = None, - query_uri: Optional[str] = None, - variables: Optional[Dict] = None, + query: str | None = None, + query_uri: str | None = None, + variables: dict | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -1062,18 +1099,18 @@ def generate_job(self): if self.query is None: if self.query_uri is None: - raise AirflowException('One of query or query_uri should be set here') + raise AirflowException("One of query or query_uri should be set here") job_template.add_query_uri(self.query_uri) else: job_template.add_query(self.query) job_template.add_variables(self.variables) return self._generate_job_template() - def execute(self, context: 'Context'): + def execute(self, context: Context): job_template = self.create_job_template() if self.query is None: if self.query_uri is None: - raise AirflowException('One of query or query_uri should be set here') + raise AirflowException("One of query or query_uri should be set here") job_template.add_query_uri(self.query_uri) else: job_template.add_query(self.query) @@ -1092,25 +1129,25 @@ class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator): """ template_fields: Sequence[str] = ( - 'query', - 'variables', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ) - template_ext = ('.q', '.hql') - ui_color = '#0273d4' - job_type = 'hive_job' + template_ext = (".q", ".hql") + ui_color = "#0273d4" + job_type = "hive_job" def __init__( self, *, - query: Optional[str] = None, - query_uri: Optional[str] = None, - variables: 
Optional[Dict] = None, + query: str | None = None, + query_uri: str | None = None, + variables: dict | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -1127,7 +1164,7 @@ def __init__( self.query_uri = query_uri self.variables = variables if self.query is not None and self.query_uri is not None: - raise AirflowException('Only one of `query` and `query_uri` can be passed.') + raise AirflowException("Only one of `query` and `query_uri` can be passed.") def generate_job(self): """ @@ -1137,18 +1174,18 @@ def generate_job(self): job_template = self.create_job_template() if self.query is None: if self.query_uri is None: - raise AirflowException('One of query or query_uri should be set here') + raise AirflowException("One of query or query_uri should be set here") job_template.add_query_uri(self.query_uri) else: job_template.add_query(self.query) job_template.add_variables(self.variables) return self._generate_job_template() - def execute(self, context: 'Context'): + def execute(self, context: Context): job_template = self.create_job_template() if self.query is None: if self.query_uri is None: - raise AirflowException('One of query or query_uri should be set here') + raise AirflowException("One of query or query_uri should be set here") job_template.add_query_uri(self.query_uri) else: job_template.add_query(self.query) @@ -1166,26 +1203,26 @@ class DataprocSubmitSparkSqlJobOperator(DataprocJobBaseOperator): """ template_fields: Sequence[str] = ( - 'query', - 'variables', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ) - template_ext = ('.q',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#0273d4' - job_type = 'spark_sql_job' + template_ext = (".q",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#0273d4" + job_type = "spark_sql_job" def __init__( self, *, - query: Optional[str] = None, - query_uri: Optional[str] = None, - variables: Optional[Dict] = None, + query: str | None = None, + query_uri: str | None = None, + variables: dict | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -1202,7 +1239,7 @@ def __init__( self.query_uri = query_uri self.variables = variables if self.query is not None and self.query_uri is not None: - raise AirflowException('Only one of `query` and `query_uri` can be passed.') + raise AirflowException("Only one of `query` and `query_uri` can be passed.") def generate_job(self): """ @@ -1217,11 +1254,11 @@ def generate_job(self): job_template.add_variables(self.variables) return self._generate_job_template() - def execute(self, context: 'Context'): + def execute(self, context: Context): job_template = self.create_job_template() if self.query is None: if self.query_uri is None: - raise AirflowException('One of query or query_uri should be set here') + raise AirflowException("One of query or query_uri should be set here") job_template.add_query_uri(self.query_uri) else: job_template.add_query(self.query) @@ -1244,25 +1281,25 @@ class DataprocSubmitSparkJobOperator(DataprocJobBaseOperator): """ template_fields: Sequence[str] = ( - 'arguments', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ) - ui_color = '#0273d4' - job_type = 
'spark_job' + ui_color = "#0273d4" + job_type = "spark_job" def __init__( self, *, - main_jar: Optional[str] = None, - main_class: Optional[str] = None, - arguments: Optional[List] = None, - archives: Optional[List] = None, - files: Optional[List] = None, + main_jar: str | None = None, + main_class: str | None = None, + arguments: list | None = None, + archives: list | None = None, + files: list | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -1293,7 +1330,7 @@ def generate_job(self): job_template.add_file_uris(self.files) return self._generate_job_template() - def execute(self, context: 'Context'): + def execute(self, context: Context): job_template = self.create_job_template() job_template.set_main(self.main_jar, self.main_class) job_template.add_args(self.arguments) @@ -1317,25 +1354,25 @@ class DataprocSubmitHadoopJobOperator(DataprocJobBaseOperator): """ template_fields: Sequence[str] = ( - 'arguments', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ) - ui_color = '#0273d4' - job_type = 'hadoop_job' + ui_color = "#0273d4" + job_type = "hadoop_job" def __init__( self, *, - main_jar: Optional[str] = None, - main_class: Optional[str] = None, - arguments: Optional[List] = None, - archives: Optional[List] = None, - files: Optional[List] = None, + main_jar: str | None = None, + main_class: str | None = None, + arguments: list | None = None, + archives: list | None = None, + files: list | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -1366,7 +1403,7 @@ def generate_job(self): job_template.add_file_uris(self.files) return self._generate_job_template() - def execute(self, context: 'Context'): + def execute(self, context: Context): job_template = self.create_job_template() job_template.set_main(self.main_jar, self.main_class) job_template.add_args(self.arguments) @@ -1390,21 +1427,21 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator): """ template_fields: Sequence[str] = ( - 'main', - 'arguments', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "main", + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ) - ui_color = '#0273d4' - job_type = 'pyspark_job' + ui_color = "#0273d4" + job_type = "pyspark_job" @staticmethod def _generate_temp_filename(filename): - date = time.strftime('%Y%m%d%H%M%S') + date = time.strftime("%Y%m%d%H%M%S") return f"{date}_{str(uuid.uuid4())[:8]}_{ntpath.basename(filename)}" def _upload_file_temp(self, bucket, local_file): @@ -1421,7 +1458,7 @@ def _upload_file_temp(self, bucket, local_file): GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain).upload( bucket_name=bucket, object_name=temp_filename, - mime_type='application/x-python', + mime_type="application/x-python", filename=local_file, ) return f"gs://{bucket}/{temp_filename}" @@ -1430,10 +1467,10 @@ def __init__( self, *, main: str, - arguments: Optional[List] = None, - archives: Optional[List] = None, - pyfiles: Optional[List] = None, - files: Optional[List] = None, + arguments: list | None = None, + archives: list | None = None, + pyfiles: list | None = None, + files: list | None = None, **kwargs, ) -> None: # TODO: Remove one day @@ -1463,7 +1500,7 @@ def generate_job(self): cluster_info = 
self.hook.get_cluster( project_id=self.project_id, region=self.region, cluster_name=self.cluster_name ) - bucket = cluster_info['config']['config_bucket'] + bucket = cluster_info["config"]["config_bucket"] self.main = f"gs://{bucket}/{self.main}" job_template.set_python_main(self.main) job_template.add_args(self.arguments) @@ -1473,14 +1510,14 @@ def generate_job(self): return self._generate_job_template() - def execute(self, context: 'Context'): + def execute(self, context: Context): job_template = self.create_job_template() # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( project_id=self.project_id, region=self.region, cluster_name=self.cluster_name ) - bucket = cluster_info['config']['config_bucket'] + bucket = cluster_info["config"]["config_bucket"] self.main = self._upload_file_temp(bucket, self.main) job_template.set_python_main(self.main) @@ -1513,14 +1550,14 @@ class DataprocCreateWorkflowTemplateOperator(BaseOperator): def __init__( self, *, - template: Dict, + template: dict, region: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -1533,7 +1570,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Creating template") try: @@ -1594,7 +1631,7 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ('template_id', 'impersonation_chain', 'request_id', 'parameters') + template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters") template_fields_renderers = {"parameters": "json"} operator_extra_links = (DataprocLink(),) @@ -1603,15 +1640,15 @@ def __init__( *, template_id: str, region: str, - project_id: Optional[str] = None, - version: Optional[int] = None, - request_id: Optional[str] = None, - parameters: Optional[Dict[str, str]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + version: int | None = None, + request_id: str | None = None, + parameters: dict[str, str] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1628,9 +1665,9 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - self.log.info('Instantiating template %s', self.template_id) + self.log.info("Instantiating template %s", self.template_id) operation = hook.instantiate_workflow_template( project_id=self.project_id, region=self.region, @@ -1642,12 +1679,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) - operation.result() - workflow_id = operation.operation.name.split('/')[-1] + self.workflow_id = operation.operation.name.split("/")[-1] DataprocLink.persist( - context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=workflow_id + context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id ) - self.log.info('Template instantiated.') + self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id) + operation.result() + self.log.info("Workflow %s completed successfully", self.workflow_id) class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator): @@ -1691,22 +1729,22 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ('template', 'impersonation_chain') + template_fields: Sequence[str] = ("template", "impersonation_chain") template_fields_renderers = {"template": "json"} operator_extra_links = (DataprocLink(),) def __init__( self, *, - template: Dict, + template: dict, region: str, - project_id: Optional[str] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1721,8 +1759,8 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): - self.log.info('Instantiating Inline Template') + def execute(self, context: Context): + self.log.info("Instantiating Inline Template") hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) operation = hook.instantiate_inline_workflow_template( template=self.template, @@ -1733,12 +1771,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) - operation.result() - workflow_id = operation.operation.name.split('/')[-1] + self.workflow_id = operation.operation.name.split("/")[-1] DataprocLink.persist( - context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=workflow_id + context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id ) - self.log.info('Template instantiated.') + self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id) + operation.result() + self.log.info("Workflow %s completed successfully", self.workflow_id) class DataprocSubmitJobOperator(BaseOperator): @@ -1771,11 +1810,14 @@ class DataprocSubmitJobOperator(BaseOperator): :param asynchronous: Flag to return after submitting the job to the Dataproc API. This is useful for submitting long running jobs and waiting on them asynchronously using the DataprocJobSensor + :param deferrable: Run operator in the deferrable mode + :param polling_interval_seconds: time in seconds between polling for job completion. + The value is considered only when running in deferrable mode. Must be greater than 0. :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called :param wait_timeout: How many seconds wait for job to be ready. 
Used only if ``asynchronous`` is False """ - template_fields: Sequence[str] = ('project_id', 'region', 'job', 'impersonation_chain', 'request_id') + template_fields: Sequence[str] = ("project_id", "region", "job", "impersonation_chain", "request_id") template_fields_renderers = {"job": "json"} operator_extra_links = (DataprocLink(),) @@ -1783,21 +1825,25 @@ class DataprocSubmitJobOperator(BaseOperator): def __init__( self, *, - job: Dict, + job: dict, region: str, - project_id: Optional[str] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, asynchronous: bool = False, + deferrable: bool = False, + polling_interval_seconds: int = 10, cancel_on_kill: bool = True, - wait_timeout: Optional[int] = None, + wait_timeout: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) + if deferrable and polling_interval_seconds <= 0: + raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0") self.project_id = project_id self.region = region self.job = job @@ -1808,12 +1854,14 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.asynchronous = asynchronous + self.deferrable = deferrable + self.polling_interval_seconds = polling_interval_seconds self.cancel_on_kill = cancel_on_kill - self.hook: Optional[DataprocHook] = None - self.job_id: Optional[str] = None + self.hook: DataprocHook | None = None + self.job_id: str | None = None self.wait_timeout = wait_timeout - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info("Submitting job") self.hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) job_object = self.hook.submit_job( @@ -1826,22 +1874,48 @@ def execute(self, context: 'Context'): metadata=self.metadata, ) new_job_id: str = job_object.reference.job_id - self.log.info('Job %s submitted successfully.', new_job_id) + self.log.info("Job %s submitted successfully.", new_job_id) # Save data required by extra links no matter what the job status will be DataprocLink.persist( context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=new_job_id ) self.job_id = new_job_id - if not self.asynchronous: - self.log.info('Waiting for job %s to complete', new_job_id) + if self.deferrable: + self.defer( + trigger=DataprocBaseTrigger( + job_id=self.job_id, + project_id=self.project_id, + region=self.region, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + polling_interval_seconds=self.polling_interval_seconds, + ), + method_name="execute_complete", + ) + elif not self.asynchronous: + self.log.info("Waiting for job %s to complete", new_job_id) self.hook.wait_for_job( job_id=new_job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout ) - self.log.info('Job %s completed successfully.', new_job_id) + self.log.info("Job %s completed successfully.", new_job_id) return self.job_id + def execute_complete(self, context, event=None) -> None: + """ + Callback for when the trigger fires - returns immediately. 
+ Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + job_state = event["job_state"] + job_id = event["job_id"] + if job_state == JobStatus.State.ERROR: + raise AirflowException(f"Job failed:\n{job_id}") + if job_state == JobStatus.State.CANCELLED: + raise AirflowException(f"Job was cancelled:\n{job_id}") + self.log.info("%s completed successfully.", self.task_id) + def on_kill(self): if self.job_id and self.cancel_on_kill: self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, region=self.region) @@ -1888,12 +1962,12 @@ class DataprocUpdateClusterOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'cluster_name', - 'cluster', - 'region', - 'request_id', - 'project_id', - 'impersonation_chain', + "cluster_name", + "cluster", + "region", + "request_id", + "project_id", + "impersonation_chain", ) operator_extra_links = (DataprocLink(),) @@ -1901,17 +1975,17 @@ def __init__( self, *, cluster_name: str, - cluster: Union[Dict, Cluster], - update_mask: Union[Dict, FieldMask], - graceful_decommission_timeout: Union[Dict, Duration], + cluster: dict | Cluster, + update_mask: dict | FieldMask, + graceful_decommission_timeout: dict | Duration, region: str, - request_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -1928,7 +2002,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) # Save data required by extra links no matter what the cluster status will be DataprocLink.persist( @@ -1966,6 +2040,8 @@ class DataprocCreateBatchOperator(BaseOperator): the first ``google.longrunning.Operation`` created and stored in the backend is returned. :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be retried. + :param result_retry: Result retry object used to retry requests. Is used to decrease delay between + executing chained tasks in a DAG by specifying exact amount of seconds for executing. :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if ``retry`` is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. 
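
A small usage sketch of the new deferrable mode added to DataprocSubmitJobOperator above, assuming the DataprocBaseTrigger wired up in execute(). The job payload, project, region and bucket below are placeholders, not values from this change.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

# Illustrative Dataproc job payload; the cluster name, bucket and file are placeholders.
PYSPARK_JOB = {
    "reference": {"project_id": "example-project"},
    "placement": {"cluster_name": "example-cluster"},
    "pyspark_job": {"main_python_file_uri": "gs://example-bucket/jobs/hello.py"},
}

with DAG("example_dataproc_submit_job", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    # deferrable=True releases the worker slot while the job runs; the trigger
    # polls Dataproc every polling_interval_seconds and execute_complete() is
    # called once the job reaches a terminal state.
    submit_job = DataprocSubmitJobOperator(
        task_id="submit_pyspark_job",
        job=PYSPARK_JOB,
        region="europe-west1",
        project_id="example-project",
        deferrable=True,
        polling_interval_seconds=30,
    )
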
@@ -1981,27 +2057,28 @@ class DataprocCreateBatchOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'batch', - 'batch_id', - 'region', - 'impersonation_chain', + "project_id", + "batch", + "batch_id", + "region", + "impersonation_chain", ) operator_extra_links = (DataprocLink(),) def __init__( self, *, - region: Optional[str] = None, - project_id: Optional[str] = None, - batch: Union[Dict, Batch], - batch_id: Optional[str] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + region: str | None = None, + project_id: str | None = None, + batch: dict | Batch, + batch_id: str | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, + result_retry: Retry | _MethodDefault = DEFAULT, **kwargs, ): super().__init__(**kwargs) @@ -2011,17 +2088,18 @@ def __init__( self.batch_id = batch_id self.request_id = request_id self.retry = retry + self.result_retry = result_retry self.timeout = timeout self.metadata = metadata self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.operation: Optional[operation.Operation] = None + self.operation: operation.Operation | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Creating batch") if self.region is None: - raise AirflowException('Region should be set here') + raise AirflowException("Region should be set here") try: self.operation = hook.create_batch( region=self.region, @@ -2035,12 +2113,14 @@ def execute(self, context: 'Context'): ) if self.operation is None: raise RuntimeError("The operation should be set here!") - result = hook.wait_for_operation(timeout=self.timeout, operation=self.operation) + result = hook.wait_for_operation( + timeout=self.timeout, result_retry=self.result_retry, operation=self.operation + ) self.log.info("Batch %s created", self.batch_id) except AlreadyExists: self.log.info("Batch with given id already exists") if self.batch_id is None: - raise AirflowException('Batch Id should be set here') + raise AirflowException("Batch Id should be set here") result = hook.get_batch( batch_id=self.batch_id, region=self.region, @@ -2049,7 +2129,23 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) - batch_id = self.batch_id or result.name.split('/')[-1] + + # The existing batch may be a number of states other than 'SUCCEEDED' + if result.state != Batch.State.SUCCEEDED: + if result.state == Batch.State.FAILED or result.state == Batch.State.CANCELLED: + raise AirflowException( + f"Existing Batch {self.batch_id} failed or cancelled. " + f"Error: {result.state_message}" + ) + else: + # Batch state is either: RUNNING, PENDING, CANCELLING, or UNSPECIFIED + self.log.info( + f"Batch {self.batch_id} is in state {result.state.name}." + "Waiting for state change..." 
+ ) + result = hook.wait_for_operation(timeout=self.timeout, operation=result) + + batch_id = self.batch_id or result.name.split("/")[-1] DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id) return Batch.to_dict(result) @@ -2090,12 +2186,12 @@ def __init__( *, batch_id: str, region: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -2108,7 +2204,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Deleting batch: %s", self.batch_id) hook.delete_batch( @@ -2155,12 +2251,12 @@ def __init__( *, batch_id: str, region: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -2173,7 +2269,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Getting batch: %s", self.batch_id) batch = hook.get_batch( @@ -2215,7 +2311,6 @@ class DataprocListBatchesOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
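
Likewise, a brief sketch of DataprocCreateBatchOperator with the new result_retry argument, which the hunk above forwards to hook.wait_for_operation(). The batch payload and ids are placeholders.

from datetime import datetime

from google.api_core.retry import Retry

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

# Illustrative serverless batch payload; the file URI is a placeholder.
BATCH_CONFIG = {
    "pyspark_batch": {"main_python_file_uri": "gs://example-bucket/jobs/hello.py"},
}

with DAG("example_dataproc_batch", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    # result_retry controls how the terminal result of the long-running
    # batch operation is polled, independently of the request retry.
    create_batch = DataprocCreateBatchOperator(
        task_id="create_batch",
        project_id="example-project",
        region="europe-west1",
        batch=BATCH_CONFIG,
        batch_id="example-batch",
        result_retry=Retry(maximum=10.0),
    )
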
- :rtype: List[dict] """ template_fields: Sequence[str] = ("region", "project_id", "impersonation_chain") @@ -2225,14 +2320,14 @@ def __init__( self, *, region: str, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2246,7 +2341,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) results = hook.list_batches( region=self.region, diff --git a/airflow/providers/google/cloud/operators/dataproc_metastore.py b/airflow/providers/google/cloud/operators/dataproc_metastore.py index d0ca4a5f28672..e0585805cc4bd 100644 --- a/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -15,12 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Dataproc Metastore operators.""" +from __future__ import annotations -from datetime import datetime from time import sleep -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry, exponential_sleep_generator @@ -58,15 +57,15 @@ class DataprocMetastoreLink(BaseOperatorLink): @staticmethod def persist( - context: "Context", - task_instance: Union[ - "DataprocMetastoreCreateServiceOperator", - "DataprocMetastoreGetServiceOperator", - "DataprocMetastoreRestoreServiceOperator", - "DataprocMetastoreUpdateServiceOperator", - "DataprocMetastoreListBackupsOperator", - "DataprocMetastoreExportMetadataOperator", - ], + context: Context, + task_instance: ( + DataprocMetastoreCreateServiceOperator + | DataprocMetastoreGetServiceOperator + | DataprocMetastoreRestoreServiceOperator + | DataprocMetastoreUpdateServiceOperator + | DataprocMetastoreListBackupsOperator + | DataprocMetastoreExportMetadataOperator + ), url: str, ): task_instance.xcom_push( @@ -82,20 +81,11 @@ def persist( def get_link( self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, ) -> str: - if ti_key is not None: - conf = XCom.get_value(key=self.key, ti_key=ti_key) - else: - assert dttm - conf = XCom.get_one( - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - key=self.key, - ) + conf = XCom.get_value(key=self.key, ti_key=ti_key) return ( conf["url"].format( region=conf["region"], @@ -115,11 +105,10 @@ class DataprocMetastoreDetailedLink(BaseOperatorLink): @staticmethod def persist( - context: "Context", - task_instance: Union[ - "DataprocMetastoreCreateBackupOperator", - 
"DataprocMetastoreCreateMetadataImportOperator", - ], + context: Context, + task_instance: ( + DataprocMetastoreCreateBackupOperator | DataprocMetastoreCreateMetadataImportOperator + ), url: str, resource: str, ): @@ -137,20 +126,11 @@ def persist( def get_link( self, - operator, - dttm: Optional[datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, ) -> str: - if ti_key is not None: - conf = XCom.get_value(key=self.key, ti_key=ti_key) - else: - assert dttm - conf = XCom.get_one( - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - key=DataprocMetastoreDetailedLink.key, - ) + conf = XCom.get_value(key=self.key, ti_key=ti_key) return ( conf["url"].format( region=conf["region"], @@ -203,11 +183,11 @@ class DataprocMetastoreCreateBackupOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'backup', - 'impersonation_chain', + "project_id", + "backup", + "impersonation_chain", ) - template_fields_renderers = {'backup': 'json'} + template_fields_renderers = {"backup": "json"} operator_extra_links = (DataprocMetastoreDetailedLink(),) def __init__( @@ -216,14 +196,14 @@ def __init__( project_id: str, region: str, service_id: str, - backup: Union[Dict, Backup], + backup: dict | Backup, backup_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -239,7 +219,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> dict: + def execute(self, context: Context) -> dict: hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -260,7 +240,7 @@ def execute(self, context: "Context") -> dict: backup = hook.wait_for_operation(self.timeout, operation) self.log.info("Backup %s created successfully", self.backup_id) except HttpError as err: - if err.resp.status not in (409, '409'): + if err.resp.status not in (409, "409"): raise self.log.info("Backup %s already exists", self.backup_id) backup = hook.get_backup( @@ -318,11 +298,11 @@ class DataprocMetastoreCreateMetadataImportOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'metadata_import', - 'impersonation_chain', + "project_id", + "metadata_import", + "impersonation_chain", ) - template_fields_renderers = {'metadata_import': 'json'} + template_fields_renderers = {"metadata_import": "json"} operator_extra_links = (DataprocMetastoreDetailedLink(),) def __init__( @@ -331,14 +311,14 @@ def __init__( project_id: str, region: str, service_id: str, - metadata_import: MetadataImport, + metadata_import: dict | MetadataImport, metadata_import_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = 
"google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -354,7 +334,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -413,11 +393,11 @@ class DataprocMetastoreCreateServiceOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'service', - 'impersonation_chain', + "project_id", + "service", + "impersonation_chain", ) - template_fields_renderers = {'service': 'json'} + template_fields_renderers = {"service": "json"} operator_extra_links = (DataprocMetastoreLink(),) def __init__( @@ -425,14 +405,14 @@ def __init__( *, region: str, project_id: str, - service: Union[Dict, Service], + service: dict | Service, service_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -447,7 +427,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> dict: + def execute(self, context: Context) -> dict: hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -466,7 +446,7 @@ def execute(self, context: "Context") -> dict: service = hook.wait_for_operation(self.timeout, operation) self.log.info("Service %s created successfully", self.service_id) except HttpError as err: - if err.resp.status not in (409, '409'): + if err.resp.status not in (409, "409"): raise self.log.info("Instance %s already exists", self.service_id) service = hook.get_service( @@ -516,8 +496,8 @@ class DataprocMetastoreDeleteBackupOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'impersonation_chain', + "project_id", + "impersonation_chain", ) def __init__( @@ -527,12 +507,12 @@ def __init__( region: str, service_id: str, backup_id: str, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -547,7 +527,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -580,8 +560,8 @@ class DataprocMetastoreDeleteServiceOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 
'project_id', - 'impersonation_chain', + "project_id", + "impersonation_chain", ) def __init__( @@ -590,11 +570,11 @@ def __init__( region: str, project_id: str, service_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -607,7 +587,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -656,8 +636,8 @@ class DataprocMetastoreExportMetadataOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'impersonation_chain', + "project_id", + "impersonation_chain", ) operator_extra_links = (DataprocMetastoreLink(), StorageLink()) @@ -668,13 +648,13 @@ def __init__( project_id: str, region: str, service_id: str, - request_id: Optional[str] = None, - database_dump_type: Optional[DatabaseDumpSpec] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + request_id: str | None = None, + database_dump_type: DatabaseDumpSpec | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -690,7 +670,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -711,7 +691,7 @@ def execute(self, context: "Context"): DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_EXPORT_LINK) uri = self._get_uri_from_destination(MetadataExport.to_dict(metadata_export)["destination_gcs_uri"]) - StorageLink.persist(context=context, task_instance=self, uri=uri) + StorageLink.persist(context=context, task_instance=self, uri=uri, project_id=self.project_id) return MetadataExport.to_dict(metadata_export) def _get_uri_from_destination(self, destination_uri: str): @@ -770,8 +750,8 @@ class DataprocMetastoreGetServiceOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'impersonation_chain', + "project_id", + "impersonation_chain", ) operator_extra_links = (DataprocMetastoreLink(),) @@ -781,11 +761,11 @@ def __init__( region: str, project_id: str, service_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) 
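
An illustrative sketch of DataprocMetastoreCreateServiceOperator from this file; the service body, project, region and service id are placeholders rather than values from the patch.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc_metastore import (
    DataprocMetastoreCreateServiceOperator,
)

# Illustrative service body following the Dataproc Metastore Service resource shape.
SERVICE = {"hive_metastore_config": {"version": "3.1.2"}}

with DAG("example_dataproc_metastore", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    create_service = DataprocMetastoreCreateServiceOperator(
        task_id="create_service",
        project_id="example-project",
        region="europe-west1",
        service=SERVICE,
        service_id="example-service",
        timeout=3600,
    )
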
@@ -798,7 +778,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> dict: + def execute(self, context: Context) -> dict: hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -843,8 +823,8 @@ class DataprocMetastoreListBackupsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'impersonation_chain', + "project_id", + "impersonation_chain", ) operator_extra_links = (DataprocMetastoreLink(),) @@ -854,15 +834,15 @@ def __init__( project_id: str, region: str, service_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -879,7 +859,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> List[dict]: + def execute(self, context: Context) -> list[dict]: hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -940,8 +920,8 @@ class DataprocMetastoreRestoreServiceOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'impersonation_chain', + "project_id", + "impersonation_chain", ) operator_extra_links = (DataprocMetastoreLink(),) @@ -955,13 +935,13 @@ def __init__( backup_region: str, backup_service_id: str, backup_id: str, - restore_type: Optional[Restore] = None, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + restore_type: Restore | None = None, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -980,7 +960,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) @@ -1070,8 +1050,8 @@ class DataprocMetastoreUpdateServiceOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'impersonation_chain', + "project_id", + "impersonation_chain", ) operator_extra_links = (DataprocMetastoreLink(),) @@ -1081,14 +1061,14 @@ def __init__( project_id: str, region: str, service_id: str, - service: Union[Dict, Service], + service: dict | Service, update_mask: FieldMask, - request_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + 
request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1104,7 +1084,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DataprocMetastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) diff --git a/airflow/providers/google/cloud/operators/datastore.py b/airflow/providers/google/cloud/operators/datastore.py index 8a92665e3694e..21e318decfea9 100644 --- a/airflow/providers/google/cloud/operators/datastore.py +++ b/airflow/providers/google/cloud/operators/datastore.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Datastore operators.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -72,11 +73,11 @@ class CloudDatastoreExportEntitiesOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'namespace', - 'entity_filter', - 'labels', - 'impersonation_chain', + "bucket", + "namespace", + "entity_filter", + "labels", + "impersonation_chain", ) operator_extra_links = (StorageLink(),) @@ -84,16 +85,16 @@ def __init__( self, *, bucket: str, - namespace: Optional[str] = None, - datastore_conn_id: str = 'google_cloud_default', - cloud_storage_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - entity_filter: Optional[dict] = None, - labels: Optional[dict] = None, + namespace: str | None = None, + datastore_conn_id: str = "google_cloud_default", + cloud_storage_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + entity_filter: dict | None = None, + labels: dict | None = None, polling_interval_in_seconds: int = 10, overwrite_existing: bool = False, - project_id: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -109,8 +110,8 @@ def __init__( self.project_id = project_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: - self.log.info('Exporting data to Cloud Storage bucket %s', self.bucket) + def execute(self, context: Context) -> dict: + self.log.info("Exporting data to Cloud Storage bucket %s", self.bucket) if self.overwrite_existing and self.namespace: gcs_hook = GCSHook(self.cloud_storage_conn_id, impersonation_chain=self.impersonation_chain) @@ -130,16 +131,17 @@ def execute(self, context: 'Context') -> dict: labels=self.labels, project_id=self.project_id, ) - operation_name = result['name'] + operation_name = result["name"] result = ds_hook.poll_operation_until_done(operation_name, self.polling_interval_in_seconds) - state = result['metadata']['common']['state'] - if state != 'SUCCESSFUL': - raise AirflowException(f'Operation failed: result={result}') + state = 
result["metadata"]["common"]["state"] + if state != "SUCCESSFUL": + raise AirflowException(f"Operation failed: result={result}") StorageLink.persist( context=context, task_instance=self, uri=f"{self.bucket}/{result['response']['outputUrl'].split('/')[3]}", + project_id=self.project_id or ds_hook.project_id, ) return result @@ -181,12 +183,12 @@ class CloudDatastoreImportEntitiesOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'file', - 'namespace', - 'entity_filter', - 'labels', - 'impersonation_chain', + "bucket", + "file", + "namespace", + "entity_filter", + "labels", + "impersonation_chain", ) operator_extra_links = (CloudDatastoreImportExportLink(),) @@ -195,14 +197,14 @@ def __init__( *, bucket: str, file: str, - namespace: Optional[str] = None, - entity_filter: Optional[dict] = None, - labels: Optional[dict] = None, - datastore_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, + namespace: str | None = None, + entity_filter: dict | None = None, + labels: dict | None = None, + datastore_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, polling_interval_in_seconds: float = 10, - project_id: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -217,8 +219,8 @@ def __init__( self.project_id = project_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): - self.log.info('Importing data from Cloud Storage bucket %s', self.bucket) + def execute(self, context: Context): + self.log.info("Importing data from Cloud Storage bucket %s", self.bucket) ds_hook = DatastoreHook( self.datastore_conn_id, self.delegate_to, @@ -232,12 +234,12 @@ def execute(self, context: 'Context'): labels=self.labels, project_id=self.project_id, ) - operation_name = result['name'] + operation_name = result["name"] result = ds_hook.poll_operation_until_done(operation_name, self.polling_interval_in_seconds) - state = result['metadata']['common']['state'] - if state != 'SUCCESSFUL': - raise AirflowException(f'Operation failed: result={result}') + state = result["metadata"]["common"]["state"] + if state != "SUCCESSFUL": + raise AirflowException(f"Operation failed: result={result}") CloudDatastoreImportExportLink.persist(context=context, task_instance=self) return result @@ -279,11 +281,11 @@ class CloudDatastoreAllocateIdsOperator(BaseOperator): def __init__( self, *, - partial_keys: List, - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + partial_keys: list, + project_id: str | None = None, + delegate_to: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -294,7 +296,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> list: + def execute(self, context: Context) -> list: hook = DatastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -342,11 +344,11 @@ class CloudDatastoreBeginTransactionOperator(BaseOperator): def __init__( self, *, - transaction_options: Dict[str, Any], - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - gcp_conn_id: str 
= 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + transaction_options: dict[str, Any], + project_id: str | None = None, + delegate_to: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -357,7 +359,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: hook = DatastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -405,11 +407,11 @@ class CloudDatastoreCommitOperator(BaseOperator): def __init__( self, *, - body: Dict[str, Any], - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + body: dict[str, Any], + project_id: str | None = None, + delegate_to: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -420,7 +422,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = DatastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -469,10 +471,10 @@ def __init__( self, *, transaction: str, - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + delegate_to: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -483,7 +485,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DatastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -529,11 +531,11 @@ class CloudDatastoreRunQueryOperator(BaseOperator): def __init__( self, *, - body: Dict[str, Any], - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + body: dict[str, Any], + project_id: str | None = None, + delegate_to: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -544,7 +546,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = DatastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -591,9 +593,9 @@ def __init__( self, *, name: str, - delegate_to: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -603,7 +605,7 @@ def __init__( 
self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = DatastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -647,9 +649,9 @@ def __init__( self, *, name: str, - delegate_to: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -659,7 +661,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = DatastoreHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/dlp.py b/airflow/providers/google/cloud/operators/dlp.py index bd5dedbf7b45e..09ffd21591cfd 100644 --- a/airflow/providers/google/cloud/operators/dlp.py +++ b/airflow/providers/google/cloud/operators/dlp.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - """ This module contains various Google Cloud DLP operators which allow you to perform basic operations using Cloud DLP. """ -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import AlreadyExists, InvalidArgument, NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -45,6 +45,19 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.dlp import CloudDLPHook +from airflow.providers.google.cloud.links.data_loss_prevention import ( + CloudDLPDeidentifyTemplateDetailsLink, + CloudDLPDeidentifyTemplatesListLink, + CloudDLPInfoTypeDetailsLink, + CloudDLPInfoTypesListLink, + CloudDLPInspectTemplateDetailsLink, + CloudDLPInspectTemplatesListLink, + CloudDLPJobDetailsLink, + CloudDLPJobsListLink, + CloudDLPJobTriggerDetailsLink, + CloudDLPJobTriggersListLink, + CloudDLPPossibleInfoTypesListLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -85,17 +98,18 @@ class CloudDLPCancelDLPJobOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobDetailsLink(),) def __init__( self, *, dlp_job_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -107,7 +121,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -120,6 +134,15 @@ def execute(self, context: 'Context') -> None: 
metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_name=self.dlp_job_id, + ) + class CloudDLPCreateDeidentifyTemplateOperator(BaseOperator): """ @@ -153,7 +176,6 @@ class CloudDLPCreateDeidentifyTemplateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate """ template_fields: Sequence[str] = ( @@ -164,19 +186,20 @@ class CloudDLPCreateDeidentifyTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPDeidentifyTemplateDetailsLink(),) def __init__( self, *, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - deidentify_template: Optional[Union[Dict, DeidentifyTemplate]] = None, - template_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + deidentify_template: dict | DeidentifyTemplate | None = None, + template_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -190,7 +213,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -216,8 +239,19 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + result = MessageToDict(template) + + project_id = self.project_id or hook.project_id + template_id = self.template_id or result["name"].split("/")[-1] if result["name"] else None + if project_id and template_id: + CloudDLPDeidentifyTemplateDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + template_name=template_id, + ) - return MessageToDict(template) + return result class CloudDLPCreateDLPJobOperator(BaseOperator): @@ -252,7 +286,6 @@ class CloudDLPCreateDLPJobOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
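
An illustrative sketch of CloudDLPCreateDeidentifyTemplateOperator, which now also persists CloudDLPDeidentifyTemplateDetailsLink; the template body follows the public DLP API shape and is a placeholder, not taken from this change.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dlp import CloudDLPCreateDeidentifyTemplateOperator

# Illustrative template body; masks every finding with '#' characters.
DEIDENTIFY_TEMPLATE = {
    "deidentify_config": {
        "info_type_transformations": {
            "transformations": [
                {"primitive_transformation": {"character_mask_config": {"masking_character": "#"}}}
            ]
        }
    }
}

with DAG("example_dlp_templates", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    # The created template can now be opened from the task's extra links in the UI.
    create_template = CloudDLPCreateDeidentifyTemplateOperator(
        task_id="create_deidentify_template",
        project_id="example-project",
        template_id="example-template",
        deidentify_template=DEIDENTIFY_TEMPLATE,
    )
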
- :rtype: google.cloud.dlp_v2.types.DlpJob """ template_fields: Sequence[str] = ( @@ -263,20 +296,21 @@ class CloudDLPCreateDLPJobOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobDetailsLink(),) def __init__( self, *, - project_id: Optional[str] = None, - inspect_job: Optional[Union[Dict, InspectJobConfig]] = None, - risk_job: Optional[Union[Dict, RiskAnalysisJobConfig]] = None, - job_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + inspect_job: dict | InspectJobConfig | None = None, + risk_job: dict | RiskAnalysisJobConfig | None = None, + job_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), wait_until_finished: bool = True, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -291,7 +325,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -317,7 +351,19 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(job) + + result = MessageToDict(job) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_name=result["name"].split("/")[-1] if result["name"] else None, + ) + + return result class CloudDLPCreateInspectTemplateOperator(BaseOperator): @@ -352,7 +398,6 @@ class CloudDLPCreateInspectTemplateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.InspectTemplate """ template_fields: Sequence[str] = ( @@ -363,19 +408,20 @@ class CloudDLPCreateInspectTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInspectTemplateDetailsLink(),) def __init__( self, *, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - inspect_template: Optional[InspectTemplate] = None, - template_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + inspect_template: InspectTemplate | None = None, + template_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -389,7 +435,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -415,7 +461,20 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(template) + + result = MessageToDict(template) + + template_id = self.template_id or result["name"].split("/")[-1] if result["name"] else None + project_id = self.project_id or hook.project_id + if project_id and template_id: + CloudDLPInspectTemplateDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + template_name=template_id, + ) + + return result class CloudDLPCreateJobTriggerOperator(BaseOperator): @@ -448,7 +507,6 @@ class CloudDLPCreateJobTriggerOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.JobTrigger """ template_fields: Sequence[str] = ( @@ -458,18 +516,19 @@ class CloudDLPCreateJobTriggerOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobTriggerDetailsLink(),) def __init__( self, *, - project_id: Optional[str] = None, - job_trigger: Optional[Union[Dict, JobTrigger]] = None, - trigger_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + job_trigger: dict | JobTrigger | None = None, + trigger_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -482,7 +541,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -508,7 +567,20 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(trigger) + + result = MessageToDict(trigger) + + project_id = self.project_id or hook.project_id + trigger_name = result["name"].split("/")[-1] if result["name"] else None + if project_id: + CloudDLPJobTriggerDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + trigger_name=trigger_name, + ) + + return result class CloudDLPCreateStoredInfoTypeOperator(BaseOperator): @@ -542,7 +614,6 @@ class CloudDLPCreateStoredInfoTypeOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.StoredInfoType """ template_fields: Sequence[str] = ( @@ -553,19 +624,20 @@ class CloudDLPCreateStoredInfoTypeOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInfoTypeDetailsLink(),) def __init__( self, *, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - config: Optional[StoredInfoTypeConfig] = None, - stored_info_type_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + config: StoredInfoTypeConfig | None = None, + stored_info_type_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -579,7 +651,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -607,7 +679,22 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(info) + + result = MessageToDict(info) + + project_id = self.project_id or hook.project_id + stored_info_type_id = ( + self.stored_info_type_id or result["name"].split("/")[-1] if result["name"] else None + ) + if project_id and stored_info_type_id: + CloudDLPInfoTypeDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + info_type_name=stored_info_type_id, + ) + + return result class CloudDLPDeidentifyContentOperator(BaseOperator): @@ -649,7 +736,6 @@ class CloudDLPDeidentifyContentOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
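# Illustrative only: example argument shapes for CloudDLPDeidentifyContentOperator with the
# re-annotated dict-or-message parameters; the info type, item text and task id are
# placeholders, and the dicts follow the google.cloud.dlp_v2 request message layout
# (consult the DLP API docs for the full schema). Normally this lives inside a DAG context.
from airflow.providers.google.cloud.operators.dlp import CloudDLPDeidentifyContentOperator

deidentify_content = CloudDLPDeidentifyContentOperator(
    task_id="deidentify_content",
    project_id="my-gcp-project",  # placeholder
    item={"value": "My phone number is (206) 555-0123"},
    inspect_config={"info_types": [{"name": "PHONE_NUMBER"}]},
    deidentify_config={
        "info_type_transformations": {
            "transformations": [
                {"primitive_transformation": {"character_mask_config": {"masking_character": "#"}}}
            ]
        }
    },
)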
- :rtype: google.cloud.dlp_v2.types.DeidentifyContentResponse """ template_fields: Sequence[str] = ( @@ -666,17 +752,17 @@ class CloudDLPDeidentifyContentOperator(BaseOperator): def __init__( self, *, - project_id: Optional[str] = None, - deidentify_config: Optional[Union[Dict, DeidentifyConfig]] = None, - inspect_config: Optional[Union[Dict, InspectConfig]] = None, - item: Optional[Union[Dict, ContentItem]] = None, - inspect_template_name: Optional[str] = None, - deidentify_template_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + deidentify_config: dict | DeidentifyConfig | None = None, + inspect_config: dict | InspectConfig | None = None, + item: dict | ContentItem | None = None, + inspect_template_name: str | None = None, + deidentify_template_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -692,7 +778,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -749,18 +835,19 @@ class CloudDLPDeleteDeidentifyTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPDeidentifyTemplatesListLink(),) def __init__( self, *, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -773,7 +860,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -787,6 +874,13 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPDeidentifyTemplatesListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) except NotFound: self.log.error("Template %s not found.", self.template_id) @@ -800,7 +894,7 @@ class CloudDLPDeleteDLPJobOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:CloudDLPDeleteDLPJobOperator` - :param dlp_job_id: The ID of the DLP job resource to be cancelled. + :param dlp_job_id: The ID of the DLP job resource to be deleted. :param project_id: (Optional) Google Cloud project ID where the DLP Instance exists. 
If set to None or missing, the default project_id from the Google Cloud connection is used. @@ -827,17 +921,18 @@ class CloudDLPDeleteDLPJobOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobsListLink(),) def __init__( self, *, dlp_job_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -849,7 +944,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -862,6 +957,15 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + except NotFound: self.log.error("Job %s id not found.", self.dlp_job_id) @@ -904,18 +1008,19 @@ class CloudDLPDeleteInspectTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInspectTemplatesListLink(),) def __init__( self, *, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -928,7 +1033,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -942,6 +1047,15 @@ def execute(self, context: 'Context') -> None: timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInspectTemplatesListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + except NotFound: self.log.error("Template %s not found", self.template_id) @@ -981,17 +1095,18 @@ class CloudDLPDeleteJobTriggerOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobTriggersListLink(),) def __init__( self, *, job_trigger_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), 
gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1003,7 +1118,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1016,6 +1131,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobTriggersListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + except NotFound: self.log.error("Trigger %s not found", self.job_trigger_id) @@ -1058,18 +1182,19 @@ class CloudDLPDeleteStoredInfoTypeOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInfoTypesListLink(),) def __init__( self, *, stored_info_type_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1082,7 +1207,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1099,6 +1224,14 @@ def execute(self, context: 'Context'): except NotFound: self.log.error("Stored info %s not found", self.stored_info_type_id) + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInfoTypesListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + class CloudDLPGetDeidentifyTemplateOperator(BaseOperator): """ @@ -1130,7 +1263,6 @@ class CloudDLPGetDeidentifyTemplateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate """ template_fields: Sequence[str] = ( @@ -1140,18 +1272,19 @@ class CloudDLPGetDeidentifyTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPDeidentifyTemplateDetailsLink(),) def __init__( self, *, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1164,7 +1297,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1177,6 +1310,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPDeidentifyTemplateDetailsLink.persist( + context=context, task_instance=self, project_id=project_id, template_name=self.template_id + ) + return MessageToDict(template) @@ -1208,7 +1348,6 @@ class CloudDLPGetDLPJobOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: google.cloud.dlp_v2.types.DlpJob """ template_fields: Sequence[str] = ( @@ -1217,17 +1356,18 @@ class CloudDLPGetDLPJobOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobDetailsLink(),) def __init__( self, *, dlp_job_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1239,7 +1379,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1251,6 +1391,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_name=self.dlp_job_id, + ) + return MessageToDict(job) @@ -1284,7 +1434,6 @@ class CloudDLPGetInspectTemplateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.InspectTemplate """ template_fields: Sequence[str] = ( @@ -1294,18 +1443,19 @@ class CloudDLPGetInspectTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInspectTemplateDetailsLink(),) def __init__( self, *, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1318,7 +1468,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1331,6 +1481,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInspectTemplateDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + template_name=self.template_id, + ) + return MessageToDict(template) @@ -1362,7 +1522,6 @@ class CloudDLPGetDLPJobTriggerOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.JobTrigger """ template_fields: Sequence[str] = ( @@ -1371,17 +1530,18 @@ class CloudDLPGetDLPJobTriggerOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobTriggerDetailsLink(),) def __init__( self, *, job_trigger_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1393,7 +1553,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1405,6 +1565,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobTriggerDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + trigger_name=self.job_trigger_id, + ) + return MessageToDict(trigger) @@ -1438,7 +1608,6 @@ class CloudDLPGetStoredInfoTypeOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: google.cloud.dlp_v2.types.StoredInfoType """ template_fields: Sequence[str] = ( @@ -1448,18 +1617,19 @@ class CloudDLPGetStoredInfoTypeOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInfoTypeDetailsLink(),) def __init__( self, *, stored_info_type_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1472,7 +1642,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1485,6 +1655,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInfoTypeDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + info_type_name=self.stored_info_type_id, + ) + return MessageToDict(info) @@ -1521,7 +1701,6 @@ class CloudDLPInspectContentOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.tasks_v2.types.InspectContentResponse """ template_fields: Sequence[str] = ( @@ -1536,15 +1715,15 @@ class CloudDLPInspectContentOperator(BaseOperator): def __init__( self, *, - project_id: Optional[str] = None, - inspect_config: Optional[Union[Dict, InspectConfig]] = None, - item: Optional[Union[Dict, ContentItem]] = None, - inspect_template_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + inspect_config: dict | InspectConfig | None = None, + item: dict | ContentItem | None = None, + inspect_template_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1558,7 +1737,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1608,7 +1787,6 @@ class CloudDLPListDeidentifyTemplatesOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: list[google.cloud.dlp_v2.types.DeidentifyTemplate] """ template_fields: Sequence[str] = ( @@ -1617,19 +1795,20 @@ class CloudDLPListDeidentifyTemplatesOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPDeidentifyTemplatesListLink(),) def __init__( self, *, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + page_size: int | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1643,12 +1822,12 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - template = hook.list_deidentify_templates( + templates = hook.list_deidentify_templates( organization_id=self.organization_id, project_id=self.project_id, page_size=self.page_size, @@ -1658,7 +1837,16 @@ def execute(self, context: 'Context'): metadata=self.metadata, ) # the MessageToDict does not have the right type defined as possible to pass in constructor - return MessageToDict(template) # type: ignore[arg-type] + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPDeidentifyTemplatesListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + + return 
[MessageToDict(template) for template in templates] # type: ignore[arg-type] class CloudDLPListDLPJobsOperator(BaseOperator): @@ -1694,7 +1882,6 @@ class CloudDLPListDLPJobsOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: list[google.cloud.dlp_v2.types.DlpJob] """ template_fields: Sequence[str] = ( @@ -1702,20 +1889,21 @@ class CloudDLPListDLPJobsOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobsListLink(),) def __init__( self, *, - project_id: Optional[str] = None, - results_filter: Optional[str] = None, - page_size: Optional[int] = None, - job_type: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + results_filter: str | None = None, + page_size: int | None = None, + job_type: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1730,12 +1918,12 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - job = hook.list_dlp_jobs( + jobs = hook.list_dlp_jobs( project_id=self.project_id, results_filter=self.results_filter, page_size=self.page_size, @@ -1745,8 +1933,17 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + # the MessageToDict does not have the right type defined as possible to pass in constructor - return MessageToDict(job) # type: ignore[arg-type] + return [MessageToDict(job) for job in jobs] # type: ignore[arg-type] class CloudDLPListInfoTypesOperator(BaseOperator): @@ -1777,7 +1974,6 @@ class CloudDLPListInfoTypesOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
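# Illustrative only: because the list operators above now return a list of dicts (one
# MessageToDict result per job or template) instead of a single serialized pager, a
# downstream task can consume the XCom value directly. The dag_id, task ids and project
# are placeholders, not part of this change.
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.providers.google.cloud.operators.dlp import CloudDLPListDLPJobsOperator

with DAG(dag_id="example_dlp_list_jobs", start_date=datetime(2022, 1, 1), schedule=None) as dag:
    list_jobs = CloudDLPListDLPJobsOperator(
        task_id="list_dlp_jobs",
        project_id="my-gcp-project",  # placeholder
    )

    @task
    def count_jobs(jobs):
        # each element is the dict form of a google.cloud.dlp_v2.types.DlpJob
        return len(jobs)

    count_jobs(list_jobs.output)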
- :rtype: ListInfoTypesResponse """ template_fields: Sequence[str] = ( @@ -1785,20 +1981,23 @@ class CloudDLPListInfoTypesOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPPossibleInfoTypesListLink(),) def __init__( self, *, - language_code: Optional[str] = None, - results_filter: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + language_code: str | None = None, + results_filter: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) + self.project_id = project_id self.language_code = language_code self.results_filter = results_filter self.retry = retry @@ -1807,7 +2006,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1819,6 +2018,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPPossibleInfoTypesListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + return MessageToDict(response) @@ -1855,7 +2063,6 @@ class CloudDLPListInspectTemplatesOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: list[google.cloud.dlp_v2.types.InspectTemplate] """ template_fields: Sequence[str] = ( @@ -1864,19 +2071,20 @@ class CloudDLPListInspectTemplatesOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInspectTemplatesListLink(),) def __init__( self, *, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + page_size: int | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1890,7 +2098,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1904,6 +2112,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInspectTemplatesListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + return [MessageToDict(t) for t in templates] @@ -1939,7 +2156,6 @@ class CloudDLPListJobTriggersOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: list[google.cloud.dlp_v2.types.JobTrigger] """ template_fields: Sequence[str] = ( @@ -1947,19 +2163,20 @@ class CloudDLPListJobTriggersOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobTriggersListLink(),) def __init__( self, *, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - results_filter: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + page_size: int | None = None, + order_by: str | None = None, + results_filter: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1973,7 +2190,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1987,6 +2204,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobTriggersListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + return [MessageToDict(j) for j in jobs] @@ -2023,7 +2249,6 @@ class CloudDLPListStoredInfoTypesOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: list[google.cloud.dlp_v2.types.StoredInfoType] """ template_fields: Sequence[str] = ( @@ -2032,19 +2257,20 @@ class CloudDLPListStoredInfoTypesOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInfoTypesListLink(),) def __init__( self, *, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - page_size: Optional[int] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + page_size: int | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2058,7 +2284,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -2072,6 +2298,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInfoTypesListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + return [MessageToDict(i) for i in infos] @@ -2110,7 +2345,6 @@ class CloudDLPRedactImageOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.RedactImageResponse """ template_fields: Sequence[str] = ( @@ -2126,18 +2360,16 @@ class CloudDLPRedactImageOperator(BaseOperator): def __init__( self, *, - project_id: Optional[str] = None, - inspect_config: Optional[Union[Dict, InspectConfig]] = None, - image_redaction_configs: Optional[ - Union[List[dict], List[RedactImageRequest.ImageRedactionConfig]] - ] = None, - include_findings: Optional[bool] = None, - byte_item: Optional[Union[Dict, ByteContentItem]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + inspect_config: dict | InspectConfig | None = None, + image_redaction_configs: None | (list[dict] | list[RedactImageRequest.ImageRedactionConfig]) = None, + include_findings: bool | None = None, + byte_item: dict | ByteContentItem | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2152,7 +2384,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -2206,7 +2438,6 @@ class CloudDLPReidentifyContentOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.ReidentifyContentResponse """ template_fields: Sequence[str] = ( @@ -2223,17 +2454,17 @@ class CloudDLPReidentifyContentOperator(BaseOperator): def __init__( self, *, - project_id: Optional[str] = None, - reidentify_config: Optional[Union[Dict, DeidentifyConfig]] = None, - inspect_config: Optional[Union[Dict, InspectConfig]] = None, - item: Optional[Union[Dict, ContentItem]] = None, - inspect_template_name: Optional[str] = None, - reidentify_template_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + reidentify_config: dict | DeidentifyConfig | None = None, + inspect_config: dict | InspectConfig | None = None, + item: dict | ContentItem | None = None, + inspect_template_name: str | None = None, + reidentify_template_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2249,7 +2480,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -2300,7 +2531,6 @@ class CloudDLPUpdateDeidentifyTemplateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate """ template_fields: Sequence[str] = ( @@ -2312,20 +2542,21 @@ class CloudDLPUpdateDeidentifyTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPDeidentifyTemplateDetailsLink(),) def __init__( self, *, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - deidentify_template: Optional[Union[Dict, DeidentifyTemplate]] = None, - update_mask: Optional[Union[Dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + deidentify_template: dict | DeidentifyTemplate | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2340,7 +2571,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -2355,6 +2586,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPDeidentifyTemplateDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + template_name=self.template_id, + ) + return MessageToDict(template) @@ -2390,7 +2631,6 @@ class CloudDLPUpdateInspectTemplateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.InspectTemplate """ template_fields: Sequence[str] = ( @@ -2402,20 +2642,21 @@ class CloudDLPUpdateInspectTemplateOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInspectTemplateDetailsLink(),) def __init__( self, *, template_id: str, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - inspect_template: Optional[Union[Dict, InspectTemplate]] = None, - update_mask: Optional[Union[Dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + inspect_template: dict | InspectTemplate | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2430,7 +2671,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -2445,6 +2686,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInspectTemplateDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + template_name=self.template_id, + ) + return MessageToDict(template) @@ -2478,7 +2729,6 @@ class CloudDLPUpdateJobTriggerOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.InspectTemplate """ template_fields: Sequence[str] = ( @@ -2489,19 +2739,20 @@ class CloudDLPUpdateJobTriggerOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPJobTriggerDetailsLink(),) def __init__( self, *, job_trigger_id, - project_id: Optional[str] = None, - job_trigger: Optional[JobTrigger] = None, - update_mask: Optional[Union[Dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + job_trigger: JobTrigger | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2515,7 +2766,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -2529,6 +2780,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPJobTriggerDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + trigger_name=self.job_trigger_id, + ) + return MessageToDict(trigger) @@ -2565,7 +2826,6 @@ class CloudDLPUpdateStoredInfoTypeOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.dlp_v2.types.StoredInfoType """ template_fields: Sequence[str] = ( @@ -2577,20 +2837,21 @@ class CloudDLPUpdateStoredInfoTypeOperator(BaseOperator): "gcp_conn_id", "impersonation_chain", ) + operator_extra_links = (CloudDLPInfoTypeDetailsLink(),) def __init__( self, *, stored_info_type_id, - organization_id: Optional[str] = None, - project_id: Optional[str] = None, - config: Optional[Union[Dict, StoredInfoTypeConfig]] = None, - update_mask: Optional[Union[Dict, FieldMask]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + organization_id: str | None = None, + project_id: str | None = None, + config: dict | StoredInfoTypeConfig | None = None, + update_mask: dict | FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2605,7 +2866,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudDLPHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -2620,4 +2881,14 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + project_id = self.project_id or hook.project_id + if project_id: + CloudDLPInfoTypeDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + info_type_name=self.stored_info_type_id, + ) + return MessageToDict(info) diff --git a/airflow/providers/google/cloud/operators/functions.py b/airflow/providers/google/cloud/operators/functions.py index 6c84a1fdb60fa..e895c85e55da4 100644 --- a/airflow/providers/google/cloud/operators/functions.py +++ b/airflow/providers/google/cloud/operators/functions.py @@ -16,15 +16,20 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud Functions operators.""" +from __future__ import annotations import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from googleapiclient.errors import HttpError from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook +from airflow.providers.google.cloud.links.cloud_functions import ( + CloudFunctionsDetailsLink, + CloudFunctionsListLink, +) from airflow.providers.google.cloud.utils.field_validator import ( GcpBodyFieldValidator, GcpFieldValidationException, @@ -45,24 +50,24 @@ def _validate_max_instances(value): raise GcpFieldValidationException("The max instances parameter has to be greater than 0") -CLOUD_FUNCTION_VALIDATION = [ +CLOUD_FUNCTION_VALIDATION: list[dict[str, Any]] = [ dict(name="name", regexp="^.+$"), dict(name="description", regexp="^.+$", optional=True), - dict(name="entryPoint", regexp=r'^.+$', optional=True), - dict(name="runtime", regexp=r'^.+$', optional=True), - dict(name="timeout", regexp=r'^.+$', optional=True), + dict(name="entryPoint", regexp=r"^.+$", optional=True), + dict(name="runtime", regexp=r"^.+$", optional=True), + dict(name="timeout", regexp=r"^.+$", optional=True), dict(name="availableMemoryMb", custom_validation=_validate_available_memory_in_mb, optional=True), dict(name="labels", optional=True), dict(name="environmentVariables", optional=True), - dict(name="network", regexp=r'^.+$', optional=True), + dict(name="network", regexp=r"^.+$", optional=True), dict(name="maxInstances", optional=True, custom_validation=_validate_max_instances), dict( name="source_code", type="union", fields=[ - dict(name="sourceArchiveUrl", regexp=r'^.+$'), - dict(name="sourceRepositoryUrl", regexp=r'^.+$', api_version='v1beta2'), - dict(name="sourceRepository", type="dict", fields=[dict(name="url", regexp=r'^.+$')]), + dict(name="sourceArchiveUrl", regexp=r"^.+$"), + dict(name="sourceRepositoryUrl", regexp=r"^.+$", api_version="v1beta2"), + dict(name="sourceRepository", type="dict", fields=[dict(name="url", regexp=r"^.+$")]), dict(name="sourceUploadUrl"), ], ), @@ -81,9 +86,9 @@ def _validate_max_instances(value): name="eventTrigger", type="dict", fields=[ - dict(name="eventType", regexp=r'^.+$'), - dict(name="resource", regexp=r'^.+$'), - dict(name="service", regexp=r'^.+$', optional=True), + dict(name="eventType", regexp=r"^.+$"), + dict(name="resource", regexp=r"^.+$"), + dict(name="service", regexp=r"^.+$", optional=True), dict( name="failurePolicy", type="dict", @@ -94,7 +99,7 @@ def _validate_max_instances(value): ), ], ), -] # type: List[Dict[str, Any]] +] class CloudFunctionDeployFunctionOperator(BaseOperator): @@ -135,26 +140,27 @@ class CloudFunctionDeployFunctionOperator(BaseOperator): # [START gcf_function_deploy_template_fields] template_fields: Sequence[str] = ( - 'body', - 'project_id', - 'location', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "body", + "project_id", + "location", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcf_function_deploy_template_fields] + operator_extra_links = (CloudFunctionsDetailsLink(),) def __init__( self, *, location: str, - body: Dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - zip_path: Optional[str] = None, + body: dict, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + 
api_version: str = "v1", + zip_path: str | None = None, validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -164,7 +170,7 @@ def __init__( self.api_version = api_version self.zip_path = zip_path self.zip_path_preprocessor = ZipPathPreprocessor(body, zip_path) - self._field_validator = None # type: Optional[GcpBodyFieldValidator] + self._field_validator: GcpBodyFieldValidator | None = None self.impersonation_chain = impersonation_chain if validate_body: self._field_validator = GcpBodyFieldValidator(CLOUD_FUNCTION_VALIDATION, api_version=api_version) @@ -186,10 +192,10 @@ def _create_new_function(self, hook) -> None: hook.create_new_function(project_id=self.project_id, location=self.location, body=self.body) def _update_function(self, hook) -> None: - hook.update_function(self.body['name'], self.body, self.body.keys()) + hook.update_function(self.body["name"], self.body, self.body.keys()) def _check_if_function_exists(self, hook) -> bool: - name = self.body.get('name') + name = self.body.get("name") if not name: raise GcpFieldValidationException(f"The 'name' field should be present in body: '{self.body}'.") try: @@ -207,11 +213,11 @@ def _upload_source_code(self, hook): ) def _set_airflow_version_label(self) -> None: - if 'labels' not in self.body.keys(): - self.body['labels'] = {} - self.body['labels'].update({'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}) + if "labels" not in self.body.keys(): + self.body["labels"] = {} + self.body["labels"].update({"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}) - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudFunctionsHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -225,12 +231,21 @@ def execute(self, context: 'Context'): self._create_new_function(hook) else: self._update_function(hook) + project_id = self.project_id or hook.project_id + if project_id: + CloudFunctionsDetailsLink.persist( + context=context, + task_instance=self, + location=self.location, + project_id=project_id, + function_name=self.body["name"].split("/")[-1], + ) -GCF_SOURCE_ARCHIVE_URL = 'sourceArchiveUrl' -GCF_SOURCE_UPLOAD_URL = 'sourceUploadUrl' -SOURCE_REPOSITORY = 'sourceRepository' -GCF_ZIP_PATH = 'zip_path' +GCF_SOURCE_ARCHIVE_URL = "sourceArchiveUrl" +GCF_SOURCE_UPLOAD_URL = "sourceUploadUrl" +SOURCE_REPOSITORY = "sourceRepository" +GCF_ZIP_PATH = "zip_path" class ZipPathPreprocessor: @@ -254,9 +269,9 @@ class ZipPathPreprocessor: """ - upload_function = None # type: Optional[bool] + upload_function: bool | None = None - def __init__(self, body: dict, zip_path: Optional[str] = None) -> None: + def __init__(self, body: dict, zip_path: str | None = None) -> None: self.body = body self.zip_path = zip_path @@ -291,13 +306,9 @@ def _verify_archive_url_and_zip_path(self) -> None: ) def should_upload_function(self) -> bool: - """ - Checks if function source should be uploaded. 
- - :rtype: bool - """ + """Checks if function source should be uploaded.""" if self.upload_function is None: - raise AirflowException('validate() method has to be invoked before should_upload_function') + raise AirflowException("validate() method has to be invoked before should_upload_function") return self.upload_function def preprocess_body(self) -> None: @@ -312,7 +323,7 @@ def preprocess_body(self) -> None: self.upload_function = False -FUNCTION_NAME_PATTERN = '^projects/[^/]+/locations/[^/]+/functions/[^/]+$' +FUNCTION_NAME_PATTERN = "^projects/[^/]+/locations/[^/]+/functions/[^/]+$" FUNCTION_NAME_COMPILED_PATTERN = re.compile(FUNCTION_NAME_PATTERN) @@ -340,23 +351,26 @@ class CloudFunctionDeleteFunctionOperator(BaseOperator): # [START gcf_function_delete_template_fields] template_fields: Sequence[str] = ( - 'name', - 'gcp_conn_id', - 'api_version', - 'impersonation_chain', + "name", + "gcp_conn_id", + "api_version", + "impersonation_chain", ) # [END gcf_function_delete_template_fields] + operator_extra_links = (CloudFunctionsListLink(),) def __init__( self, *, name: str, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: str | Sequence[str] | None = None, + project_id: str | None = None, **kwargs, ) -> None: self.name = name + self.project_id = project_id self.gcp_conn_id = gcp_conn_id self.api_version = api_version self.impersonation_chain = impersonation_chain @@ -365,27 +379,34 @@ def __init__( def _validate_inputs(self) -> None: if not self.name: - raise AttributeError('Empty parameter: name') + raise AttributeError("Empty parameter: name") else: pattern = FUNCTION_NAME_COMPILED_PATTERN if not pattern.match(self.name): - raise AttributeError(f'Parameter name must match pattern: {FUNCTION_NAME_PATTERN}') + raise AttributeError(f"Parameter name must match pattern: {FUNCTION_NAME_PATTERN}") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudFunctionsHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) try: + project_id = self.project_id or hook.project_id + if project_id: + CloudFunctionsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) return hook.delete_function(self.name) except HttpError as e: status = e.resp.status if status == 404: - self.log.info('The function does not exist in this project') + self.log.info("The function does not exist in this project") return None else: - self.log.error('An error occurred. Exiting.') + self.log.error("An error occurred. 
Exiting.") raise e @@ -416,23 +437,24 @@ class CloudFunctionInvokeFunctionOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'function_id', - 'input_data', - 'location', - 'project_id', - 'impersonation_chain', + "function_id", + "input_data", + "location", + "project_id", + "impersonation_chain", ) + operator_extra_links = (CloudFunctionsDetailsLink(),) def __init__( self, *, function_id: str, - input_data: Dict, + input_data: dict, location: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -444,19 +466,30 @@ def __init__( self.api_version = api_version self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudFunctionsHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.log.info('Calling function %s.', self.function_id) + self.log.info("Calling function %s.", self.function_id) result = hook.call_function( function_id=self.function_id, input_data=self.input_data, location=self.location, project_id=self.project_id, ) - self.log.info('Function called successfully. Execution id %s', result.get('executionId')) - self.xcom_push(context=context, key='execution_id', value=result.get('executionId')) + self.log.info("Function called successfully. Execution id %s", result.get("executionId")) + self.xcom_push(context=context, key="execution_id", value=result.get("executionId")) + + project_id = self.project_id or hook.project_id + if project_id: + CloudFunctionsDetailsLink.persist( + context=context, + task_instance=self, + location=self.location, + project_id=project_id, + function_name=self.function_id, + ) + return result diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 27cc6f79bd108..e5d249fa82e8c 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -16,12 +16,14 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Google Cloud Storage Bucket operator.""" +from __future__ import annotations + import datetime import subprocess import sys from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence import pendulum @@ -30,11 +32,11 @@ from google.api_core.exceptions import Conflict from google.cloud.exceptions import GoogleCloudError -from pendulum.datetime import DateTime from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.common.links.storage import FileDetailsLink, StorageLink from airflow.utils import timezone @@ -100,26 +102,27 @@ class GCSCreateBucketOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket_name', - 'storage_class', - 'location', - 'project_id', - 'impersonation_chain', + "bucket_name", + "storage_class", + "location", + "project_id", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" + operator_extra_links = (StorageLink(),) def __init__( self, *, bucket_name: str, - resource: Optional[Dict] = None, - storage_class: str = 'MULTI_REGIONAL', - location: str = 'US', - project_id: Optional[str] = None, - labels: Optional[Dict] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + resource: dict | None = None, + storage_class: str = "MULTI_REGIONAL", + location: str = "US", + project_id: str | None = None, + labels: dict | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -133,12 +136,18 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.bucket_name, + project_id=self.project_id or hook.project_id, + ) try: hook.create_bucket( bucket_name=self.bucket_name, @@ -157,7 +166,7 @@ class GCSListObjectsOperator(BaseOperator): List all objects from the bucket with the given string prefix and delimiter in name. This operator returns a python list with the name of objects which can be used by - `xcom` in the downstream task. + XCom in the downstream task. :param bucket: The Google Cloud Storage bucket to find the objects. 
(templated) :param prefix: Prefix string which filters objects whose name begin with @@ -192,23 +201,25 @@ class GCSListObjectsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'prefix', - 'delimiter', - 'impersonation_chain', + "bucket", + "prefix", + "delimiter", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" + + operator_extra_links = (StorageLink(),) def __init__( self, *, bucket: str, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + prefix: str | None = None, + delimiter: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -219,7 +230,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> list: + def execute(self, context: Context) -> list: hook = GCSHook( gcp_conn_id=self.gcp_conn_id, @@ -228,12 +239,19 @@ def execute(self, context: "Context") -> list: ) self.log.info( - 'Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s', + "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s", self.bucket, self.delimiter, self.prefix, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.bucket, + project_id=hook.project_id, + ) + return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter) @@ -263,21 +281,21 @@ class GCSDeleteObjectsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket_name', - 'prefix', - 'objects', - 'impersonation_chain', + "bucket_name", + "prefix", + "objects", + "impersonation_chain", ) def __init__( self, *, bucket_name: str, - objects: Optional[List[str]] = None, - prefix: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + objects: list[str] | None = None, + prefix: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -288,12 +306,15 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - if not objects and not prefix: - raise ValueError("Either object or prefix should be set. Both are None") + if objects is None and prefix is None: + err_message = "(Task {task_id}) Either object or prefix should be set. 
Both are None.".format( + **kwargs + ) + raise ValueError(err_message) super().__init__(**kwargs) - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -339,13 +360,14 @@ class GCSBucketCreateAclEntryOperator(BaseOperator): # [START gcs_bucket_create_acl_template_fields] template_fields: Sequence[str] = ( - 'bucket', - 'entity', - 'role', - 'user_project', - 'impersonation_chain', + "bucket", + "entity", + "role", + "user_project", + "impersonation_chain", ) # [END gcs_bucket_create_acl_template_fields] + operator_extra_links = (StorageLink(),) def __init__( self, @@ -353,9 +375,9 @@ def __init__( bucket: str, entity: str, role: str, - user_project: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + user_project: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -366,11 +388,17 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = GCSHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.bucket, + project_id=hook.project_id, + ) hook.insert_bucket_acl( bucket_name=self.bucket, entity=self.entity, role=self.role, user_project=self.user_project ) @@ -409,15 +437,16 @@ class GCSObjectCreateAclEntryOperator(BaseOperator): # [START gcs_object_create_acl_template_fields] template_fields: Sequence[str] = ( - 'bucket', - 'object_name', - 'entity', - 'generation', - 'role', - 'user_project', - 'impersonation_chain', + "bucket", + "object_name", + "entity", + "generation", + "role", + "user_project", + "impersonation_chain", ) # [END gcs_object_create_acl_template_fields] + operator_extra_links = (FileDetailsLink(),) def __init__( self, @@ -426,10 +455,10 @@ def __init__( object_name: str, entity: str, role: str, - generation: Optional[int] = None, - user_project: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + generation: int | None = None, + user_project: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -442,11 +471,17 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = GCSHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + FileDetailsLink.persist( + context=context, + task_instance=self, + uri=f"{self.bucket}/{self.object_name}", + project_id=hook.project_id, + ) hook.insert_object_acl( bucket_name=self.bucket, object_name=self.object_name, @@ -491,24 +526,25 @@ class GCSFileTransformOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_bucket', - 'source_object', - 'destination_bucket', - 'destination_object', - 'transform_script', - 'impersonation_chain', + "source_bucket", + "source_object", + "destination_bucket", + "destination_object", + "transform_script", + "impersonation_chain", ) + operator_extra_links = 
(FileDetailsLink(),) def __init__( self, *, source_bucket: str, source_object: str, - transform_script: Union[str, List[str]], - destination_bucket: Optional[str] = None, - destination_object: Optional[str] = None, + transform_script: str | list[str], + destination_bucket: str | None = None, + destination_object: str | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -522,7 +558,7 @@ def __init__( self.output_encoding = sys.getdefaultencoding() self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) with NamedTemporaryFile() as source_file, NamedTemporaryFile() as destination_file: @@ -539,7 +575,7 @@ def execute(self, context: "Context") -> None: ) as process: self.log.info("Process output:") if process.stdout: - for line in iter(process.stdout.readline, b''): + for line in iter(process.stdout.readline, b""): self.log.info(line.decode(self.output_encoding).rstrip()) process.wait() @@ -549,6 +585,12 @@ def execute(self, context: "Context") -> None: self.log.info("Transformation succeeded. Output temporarily located at %s", destination_file.name) self.log.info("Uploading file to %s as %s", self.destination_bucket, self.destination_object) + FileDetailsLink.persist( + context=context, + task_instance=self, + uri=f"{self.destination_bucket}/{self.destination_object}", + project_id=hook.project_id, + ) hook.upload( bucket_name=self.destination_bucket, object_name=self.destination_object, @@ -620,17 +662,18 @@ class GCSTimeSpanFileTransformOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_bucket', - 'source_prefix', - 'destination_bucket', - 'destination_prefix', - 'transform_script', - 'source_impersonation_chain', - 'destination_impersonation_chain', + "source_bucket", + "source_prefix", + "destination_bucket", + "destination_prefix", + "transform_script", + "source_impersonation_chain", + "destination_impersonation_chain", ) + operator_extra_links = (StorageLink(),) @staticmethod - def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[str]: + def interpolate_prefix(prefix: str, dt: datetime.datetime) -> str | None: """Interpolate prefix with datetime. 
:param prefix: The prefix to interpolate @@ -648,13 +691,13 @@ def __init__( destination_bucket: str, destination_prefix: str, destination_gcp_conn_id: str, - transform_script: Union[str, List[str]], - source_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - destination_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - chunk_size: Optional[int] = None, - download_continue_on_fail: Optional[bool] = False, + transform_script: str | list[str], + source_impersonation_chain: str | Sequence[str] | None = None, + destination_impersonation_chain: str | Sequence[str] | None = None, + chunk_size: int | None = None, + download_continue_on_fail: bool | None = False, download_num_attempts: int = 1, - upload_continue_on_fail: Optional[bool] = False, + upload_continue_on_fail: bool | None = False, upload_num_attempts: int = 1, **kwargs, ) -> None: @@ -678,25 +721,28 @@ def __init__( self.upload_continue_on_fail = upload_continue_on_fail self.upload_num_attempts = upload_num_attempts - def execute(self, context: "Context") -> List[str]: + def execute(self, context: Context) -> list[str]: # Define intervals and prefixes. try: - timespan_start = context["data_interval_start"] - timespan_end = context["data_interval_end"] + orig_start = context["data_interval_start"] + orig_end = context["data_interval_end"] except KeyError: - timespan_start = pendulum.instance(context["execution_date"]) + orig_start = pendulum.instance(context["execution_date"]) following_execution_date = context["dag"].following_schedule(context["execution_date"]) if following_execution_date is None: - timespan_end = None + orig_end = None else: - timespan_end = pendulum.instance(following_execution_date) - - if timespan_end is None: # Only possible in Airflow before 2.2. - self.log.warning("No following schedule found, setting timespan end to max %s", timespan_end) - timespan_end = DateTime.max - elif timespan_start >= timespan_end: # Airflow 2.2 sets start == end for non-periodic schedules. - self.log.warning("DAG schedule not periodic, setting timespan end to max %s", timespan_end) - timespan_end = DateTime.max + orig_end = pendulum.instance(following_execution_date) + + timespan_start = orig_start + if orig_end is None: # Only possible in Airflow before 2.2. + self.log.warning("No following schedule found, setting timespan end to max %s", orig_end) + timespan_end = pendulum.instance(datetime.datetime.max) + elif orig_start >= orig_end: # Airflow 2.2 sets start == end for non-periodic schedules. + self.log.warning("DAG schedule not periodic, setting timespan end to max %s", orig_end) + timespan_end = pendulum.instance(datetime.datetime.max) + else: + timespan_end = orig_end timespan_start = timespan_start.in_timezone(timezone.utc) timespan_end = timespan_end.in_timezone(timezone.utc) @@ -718,6 +764,12 @@ def execute(self, context: "Context") -> List[str]: gcp_conn_id=self.destination_gcp_conn_id, impersonation_chain=self.destination_impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.destination_bucket, + project_id=destination_hook.project_id, + ) # Fetch list of files. 
blobs_to_transform = source_hook.list_by_timespan( @@ -761,7 +813,7 @@ def execute(self, context: "Context") -> List[str]: ) as process: self.log.info("Process output:") if process.stdout: - for line in iter(process.stdout.readline, b''): + for line in iter(process.stdout.readline, b""): self.log.info(line.decode(self.output_encoding).rstrip()) process.wait() @@ -824,7 +876,7 @@ class GCSDeleteBucketOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket_name', + "bucket_name", "gcp_conn_id", "impersonation_chain", ) @@ -834,8 +886,8 @@ def __init__( *, bucket_name: str, force: bool = True, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -845,7 +897,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) hook.delete_bucket(bucket_name=self.bucket_name, force=self.force) @@ -893,31 +945,32 @@ class GCSSynchronizeBucketsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_bucket', - 'destination_bucket', - 'source_object', - 'destination_object', - 'recursive', - 'delete_extra_files', - 'allow_overwrite', - 'gcp_conn_id', - 'delegate_to', - 'impersonation_chain', + "source_bucket", + "destination_bucket", + "source_object", + "destination_object", + "recursive", + "delete_extra_files", + "allow_overwrite", + "gcp_conn_id", + "delegate_to", + "impersonation_chain", ) + operator_extra_links = (StorageLink(),) def __init__( self, *, source_bucket: str, destination_bucket: str, - source_object: Optional[str] = None, - destination_object: Optional[str] = None, + source_object: str | None = None, + destination_object: str | None = None, recursive: bool = True, delete_extra_files: bool = False, allow_overwrite: bool = False, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -932,12 +985,18 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self._get_uri(self.destination_bucket, self.destination_object), + project_id=hook.project_id, + ) hook.sync( source_bucket=self.source_bucket, destination_bucket=self.destination_bucket, @@ -947,3 +1006,8 @@ def execute(self, context: "Context") -> None: delete_extra_files=self.delete_extra_files, allow_overwrite=self.allow_overwrite, ) + + def _get_uri(self, gcs_bucket: str, gcs_object: str | None) -> str: + if gcs_object and gcs_object[-1] == "/": + gcs_object = gcs_object[:-1] + return f"{gcs_bucket}/{gcs_object}" if gcs_object else gcs_bucket diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py 
b/airflow/providers/google/cloud/operators/kubernetes_engine.py index 83c013ba441fc..045935771c715 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Kubernetes Engine operators.""" +from __future__ import annotations import os import tempfile import warnings from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Generator, Optional, Sequence, Union +from typing import TYPE_CHECKING, Generator, Sequence from google.cloud.container_v1.types import Cluster @@ -30,6 +30,10 @@ from airflow.models import BaseOperator from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook +from airflow.providers.google.cloud.links.kubernetes_engine import ( + KubernetesEngineClusterLink, + KubernetesEnginePodLink, +) from airflow.providers.google.common.hooks.base_google import GoogleBaseHook from airflow.utils.process_utils import execute_in_subprocess, patch_environ @@ -77,12 +81,12 @@ class GKEDeleteClusterOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'gcp_conn_id', - 'name', - 'location', - 'api_version', - 'impersonation_chain', + "project_id", + "gcp_conn_id", + "name", + "location", + "api_version", + "impersonation_chain", ) def __init__( @@ -90,10 +94,10 @@ def __init__( *, name: str, location: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v2', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v2", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -108,10 +112,10 @@ def __init__( def _check_input(self) -> None: if not all([self.project_id, self.name, self.location]): - self.log.error('One of (project_id, name, location) is missing or incorrect') - raise AirflowException('Operator has incorrect or missing input.') + self.log.error("One of (project_id, name, location) is missing or incorrect") + raise AirflowException("Operator has incorrect or missing input.") - def execute(self, context: 'Context') -> Optional[str]: + def execute(self, context: Context) -> str | None: hook = GKEHook( gcp_conn_id=self.gcp_conn_id, location=self.location, @@ -173,23 +177,24 @@ class GKECreateClusterOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'gcp_conn_id', - 'location', - 'api_version', - 'body', - 'impersonation_chain', + "project_id", + "gcp_conn_id", + "location", + "api_version", + "body", + "impersonation_chain", ) + operator_extra_links = (KubernetesEngineClusterLink(),) def __init__( self, *, location: str, - body: Optional[Union[Dict, Cluster]], - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v2', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + body: dict | Cluster | None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v2", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -233,13 +238,14 @@ def 
_check_input(self) -> None: self.log.error("Only one of body['initial_node_count']) and body['node_pools'] may be specified") raise AirflowException("Operator has incorrect or missing input.") - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: hook = GKEHook( gcp_conn_id=self.gcp_conn_id, location=self.location, impersonation_chain=self.impersonation_chain, ) create_op = hook.create_cluster(cluster=self.body, project_id=self.project_id) + KubernetesEngineClusterLink.persist(context=context, task_instance=self, cluster=self.body) return create_op @@ -290,8 +296,9 @@ class GKEStartPodOperator(KubernetesPodOperator): """ template_fields: Sequence[str] = tuple( - {'project_id', 'location', 'cluster_name'} | set(KubernetesPodOperator.template_fields) + {"project_id", "location", "cluster_name"} | set(KubernetesPodOperator.template_fields) ) + operator_extra_links = (KubernetesEnginePodLink(),) def __init__( self, @@ -299,11 +306,11 @@ def __init__( location: str, cluster_name: str, use_internal_ip: bool = False, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, regional: bool = False, - is_delete_operator_pod: Optional[bool] = None, + is_delete_operator_pod: bool | None = None, **kwargs, ) -> None: if is_delete_operator_pod is None: @@ -341,9 +348,9 @@ def __init__( @contextmanager def get_gke_config_file( gcp_conn_id, - project_id: Optional[str], + project_id: str | None, cluster_name: str, - impersonation_chain: Optional[Union[str, Sequence[str]]], + impersonation_chain: str | Sequence[str] | None, regional: bool, location: str, use_internal_ip: bool, @@ -391,23 +398,23 @@ def get_gke_config_file( cmd.extend( [ - '--impersonate-service-account', + "--impersonate-service-account", impersonation_account, ] ) if regional: - cmd.append('--region') + cmd.append("--region") else: - cmd.append('--zone') + cmd.append("--zone") cmd.append(location) if use_internal_ip: - cmd.append('--internal-ip') + cmd.append("--internal-ip") execute_in_subprocess(cmd) # Tell `KubernetesPodOperator` where the config file is located yield os.environ[KUBE_CONFIG_ENV_VAR] - def execute(self, context: 'Context') -> Optional[str]: + def execute(self, context: Context) -> str | None: with GKEStartPodOperator.get_gke_config_file( gcp_conn_id=self.gcp_conn_id, @@ -419,4 +426,7 @@ def execute(self, context: 'Context') -> Optional[str]: use_internal_ip=self.use_internal_ip, ) as config_file: self.config_file = config_file - return super().execute(context) + result = super().execute(context) + if not self.is_delete_operator_pod: + KubernetesEnginePodLink.persist(context=context, task_instance=self) + return result diff --git a/airflow/providers/google/cloud/operators/life_sciences.py b/airflow/providers/google/cloud/operators/life_sciences.py index a3fc1b3ff5ecd..b549d2f786e66 100644 --- a/airflow/providers/google/cloud/operators/life_sciences.py +++ b/airflow/providers/google/cloud/operators/life_sciences.py @@ -16,12 +16,14 @@ # specific language governing permissions and limitations # under the License. 
"""Operators that interact with Google Cloud Life Sciences service.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook +from airflow.providers.google.cloud.links.life_sciences import LifeSciencesLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -57,16 +59,17 @@ class LifeSciencesRunPipelineOperator(BaseOperator): "api_version", "impersonation_chain", ) + operator_extra_links = (LifeSciencesLink(),) def __init__( self, *, body: dict, location: str, - project_id: Optional[str] = None, + project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", api_version: str = "v2beta", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -84,11 +87,17 @@ def _validate_inputs(self) -> None: if not self.location: raise AirflowException("The required parameter 'location' is missing") - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = LifeSciencesHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) - + project_id = self.project_id or hook.project_id + if project_id: + LifeSciencesLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) return hook.run_pipeline(body=self.body, location=self.location, project_id=self.project_id) diff --git a/airflow/providers/google/cloud/operators/looker.py b/airflow/providers/google/cloud/operators/looker.py index 917884ab666a2..971fc3fddbbf0 100644 --- a/airflow/providers/google/cloud/operators/looker.py +++ b/airflow/providers/google/cloud/operators/looker.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Cloud Looker operators.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -53,11 +53,11 @@ def __init__( looker_conn_id: str, model: str, view: str, - query_params: Optional[Dict] = None, + query_params: dict | None = None, asynchronous: bool = False, cancel_on_kill: bool = True, wait_time: int = 10, - wait_timeout: Optional[int] = None, + wait_timeout: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -69,10 +69,10 @@ def __init__( self.cancel_on_kill = cancel_on_kill self.wait_time = wait_time self.wait_timeout = wait_timeout - self.hook: Optional[LookerHook] = None - self.materialization_id: Optional[str] = None + self.hook: LookerHook | None = None + self.materialization_id: str | None = None - def execute(self, context: "Context") -> str: + def execute(self, context: Context) -> str: self.hook = LookerHook(looker_conn_id=self.looker_conn_id) @@ -86,7 +86,7 @@ def execute(self, context: "Context") -> str: if not self.materialization_id: raise AirflowException( - f'No `materialization_id` was returned for model: {self.model}, view: {self.view}.' + f"No `materialization_id` was returned for model: {self.model}, view: {self.view}." ) self.log.info("PDT materialization job submitted successfully. 
Job id: %s.", self.materialization_id) diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index ce7d6ca9d51fd..bdc47a54d8e51 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -16,18 +16,26 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud MLEngine operators.""" -import datetime + +from __future__ import annotations + import logging import re import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink, XCom +from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook +from airflow.providers.google.cloud.links.mlengine import ( + MLEngineJobDetailsLink, + MLEngineJobSListLink, + MLEngineModelLink, + MLEngineModelsListLink, + MLEngineModelVersionDetailsLink, +) if TYPE_CHECKING: - from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context @@ -43,25 +51,24 @@ def _normalize_mlengine_job_id(job_id: str) -> str: :param job_id: A job_id str that may have invalid characters. :return: A valid job_id representation. - :rtype: str """ # Add a prefix when a job_id starts with a digit or a template - match = re.search(r'\d|\{{2}', job_id) + match = re.search(r"\d|\{{2}", job_id) if match and match.start() == 0: - job = f'z_{job_id}' + job = f"z_{job_id}" else: job = job_id # Clean up 'bad' characters except templates tracker = 0 - cleansed_job_id = '' - for match in re.finditer(r'\{{2}.+?\}{2}', job): - cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', job[tracker : match.start()]) + cleansed_job_id = "" + for match in re.finditer(r"\{{2}.+?\}{2}", job): + cleansed_job_id += re.sub(r"[^0-9a-zA-Z]+", "_", job[tracker : match.start()]) cleansed_job_id += job[match.start() : match.end()] tracker = match.end() # Clean up last substring or the full string if no templates - cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', job[tracker:]) + cleansed_job_id += re.sub(r"[^0-9a-zA-Z]+", "_", job[tracker:]) return cleansed_job_id @@ -152,15 +159,15 @@ class MLEngineStartBatchPredictionJobOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_job_id', - '_region', - '_input_paths', - '_output_path', - '_model_name', - '_version_name', - '_uri', - '_impersonation_chain', + "_project_id", + "_job_id", + "_region", + "_input_paths", + "_output_path", + "_model_name", + "_version_name", + "_uri", + "_impersonation_chain", ) def __init__( @@ -169,19 +176,19 @@ def __init__( job_id: str, region: str, data_format: str, - input_paths: List[str], + input_paths: list[str], output_path: str, - model_name: Optional[str] = None, - version_name: Optional[str] = None, - uri: Optional[str] = None, - max_worker_count: Optional[int] = None, - runtime_version: Optional[str] = None, - signature_name: Optional[str] = None, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + model_name: str | None = None, + version_name: str | None = None, + uri: str | None = None, + max_worker_count: int | None = None, + runtime_version: str | None = None, + 
signature_name: str | None = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + labels: dict[str, str] | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -204,60 +211,60 @@ def __init__( self._impersonation_chain = impersonation_chain if not self._project_id: - raise AirflowException('Google Cloud project id is required.') + raise AirflowException("Google Cloud project id is required.") if not self._job_id: - raise AirflowException('An unique job id is required for Google MLEngine prediction job.') + raise AirflowException("An unique job id is required for Google MLEngine prediction job.") if self._uri: if self._model_name or self._version_name: raise AirflowException( - 'Ambiguous model origin: Both uri and model/version name are provided.' + "Ambiguous model origin: Both uri and model/version name are provided." ) if self._version_name and not self._model_name: raise AirflowException( - 'Missing model: Batch prediction expects a model name when a version name is provided.' + "Missing model: Batch prediction expects a model name when a version name is provided." ) if not (self._uri or self._model_name): raise AirflowException( - 'Missing model origin: Batch prediction expects a model, ' - 'a model & version combination, or a URI to a savedModel.' + "Missing model origin: Batch prediction expects a model, " + "a model & version combination, or a URI to a savedModel." ) - def execute(self, context: 'Context'): + def execute(self, context: Context): job_id = _normalize_mlengine_job_id(self._job_id) - prediction_request: Dict[str, Any] = { - 'jobId': job_id, - 'predictionInput': { - 'dataFormat': self._data_format, - 'inputPaths': self._input_paths, - 'outputPath': self._output_path, - 'region': self._region, + prediction_request: dict[str, Any] = { + "jobId": job_id, + "predictionInput": { + "dataFormat": self._data_format, + "inputPaths": self._input_paths, + "outputPath": self._output_path, + "region": self._region, }, } if self._labels: - prediction_request['labels'] = self._labels + prediction_request["labels"] = self._labels if self._uri: - prediction_request['predictionInput']['uri'] = self._uri + prediction_request["predictionInput"]["uri"] = self._uri elif self._model_name: - origin_name = f'projects/{self._project_id}/models/{self._model_name}' + origin_name = f"projects/{self._project_id}/models/{self._model_name}" if not self._version_name: - prediction_request['predictionInput']['modelName'] = origin_name + prediction_request["predictionInput"]["modelName"] = origin_name else: - prediction_request['predictionInput']['versionName'] = ( - origin_name + f'/versions/{self._version_name}' + prediction_request["predictionInput"]["versionName"] = ( + origin_name + f"/versions/{self._version_name}" ) if self._max_worker_count: - prediction_request['predictionInput']['maxWorkerCount'] = self._max_worker_count + prediction_request["predictionInput"]["maxWorkerCount"] = self._max_worker_count if self._runtime_version: - prediction_request['predictionInput']['runtimeVersion'] = self._runtime_version + prediction_request["predictionInput"]["runtimeVersion"] = self._runtime_version if self._signature_name: - prediction_request['predictionInput']['signatureName'] = self._signature_name + prediction_request["predictionInput"]["signatureName"] = self._signature_name hook = MLEngineHook( self._gcp_conn_id, self._delegate_to, 
impersonation_chain=self._impersonation_chain @@ -266,17 +273,17 @@ def execute(self, context: 'Context'): # Helper method to check if the existing job's prediction input is the # same as the request we get here. def check_existing_job(existing_job): - return existing_job.get('predictionInput') == prediction_request['predictionInput'] + return existing_job.get("predictionInput") == prediction_request["predictionInput"] finished_prediction_job = hook.create_job( project_id=self._project_id, job=prediction_request, use_existing_job_fn=check_existing_job ) - if finished_prediction_job['state'] != 'SUCCEEDED': - self.log.error('MLEngine batch prediction job failed: %s', str(finished_prediction_job)) - raise RuntimeError(finished_prediction_job['errorMessage']) + if finished_prediction_job["state"] != "SUCCEEDED": + self.log.error("MLEngine batch prediction job failed: %s", str(finished_prediction_job)) + raise RuntimeError(finished_prediction_job["errorMessage"]) - return finished_prediction_job['predictionOutput'] + return finished_prediction_job["predictionOutput"] class MLEngineManageModelOperator(BaseOperator): @@ -315,20 +322,20 @@ class MLEngineManageModelOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model', - '_impersonation_chain', + "_project_id", + "_model", + "_impersonation_chain", ) def __init__( self, *, model: dict, - operation: str = 'create', - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + operation: str = "create", + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -347,18 +354,18 @@ def __init__( self._delegate_to = delegate_to self._impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) - if self._operation == 'create': + if self._operation == "create": return hook.create_model(project_id=self._project_id, model=self._model) - elif self._operation == 'get': - return hook.get_model(project_id=self._project_id, model_name=self._model['name']) + elif self._operation == "get": + return hook.get_model(project_id=self._project_id, model_name=self._model["name"]) else: - raise ValueError(f'Unknown operation: {self._operation}') + raise ValueError(f"Unknown operation: {self._operation}") class MLEngineCreateModelOperator(BaseOperator): @@ -390,19 +397,20 @@ class MLEngineCreateModelOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model', - '_impersonation_chain', + "_project_id", + "_model", + "_impersonation_chain", ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, *, model: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -412,12 +420,22 @@ def __init__( self._delegate_to = delegate_to self._impersonation_chain = impersonation_chain - 
def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model["name"], + ) + return hook.create_model(project_id=self._project_id, model=self._model) @@ -450,19 +468,20 @@ class MLEngineGetModelOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model_name', - '_impersonation_chain', + "_project_id", + "_model_name", + "_impersonation_chain", ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, *, model_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -472,12 +491,21 @@ def __init__( self._delegate_to = delegate_to self._impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + ) + return hook.get_model(project_id=self._project_id, model_name=self._model_name) @@ -513,20 +541,21 @@ class MLEngineDeleteModelOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model_name', - '_impersonation_chain', + "_project_id", + "_model_name", + "_impersonation_chain", ) + operator_extra_links = (MLEngineModelsListLink(),) def __init__( self, *, model_name: str, delete_contents: bool = False, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -537,13 +566,21 @@ def __init__( self._delegate_to = delegate_to self._impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + return hook.delete_model( project_id=self._project_id, model_name=self._model_name, delete_contents=self._delete_contents ) @@ -606,24 +643,24 @@ class MLEngineManageVersionOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model_name', - '_version_name', - '_version', - '_impersonation_chain', + "_project_id", + "_model_name", + "_version_name", + "_version", + "_impersonation_chain", ) def __init__( self, *, model_name: str, - version_name: 
Optional[str] = None, - version: Optional[dict] = None, - operation: str = 'create', - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + version_name: str | None = None, + version: dict | None = None, + operation: str = "create", + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -643,9 +680,9 @@ def __init__( stacklevel=3, ) - def execute(self, context: 'Context'): - if 'name' not in self._version: - self._version['name'] = self._version_name + def execute(self, context: Context): + if "name" not in self._version: + self._version["name"] = self._version_name hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, @@ -653,24 +690,24 @@ def execute(self, context: 'Context'): impersonation_chain=self._impersonation_chain, ) - if self._operation == 'create': + if self._operation == "create": if not self._version: raise ValueError(f"version attribute of {self.__class__.__name__} could not be empty") return hook.create_version( project_id=self._project_id, model_name=self._model_name, version_spec=self._version ) - elif self._operation == 'set_default': + elif self._operation == "set_default": return hook.set_default_version( - project_id=self._project_id, model_name=self._model_name, version_name=self._version['name'] + project_id=self._project_id, model_name=self._model_name, version_name=self._version["name"] ) - elif self._operation == 'list': + elif self._operation == "list": return hook.list_versions(project_id=self._project_id, model_name=self._model_name) - elif self._operation == 'delete': + elif self._operation == "delete": return hook.delete_version( - project_id=self._project_id, model_name=self._model_name, version_name=self._version['name'] + project_id=self._project_id, model_name=self._model_name, version_name=self._version["name"] ) else: - raise ValueError(f'Unknown operation: {self._operation}') + raise ValueError(f"Unknown operation: {self._operation}") class MLEngineCreateVersionOperator(BaseOperator): @@ -704,21 +741,22 @@ class MLEngineCreateVersionOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model_name', - '_version', - '_impersonation_chain', + "_project_id", + "_model_name", + "_version", + "_impersonation_chain", ) + operator_extra_links = (MLEngineModelVersionDetailsLink(),) def __init__( self, *, model_name: str, version: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -738,13 +776,23 @@ def _validate_inputs(self): if not self._version: raise AirflowException("The version parameter could not be empty.") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelVersionDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + 
model_id=self._model_name, + version_id=self._version["name"], + ) + return hook.create_version( project_id=self._project_id, model_name=self._model_name, version_spec=self._version ) @@ -781,21 +829,22 @@ class MLEngineSetDefaultVersionOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model_name', - '_version_name', - '_impersonation_chain', + "_project_id", + "_model_name", + "_version_name", + "_impersonation_chain", ) + operator_extra_links = (MLEngineModelVersionDetailsLink(),) def __init__( self, *, model_name: str, version_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -815,13 +864,23 @@ def _validate_inputs(self): if not self._version_name: raise AirflowException("The version_name parameter could not be empty.") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelVersionDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + version_id=self._version_name, + ) + return hook.set_default_version( project_id=self._project_id, model_name=self._model_name, version_name=self._version_name ) @@ -857,19 +916,20 @@ class MLEngineListVersionsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model_name', - '_impersonation_chain', + "_project_id", + "_model_name", + "_impersonation_chain", ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, *, model_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -885,13 +945,22 @@ def _validate_inputs(self): if not self._model_name: raise AirflowException("The model_name parameter could not be empty.") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + ) + return hook.list_versions( project_id=self._project_id, model_name=self._model_name, @@ -929,21 +998,22 @@ class MLEngineDeleteVersionOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_model_name', - '_version_name', - '_impersonation_chain', + "_project_id", + "_model_name", + "_version_name", + "_impersonation_chain", ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, *, model_name: str, version_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: 
Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -963,47 +1033,27 @@ def _validate_inputs(self): if not self._version_name: raise AirflowException("The version_name parameter could not be empty.") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + ) + return hook.delete_version( project_id=self._project_id, model_name=self._model_name, version_name=self._version_name ) -class AIPlatformConsoleLink(BaseOperatorLink): - """Helper class for constructing AI Platform Console link.""" - - name = "AI Platform Console" - - def get_link( - self, - operator, - dttm: Optional[datetime.datetime] = None, - ti_key: Optional["TaskInstanceKey"] = None, - ) -> str: - if ti_key is not None: - gcp_metadata_dict = XCom.get_value(key="gcp_metadata", ti_key=ti_key) - else: - assert dttm is not None - gcp_metadata_dict = XCom.get_one( - key="gcp_metadata", - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - ) - if not gcp_metadata_dict: - return '' - job_id = gcp_metadata_dict['job_id'] - project_id = gcp_metadata_dict['project_id'] - console_link = f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}" - return console_link - - class MLEngineStartTrainingJobOperator(BaseOperator): """ Operator for launching a MLEngine training job. 
@@ -1069,47 +1119,46 @@ class MLEngineStartTrainingJobOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_job_id', - '_region', - '_package_uris', - '_training_python_module', - '_training_args', - '_scale_tier', - '_master_type', - '_master_config', - '_runtime_version', - '_python_version', - '_job_dir', - '_service_account', - '_hyperparameters', - '_impersonation_chain', + "_project_id", + "_job_id", + "_region", + "_package_uris", + "_training_python_module", + "_training_args", + "_scale_tier", + "_master_type", + "_master_config", + "_runtime_version", + "_python_version", + "_job_dir", + "_service_account", + "_hyperparameters", + "_impersonation_chain", ) - - operator_extra_links = (AIPlatformConsoleLink(),) + operator_extra_links = (MLEngineJobDetailsLink(),) def __init__( self, *, job_id: str, region: str, - package_uris: Optional[List[str]] = None, - training_python_module: Optional[str] = None, - training_args: Optional[List[str]] = None, - scale_tier: Optional[str] = None, - master_type: Optional[str] = None, - master_config: Optional[Dict] = None, - runtime_version: Optional[str] = None, - python_version: Optional[str] = None, - job_dir: Optional[str] = None, - service_account: Optional[str] = None, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - mode: str = 'PRODUCTION', - labels: Optional[Dict[str, str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - hyperparameters: Optional[Dict] = None, + package_uris: list[str] | None = None, + training_python_module: str | None = None, + training_args: list[str] | None = None, + scale_tier: str | None = None, + master_type: str | None = None, + master_config: dict | None = None, + runtime_version: str | None = None, + python_version: str | None = None, + job_dir: str | None = None, + service_account: str | None = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + mode: str = "PRODUCTION", + labels: dict[str, str] | None = None, + impersonation_chain: str | Sequence[str] | None = None, + hyperparameters: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1133,78 +1182,78 @@ def __init__( self._hyperparameters = hyperparameters self._impersonation_chain = impersonation_chain - custom = self._scale_tier is not None and self._scale_tier.upper() == 'CUSTOM' + custom = self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" custom_image = ( custom and self._master_config is not None - and self._master_config.get('imageUri', None) is not None + and self._master_config.get("imageUri", None) is not None ) if not self._project_id: - raise AirflowException('Google Cloud project id is required.') + raise AirflowException("Google Cloud project id is required.") if not self._job_id: - raise AirflowException('An unique job id is required for Google MLEngine training job.') + raise AirflowException("An unique job id is required for Google MLEngine training job.") if not self._region: - raise AirflowException('Google Compute Engine region is required.') + raise AirflowException("Google Compute Engine region is required.") if custom and not self._master_type: - raise AirflowException('master_type must be set when scale_tier is CUSTOM') + raise AirflowException("master_type must be set when scale_tier is CUSTOM") if self._master_config and not self._master_type: - raise AirflowException('master_type must be set when 
master_config is provided') + raise AirflowException("master_type must be set when master_config is provided") if not (package_uris and training_python_module) and not custom_image: raise AirflowException( - 'Either a Python package with a Python module or a custom Docker image should be provided.' + "Either a Python package with a Python module or a custom Docker image should be provided." ) if (package_uris or training_python_module) and custom_image: raise AirflowException( - 'Either a Python package with a Python module or ' - 'a custom Docker image should be provided but not both.' + "Either a Python package with a Python module or " + "a custom Docker image should be provided but not both." ) - def execute(self, context: 'Context'): + def execute(self, context: Context): job_id = _normalize_mlengine_job_id(self._job_id) - training_request: Dict[str, Any] = { - 'jobId': job_id, - 'trainingInput': { - 'scaleTier': self._scale_tier, - 'region': self._region, + training_request: dict[str, Any] = { + "jobId": job_id, + "trainingInput": { + "scaleTier": self._scale_tier, + "region": self._region, }, } if self._package_uris: - training_request['trainingInput']['packageUris'] = self._package_uris + training_request["trainingInput"]["packageUris"] = self._package_uris if self._training_python_module: - training_request['trainingInput']['pythonModule'] = self._training_python_module + training_request["trainingInput"]["pythonModule"] = self._training_python_module if self._training_args: - training_request['trainingInput']['args'] = self._training_args + training_request["trainingInput"]["args"] = self._training_args if self._master_type: - training_request['trainingInput']['masterType'] = self._master_type + training_request["trainingInput"]["masterType"] = self._master_type if self._master_config: - training_request['trainingInput']['masterConfig'] = self._master_config + training_request["trainingInput"]["masterConfig"] = self._master_config if self._runtime_version: - training_request['trainingInput']['runtimeVersion'] = self._runtime_version + training_request["trainingInput"]["runtimeVersion"] = self._runtime_version if self._python_version: - training_request['trainingInput']['pythonVersion'] = self._python_version + training_request["trainingInput"]["pythonVersion"] = self._python_version if self._job_dir: - training_request['trainingInput']['jobDir'] = self._job_dir + training_request["trainingInput"]["jobDir"] = self._job_dir if self._service_account: - training_request['trainingInput']['serviceAccount'] = self._service_account + training_request["trainingInput"]["serviceAccount"] = self._service_account if self._hyperparameters: - training_request['trainingInput']['hyperparameters'] = self._hyperparameters + training_request["trainingInput"]["hyperparameters"] = self._hyperparameters if self._labels: - training_request['labels'] = self._labels + training_request["labels"] = self._labels - if self._mode == 'DRY_RUN': - self.log.info('In dry_run mode.') - self.log.info('MLEngine Training job request is: %s', training_request) + if self._mode == "DRY_RUN": + self.log.info("In dry_run mode.") + self.log.info("MLEngine Training job request is: %s", training_request) return hook = MLEngineHook( @@ -1216,14 +1265,14 @@ def execute(self, context: 'Context'): # Helper method to check if the existing job's training input is the # same as the request we get here. 
def check_existing_job(existing_job): - existing_training_input = existing_job.get('trainingInput') - requested_training_input = training_request['trainingInput'] - if 'scaleTier' not in existing_training_input: - existing_training_input['scaleTier'] = None + existing_training_input = existing_job.get("trainingInput") + requested_training_input = training_request["trainingInput"] + if "scaleTier" not in existing_training_input: + existing_training_input["scaleTier"] = None - existing_training_input['args'] = existing_training_input.get('args') + existing_training_input["args"] = existing_training_input.get("args") requested_training_input["args"] = ( - requested_training_input['args'] if requested_training_input["args"] else None + requested_training_input["args"] if requested_training_input["args"] else None ) return existing_training_input == requested_training_input @@ -1232,15 +1281,18 @@ def check_existing_job(existing_job): project_id=self._project_id, job=training_request, use_existing_job_fn=check_existing_job ) - if finished_training_job['state'] != 'SUCCEEDED': - self.log.error('MLEngine training job failed: %s', str(finished_training_job)) - raise RuntimeError(finished_training_job['errorMessage']) - - gcp_metadata = { - "job_id": job_id, - "project_id": self._project_id, - } - context['task_instance'].xcom_push("gcp_metadata", gcp_metadata) + if finished_training_job["state"] != "SUCCEEDED": + self.log.error("MLEngine training job failed: %s", str(finished_training_job)) + raise RuntimeError(finished_training_job["errorMessage"]) + + project_id = self._project_id or hook.project_id + if project_id: + MLEngineJobDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_id=job_id, + ) class MLEngineTrainingCancelJobOperator(BaseOperator): @@ -1267,19 +1319,20 @@ class MLEngineTrainingCancelJobOperator(BaseOperator): """ template_fields: Sequence[str] = ( - '_project_id', - '_job_id', - '_impersonation_chain', + "_project_id", + "_job_id", + "_impersonation_chain", ) + operator_extra_links = (MLEngineJobSListLink(),) def __init__( self, *, job_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1290,9 +1343,9 @@ def __init__( self._impersonation_chain = impersonation_chain if not self._project_id: - raise AirflowException('Google Cloud project id is required.') + raise AirflowException("Google Cloud project id is required.") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, @@ -1300,4 +1353,12 @@ def execute(self, context: 'Context'): impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineJobSListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + hook.cancel_job(project_id=self._project_id, job_id=_normalize_mlengine_job_id(self._job_id)) diff --git a/airflow/providers/google/cloud/operators/natural_language.py b/airflow/providers/google/cloud/operators/natural_language.py index a920bd137963b..20f3085222983 100644 --- a/airflow/providers/google/cloud/operators/natural_language.py +++ 
b/airflow/providers/google/cloud/operators/natural_language.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud Language operators.""" -from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence, Tuple from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -73,13 +75,13 @@ class CloudNaturalLanguageAnalyzeEntitiesOperator(BaseOperator): def __init__( self, *, - document: Union[dict, Document], - encoding_type: Optional[enums.EncodingType] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + document: dict | Document, + encoding_type: enums.EncodingType | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -91,7 +93,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudNaturalLanguageHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -133,7 +135,6 @@ class CloudNaturalLanguageAnalyzeEntitySentimentOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse """ # [START natural_language_analyze_entity_sentiment_template_fields] @@ -147,13 +148,13 @@ class CloudNaturalLanguageAnalyzeEntitySentimentOperator(BaseOperator): def __init__( self, *, - document: Union[dict, Document], - encoding_type: Optional[enums.EncodingType] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + document: dict | Document, + encoding_type: enums.EncodingType | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -165,7 +166,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudNaturalLanguageHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -210,7 +211,6 @@ class CloudNaturalLanguageAnalyzeSentimentOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse """ # [START natural_language_analyze_sentiment_template_fields] @@ -224,13 +224,13 @@ class CloudNaturalLanguageAnalyzeSentimentOperator(BaseOperator): def __init__( self, *, - document: Union[dict, Document], - encoding_type: Optional[enums.EncodingType] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + document: dict | Document, + encoding_type: enums.EncodingType | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -242,7 +242,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudNaturalLanguageHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -294,12 +294,12 @@ class CloudNaturalLanguageClassifyTextOperator(BaseOperator): def __init__( self, *, - document: Union[dict, Document], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + document: dict | Document, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -310,7 +310,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudNaturalLanguageHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/pubsub.py b/airflow/providers/google/cloud/operators/pubsub.py index 7b74427b68c1d..b23e7f20a84d9 100644 --- a/airflow/providers/google/cloud/operators/pubsub.py +++ b/airflow/providers/google/cloud/operators/pubsub.py @@ -22,7 +22,9 @@ MessageStoragePolicy """ -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -38,6 +40,7 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.pubsub import PubSubHook +from airflow.providers.google.cloud.links.pubsub import PubSubSubscriptionLink, PubSubTopicLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -90,7 +93,7 @@ class PubSubCreateTopicOperator(BaseOperator): of Google Cloud regions where messages published to the topic may be stored. If not present, then no constraints are in effect. - Union[Dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] + Union[dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] :param kms_key_name: The resource name of the Cloud KMS CryptoKey to be used to protect access to messages published on this topic. 
The expected format is @@ -112,27 +115,28 @@ class PubSubCreateTopicOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'topic', - 'impersonation_chain', + "project_id", + "topic", + "impersonation_chain", ) - ui_color = '#0273d4' + ui_color = "#0273d4" + operator_extra_links = (PubSubTopicLink(),) def __init__( self, *, topic: str, - project_id: Optional[str] = None, + project_id: str | None = None, fail_if_exists: bool = False, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - message_storage_policy: Union[Dict, MessageStoragePolicy] = None, - kms_key_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + labels: dict[str, str] | None = None, + message_storage_policy: dict | MessageStoragePolicy = None, + kms_key_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -150,7 +154,7 @@ def __init__( self.metadata = metadata self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -170,6 +174,12 @@ def execute(self, context: 'Context') -> None: metadata=self.metadata, ) self.log.info("Created topic %s", self.topic) + PubSubTopicLink.persist( + context=context, + task_instance=self, + topic_id=self.topic, + project_id=self.project_id or hook.project_id, + ) class PubSubCreateSubscriptionOperator(BaseOperator): @@ -265,7 +275,7 @@ class PubSubCreateSubscriptionOperator(BaseOperator): in which they are received by the Pub/Sub system. Otherwise, they may be delivered in any order. :param expiration_policy: A policy that specifies the conditions for this - subscription’s expiration. A subscription is considered active as long as any + subscription's expiration. A subscription is considered active as long as any connected subscriber is successfully consuming messages from the subscription or is issuing operations on the subscription. If expiration_policy is not set, a default policy with ttl of 31 days will be used. 
The minimum allowed value for @@ -298,38 +308,39 @@ class PubSubCreateSubscriptionOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'topic', - 'subscription', - 'subscription_project_id', - 'impersonation_chain', + "project_id", + "topic", + "subscription", + "subscription_project_id", + "impersonation_chain", ) - ui_color = '#0273d4' + ui_color = "#0273d4" + operator_extra_links = (PubSubSubscriptionLink(),) def __init__( self, *, topic: str, - project_id: Optional[str] = None, - subscription: Optional[str] = None, - subscription_project_id: Optional[str] = None, + project_id: str | None = None, + subscription: str | None = None, + subscription_project_id: str | None = None, ack_deadline_secs: int = 10, fail_if_exists: bool = False, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - push_config: Optional[Union[Dict, PushConfig]] = None, - retain_acked_messages: Optional[bool] = None, - message_retention_duration: Optional[Union[Dict, Duration]] = None, - labels: Optional[Dict[str, str]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + push_config: dict | PushConfig | None = None, + retain_acked_messages: bool | None = None, + message_retention_duration: dict | Duration | None = None, + labels: dict[str, str] | None = None, enable_message_ordering: bool = False, - expiration_policy: Optional[Union[Dict, ExpirationPolicy]] = None, - filter_: Optional[str] = None, - dead_letter_policy: Optional[Union[Dict, DeadLetterPolicy]] = None, - retry_policy: Optional[Union[Dict, RetryPolicy]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + expiration_policy: dict | ExpirationPolicy | None = None, + filter_: str | None = None, + dead_letter_policy: dict | DeadLetterPolicy | None = None, + retry_policy: dict | RetryPolicy | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -355,7 +366,7 @@ def __init__( self.metadata = metadata self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -385,6 +396,12 @@ def execute(self, context: 'Context') -> str: ) self.log.info("Created subscription for topic %s", self.topic) + PubSubSubscriptionLink.persist( + context=context, + task_instance=self, + subscription_id=self.subscription or result, # result returns subscription name + project_id=self.project_id or hook.project_id, + ) return result @@ -440,24 +457,24 @@ class PubSubDeleteTopicOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'topic', - 'impersonation_chain', + "project_id", + "topic", + "impersonation_chain", ) - ui_color = '#cb4335' + ui_color = "#cb4335" def __init__( self, *, topic: str, - project_id: Optional[str] = None, + project_id: str | None = None, fail_if_not_exists: bool = False, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + 
gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -471,7 +488,7 @@ def __init__( self.metadata = metadata self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -544,24 +561,24 @@ class PubSubDeleteSubscriptionOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'subscription', - 'impersonation_chain', + "project_id", + "subscription", + "impersonation_chain", ) - ui_color = '#cb4335' + ui_color = "#cb4335" def __init__( self, *, subscription: str, - project_id: Optional[str] = None, + project_id: str | None = None, fail_if_not_exists: bool = False, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -575,7 +592,7 @@ def __init__( self.metadata = metadata self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -652,22 +669,22 @@ class PubSubPublishMessageOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'topic', - 'messages', - 'impersonation_chain', + "project_id", + "topic", + "messages", + "impersonation_chain", ) - ui_color = '#0273d4' + ui_color = "#0273d4" def __init__( self, *, topic: str, - messages: List, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + messages: list, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -678,7 +695,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -738,9 +755,9 @@ class PubSubPullOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'subscription', - 'impersonation_chain', + "project_id", + "subscription", + "impersonation_chain", ) def __init__( @@ -750,10 +767,10 @@ def __init__( subscription: str, max_messages: int = 5, ack_messages: bool = False, - messages_callback: Optional[Callable[[List[ReceivedMessage], "Context"], Any]] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + messages_callback: Callable[[list[ReceivedMessage], 
Context], Any] | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -766,7 +783,7 @@ def __init__( self.messages_callback = messages_callback self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> list: + def execute(self, context: Context) -> list: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -795,8 +812,8 @@ def execute(self, context: 'Context') -> list: def _default_message_callback( self, - pulled_messages: List[ReceivedMessage], - context: "Context", + pulled_messages: list[ReceivedMessage], + context: Context, ) -> list: """ This method can be overridden by subclasses or by `messages_callback` constructor argument. diff --git a/airflow/providers/google/cloud/operators/spanner.py b/airflow/providers/google/cloud/operators/spanner.py index 0780621e2901c..a23a9fc4682fa 100644 --- a/airflow/providers/google/cloud/operators/spanner.py +++ b/airflow/providers/google/cloud/operators/spanner.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Spanner operators.""" -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -60,12 +62,12 @@ class SpannerDeployInstanceOperator(BaseOperator): # [START gcp_spanner_deploy_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'configuration_name', - 'display_name', - 'gcp_conn_id', - 'impersonation_chain', + "project_id", + "instance_id", + "configuration_name", + "display_name", + "gcp_conn_id", + "impersonation_chain", ) # [END gcp_spanner_deploy_template_fields] operator_extra_links = (SpannerInstanceLink(),) @@ -77,9 +79,9 @@ def __init__( configuration_name: str, node_count: int, display_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.instance_id = instance_id @@ -93,12 +95,12 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: raise AirflowException("The required parameter 'instance_id' is empty or None") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = SpannerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -149,10 +151,10 @@ class SpannerDeleteInstanceOperator(BaseOperator): # [START gcp_spanner_delete_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'gcp_conn_id', - 'impersonation_chain', + "project_id", + "instance_id", + "gcp_conn_id", + "impersonation_chain", ) # [END gcp_spanner_delete_template_fields] @@ -160,9 +162,9 @@ def __init__( self, *, instance_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = 
"google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.instance_id = instance_id @@ -173,12 +175,12 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: raise AirflowException("The required parameter 'instance_id' is empty or None") - def execute(self, context: 'Context') -> Optional[bool]: + def execute(self, context: Context) -> bool | None: hook = SpannerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -221,15 +223,15 @@ class SpannerQueryDatabaseInstanceOperator(BaseOperator): # [START gcp_spanner_query_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'database_id', - 'query', - 'gcp_conn_id', - 'impersonation_chain', + "project_id", + "instance_id", + "database_id", + "query", + "gcp_conn_id", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'query': 'sql'} + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"query": "sql"} # [END gcp_spanner_query_template_fields] operator_extra_links = (SpannerDatabaseLink(),) @@ -238,10 +240,10 @@ def __init__( *, instance_id: str, database_id: str, - query: Union[str, List[str]], - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + query: str | list[str], + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.instance_id = instance_id @@ -254,7 +256,7 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: raise AirflowException("The required parameter 'instance_id' is empty or None") @@ -263,13 +265,13 @@ def _validate_inputs(self) -> None: if not self.query: raise AirflowException("The required parameter 'query' is empty") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = SpannerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) if isinstance(self.query, str): - queries = [x.strip() for x in self.query.split(';')] + queries = [x.strip() for x in self.query.split(";")] self.sanitize_queries(queries) else: queries = self.query @@ -295,14 +297,13 @@ def execute(self, context: 'Context'): ) @staticmethod - def sanitize_queries(queries: List[str]) -> None: + def sanitize_queries(queries: list[str]) -> None: """ Drops empty query in queries. 
:param queries: queries - :rtype: None """ - if queries and queries[-1] == '': + if queries and queries[-1] == "": del queries[-1] @@ -333,15 +334,15 @@ class SpannerDeployDatabaseInstanceOperator(BaseOperator): # [START gcp_spanner_database_deploy_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'database_id', - 'ddl_statements', - 'gcp_conn_id', - 'impersonation_chain', + "project_id", + "instance_id", + "database_id", + "ddl_statements", + "gcp_conn_id", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'ddl_statements': 'sql'} + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"ddl_statements": "sql"} # [END gcp_spanner_database_deploy_template_fields] operator_extra_links = (SpannerDatabaseLink(),) @@ -350,10 +351,10 @@ def __init__( *, instance_id: str, database_id: str, - ddl_statements: List[str], - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ddl_statements: list[str], + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.instance_id = instance_id @@ -366,14 +367,14 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: raise AirflowException("The required parameter 'instance_id' is empty or None") if not self.database_id: raise AirflowException("The required parameter 'database_id' is empty or None") - def execute(self, context: 'Context') -> Optional[bool]: + def execute(self, context: Context) -> bool | None: hook = SpannerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -439,15 +440,15 @@ class SpannerUpdateDatabaseInstanceOperator(BaseOperator): # [START gcp_spanner_database_update_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'database_id', - 'ddl_statements', - 'gcp_conn_id', - 'impersonation_chain', + "project_id", + "instance_id", + "database_id", + "ddl_statements", + "gcp_conn_id", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'ddl_statements': 'sql'} + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"ddl_statements": "sql"} # [END gcp_spanner_database_update_template_fields] operator_extra_links = (SpannerDatabaseLink(),) @@ -456,11 +457,11 @@ def __init__( *, instance_id: str, database_id: str, - ddl_statements: List[str], - project_id: Optional[str] = None, - operation_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ddl_statements: list[str], + project_id: str | None = None, + operation_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.instance_id = instance_id @@ -474,7 +475,7 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: raise AirflowException("The required parameter 'instance_id' is empty or None") @@ -483,7 +484,7 @@ def _validate_inputs(self) -> 
None: if not self.ddl_statements: raise AirflowException("The required parameter 'ddl_statements' is empty or None") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = SpannerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -538,11 +539,11 @@ class SpannerDeleteDatabaseInstanceOperator(BaseOperator): # [START gcp_spanner_database_delete_template_fields] template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'database_id', - 'gcp_conn_id', - 'impersonation_chain', + "project_id", + "instance_id", + "database_id", + "gcp_conn_id", + "impersonation_chain", ) # [END gcp_spanner_database_delete_template_fields] @@ -551,9 +552,9 @@ def __init__( *, instance_id: str, database_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.instance_id = instance_id @@ -565,14 +566,14 @@ def __init__( super().__init__(**kwargs) def _validate_inputs(self) -> None: - if self.project_id == '': + if self.project_id == "": raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: raise AirflowException("The required parameter 'instance_id' is empty or None") if not self.database_id: raise AirflowException("The required parameter 'database_id' is empty or None") - def execute(self, context: 'Context') -> bool: + def execute(self, context: Context) -> bool: hook = SpannerHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/speech_to_text.py b/airflow/providers/google/cloud/operators/speech_to_text.py index 6d046c01b9df2..1f044536eec68 100644 --- a/airflow/providers/google/cloud/operators/speech_to_text.py +++ b/airflow/providers/google/cloud/operators/speech_to_text.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
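The SpannerQueryDatabaseInstanceOperator hunks above keep the existing behaviour of splitting a templated SQL string on ";" and dropping a trailing empty statement via sanitize_queries; only the quoting style changes. A standalone sketch of that behaviour (not the operator code itself):

from __future__ import annotations


def split_spanner_queries(query: str) -> list[str]:
    # Split a ';'-separated SQL string and drop the empty tail left by a final semicolon.
    queries = [x.strip() for x in query.split(";")]
    if queries and queries[-1] == "":
        del queries[-1]
    return queries


assert split_spanner_queries("DELETE FROM tbl WHERE TRUE; ") == ["DELETE FROM tbl WHERE TRUE"]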
"""This module contains a Google Speech to Text operator.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -80,11 +82,11 @@ def __init__( *, audio: RecognitionAudio, config: RecognitionConfig, - project_id: Optional[str] = None, + project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.audio = audio @@ -103,7 +105,7 @@ def _validate_inputs(self) -> None: if self.config == "": raise AirflowException("The required parameter 'config' is empty") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudSpeechToTextHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/stackdriver.py b/airflow/providers/google/cloud/operators/stackdriver.py index 919afa566dba8..6006867b0fc94 100644 --- a/airflow/providers/google/cloud/operators/stackdriver.py +++ b/airflow/providers/google/cloud/operators/stackdriver.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -83,8 +84,8 @@ class StackdriverListAlertPoliciesOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'filter_', - 'impersonation_chain', + "filter_", + "impersonation_chain", ) operator_extra_links = (StackdriverPoliciesLink(),) ui_color = "#e5ffcc" @@ -92,17 +93,17 @@ class StackdriverListAlertPoliciesOperator(BaseOperator): def __init__( self, *, - format_: Optional[str] = None, - filter_: Optional[str] = None, - order_by: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + format_: str | None = None, + filter_: str | None = None, + order_by: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -117,11 +118,11 @@ def __init__( self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info( - 'List Alert Policies: Project id: %s Format: %s Filter: %s Order By: %s Page 
Size: %s', + "List Alert Policies: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %s", self.project_id, self.format_, self.filter_, @@ -189,22 +190,22 @@ class StackdriverEnableAlertPoliciesOperator(BaseOperator): ui_color = "#e5ffcc" template_fields: Sequence[str] = ( - 'filter_', - 'impersonation_chain', + "filter_", + "impersonation_chain", ) operator_extra_links = (StackdriverPoliciesLink(),) def __init__( self, *, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -216,10 +217,10 @@ def __init__( self.timeout = timeout self.metadata = metadata self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): - self.log.info('Enable Alert Policies: Project id: %s Filter: %s', self.project_id, self.filter_) + def execute(self, context: Context): + self.log.info("Enable Alert Policies: Project id: %s Filter: %s", self.project_id, self.filter_) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -277,22 +278,22 @@ class StackdriverDisableAlertPoliciesOperator(BaseOperator): ui_color = "#e5ffcc" template_fields: Sequence[str] = ( - 'filter_', - 'impersonation_chain', + "filter_", + "impersonation_chain", ) operator_extra_links = (StackdriverPoliciesLink(),) def __init__( self, *, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -304,10 +305,10 @@ def __init__( self.timeout = timeout self.metadata = metadata self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): - self.log.info('Disable Alert Policies: Project id: %s Filter: %s', self.project_id, self.filter_) + def execute(self, context: Context): + self.log.info("Disable Alert Policies: Project id: %s Filter: %s", self.project_id, self.filter_) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -364,10 +365,10 @@ class StackdriverUpsertAlertOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'alerts', - 'impersonation_chain', + "alerts", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.json',) + 
template_ext: Sequence[str] = (".json",) operator_extra_links = (StackdriverPoliciesLink(),) ui_color = "#e5ffcc" @@ -376,13 +377,13 @@ def __init__( self, *, alerts: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -394,10 +395,10 @@ def __init__( self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): - self.log.info('Upsert Alert Policies: Alerts: %s Project id: %s', self.alerts, self.project_id) + def execute(self, context: Context): + self.log.info("Upsert Alert Policies: Alerts: %s Project id: %s", self.alerts, self.project_id) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -451,8 +452,8 @@ class StackdriverDeleteAlertOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'name', - 'impersonation_chain', + "name", + "impersonation_chain", ) ui_color = "#e5ffcc" @@ -461,13 +462,13 @@ def __init__( self, *, name: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -479,10 +480,10 @@ def __init__( self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): - self.log.info('Delete Alert Policy: Project id: %s Name: %s', self.project_id, self.name) + def execute(self, context: Context): + self.log.info("Delete Alert Policy: Project id: %s Name: %s", self.project_id, self.name) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -547,8 +548,8 @@ class StackdriverListNotificationChannelsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'filter_', - 'impersonation_chain', + "filter_", + "impersonation_chain", ) operator_extra_links = (StackdriverNotificationsLink(),) @@ -557,17 +558,17 @@ class StackdriverListNotificationChannelsOperator(BaseOperator): def __init__( self, *, - format_: Optional[str] = None, - filter_: Optional[str] = None, - order_by: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 
'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + format_: str | None = None, + filter_: str | None = None, + order_by: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -582,11 +583,11 @@ def __init__( self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info( - 'List Notification Channels: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %s', + "List Notification Channels: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %s", self.project_id, self.format_, self.filter_, @@ -652,8 +653,8 @@ class StackdriverEnableNotificationChannelsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'filter_', - 'impersonation_chain', + "filter_", + "impersonation_chain", ) operator_extra_links = (StackdriverNotificationsLink(),) @@ -662,14 +663,14 @@ class StackdriverEnableNotificationChannelsOperator(BaseOperator): def __init__( self, *, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -681,11 +682,11 @@ def __init__( self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info( - 'Enable Notification Channels: Project id: %s Filter: %s', self.project_id, self.filter_ + "Enable Notification Channels: Project id: %s Filter: %s", self.project_id, self.filter_ ) if self.hook is None: self.hook = StackdriverHook( @@ -742,8 +743,8 @@ class StackdriverDisableNotificationChannelsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'filter_', - 'impersonation_chain', + "filter_", + "impersonation_chain", ) operator_extra_links = (StackdriverNotificationsLink(),) @@ -752,14 +753,14 @@ class StackdriverDisableNotificationChannelsOperator(BaseOperator): def __init__( self, *, - filter_: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = 
None, + filter_: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -771,11 +772,11 @@ def __init__( self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info( - 'Disable Notification Channels: Project id: %s Filter: %s', self.project_id, self.filter_ + "Disable Notification Channels: Project id: %s Filter: %s", self.project_id, self.filter_ ) if self.hook is None: self.hook = StackdriverHook( @@ -833,10 +834,10 @@ class StackdriverUpsertNotificationChannelOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'channels', - 'impersonation_chain', + "channels", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.json',) + template_ext: Sequence[str] = (".json",) operator_extra_links = (StackdriverNotificationsLink(),) ui_color = "#e5ffcc" @@ -845,13 +846,13 @@ def __init__( self, *, channels: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -863,11 +864,11 @@ def __init__( self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info( - 'Upsert Notification Channels: Channels: %s Project id: %s', self.channels, self.project_id + "Upsert Notification Channels: Channels: %s Project id: %s", self.channels, self.project_id ) if self.hook is None: self.hook = StackdriverHook( @@ -922,8 +923,8 @@ class StackdriverDeleteNotificationChannelOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'name', - 'impersonation_chain', + "name", + "impersonation_chain", ) ui_color = "#e5ffcc" @@ -932,13 +933,13 @@ def __init__( self, *, name: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - gcp_conn_id: str = 'google_cloud_default', - project_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + project_id: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -950,10 +951,10 @@ def __init__( 
self.project_id = project_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[StackdriverHook] = None + self.hook: StackdriverHook | None = None - def execute(self, context: 'Context'): - self.log.info('Delete Notification Channel: Project id: %s Name: %s', self.project_id, self.name) + def execute(self, context: Context): + self.log.info("Delete Notification Channel: Project id: %s Name: %s", self.project_id, self.name) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/google/cloud/operators/tasks.py b/airflow/providers/google/cloud/operators/tasks.py index b0c60309bb4c8..73991bc9b6986 100644 --- a/airflow/providers/google/cloud/operators/tasks.py +++ b/airflow/providers/google/cloud/operators/tasks.py @@ -15,13 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module contains various Google Cloud Tasks operators which allow you to perform basic operations using Cloud Tasks queues/tasks. """ -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence, Tuple from google.api_core.exceptions import AlreadyExists from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -72,7 +73,6 @@ class CloudTasksQueueCreateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: google.cloud.tasks_v2.types.Queue """ template_fields: Sequence[str] = ( @@ -90,13 +90,13 @@ def __init__( *, location: str, task_queue: Queue, - project_id: Optional[str] = None, - queue_name: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + queue_name: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -110,7 +110,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -180,7 +180,6 @@ class CloudTasksQueueUpdateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.tasks_v2.types.Queue """ template_fields: Sequence[str] = ( @@ -198,15 +197,15 @@ def __init__( self, *, task_queue: Queue, - project_id: Optional[str] = None, - location: Optional[str] = None, - queue_name: Optional[str] = None, - update_mask: Optional[FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + location: str | None = None, + queue_name: str | None = None, + update_mask: FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -221,7 +220,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -272,7 +271,6 @@ class CloudTasksQueueGetOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: google.cloud.tasks_v2.types.Queue """ template_fields: Sequence[str] = ( @@ -289,12 +287,12 @@ def __init__( *, location: str, queue_name: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -307,7 +305,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -358,7 +356,6 @@ class CloudTasksQueuesListOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: list[google.cloud.tasks_v2.types.Queue] """ template_fields: Sequence[str] = ( @@ -373,14 +370,14 @@ def __init__( self, *, location: str, - project_id: Optional[str] = None, - results_filter: Optional[str] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + results_filter: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -394,7 +391,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -458,12 +455,12 @@ def __init__( *, location: str, queue_name: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -476,7 +473,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -519,7 +516,6 @@ class CloudTasksQueuePurgeOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: list[google.cloud.tasks_v2.types.Queue] """ template_fields: Sequence[str] = ( @@ -536,12 +532,12 @@ def __init__( *, location: str, queue_name: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -554,7 +550,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -603,7 +599,6 @@ class CloudTasksQueuePauseOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: list[google.cloud.tasks_v2.types.Queue] """ template_fields: Sequence[str] = ( @@ -620,12 +615,12 @@ def __init__( *, location: str, queue_name: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -638,7 +633,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -687,7 +682,6 @@ class CloudTasksQueueResumeOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: list[google.cloud.tasks_v2.types.Queue] """ template_fields: Sequence[str] = ( @@ -704,12 +698,12 @@ def __init__( *, location: str, queue_name: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -722,7 +716,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -777,7 +771,6 @@ class CloudTasksTaskCreateOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.tasks_v2.types.Task """ template_fields: Sequence[str] = ( @@ -796,15 +789,15 @@ def __init__( *, location: str, queue_name: str, - task: Union[Dict, Task], - project_id: Optional[str] = None, - task_name: Optional[str] = None, - response_view: Optional[Task.View] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + task: dict | Task, + project_id: str | None = None, + task_name: str | None = None, + response_view: Task.View | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -820,7 +813,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -875,7 +868,6 @@ class CloudTasksTaskGetOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :rtype: google.cloud.tasks_v2.types.Task """ template_fields: Sequence[str] = ( @@ -894,13 +886,13 @@ def __init__( location: str, queue_name: str, task_name: str, - project_id: Optional[str] = None, - response_view: Optional[Task.View] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + response_view: Task.View | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -915,7 +907,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -970,7 +962,6 @@ class CloudTasksTasksListOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: list[google.cloud.tasks_v2.types.Task] """ template_fields: Sequence[str] = ( @@ -987,14 +978,14 @@ def __init__( *, location: str, queue_name: str, - project_id: Optional[str] = None, - response_view: Optional[Task.View] = None, - page_size: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + response_view: Task.View | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1009,7 +1000,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1078,12 +1069,12 @@ def __init__( location: str, queue_name: str, task_name: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1097,7 +1088,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1144,7 +1135,6 @@ class CloudTasksTaskRunOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
- :rtype: google.cloud.tasks_v2.types.Task """ template_fields: Sequence[str] = ( @@ -1163,13 +1153,13 @@ def __init__( location: str, queue_name: str, task_name: str, - project_id: Optional[str] = None, - response_view: Optional[Task.View] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + response_view: Task.View | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1184,7 +1174,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudTasksHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/text_to_speech.py b/airflow/providers/google/cloud/operators/text_to_speech.py index 1d14134c32781..4181180e01647 100644 --- a/airflow/providers/google/cloud/operators/text_to_speech.py +++ b/airflow/providers/google/cloud/operators/text_to_speech.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Text to Speech operator.""" +from __future__ import annotations from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -86,16 +87,16 @@ class CloudTextToSpeechSynthesizeOperator(BaseOperator): def __init__( self, *, - input_data: Union[Dict, SynthesisInput], - voice: Union[Dict, VoiceSelectionParams], - audio_config: Union[Dict, AudioConfig], + input_data: dict | SynthesisInput, + voice: dict | VoiceSelectionParams, + audio_config: dict | AudioConfig, target_bucket_name: str, target_filename: str, - project_id: Optional[str] = None, + project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.input_data = input_data @@ -122,7 +123,7 @@ def _validate_inputs(self) -> None: if getattr(self, parameter) == "": raise AirflowException(f"The required parameter '{parameter}' is empty") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = CloudTextToSpeechHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/translate.py b/airflow/providers/google/cloud/operators/translate.py index 40bd4dda6e981..9af76795261a1 100644 --- a/airflow/providers/google/cloud/operators/translate.py +++ b/airflow/providers/google/cloud/operators/translate.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
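Illustration (not part of the diff): how one of the Cloud Tasks operators touched above might be wired into a DAG after this change. The project, location and queue names are placeholders, the snippet assumes the Google provider is installed, and the `schedule` argument is the Airflow 2.4+ spelling.

from __future__ import annotations

import pendulum
from google.cloud.tasks_v2.types import Queue

from airflow import DAG
from airflow.providers.google.cloud.operators.tasks import CloudTasksQueueCreateOperator

with DAG(
    dag_id="example_cloud_tasks_queue",  # placeholder dag id
    start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
    schedule=None,
    catchup=False,
) as dag:
    create_queue = CloudTasksQueueCreateOperator(
        task_id="create_queue",
        location="europe-west1",       # placeholder
        queue_name="example-queue",    # placeholder
        task_queue=Queue(),            # a queue with default settings
        project_id="my-gcp-project",   # placeholder
        gcp_conn_id="google_cloud_default",
    )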
"""This module contains Google Translate operators.""" -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -80,26 +82,26 @@ class CloudTranslateTextOperator(BaseOperator): # [START translate_template_fields] template_fields: Sequence[str] = ( - 'values', - 'target_language', - 'format_', - 'source_language', - 'model', - 'gcp_conn_id', - 'impersonation_chain', + "values", + "target_language", + "format_", + "source_language", + "model", + "gcp_conn_id", + "impersonation_chain", ) # [END translate_template_fields] def __init__( self, *, - values: Union[List[str], str], + values: list[str] | str, target_language: str, format_: str, - source_language: Optional[str], + source_language: str | None, model: str, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -111,7 +113,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = CloudTranslateHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -127,6 +129,6 @@ def execute(self, context: 'Context') -> dict: self.log.debug("Translation %s", translation) return translation except ValueError as e: - self.log.error('An error has been thrown from translate method:') + self.log.error("An error has been thrown from translate method:") self.log.error(e) raise AirflowException(e) diff --git a/airflow/providers/google/cloud/operators/translate_speech.py b/airflow/providers/google/cloud/operators/translate_speech.py index 7e1f7caa724c5..8a88ffff73f07 100644 --- a/airflow/providers/google/cloud/operators/translate_speech.py +++ b/airflow/providers/google/cloud/operators/translate_speech.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Google Cloud Translate Speech operator.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig from google.protobuf.json_format import MessageToDict @@ -25,6 +27,7 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook +from airflow.providers.google.common.links.storage import FileDetailsLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -101,14 +104,15 @@ class CloudTranslateSpeechOperator(BaseOperator): # [START translate_speech_template_fields] template_fields: Sequence[str] = ( - 'target_language', - 'format_', - 'source_language', - 'model', - 'project_id', - 'gcp_conn_id', - 'impersonation_chain', + "target_language", + "format_", + "source_language", + "model", + "project_id", + "gcp_conn_id", + "impersonation_chain", ) + operator_extra_links = (FileDetailsLink(),) # [END translate_speech_template_fields] def __init__( @@ -118,11 +122,11 @@ def __init__( config: RecognitionConfig, target_language: str, format_: str, - source_language: Optional[str], + source_language: str | None, model: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -136,7 +140,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: speech_to_text_hook = CloudSpeechToTextHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -151,13 +155,13 @@ def execute(self, context: 'Context') -> dict: self.log.info("Recognition operation finished") - if not recognize_dict['results']: + if not recognize_dict["results"]: self.log.info("No recognition results") return {} self.log.debug("Recognition result: %s", recognize_dict) try: - transcript = recognize_dict['results'][0]['alternatives'][0]['transcript'] + transcript = recognize_dict["results"][0]["alternatives"][0]["transcript"] except KeyError as key: raise AirflowException( f"Wrong response '{recognize_dict}' returned - it should contain {key} field" @@ -171,9 +175,15 @@ def execute(self, context: 'Context') -> dict: source_language=self.source_language, model=self.model, ) - self.log.info('Translated output: %s', translation) + self.log.info("Translated output: %s", translation) + FileDetailsLink.persist( + context=context, + task_instance=self, + uri=self.audio["uri"][5:], + project_id=self.project_id or translate_hook.project_id, + ) return translation except ValueError as e: - self.log.error('An error has been thrown from translate speech method:') + self.log.error("An error has been thrown from translate speech method:") self.log.error(e) raise AirflowException(e) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index f2882631df7f2..6851d5263351c 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ 
b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Vertex AI operators.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -29,7 +29,11 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook -from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink +from airflow.providers.google.cloud.links.vertex_ai import ( + VertexAIModelLink, + VertexAITrainingLink, + VertexAITrainingPipelinesLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -44,18 +48,18 @@ def __init__( project_id: str, region: str, display_name: str, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, # RUN - training_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, + training_fraction_split: float | None = None, + test_fraction_split: float | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, sync: bool = True, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -75,7 +79,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook = None # type: Optional[AutoMLHook] + self.hook: AutoMLHook | None = None def on_kill(self) -> None: """ @@ -89,11 +93,12 @@ def on_kill(self) -> None: class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator): """Create AutoML Forecasting Training job""" - template_fields = [ - 'region', - 'impersonation_chain', - ] - operator_extra_links = (VertexAIModelLink(),) + template_fields = ( + "dataset_id", + "region", + "impersonation_chain", + ) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, @@ -102,24 +107,24 @@ def __init__( target_column: str, time_column: str, time_series_identifier_column: str, - unavailable_at_forecast_columns: List[str], - available_at_forecast_columns: List[str], + unavailable_at_forecast_columns: list[str], + available_at_forecast_columns: list[str], forecast_horizon: int, data_granularity_unit: str, data_granularity_count: int, - optimization_objective: Optional[str] = None, - column_specs: Optional[Dict[str, str]] = None, - column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, - validation_fraction_split: Optional[float] = None, - predefined_split_column_name: Optional[str] = None, - weight_column: Optional[str] = None, - time_series_attribute_columns: Optional[List[str]] = None, - 
context_window: Optional[int] = None, + optimization_objective: str | None = None, + column_specs: dict[str, str] | None = None, + column_transformations: list[dict[str, dict[str, str]]] | None = None, + validation_fraction_split: float | None = None, + predefined_split_column_name: str | None = None, + weight_column: str | None = None, + time_series_attribute_columns: list[str] | None = None, + context_window: int | None = None, export_evaluated_data_items: bool = False, - export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_bigquery_destination_uri: str | None = None, export_evaluated_data_items_override_destination: bool = False, - quantiles: Optional[List[float]] = None, - validation_options: Optional[str] = None, + quantiles: list[float] | None = None, + validation_options: str | None = None, budget_milli_node_hours: int = 1000, **kwargs, ) -> None: @@ -152,13 +157,13 @@ def __init__( self.validation_options = validation_options self.budget_milli_node_hours = budget_milli_node_hours - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = AutoMLHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - model = self.hook.create_auto_ml_forecasting_training_job( + model, training_id = self.hook.create_auto_ml_forecasting_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -199,20 +204,26 @@ def execute(self, context: "Context"): sync=self.sync, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator): """Create Auto ML Image Training job""" - template_fields = [ - 'region', - 'impersonation_chain', - ] - operator_extra_links = (VertexAIModelLink(),) + template_fields = ( + "dataset_id", + "region", + "impersonation_chain", + ) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, @@ -221,12 +232,12 @@ def __init__( prediction_type: str = "classification", multi_label: bool = False, model_type: str = "CLOUD", - base_model: Optional[Model] = None, - validation_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - budget_milli_node_hours: Optional[int] = None, + base_model: Model | None = None, + validation_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + budget_milli_node_hours: int | None = None, disable_early_stopping: bool = False, **kwargs, ) -> None: @@ -243,13 +254,13 @@ def __init__( self.budget_milli_node_hours = budget_milli_node_hours self.disable_early_stopping = disable_early_stopping - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = AutoMLHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, 
impersonation_chain=self.impersonation_chain, ) - model = self.hook.create_auto_ml_image_training_job( + model, training_id = self.hook.create_auto_ml_image_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -274,20 +285,26 @@ def execute(self, context: "Context"): sync=self.sync, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator): """Create Auto ML Tabular Training job""" - template_fields = [ - 'region', - 'impersonation_chain', - ] - operator_extra_links = (VertexAIModelLink(),) + template_fields = ( + "dataset_id", + "region", + "impersonation_chain", + ) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, @@ -295,19 +312,19 @@ def __init__( dataset_id: str, target_column: str, optimization_prediction_type: str, - optimization_objective: Optional[str] = None, - column_specs: Optional[Dict[str, str]] = None, - column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, - optimization_objective_recall_value: Optional[float] = None, - optimization_objective_precision_value: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - weight_column: Optional[str] = None, + optimization_objective: str | None = None, + column_specs: dict[str, str] | None = None, + column_transformations: list[dict[str, dict[str, str]]] | None = None, + optimization_objective_recall_value: float | None = None, + optimization_objective_precision_value: float | None = None, + validation_fraction_split: float | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + weight_column: str | None = None, budget_milli_node_hours: int = 1000, disable_early_stopping: bool = False, export_evaluated_data_items: bool = False, - export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_bigquery_destination_uri: str | None = None, export_evaluated_data_items_override_destination: bool = False, **kwargs, ) -> None: @@ -334,13 +351,13 @@ def __init__( export_evaluated_data_items_override_destination ) - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = AutoMLHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - model = self.hook.create_auto_ml_tabular_training_job( + model, training_id = self.hook.create_auto_ml_tabular_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -375,9 +392,14 @@ def execute(self, context: "Context"): sync=self.sync, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = 
self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result @@ -385,10 +407,11 @@ class CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator): """Create Auto ML Text Training job""" template_fields = [ - 'region', - 'impersonation_chain', + "dataset_id", + "region", + "impersonation_chain", ] - operator_extra_links = (VertexAIModelLink(),) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, @@ -397,10 +420,10 @@ def __init__( prediction_type: str, multi_label: bool = False, sentiment_max: int = 10, - validation_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, + validation_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -413,13 +436,13 @@ def __init__( self.validation_filter_split = validation_filter_split self.test_filter_split = test_filter_split - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = AutoMLHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - model = self.hook.create_auto_ml_text_training_job( + model, training_id = self.hook.create_auto_ml_text_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -441,20 +464,26 @@ def execute(self, context: "Context"): sync=self.sync, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator): """Create Auto ML Video Training job""" - template_fields = [ - 'region', - 'impersonation_chain', - ] - operator_extra_links = (VertexAIModelLink(),) + template_fields = ( + "dataset_id", + "region", + "impersonation_chain", + ) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, @@ -462,8 +491,8 @@ def __init__( dataset_id: str, prediction_type: str = "classification", model_type: str = "CLOUD", - training_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, + training_filter_split: str | None = None, + test_filter_split: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -473,13 +502,13 @@ def __init__( self.training_filter_split = training_filter_split self.test_filter_split = test_filter_split - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = AutoMLHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - model = 
self.hook.create_auto_ml_video_training_job( + model, training_id = self.hook.create_auto_ml_video_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -498,9 +527,14 @@ def execute(self, context: "Context"): sync=self.sync, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result @@ -509,7 +543,7 @@ class DeleteAutoMLTrainingJobOperator(BaseOperator): AutoMLTextTrainingJob, or AutoMLVideoTrainingJob. """ - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = ("training_pipeline", "region", "project_id", "impersonation_chain") def __init__( self, @@ -517,12 +551,12 @@ def __init__( training_pipeline_id: str, region: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -536,7 +570,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = AutoMLHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -563,11 +597,11 @@ class ListAutoMLTrainingJobOperator(BaseOperator): AutoMLTextTrainingJob, or AutoMLVideoTrainingJob in a Location. 
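Illustration (not part of the diff): the Create*TrainingJobOperator hunks above now push the pipeline id to XCom under the "training_id" key and only persist the model link when a model was actually produced, so a downstream task can read that id directly. Task ids are placeholders and the task is assumed to be defined inside a DAG context.

from airflow.operators.python import PythonOperator


def _report_training_id(ti):
    # "create_forecasting_training_job" is a placeholder task id for one of the
    # Create*TrainingJobOperator tasks above.
    training_id = ti.xcom_pull(task_ids="create_forecasting_training_job", key="training_id")
    print(f"Vertex AI training pipeline: {training_id}")


report_training_id = PythonOperator(
    task_id="report_training_id",
    python_callable=_report_training_id,
)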
""" - template_fields = [ + template_fields = ( "region", "project_id", "impersonation_chain", - ] + ) operator_extra_links = [ VertexAITrainingPipelinesLink(), ] @@ -577,16 +611,16 @@ def __init__( *, region: str, project_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -603,7 +637,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = AutoMLHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py index 4a30e093c0a6f..448e46611eaea 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Vertex AI operators. .. spelling:: @@ -25,8 +24,9 @@ aiplatform gapic """ +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -156,7 +156,7 @@ class CreateBatchPredictionJobOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = ("region", "project_id", "model_name", "impersonation_chain") operator_extra_links = (VertexAIBatchPredictionJobLink(),) def __init__( @@ -165,28 +165,28 @@ def __init__( region: str, project_id: str, job_display_name: str, - model_name: Union[str, "Model"], + model_name: str | Model, instances_format: str = "jsonl", predictions_format: str = "jsonl", - gcs_source: Optional[Union[str, Sequence[str]]] = None, - bigquery_source: Optional[str] = None, - gcs_destination_prefix: Optional[str] = None, - bigquery_destination_prefix: Optional[str] = None, - model_parameters: Optional[Dict] = None, - machine_type: Optional[str] = None, - accelerator_type: Optional[str] = None, - accelerator_count: Optional[int] = None, - starting_replica_count: Optional[int] = None, - max_replica_count: Optional[int] = None, - generate_explanation: Optional[bool] = False, - explanation_metadata: Optional["explain.ExplanationMetadata"] = None, - explanation_parameters: Optional["explain.ExplanationParameters"] = None, - labels: Optional[Dict[str, str]] = None, - encryption_spec_key_name: Optional[str] = None, + gcs_source: str | Sequence[str] | None = None, + bigquery_source: str | None = None, + gcs_destination_prefix: str | None = None, + bigquery_destination_prefix: str | None = None, + model_parameters: dict | None = None, + machine_type: str | None = None, + accelerator_type: str | None = None, + accelerator_count: int | None = None, + starting_replica_count: int | None = None, + max_replica_count: int | None = None, + generate_explanation: bool | None = False, + explanation_metadata: explain.ExplanationMetadata | None = None, + explanation_parameters: explain.ExplanationParameters | None = None, + labels: dict[str, str] | None = None, + encryption_spec_key_name: str | None = None, sync: bool = True, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -215,9 +215,9 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook = None # type: Optional[BatchPredictionJobHook] + self.hook: BatchPredictionJobHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info("Creating Batch prediction job") self.hook = BatchPredictionJobHook( gcp_conn_id=self.gcp_conn_id, @@ -300,12 +300,12 @@ def __init__( region: str, project_id: str, batch_prediction_job_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -319,7 +319,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = BatchPredictionJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ 
-375,12 +375,12 @@ def __init__( region: str, project_id: str, batch_prediction_job: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -394,7 +394,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = BatchPredictionJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -464,16 +464,16 @@ def __init__( *, region: str, project_id: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -490,7 +490,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = BatchPredictionJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 6196bf38805dc..b10a0715fc826 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# """This module contains Google Vertex AI operators.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -29,7 +29,11 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook -from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink +from airflow.providers.google.cloud.links.vertex_ai import ( + VertexAIModelLink, + VertexAITrainingLink, + VertexAITrainingPipelinesLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -45,51 +49,51 @@ def __init__( region: str, display_name: str, container_uri: str, - model_serving_container_image_uri: Optional[str] = None, - model_serving_container_predict_route: Optional[str] = None, - model_serving_container_health_route: Optional[str] = None, - model_serving_container_command: Optional[Sequence[str]] = None, - model_serving_container_args: Optional[Sequence[str]] = None, - model_serving_container_environment_variables: Optional[Dict[str, str]] = None, - model_serving_container_ports: Optional[Sequence[int]] = None, - model_description: Optional[str] = None, - model_instance_schema_uri: Optional[str] = None, - model_parameters_schema_uri: Optional[str] = None, - model_prediction_schema_uri: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - training_encryption_spec_key_name: Optional[str] = None, - model_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, # RUN - dataset_id: Optional[str] = None, - annotation_schema_uri: Optional[str] = None, - model_display_name: Optional[str] = None, - model_labels: Optional[Dict[str, str]] = None, - base_output_dir: Optional[str] = None, - service_account: Optional[str] = None, - network: Optional[str] = None, - bigquery_destination: Optional[str] = None, - args: Optional[List[Union[str, float, int]]] = None, - environment_variables: Optional[Dict[str, str]] = None, + dataset_id: str | None = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, replica_count: int = 1, machine_type: str = "n1-standard-4", accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", 
accelerator_count: int = 0, boot_disk_type: str = "pd-ssd", boot_disk_size_gb: int = 100, - training_fraction_split: Optional[float] = None, - validation_fraction_split: Optional[float] = None, - test_fraction_split: Optional[float] = None, - training_filter_split: Optional[str] = None, - validation_filter_split: Optional[str] = None, - test_filter_split: Optional[str] = None, - predefined_split_column_name: Optional[str] = None, - timestamp_split_column_name: Optional[str] = None, - tensorboard: Optional[str] = None, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, sync=True, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -115,7 +119,7 @@ def __init__( self.staging_bucket = staging_bucket # END Custom # START Run param - self.dataset = Dataset(name=dataset_id) if dataset_id else None + self.dataset_id = dataset_id self.annotation_schema_uri = annotation_schema_uri self.model_display_name = model_display_name self.model_labels = model_labels @@ -406,12 +410,13 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): account from the list granting this role to the originating account (templated). """ - template_fields = [ - 'region', - 'command', - 'impersonation_chain', - ] - operator_extra_links = (VertexAIModelLink(),) + template_fields = ( + "region", + "command", + "dataset_id", + "impersonation_chain", + ) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, @@ -422,13 +427,13 @@ def __init__( super().__init__(**kwargs) self.command = command - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - model = self.hook.create_custom_container_training_job( + model, training_id, custom_job_id = self.hook.create_custom_container_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -450,7 +455,7 @@ def execute(self, context: "Context"): model_encryption_spec_key_name=self.model_encryption_spec_key_name, staging_bucket=self.staging_bucket, # RUN - dataset=self.dataset, + dataset=Dataset(name=self.dataset_id) if self.dataset_id else None, annotation_schema_uri=self.annotation_schema_uri, model_display_name=self.model_display_name, model_labels=self.model_labels, @@ -478,9 +483,15 @@ def execute(self, context: "Context"): sync=True, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + self.xcom_push(context, key="custom_job_id", value=custom_job_id) + 
VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result def on_kill(self) -> None: @@ -751,11 +762,12 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator account from the list granting this role to the originating account (templated). """ - template_fields = [ - 'region', - 'impersonation_chain', - ] - operator_extra_links = (VertexAIModelLink(),) + template_fields = ( + "region", + "dataset_id", + "impersonation_chain", + ) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, @@ -768,13 +780,13 @@ def __init__( self.python_package_gcs_uri = python_package_gcs_uri self.python_module_name = python_module_name - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - model = self.hook.create_custom_python_package_training_job( + model, training_id, custom_job_id = self.hook.create_custom_python_package_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -797,7 +809,7 @@ def execute(self, context: "Context"): model_encryption_spec_key_name=self.model_encryption_spec_key_name, staging_bucket=self.staging_bucket, # RUN - dataset=self.dataset, + dataset=Dataset(name=self.dataset_id) if self.dataset_id else None, annotation_schema_uri=self.annotation_schema_uri, model_display_name=self.model_display_name, model_labels=self.model_labels, @@ -825,9 +837,15 @@ def execute(self, context: "Context"): sync=True, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + self.xcom_push(context, key="custom_job_id", value=custom_job_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result def on_kill(self) -> None: @@ -1098,32 +1116,33 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields = [ - 'region', - 'script_path', - 'requirements', - 'impersonation_chain', - ] - operator_extra_links = (VertexAIModelLink(),) + template_fields = ( + "region", + "script_path", + "requirements", + "dataset_id", + "impersonation_chain", + ) + operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) def __init__( self, *, script_path: str, - requirements: Optional[Sequence[str]] = None, + requirements: Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.requirements = requirements self.script_path = script_path - def execute(self, context: "Context"): + def execute(self, context: Context): self.hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - model = self.hook.create_custom_training_job( + model, training_id, custom_job_id = self.hook.create_custom_training_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -1146,7 +1165,7 @@ def execute(self, context: "Context"): model_encryption_spec_key_name=self.model_encryption_spec_key_name, staging_bucket=self.staging_bucket, # RUN - dataset=self.dataset, + dataset=Dataset(name=self.dataset_id) if self.dataset_id else None, annotation_schema_uri=self.annotation_schema_uri, model_display_name=self.model_display_name, model_labels=self.model_labels, @@ -1174,9 +1193,15 @@ def execute(self, context: "Context"): sync=True, ) - result = Model.to_dict(model) - model_id = self.hook.extract_model_id(result) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + if model: + result = Model.to_dict(model) + model_id = self.hook.extract_model_id(result) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + else: + result = model # type: ignore + self.xcom_push(context, key="training_id", value=training_id) + self.xcom_push(context, key="custom_job_id", value=custom_job_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id) return result def on_kill(self) -> None: @@ -1212,7 +1237,7 @@ class DeleteCustomTrainingJobOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = ("training_pipeline", "custom_job", "region", "project_id", "impersonation_chain") def __init__( self, @@ -1221,12 +1246,12 @@ def __init__( custom_job_id: str, region: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1241,7 +1266,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -1335,16 +1360,16 @@ def __init__( *, region: str, project_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1361,7 +1386,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = CustomJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index 95ec5efd9238b..1359548d0b8cd 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# """This module contains Google Vertex AI operators.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -67,13 +67,13 @@ def __init__( *, region: str, project_id: str, - dataset: Union[Dataset, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + dataset: Dataset | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -87,7 +87,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -147,13 +147,13 @@ def __init__( region: str, project_id: str, dataset_id: str, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -168,7 +168,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -225,12 +225,12 @@ def __init__( region: str, project_id: str, dataset_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -244,7 +244,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -300,13 +300,13 @@ def __init__( region: str, project_id: str, dataset_id: str, - export_config: Union[ExportDataConfig, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + export_config: ExportDataConfig | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = 
"google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -321,7 +321,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -376,13 +376,13 @@ def __init__( region: str, project_id: str, dataset_id: str, - import_configs: Union[Sequence[ImportDataConfig], List], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + import_configs: Sequence[ImportDataConfig] | list, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -397,7 +397,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -455,17 +455,17 @@ def __init__( *, region: str, project_id: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -483,7 +483,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -539,14 +539,14 @@ def __init__( project_id: str, region: str, dataset_id: str, - dataset: Union[Dataset, Dict], - update_mask: Union[FieldMask, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + dataset: Dataset | dict, + update_mask: FieldMask | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -562,7 +562,7 @@ def __init__( self.delegate_to = 
delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = DatasetHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py b/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py index 64e2b4816af07..cfe67cf7aa8ca 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Vertex AI operators. .. spelling:: @@ -27,8 +26,9 @@ FieldMask unassigns """ +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -82,14 +82,14 @@ def __init__( *, region: str, project_id: str, - endpoint: Union[Endpoint, Dict], - endpoint_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + endpoint: Endpoint | dict, + endpoint_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -104,7 +104,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -164,12 +164,12 @@ def __init__( region: str, project_id: str, endpoint_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -183,7 +183,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -246,7 +246,7 @@ class DeployModelOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
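The `template_fields` update just below adds `deployed_model`, so the dict passed to `DeployModelOperator` can itself contain Jinja templates (Airflow renders strings nested inside templated collections). A hypothetical usage sketch; the project, region, endpoint id, XCom expression and the `DeployedModel` dict layout are assumptions for illustration, not taken from the diff:

```python
# Hypothetical sketch, assumed to sit inside a `with DAG(...)` block.
from airflow.providers.google.cloud.operators.vertex_ai.endpoint_service import DeployModelOperator

deploy_model = DeployModelOperator(
    task_id="deploy_model",
    project_id="my-project",       # placeholder
    region="us-central1",          # placeholder
    endpoint_id="1234567890",      # placeholder
    deployed_model={
        # Rendered at runtime now that "deployed_model" is templated.
        "model": "{{ ti.xcom_pull(task_ids='create_custom_training_job')['name'] }}",
        "display_name": "example-deployed-model",
        "dedicated_resources": {
            "machine_spec": {"machine_type": "n1-standard-4"},
            "min_replica_count": 1,
            "max_replica_count": 1,
        },
    },
    traffic_split={"0": 100},      # route all traffic to the newly deployed model
)
```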
""" - template_fields = ("region", "endpoint_id", "project_id", "impersonation_chain") + template_fields = ("region", "endpoint_id", "project_id", "deployed_model", "impersonation_chain") operator_extra_links = (VertexAIModelLink(),) def __init__( @@ -255,14 +255,14 @@ def __init__( region: str, project_id: str, endpoint_id: str, - deployed_model: Union[DeployedModel, Dict], - traffic_split: Optional[Union[Sequence, Dict]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + deployed_model: DeployedModel | dict, + traffic_split: Sequence | dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -278,7 +278,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -340,12 +340,12 @@ def __init__( region: str, project_id: str, endpoint_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -359,7 +359,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -436,17 +436,17 @@ def __init__( *, region: str, project_id: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -464,7 +464,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -528,13 +528,13 @@ def __init__( project_id: str, endpoint_id: str, deployed_model_id: str, - 
traffic_split: Optional[Union[Sequence, Dict]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + traffic_split: Sequence | dict | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -550,7 +550,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -608,14 +608,14 @@ def __init__( project_id: str, region: str, endpoint_id: str, - endpoint: Union[Endpoint, Dict], - update_mask: Union[FieldMask, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + endpoint: Endpoint | dict, + update_mask: FieldMask | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -631,7 +631,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py index 66ebe0d063678..da70afc1b2ef5 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Vertex AI operators. .. 
spelling:: @@ -26,8 +25,9 @@ aiplatform myVPC """ +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -147,9 +147,9 @@ class CreateHyperparameterTuningJobOperator(BaseOperator): """ template_fields = [ - 'region', - 'project_id', - 'impersonation_chain', + "region", + "project_id", + "impersonation_chain", ] operator_extra_links = (VertexAITrainingLink(),) @@ -159,34 +159,34 @@ def __init__( project_id: str, region: str, display_name: str, - metric_spec: Dict[str, str], - parameter_spec: Dict[str, hyperparameter_tuning._ParameterSpec], + metric_spec: dict[str, str], + parameter_spec: dict[str, hyperparameter_tuning._ParameterSpec], max_trial_count: int, parallel_trial_count: int, # START: CustomJob param - worker_pool_specs: Union[List[Dict], List[gapic.WorkerPoolSpec]], - base_output_dir: Optional[str] = None, - custom_job_labels: Optional[Dict[str, str]] = None, - custom_job_encryption_spec_key_name: Optional[str] = None, - staging_bucket: Optional[str] = None, + worker_pool_specs: list[dict] | list[gapic.WorkerPoolSpec], + base_output_dir: str | None = None, + custom_job_labels: dict[str, str] | None = None, + custom_job_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, # END: CustomJob param max_failed_trial_count: int = 0, - search_algorithm: Optional[str] = None, - measurement_selection: Optional[str] = "best", - hyperparameter_tuning_job_labels: Optional[Dict[str, str]] = None, - hyperparameter_tuning_job_encryption_spec_key_name: Optional[str] = None, + search_algorithm: str | None = None, + measurement_selection: str | None = "best", + hyperparameter_tuning_job_labels: dict[str, str] | None = None, + hyperparameter_tuning_job_encryption_spec_key_name: str | None = None, # START: run param - service_account: Optional[str] = None, - network: Optional[str] = None, - timeout: Optional[int] = None, # seconds + service_account: str | None = None, + network: str | None = None, + timeout: int | None = None, # seconds restart_job_on_worker_restart: bool = False, enable_web_access: bool = False, - tensorboard: Optional[str] = None, + tensorboard: str | None = None, sync: bool = True, # END: run param gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -219,9 +219,9 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook = None # type: Optional[HyperparameterTuningJobHook] + self.hook: HyperparameterTuningJobHook | None = None - def execute(self, context: "Context"): + def execute(self, context: Context): self.log.info("Creating Hyperparameter Tuning job") self.hook = HyperparameterTuningJobHook( gcp_conn_id=self.gcp_conn_id, @@ -311,12 +311,12 @@ def __init__( region: str, project_id: str, hyperparameter_tuning_job_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: 
Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -330,7 +330,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = HyperparameterTuningJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -379,12 +379,12 @@ def __init__( hyperparameter_tuning_job_id: str, region: str, project_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -398,7 +398,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = HyperparameterTuningJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -458,16 +458,16 @@ def __init__( *, region: str, project_id: str, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - filter: Optional[str] = None, - read_mask: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + page_size: int | None = None, + page_token: str | None = None, + filter: str | None = None, + read_mask: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -484,7 +484,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = HyperparameterTuningJobHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/model_service.py b/airflow/providers/google/cloud/operators/vertex_ai/model_service.py index b6228d7c386b1..3b8c1e972f857 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/model_service.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/model_service.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains Google Vertex AI operators. .. 
spelling:: @@ -23,8 +22,9 @@ aiplatform camelCase """ +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -75,12 +75,12 @@ def __init__( region: str, project_id: str, model_id: str, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -94,7 +94,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -151,13 +151,13 @@ def __init__( project_id: str, region: str, model_id: str, - output_config: Union[model_service.ExportModelRequest.OutputConfig, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + output_config: model_service.ExportModelRequest.OutputConfig | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -172,7 +172,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -246,17 +246,17 @@ def __init__( *, region: str, project_id: str, - filter: Optional[str] = None, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - read_mask: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + filter: str | None = None, + page_size: int | None = None, + page_token: str | None = None, + read_mask: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -274,7 +274,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -320,7 +320,7 @@ class UploadModelOperator(BaseOperator): account from the list granting 
this role to the originating account (templated). """ - template_fields = ("region", "project_id", "impersonation_chain") + template_fields = ("region", "project_id", "model", "impersonation_chain") operator_extra_links = (VertexAIModelLink(),) def __init__( @@ -328,13 +328,13 @@ def __init__( *, project_id: str, region: str, - model: Union[Model, Dict], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + model: Model | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -348,7 +348,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context"): + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/operators/video_intelligence.py b/airflow/providers/google/cloud/operators/video_intelligence.py index b6105ffa82c50..13e019f3f71ad 100644 --- a/airflow/providers/google/cloud/operators/video_intelligence.py +++ b/airflow/providers/google/cloud/operators/video_intelligence.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud Vision operators.""" -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -80,14 +82,14 @@ def __init__( self, *, input_uri: str, - input_content: Optional[bytes] = None, - output_uri: Optional[str] = None, - video_context: Union[Dict, VideoContext] = None, - location: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + input_content: bytes | None = None, + output_uri: str | None = None, + video_context: dict | VideoContext = None, + location: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -101,7 +103,7 @@ def __init__( self.timeout = timeout self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVideoIntelligenceHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -170,14 +172,14 @@ def __init__( self, *, input_uri: str, - output_uri: Optional[str] = None, - input_content: Optional[bytes] = None, - video_context: Union[Dict, VideoContext] = None, - location: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + output_uri: str | None = None, + input_content: bytes | None = None, + video_context: dict | VideoContext = None, + location: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, gcp_conn_id: str = "google_cloud_default", - 
impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -191,7 +193,7 @@ def __init__( self.timeout = timeout self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVideoIntelligenceHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -260,14 +262,14 @@ def __init__( self, *, input_uri: str, - output_uri: Optional[str] = None, - input_content: Optional[bytes] = None, - video_context: Union[Dict, VideoContext] = None, - location: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + output_uri: str | None = None, + input_content: bytes | None = None, + video_context: dict | VideoContext = None, + location: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -281,7 +283,7 @@ def __init__( self.timeout = timeout self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVideoIntelligenceHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/operators/vision.py b/airflow/providers/google/cloud/operators/vision.py index 7df9946c1ffe5..3f4f9f96cd301 100644 --- a/airflow/providers/google/cloud/operators/vision.py +++ b/airflow/providers/google/cloud/operators/vision.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Google Cloud Vision operator.""" +from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Sequence, Tuple from google.api_core.exceptions import AlreadyExists from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -90,15 +91,15 @@ class CloudVisionCreateProductSetOperator(BaseOperator): def __init__( self, *, - product_set: Union[dict, ProductSet], + product_set: dict | ProductSet, location: str, - project_id: Optional[str] = None, - product_set_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + product_set_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -112,7 +113,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -167,11 +168,11 @@ class CloudVisionGetProductSetOperator(BaseOperator): # [START vision_productset_get_template_fields] template_fields: Sequence[str] = ( - 'location', - 'project_id', - 'product_set_id', - 'gcp_conn_id', - 'impersonation_chain', + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_productset_get_template_fields] @@ -180,12 +181,12 @@ def __init__( *, location: str, product_set_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -198,7 +199,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -241,7 +242,7 @@ class CloudVisionUpdateProductSetOperator(BaseOperator): :param project_id: (Optional) The project in which the ProductSet should be created. If set to None or missing, the default project_id from the Google Cloud connection is used. :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask - isn’t specified, all mutable fields are to be updated. Valid mask path is display_name. If a dict is + isn't specified, all mutable fields are to be updated. Valid mask path is display_name. If a dict is provided, it must be of the same form as the protobuf message `FieldMask`. :param retry: (Optional) A retry object used to retry requests. If `None` is specified, requests will not be retried. 
@@ -262,27 +263,27 @@ class CloudVisionUpdateProductSetOperator(BaseOperator): # [START vision_productset_update_template_fields] template_fields: Sequence[str] = ( - 'location', - 'project_id', - 'product_set_id', - 'gcp_conn_id', - 'impersonation_chain', + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_productset_update_template_fields] def __init__( self, *, - product_set: Union[Dict, ProductSet], - location: Optional[str] = None, - product_set_id: Optional[str] = None, - project_id: Optional[str] = None, - update_mask: Union[Dict, FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + product_set: dict | ProductSet, + location: str | None = None, + product_set_id: str | None = None, + project_id: str | None = None, + update_mask: dict | FieldMask = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -297,7 +298,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -348,11 +349,11 @@ class CloudVisionDeleteProductSetOperator(BaseOperator): # [START vision_productset_delete_template_fields] template_fields: Sequence[str] = ( - 'location', - 'project_id', - 'product_set_id', - 'gcp_conn_id', - 'impersonation_chain', + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_productset_delete_template_fields] @@ -361,12 +362,12 @@ def __init__( *, location: str, product_set_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -379,7 +380,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -437,11 +438,11 @@ class CloudVisionCreateProductOperator(BaseOperator): # [START vision_product_create_template_fields] template_fields: Sequence[str] = ( - 'location', - 'project_id', - 'product_id', - 'gcp_conn_id', - 'impersonation_chain', + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_product_create_template_fields] @@ -450,13 +451,13 @@ def __init__( *, location: str, product: str, - project_id: Optional[str] = None, - product_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + product_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | 
None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -470,7 +471,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -487,7 +488,7 @@ def execute(self, context: 'Context'): ) except AlreadyExists: self.log.info( - 'Product with id %s already exists. Exiting from the create operation.', self.product_id + "Product with id %s already exists. Exiting from the create operation.", self.product_id ) return self.product_id @@ -528,11 +529,11 @@ class CloudVisionGetProductOperator(BaseOperator): # [START vision_product_get_template_fields] template_fields: Sequence[str] = ( - 'location', - 'project_id', - 'product_id', - 'gcp_conn_id', - 'impersonation_chain', + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_product_get_template_fields] @@ -541,12 +542,12 @@ def __init__( *, location: str, product_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -559,7 +560,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -612,7 +613,7 @@ class CloudVisionUpdateProductOperator(BaseOperator): :param project_id: (Optional) The project in which the Product is located. If set to None or missing, the default project_id from the Google Cloud connection is used. :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask - isn’t specified, all mutable fields are to be updated. Valid mask paths include product_labels, + isn't specified, all mutable fields are to be updated. Valid mask paths include product_labels, display_name, and description. If a dict is provided, it must be of the same form as the protobuf message `FieldMask`. :param retry: (Optional) A retry object used to retry requests. 
If `None` is @@ -634,27 +635,27 @@ class CloudVisionUpdateProductOperator(BaseOperator): # [START vision_product_update_template_fields] template_fields: Sequence[str] = ( - 'location', - 'project_id', - 'product_id', - 'gcp_conn_id', - 'impersonation_chain', + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_product_update_template_fields] def __init__( self, *, - product: Union[Dict, Product], - location: Optional[str] = None, - product_id: Optional[str] = None, - project_id: Optional[str] = None, - update_mask: Union[Dict, FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + product: dict | Product, + location: str | None = None, + product_id: str | None = None, + project_id: str | None = None, + update_mask: dict | FieldMask = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -669,7 +670,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -725,11 +726,11 @@ class CloudVisionDeleteProductOperator(BaseOperator): # [START vision_product_delete_template_fields] template_fields: Sequence[str] = ( - 'location', - 'project_id', - 'product_id', - 'gcp_conn_id', - 'impersonation_chain', + "location", + "project_id", + "product_id", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_product_delete_template_fields] @@ -738,12 +739,12 @@ def __init__( *, location: str, product_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -756,7 +757,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -800,20 +801,20 @@ class CloudVisionImageAnnotateOperator(BaseOperator): # [START vision_annotate_image_template_fields] template_fields: Sequence[str] = ( - 'request', - 'gcp_conn_id', - 'impersonation_chain', + "request", + "gcp_conn_id", + "impersonation_chain", ) # [END vision_annotate_image_template_fields] def __init__( self, *, - request: Union[Dict, AnnotateImageRequest], - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + request: dict | AnnotateImageRequest, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] 
| None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -823,7 +824,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -893,15 +894,15 @@ def __init__( self, *, location: str, - reference_image: Union[Dict, ReferenceImage], + reference_image: dict | ReferenceImage, product_id: str, - reference_image_id: Optional[str] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + reference_image_id: str | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -916,7 +917,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): try: hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, @@ -991,12 +992,12 @@ def __init__( location: str, product_id: str, reference_image_id: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1010,7 +1011,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1034,7 +1035,7 @@ class CloudVisionAddProductToProductSetOperator(BaseOperator): Possible errors: - - Returns `NOT_FOUND` if the Product or the ProductSet doesn’t exist. + - Returns `NOT_FOUND` if the Product or the ProductSet doesn't exist. .. 
seealso:: For more information on how to use this operator, take a look at the guide: @@ -1080,12 +1081,12 @@ def __init__( product_set_id: str, product_id: str, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1099,7 +1100,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1163,12 +1164,12 @@ def __init__( product_set_id: str, product_id: str, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, metadata: MetaData = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1182,7 +1183,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1241,15 +1242,15 @@ class CloudVisionDetectTextOperator(BaseOperator): def __init__( self, - image: Union[Dict, Image], - max_results: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - language_hints: Optional[Union[str, List[str]]] = None, - web_detection_params: Optional[Dict] = None, - additional_properties: Optional[Dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + language_hints: str | list[str] | None = None, + web_detection_params: dict | None = None, + additional_properties: dict | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1266,7 +1267,7 @@ def __init__( ) self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1323,15 +1324,15 @@ class CloudVisionTextDetectOperator(BaseOperator): def __init__( self, - image: Union[Dict, Image], - max_results: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - language_hints: Optional[Union[str, List[str]]] = None, - web_detection_params: Optional[Dict] = None, - additional_properties: Optional[Dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + language_hints: str | list[str] | None = None, + web_detection_params: dict | 
None = None, + additional_properties: dict | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1347,7 +1348,7 @@ def __init__( ) self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1400,13 +1401,13 @@ class CloudVisionDetectImageLabelsOperator(BaseOperator): def __init__( self, - image: Union[Dict, Image], - max_results: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - additional_properties: Optional[Dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + additional_properties: dict | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1418,7 +1419,7 @@ def __init__( self.additional_properties = additional_properties self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1471,13 +1472,13 @@ class CloudVisionDetectImageSafeSearchOperator(BaseOperator): def __init__( self, - image: Union[Dict, Image], - max_results: Optional[int] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - additional_properties: Optional[Dict] = None, + image: dict | Image, + max_results: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + additional_properties: dict | None = None, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1489,7 +1490,7 @@ def __init__( self.additional_properties = additional_properties self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudVisionHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -1504,8 +1505,8 @@ def execute(self, context: 'Context'): def prepare_additional_parameters( - additional_properties: Optional[Dict], language_hints: Any, web_detection_params: Any -) -> Optional[Dict]: + additional_properties: dict | None, language_hints: Any, web_detection_params: Any +) -> dict | None: """ Creates additional_properties parameter based on language_hints, web_detection_params and additional_properties parameters specified by the user @@ -1518,14 +1519,14 @@ def prepare_additional_parameters( merged_additional_parameters = deepcopy(additional_properties) - if 'image_context' not in merged_additional_parameters: - merged_additional_parameters['image_context'] = {} + if "image_context" not in merged_additional_parameters: + merged_additional_parameters["image_context"] = {} - merged_additional_parameters['image_context']['language_hints'] = merged_additional_parameters[ - 'image_context' - ].get('language_hints', language_hints) - 
merged_additional_parameters['image_context']['web_detection_params'] = merged_additional_parameters[ - 'image_context' - ].get('web_detection_params', web_detection_params) + merged_additional_parameters["image_context"]["language_hints"] = merged_additional_parameters[ + "image_context" + ].get("language_hints", language_hints) + merged_additional_parameters["image_context"]["web_detection_params"] = merged_additional_parameters[ + "image_context" + ].get("web_detection_params", web_detection_params) return merged_additional_parameters diff --git a/airflow/providers/google/cloud/operators/workflows.py b/airflow/providers/google/cloud/operators/workflows.py index 157d34a057e05..e39cfc3ffbc31 100644 --- a/airflow/providers/google/cloud/operators/workflows.py +++ b/airflow/providers/google/cloud/operators/workflows.py @@ -14,12 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import hashlib import json import re import uuid from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence import pytz from google.api_core.exceptions import AlreadyExists @@ -31,6 +33,11 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook +from airflow.providers.google.cloud.links.workflows import ( + WorkflowsExecutionLink, + WorkflowsListOfWorkflowsLink, + WorkflowsWorkflowDetailsLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -60,20 +67,21 @@ class WorkflowsCreateWorkflowOperator(BaseOperator): template_fields: Sequence[str] = ("location", "workflow", "workflow_id") template_fields_renderers = {"workflow": "json"} + operator_extra_links = (WorkflowsWorkflowDetailsLink(),) def __init__( self, *, - workflow: Dict, + workflow: dict, workflow_id: str, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", force_rerun: bool = False, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -102,12 +110,12 @@ def _workflow_id(self, context): # We are limited by allowed length of workflow_id so # we use hash of whole information - exec_date = context['execution_date'].isoformat() + exec_date = context["execution_date"].isoformat() base = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{hash_base}" workflow_id = hashlib.md5(base.encode()).hexdigest() return re.sub(r"[:\-+.]", "_", workflow_id) - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) workflow_id = self._workflow_id(context) @@ -132,6 +140,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + WorkflowsWorkflowDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.location, + workflow_id=self.workflow_id, + project_id=self.project_id or hook.project_id, + ) + return Workflow.to_dict(workflow) @@ -162,19 +179,20 @@ class 
WorkflowsUpdateWorkflowOperator(BaseOperator): template_fields: Sequence[str] = ("workflow_id", "update_mask") template_fields_renderers = {"update_mask": "json"} + operator_extra_links = (WorkflowsWorkflowDetailsLink(),) def __init__( self, *, workflow_id: str, location: str, - project_id: Optional[str] = None, - update_mask: Optional[FieldMask] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + update_mask: FieldMask | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -189,7 +207,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) workflow = hook.get_workflow( @@ -209,6 +227,15 @@ def execute(self, context: 'Context'): metadata=self.metadata, ) workflow = operation.result() + + WorkflowsWorkflowDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.location, + workflow_id=self.workflow_id, + project_id=self.project_id or hook.project_id, + ) + return Workflow.to_dict(workflow) @@ -239,12 +266,12 @@ def __init__( *, workflow_id: str, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -258,7 +285,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Deleting workflow %s", self.workflow_id) operation = hook.delete_workflow( @@ -296,19 +323,20 @@ class WorkflowsListWorkflowsOperator(BaseOperator): """ template_fields: Sequence[str] = ("location", "order_by", "filter_") + operator_extra_links = (WorkflowsListOfWorkflowsLink(),) def __init__( self, *, location: str, - project_id: Optional[str] = None, - filter_: Optional[str] = None, - order_by: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + filter_: str | None = None, + order_by: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -323,7 +351,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def 
execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Retrieving workflows") workflows_iter = hook.list_workflows( @@ -335,6 +363,13 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + WorkflowsListOfWorkflowsLink.persist( + context=context, + task_instance=self, + project_id=self.project_id or hook.project_id, + ) + return [Workflow.to_dict(w) for w in workflows_iter] @@ -357,18 +392,19 @@ class WorkflowsGetWorkflowOperator(BaseOperator): """ template_fields: Sequence[str] = ("location", "workflow_id") + operator_extra_links = (WorkflowsWorkflowDetailsLink(),) def __init__( self, *, workflow_id: str, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -382,7 +418,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Retrieving workflow") workflow = hook.get_workflow( @@ -393,6 +429,15 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + WorkflowsWorkflowDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.location, + workflow_id=self.workflow_id, + project_id=self.project_id or hook.project_id, + ) + return Workflow.to_dict(workflow) @@ -418,19 +463,20 @@ class WorkflowsCreateExecutionOperator(BaseOperator): template_fields: Sequence[str] = ("location", "workflow_id", "execution") template_fields_renderers = {"execution": "json"} + operator_extra_links = (WorkflowsExecutionLink(),) def __init__( self, *, workflow_id: str, - execution: Dict, + execution: dict, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -445,7 +491,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Creating execution") execution = hook.create_execution( @@ -459,6 +505,16 @@ def execute(self, context: 'Context'): ) execution_id = execution.name.split("/")[-1] self.xcom_push(context, key="execution_id", value=execution_id) + + WorkflowsExecutionLink.persist( + context=context, + task_instance=self, + location_id=self.location, + 
workflow_id=self.workflow_id, + execution_id=execution_id, + project_id=self.project_id or hook.project_id, + ) + return Execution.to_dict(execution) @@ -482,6 +538,7 @@ class WorkflowsCancelExecutionOperator(BaseOperator): """ template_fields: Sequence[str] = ("location", "workflow_id", "execution_id") + operator_extra_links = (WorkflowsExecutionLink(),) def __init__( self, @@ -489,12 +546,12 @@ def __init__( workflow_id: str, execution_id: str, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -509,7 +566,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Canceling execution %s", self.execution_id) execution = hook.cancel_execution( @@ -521,6 +578,16 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + WorkflowsExecutionLink.persist( + context=context, + task_instance=self, + location_id=self.location, + workflow_id=self.workflow_id, + execution_id=self.execution_id, + project_id=self.project_id or hook.project_id, + ) + return Execution.to_dict(execution) @@ -549,19 +616,20 @@ class WorkflowsListExecutionsOperator(BaseOperator): """ template_fields: Sequence[str] = ("location", "workflow_id") + operator_extra_links = (WorkflowsWorkflowDetailsLink(),) def __init__( self, *, workflow_id: str, location: str, - start_date_filter: Optional[datetime] = None, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + start_date_filter: datetime | None = None, + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -576,7 +644,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Retrieving executions for workflow %s", self.workflow_id) execution_iter = hook.list_executions( @@ -588,6 +656,14 @@ def execute(self, context: 'Context'): metadata=self.metadata, ) + WorkflowsWorkflowDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.location, + workflow_id=self.workflow_id, + project_id=self.project_id or hook.project_id, + ) + return [Execution.to_dict(e) for e in execution_iter if e.start_time > self.start_date_filter] @@ -611,6 +687,7 @@ class WorkflowsGetExecutionOperator(BaseOperator): """ template_fields: Sequence[str] = ("location", "workflow_id", "execution_id") + 
operator_extra_links = (WorkflowsExecutionLink(),) def __init__( self, @@ -618,12 +695,12 @@ def __init__( workflow_id: str, execution_id: str, location: str, - project_id: Optional[str] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -638,7 +715,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Retrieving execution %s for workflow %s", self.execution_id, self.workflow_id) execution = hook.get_execution( @@ -650,4 +727,14 @@ def execute(self, context: 'Context'): timeout=self.timeout, metadata=self.metadata, ) + + WorkflowsExecutionLink.persist( + context=context, + task_instance=self, + location_id=self.location, + workflow_id=self.workflow_id, + execution_id=self.execution_id, + project_id=self.project_id or hook.project_id, + ) + return Execution.to_dict(execution) diff --git a/airflow/providers/google/cloud/secrets/secret_manager.py b/airflow/providers/google/cloud/secrets/secret_manager.py index 845b07449b158..7e1d97e60d6fd 100644 --- a/airflow/providers/google/cloud/secrets/secret_manager.py +++ b/airflow/providers/google/cloud/secrets/secret_manager.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Objects relating to sourcing connections from Google Cloud Secrets Manager""" +from __future__ import annotations + import logging import re import warnings -from typing import Optional from google.auth.exceptions import DefaultCredentialsError @@ -36,8 +36,8 @@ def _parse_version(val): - val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val) - return tuple(int(x) for x in val.split('.')) + val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), val) + return tuple(int(x) for x in val.split(".")) class CloudSecretManagerBackend(BaseSecretsBackend, LoggingMixin): @@ -83,10 +83,10 @@ def __init__( connections_prefix: str = "airflow-connections", variables_prefix: str = "airflow-variables", config_prefix: str = "airflow-config", - gcp_keyfile_dict: Optional[dict] = None, - gcp_key_path: Optional[str] = None, - gcp_scopes: Optional[str] = None, - project_id: Optional[str] = None, + gcp_keyfile_dict: dict | None = None, + gcp_key_path: str | None = None, + gcp_scopes: str | None = None, + project_id: str | None = None, sep: str = "-", **kwargs, ) -> None: @@ -107,9 +107,9 @@ def __init__( ) except (DefaultCredentialsError, FileNotFoundError): log.exception( - 'Unable to load credentials for GCP Secret Manager. ' - 'Make sure that the keyfile path, dictionary, or GOOGLE_APPLICATION_CREDENTIALS ' - 'environment variable is correct and properly configured.' + "Unable to load credentials for GCP Secret Manager. " + "Make sure that the keyfile path, dictionary, or GOOGLE_APPLICATION_CREDENTIALS " + "environment variable is correct and properly configured." 
) # In case project id provided @@ -129,7 +129,7 @@ def _is_valid_prefix_and_sep(self) -> bool: prefix = self.connections_prefix + self.sep return _SecretManagerClient.is_valid_secret_name(prefix) - def get_conn_value(self, conn_id: str) -> Optional[str]: + def get_conn_value(self, conn_id: str) -> str | None: """ Get serialized representation of Connection @@ -140,7 +140,7 @@ def get_conn_value(self, conn_id: str) -> Optional[str]: return self._get_secret(self.connections_prefix, conn_id) - def get_conn_uri(self, conn_id: str) -> Optional[str]: + def get_conn_uri(self, conn_id: str) -> str | None: """ Return URI representation of Connection conn_id. @@ -158,7 +158,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]: ) return self.get_conn_value(conn_id) - def get_variable(self, key: str) -> Optional[str]: + def get_variable(self, key: str) -> str | None: """ Get Airflow Variable from Environment Variable @@ -170,7 +170,7 @@ def get_variable(self, key: str) -> Optional[str]: return self._get_secret(self.variables_prefix, key) - def get_config(self, key: str) -> Optional[str]: + def get_config(self, key: str) -> str | None: """ Get Airflow Configuration @@ -182,7 +182,7 @@ def get_config(self, key: str) -> Optional[str]: return self._get_secret(self.config_prefix, key) - def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + def _get_secret(self, path_prefix: str, secret_id: str) -> str | None: """ Get secret value from the SecretManager based on prefix. diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index f0a9d67f58e5e..5ffe3bac56332 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -15,10 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
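# Illustrative configuration sketch for the CloudSecretManagerBackend changes above:
# the backend resolves a connection id to a secret named
# "<connections_prefix><sep><conn_id>" ("airflow-connections-<conn_id>" by default).
# The project id and connection id below are hypothetical.
import os

os.environ["AIRFLOW__SECRETS__BACKEND"] = (
    "airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend"
)
os.environ["AIRFLOW__SECRETS__BACKEND_KWARGS"] = (
    '{"connections_prefix": "airflow-connections", "sep": "-", "project_id": "my-gcp-project"}'
)
# With this in place, conn_id "my_postgres" is read from the Secret Manager
# secret "airflow-connections-my_postgres".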
-"""This module contains a Google Bigquery sensor.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +"""This module contains Google BigQuery sensors.""" +from __future__ import annotations +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -50,12 +55,12 @@ class BigQueryTableExistenceSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'dataset_id', - 'table_id', - 'impersonation_chain', + "project_id", + "dataset_id", + "table_id", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, @@ -63,12 +68,11 @@ def __init__( project_id: str, dataset_id: str, table_id: str, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) self.project_id = project_id @@ -78,9 +82,9 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context') -> bool: - table_uri = f'{self.project_id}:{self.dataset_id}.{self.table_id}' - self.log.info('Sensor checks existence of table: %s', table_uri) + def poke(self, context: Context) -> bool: + table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}" + self.log.info("Sensor checks existence of table: %s", table_uri) hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -117,13 +121,13 @@ class BigQueryTablePartitionExistenceSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'dataset_id', - 'table_id', - 'partition_id', - 'impersonation_chain', + "project_id", + "dataset_id", + "table_id", + "partition_id", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, @@ -132,12 +136,11 @@ def __init__( dataset_id: str, table_id: str, partition_id: str, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) self.project_id = project_id @@ -148,8 +151,8 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context') -> bool: - table_uri = f'{self.project_id}:{self.dataset_id}.{self.table_id}' + def poke(self, context: Context) -> bool: + table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}" self.log.info('Sensor checks existence of partition: "%s" in table: %s', self.partition_id, table_uri) hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, @@ -162,3 +165,73 @@ def poke(self, context: 'Context') -> bool: table_id=self.table_id, partition_id=self.partition_id, ) + + +class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor): + """ + Checks for the existence of a table in Google Big Query. 
+ + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + :param dataset_id: The name of the dataset in which to look for the table. + storage bucket. + :param table_id: The name of the table to check the existence of. + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param polling_interval: The interval in seconds to wait between checks table existence. + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + polling_interval: float = 5.0, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.polling_interval = polling_interval + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: Context) -> None: + """Airflow runs this method on the worker and defers using the trigger.""" + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=BigQueryTableExistenceTrigger( + dataset_id=self.dataset_id, + table_id=self.table_id, + project_id=self.project_id, + poll_interval=self.polling_interval, + gcp_conn_id=self.gcp_conn_id, + hook_params={ + "delegate_to": self.delegate_to, + "impersonation_chain": self.impersonation_chain, + }, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}" + self.log.info("Sensor checks existence of table: %s", table_uri) + if event: + if event["status"] == "success": + return event["message"] + raise AirflowException(event["message"]) + raise AirflowException("No event received in trigger callback") diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py b/airflow/providers/google/cloud/sensors/bigquery_dts.py index ef92601c0acf2..9dc2623783c18 100644 --- a/airflow/providers/google/cloud/sensors/bigquery_dts.py +++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
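# Illustrative usage sketch for the new BigQueryTableExistenceAsyncSensor; the DAG,
# project, dataset and table names are hypothetical.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceAsyncSensor

with DAG("example_bq_table_existence_async", start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False):
    # Defers to BigQueryTableExistenceTrigger instead of holding a worker slot while polling.
    wait_for_table = BigQueryTableExistenceAsyncSensor(
        task_id="wait_for_table",
        project_id="my-gcp-project",
        dataset_id="my_dataset",
        table_id="my_table",
        polling_interval=10.0,
    )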
"""This module contains a Google BigQuery Data Transfer Service sensor.""" -from typing import TYPE_CHECKING, Optional, Sequence, Set, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -76,16 +78,16 @@ def __init__( *, run_id: str, transfer_config_id: str, - expected_statuses: Union[ - Set[Union[str, TransferState, int]], str, TransferState, int - ] = TransferState.SUCCEEDED, - project_id: Optional[str] = None, + expected_statuses: ( + set[str | TransferState | int] | str | TransferState | int + ) = TransferState.SUCCEEDED, + project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", - retry: Union[Retry, _MethodDefault] = DEFAULT, - request_timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + retry: Retry | _MethodDefault = DEFAULT, + request_timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -100,7 +102,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.location = location - def _normalize_state_list(self, states) -> Set[TransferState]: + def _normalize_state_list(self, states) -> set[TransferState]: states = {states} if isinstance(states, (str, TransferState, int)) else states result = set() for state in states: @@ -120,7 +122,7 @@ def _normalize_state_list(self, states) -> Set[TransferState]: ) return result - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = BiqQueryDataTransferServiceHook( gcp_conn_id=self.gcp_cloud_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/sensors/bigtable.py b/airflow/providers/google/cloud/sensors/bigtable.py index 63a1ae35e6825..3dd689f8c6087 100644 --- a/airflow/providers/google/cloud/sensors/bigtable.py +++ b/airflow/providers/google/cloud/sensors/bigtable.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud Bigtable sensor.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence import google.api_core.exceptions from google.cloud.bigtable.table import ClusterState @@ -56,12 +58,12 @@ class BigtableTableReplicationCompletedSensor(BaseSensorOperator, BigtableValida account from the list granting this role to the originating account (templated). 
""" - REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') + REQUIRED_ATTRIBUTES = ("instance_id", "table_id") template_fields: Sequence[str] = ( - 'project_id', - 'instance_id', - 'table_id', - 'impersonation_chain', + "project_id", + "instance_id", + "table_id", + "impersonation_chain", ) operator_extra_links = (BigtableTablesLink(),) @@ -70,9 +72,9 @@ def __init__( *, instance_id: str, table_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: self.project_id = project_id @@ -83,7 +85,7 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, diff --git a/airflow/providers/google/cloud/sensors/cloud_composer.py b/airflow/providers/google/cloud/sensors/cloud_composer.py new file mode 100644 index 0000000000000..ad4ae923314f4 --- /dev/null +++ b/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -0,0 +1,99 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Cloud Composer sensor.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class CloudComposerEnvironmentSensor(BaseSensorOperator): + """ + Check the status of the Cloud Composer Environment task + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param operation_name: The name of the operation resource + :param gcp_conn_id: The connection ID to use when fetching connection info. + :param delegate_to: The account to impersonate, if any. For this to work, the service account making the + request must have domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param pooling_period_seconds: Optional: Control the rate of the poll for the result of deferrable run. + """ + + def __init__( + self, + *, + project_id: str, + region: str, + operation_name: str, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + pooling_period_seconds: int = 30, + **kwargs, + ): + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.operation_name = operation_name + self.pooling_period_seconds = pooling_period_seconds + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> None: + """Airflow runs this method on the worker and defers using the trigger.""" + self.defer( + trigger=CloudComposerExecutionTrigger( + project_id=self.project_id, + region=self.region, + operation_name=self.operation_name, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + delegate_to=self.delegate_to, + pooling_period_seconds=self.pooling_period_seconds, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event: + if event.get("operation_done"): + return event["operation_done"] + raise AirflowException(event["message"]) + raise AirflowException("No event received in trigger callback") diff --git a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py index 5cb63a37f6a9b..67f64e2497646 100644 --- a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
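# Illustrative usage sketch for the new CloudComposerEnvironmentSensor; the operation
# name would normally be produced by an earlier Composer operator (for example via
# XCom), and all literal values here are hypothetical.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerEnvironmentSensor

with DAG("example_composer_operation_wait", start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False):
    wait_for_operation = CloudComposerEnvironmentSensor(
        task_id="wait_for_operation",
        project_id="my-gcp-project",
        region="us-central1",
        operation_name="projects/my-gcp-project/locations/us-central1/operations/1234",
        pooling_period_seconds=30,
    )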
"""This module contains a Google Cloud Transfer sensor.""" -from typing import TYPE_CHECKING, Optional, Sequence, Set, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( COUNTERS, @@ -24,6 +26,7 @@ NAME, CloudDataTransferServiceHook, ) +from airflow.providers.google.cloud.links.cloud_storage_transfer import CloudStorageTransferJobLink from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -59,19 +62,20 @@ class CloudDataTransferServiceJobStatusSensor(BaseSensorOperator): # [START gcp_transfer_job_sensor_template_fields] template_fields: Sequence[str] = ( - 'job_name', - 'impersonation_chain', + "job_name", + "impersonation_chain", ) # [END gcp_transfer_job_sensor_template_fields] + operator_extra_links = (CloudStorageTransferJobLink(),) def __init__( self, *, job_name: str, - expected_statuses: Union[Set[str], str], - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + expected_statuses: set[str] | str, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -83,13 +87,13 @@ def __init__( self.gcp_cloud_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = CloudDataTransferServiceHook( gcp_conn_id=self.gcp_cloud_conn_id, impersonation_chain=self.impersonation_chain, ) operations = hook.list_transfer_operations( - request_filter={'project_id': self.project_id, 'job_names': [self.job_name]} + request_filter={"project_id": self.project_id, "job_names": [self.job_name]} ) for operation in operations: @@ -101,4 +105,13 @@ def poke(self, context: 'Context') -> bool: if check: self.xcom_push(key="sensed_operations", value=operations, context=context) + project_id = self.project_id or hook.project_id + if project_id: + CloudStorageTransferJobLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_name=self.job_name, + ) + return check diff --git a/airflow/providers/google/cloud/sensors/dataflow.py b/airflow/providers/google/cloud/sensors/dataflow.py index b423f0fa3ecb3..491dead285570 100644 --- a/airflow/providers/google/cloud/sensors/dataflow.py +++ b/airflow/providers/google/cloud/sensors/dataflow.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Dataflow sensor.""" -from typing import TYPE_CHECKING, Callable, Optional, Sequence, Set, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Sequence from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataflow import ( @@ -61,18 +63,18 @@ class DataflowJobStatusSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ('job_id',) + template_fields: Sequence[str] = ("job_id",) def __init__( self, *, job_id: str, - expected_statuses: Union[Set[str], str], - project_id: Optional[str] = None, + expected_statuses: set[str] | str, + project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -85,9 +87,9 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[DataflowHook] = None + self.hook: DataflowHook | None = None - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.log.info( "Waiting for job %s to be in one of the states: %s.", self.job_id, @@ -148,7 +150,7 @@ class DataflowJobMetricsSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). """ - template_fields: Sequence[str] = ('job_id',) + template_fields: Sequence[str] = ("job_id",) def __init__( self, @@ -156,11 +158,11 @@ def __init__( job_id: str, callback: Callable[[dict], bool], fail_on_terminal_state: bool = True, - project_id: Optional[str] = None, + project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -172,9 +174,9 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[DataflowHook] = None + self.hook: DataflowHook | None = None - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -233,7 +235,7 @@ class DataflowJobMessagesSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ('job_id',) + template_fields: Sequence[str] = ("job_id",) def __init__( self, @@ -241,11 +243,11 @@ def __init__( job_id: str, callback: Callable, fail_on_terminal_state: bool = True, - project_id: Optional[str] = None, + project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -257,9 +259,9 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[DataflowHook] = None + self.hook: DataflowHook | None = None - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -318,7 +320,7 @@ class DataflowJobAutoScalingEventsSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). """ - template_fields: Sequence[str] = ('job_id',) + template_fields: Sequence[str] = ("job_id",) def __init__( self, @@ -326,11 +328,11 @@ def __init__( job_id: str, callback: Callable, fail_on_terminal_state: bool = True, - project_id: Optional[str] = None, + project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -342,9 +344,9 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.hook: Optional[DataflowHook] = None + self.hook: DataflowHook | None = None - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/sensors/dataform.py b/airflow/providers/google/cloud/sensors/dataform.py new file mode 100644 index 0000000000000..412a7e5c80800 --- /dev/null +++ b/airflow/providers/google/cloud/sensors/dataform.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Cloud Dataform sensor.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Sequence + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.dataform import DataformHook +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class DataformWorkflowInvocationStateSensor(BaseSensorOperator): + """ + Checks for the status of a Workflow Invocation in Google Cloud Dataform. + + :param project_id: Required, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param region: Required, The location of the Dataform workflow invocation (for example europe-west1). + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation_id: Required, ID of the workflow invocation to be checked. + :param expected_statuses: The expected state of the operation. + See: + https://cloud.google.com/python/docs/reference/dataform/latest/google.cloud.dataform_v1beta1.types.WorkflowInvocation.State + :param failure_statuses: State that will terminate the sensor with an exception + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. See: + https://developers.google.com/identity/protocols/oauth2/service-account#delegatingauthority + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ("workflow_invocation_id",) + + def __init__( + self, + *, + project_id: str, + region: str, + repository_id: str, + workflow_invocation_id: str, + expected_statuses: set[int] | int, + failure_statuses: Iterable[int] | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.repository_id = repository_id + self.workflow_invocation_id = workflow_invocation_id + self.expected_statuses = ( + {expected_statuses} if isinstance(expected_statuses, int) else expected_statuses + ) + self.failure_statuses = failure_statuses + self.project_id = project_id + self.region = region + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + self.impersonation_chain = impersonation_chain + self.hook: DataformHook | None = None + + def poke(self, context: Context) -> bool: + self.hook = DataformHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + + workflow_invocation = self.hook.get_workflow_invocation( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workflow_invocation_id=self.workflow_invocation_id, + ) + workflow_status = workflow_invocation.state + if workflow_status is not None: + if self.failure_statuses and workflow_status in self.failure_statuses: + raise AirflowException( + f"Workflow Invocation with id '{self.workflow_invocation_id}' " + f"state is: {workflow_status}. Terminating sensor..." + ) + + return workflow_status in self.expected_statuses diff --git a/airflow/providers/google/cloud/sensors/datafusion.py b/airflow/providers/google/cloud/sensors/datafusion.py index 63776e28e8c23..f15f423a4a551 100644 --- a/airflow/providers/google/cloud/sensors/datafusion.py +++ b/airflow/providers/google/cloud/sensors/datafusion.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Google Cloud Data Fusion sensors.""" -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Sequence from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook @@ -55,7 +57,7 @@ class CloudDataFusionPipelineStateSensor(BaseSensorOperator): """ - template_fields: Sequence[str] = ('pipeline_id',) + template_fields: Sequence[str] = ("pipeline_id",) def __init__( self, @@ -64,12 +66,12 @@ def __init__( expected_statuses: Iterable[str], instance_name: str, location: str, - failure_statuses: Optional[Iterable[str]] = None, - project_id: Optional[str] = None, + failure_statuses: Iterable[str] | None = None, + project_id: str | None = None, namespace: str = "default", - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -85,7 +87,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.log.info( "Waiting for pipeline %s to be in one of the states: %s.", self.pipeline_id, diff --git a/airflow/providers/google/cloud/sensors/dataplex.py b/airflow/providers/google/cloud/sensors/dataplex.py index c8e46c6fb993b..19c56610a78c5 100644 --- a/airflow/providers/google/cloud/sensors/dataplex.py +++ b/airflow/providers/google/cloud/sensors/dataplex.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains Google Dataplex sensors.""" -from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence if TYPE_CHECKING: from airflow.utils.context import Context @@ -64,7 +65,7 @@ class DataplexTaskStateSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). 
""" - template_fields = ['dataplex_task_id'] + template_fields = ["dataplex_task_id"] def __init__( self, @@ -73,11 +74,11 @@ def __init__( lake_id: str, dataplex_task_id: str, api_version: str = "v1", - retry: Union[Retry, _MethodDefault] = DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + retry: Retry | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, *args, **kwargs, ) -> None: @@ -93,7 +94,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: self.log.info("Waiting for task %s to be %s", self.dataplex_task_id, TaskState.ACTIVE) hook = DataplexHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/google/cloud/sensors/dataprep.py b/airflow/providers/google/cloud/sensors/dataprep.py new file mode 100644 index 0000000000000..d30f6e18e872c --- /dev/null +++ b/airflow/providers/google/cloud/sensors/dataprep.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Dataprep Job sensor.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook, JobGroupStatuses +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class DataprepJobGroupIsFinishedSensor(BaseSensorOperator): + """ + Check the status of the Dataprep task to be finished. + + :param job_group_id: ID of the job group to check + """ + + template_fields: Sequence[str] = ("job_group_id",) + + def __init__( + self, + *, + job_group_id: int | str, + dataprep_conn_id: str = "dataprep_default", + **kwargs, + ): + super().__init__(**kwargs) + self.job_group_id = job_group_id + self.dataprep_conn_id = dataprep_conn_id + + def poke(self, context: Context) -> bool: + hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + status = hooks.get_job_group_status(job_group_id=int(self.job_group_id)) + return status != JobGroupStatuses.IN_PROGRESS diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py index 02b2d5e14d7ab..9df525c7151a9 100644 --- a/airflow/providers/google/cloud/sensors/dataproc.py +++ b/airflow/providers/google/cloud/sensors/dataproc.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Dataproc Job sensor.""" +from __future__ import annotations + # pylint: disable=C0302 import time -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import ServerError from google.cloud.dataproc_v1.types import JobStatus @@ -43,17 +45,17 @@ class DataprocJobSensor(BaseSensorOperator): :param wait_timeout: How many seconds wait for job to be ready. """ - template_fields: Sequence[str] = ('project_id', 'region', 'dataproc_job_id') - ui_color = '#f0eee4' + template_fields: Sequence[str] = ("project_id", "region", "dataproc_job_id") + ui_color = "#f0eee4" def __init__( self, *, dataproc_job_id: str, region: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - wait_timeout: Optional[int] = None, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + wait_timeout: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -62,16 +64,16 @@ def __init__( self.dataproc_job_id = dataproc_job_id self.region = region self.wait_timeout = wait_timeout - self.start_sensor_time: Optional[float] = None + self.start_sensor_time: float | None = None - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: self.start_sensor_time = time.monotonic() super().execute(context) def _duration(self): return time.monotonic() - self.start_sensor_time - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: hook = DataprocHook(gcp_conn_id=self.gcp_conn_id) if self.wait_timeout: try: @@ -93,13 +95,13 @@ def poke(self, context: "Context") -> bool: state = job.status.state if state == JobStatus.State.ERROR: - raise AirflowException(f'Job failed:\n{job}') + raise AirflowException(f"Job failed:\n{job}") elif state in { JobStatus.State.CANCELLED, JobStatus.State.CANCEL_PENDING, JobStatus.State.CANCEL_STARTED, }: - raise AirflowException(f'Job was cancelled:\n{job}') + raise AirflowException(f"Job was cancelled:\n{job}") elif JobStatus.State.DONE == state: self.log.debug("Job %s completed successfully.", self.dataproc_job_id) return True diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py index 225ce5937a907..27c26abac05fe 100644 --- a/airflow/providers/google/cloud/sensors/gcs.py +++ b/airflow/providers/google/cloud/sensors/gcs.py @@ -16,11 +16,12 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud Storage sensors.""" +from __future__ import annotations import os import textwrap from datetime import datetime -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Union +from typing import TYPE_CHECKING, Callable, Sequence from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -53,20 +54,20 @@ class GCSObjectExistenceSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'object', - 'impersonation_chain', + "bucket", + "object", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, *, bucket: str, object: str, - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -77,8 +78,8 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: "Context") -> bool: - self.log.info('Sensor checks existence of : %s, %s', self.bucket, self.object) + def poke(self, context: Context) -> bool: + self.log.info("Sensor checks existence of : %s, %s", self.bucket, self.object) hook = GCSHook( gcp_conn_id=self.google_cloud_conn_id, delegate_to=self.delegate_to, @@ -126,20 +127,20 @@ class GCSObjectUpdateSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'object', - 'impersonation_chain', + "bucket", + "object", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, bucket: str, object: str, ts_func: Callable = ts_function, - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -151,8 +152,8 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: "Context") -> bool: - self.log.info('Sensor checks existence of : %s, %s', self.bucket, self.object) + def poke(self, context: Context) -> bool: + self.log.info("Sensor checks existence of : %s, %s", self.bucket, self.object) hook = GCSHook( gcp_conn_id=self.google_cloud_conn_id, delegate_to=self.delegate_to, @@ -188,19 +189,19 @@ class GCSObjectsWithPrefixExistenceSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'prefix', - 'impersonation_chain', + "bucket", + "prefix", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, bucket: str, prefix: str, - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -208,11 +209,11 @@ def __init__( self.prefix = prefix self.google_cloud_conn_id = google_cloud_conn_id self.delegate_to = delegate_to - self._matches: List[str] = [] + self._matches: list[str] = [] self.impersonation_chain = impersonation_chain - def poke(self, context: "Context") -> bool: - 
self.log.info('Sensor checks existence of objects: %s, %s', self.bucket, self.prefix) + def poke(self, context: Context) -> bool: + self.log.info("Sensor checks existence of objects: %s, %s", self.bucket, self.prefix) hook = GCSHook( gcp_conn_id=self.google_cloud_conn_id, delegate_to=self.delegate_to, @@ -221,7 +222,7 @@ def poke(self, context: "Context") -> bool: self._matches = hook.list(self.bucket, prefix=self.prefix) return bool(self._matches) - def execute(self, context: "Context") -> List[str]: + def execute(self, context: Context) -> list[str]: """Overridden to allow matches to be passed""" super().execute(context) return self._matches @@ -240,7 +241,7 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator): """ Checks for changes in the number of objects at prefix in Google Cloud Storage bucket and returns True if the inactivity period has passed with no - increase in the number of objects. Note, this sensor will no behave correctly + increase in the number of objects. Note, this sensor will not behave correctly in reschedule mode, as the state of the listed objects in the GCS bucket will be lost between rescheduled invocations. @@ -274,11 +275,11 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'prefix', - 'impersonation_chain', + "bucket", + "prefix", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, @@ -286,11 +287,11 @@ def __init__( prefix: str, inactivity_period: float = 60 * 60, min_objects: int = 1, - previous_objects: Optional[Set[str]] = None, + previous_objects: set[str] | None = None, allow_delete: bool = True, - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_cloud_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -309,9 +310,9 @@ def __init__( self.delegate_to = delegate_to self.last_activity_time = None self.impersonation_chain = impersonation_chain - self.hook: Optional[GCSHook] = None + self.hook: GCSHook | None = None - def _get_gcs_hook(self) -> Optional[GCSHook]: + def _get_gcs_hook(self) -> GCSHook | None: if not self.hook: self.hook = GCSHook( gcp_conn_id=self.google_cloud_conn_id, @@ -320,7 +321,7 @@ def _get_gcs_hook(self) -> Optional[GCSHook]: ) return self.hook - def is_bucket_updated(self, current_objects: Set[str]) -> bool: + def is_bucket_updated(self, current_objects: set[str]) -> bool: """ Checks whether new objects have been uploaded and the inactivity_period has passed and updates the state of the sensor accordingly. @@ -394,7 +395,7 @@ def is_bucket_updated(self, current_objects: Set[str]) -> bool: return False return False - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: return self.is_bucket_updated( set(self._get_gcs_hook().list(self.bucket, prefix=self.prefix)) # type: ignore[union-attr] ) diff --git a/airflow/providers/google/cloud/sensors/looker.py b/airflow/providers/google/cloud/sensors/looker.py index 15844baa01942..e75d0fb665ac8 100644 --- a/airflow/providers/google/cloud/sensors/looker.py +++ b/airflow/providers/google/cloud/sensors/looker.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
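# Illustrative usage sketch for GCSObjectsWithPrefixExistenceSensor from the hunks
# above; the bucket and prefix are hypothetical. Because execute() returns the matches,
# the list of object names is pushed to XCom for downstream tasks.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.gcs import GCSObjectsWithPrefixExistenceSensor

with DAG("example_gcs_prefix_sensor", start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False):
    wait_for_objects = GCSObjectsWithPrefixExistenceSensor(
        task_id="wait_for_objects",
        bucket="my-bucket",
        prefix="incoming/2022-01-01/",
        google_cloud_conn_id="google_cloud_default",
    )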
-# """This module contains Google Cloud Looker sensors.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.looker import JobStatus, LookerHook @@ -47,18 +47,18 @@ def __init__( self.materialization_id = materialization_id self.looker_conn_id = looker_conn_id self.cancel_on_kill = cancel_on_kill - self.hook: Optional[LookerHook] = None + self.hook: LookerHook | None = None - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: self.hook = LookerHook(looker_conn_id=self.looker_conn_id) if not self.materialization_id: - raise AirflowException('Invalid `materialization_id`.') + raise AirflowException("Invalid `materialization_id`.") # materialization_id is templated var pulling output from start task status_dict = self.hook.pdt_build_status(materialization_id=self.materialization_id) - status = status_dict['status'] + status = status_dict["status"] if status == JobStatus.ERROR.value: msg = status_dict["message"] @@ -67,11 +67,11 @@ def poke(self, context: "Context") -> bool: ) elif status == JobStatus.CANCELLED.value: raise AirflowException( - f'PDT materialization job was cancelled. Job id: {self.materialization_id}.' + f"PDT materialization job was cancelled. Job id: {self.materialization_id}." ) elif status == JobStatus.UNKNOWN.value: raise AirflowException( - f'PDT materialization job has unknown status. Job id: {self.materialization_id}.' + f"PDT materialization job has unknown status. Job id: {self.materialization_id}." ) elif status == JobStatus.DONE.value: self.log.debug( diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index 86081de72f4a0..95c792518f118 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains a Google PubSub sensor.""" -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Sequence from google.cloud.pubsub_v1.types import ReceivedMessage @@ -83,11 +85,11 @@ class PubSubPullSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'project_id', - 'subscription', - 'impersonation_chain', + "project_id", + "subscription", + "impersonation_chain", ) - ui_color = '#ff7f50' + ui_color = "#ff7f50" def __init__( self, @@ -96,10 +98,10 @@ def __init__( subscription: str, max_messages: int = 5, ack_messages: bool = False, - gcp_conn_id: str = 'google_cloud_default', - messages_callback: Optional[Callable[[List[ReceivedMessage], "Context"], Any]] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -115,12 +117,12 @@ def __init__( self._return_value = None - def execute(self, context: "Context") -> Any: + def execute(self, context: Context) -> Any: """Overridden to allow messages to be passed""" super().execute(context) return self._return_value - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -149,8 +151,8 @@ def poke(self, context: "Context") -> bool: def _default_message_callback( self, - pulled_messages: List[ReceivedMessage], - context: "Context", + pulled_messages: list[ReceivedMessage], + context: Context, ): """ This method can be overridden by subclasses or by `messages_callback` constructor argument. diff --git a/airflow/providers/google/cloud/sensors/tasks.py b/airflow/providers/google/cloud/sensors/tasks.py new file mode 100644 index 0000000000000..714fcd18d630c --- /dev/null +++ b/airflow/providers/google/cloud/sensors/tasks.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Task sensor.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class TaskQueueEmptySensor(BaseSensorOperator): + """ + Pulls tasks count from a cloud task queue. + Always waits for queue returning tasks count as 0. 
+ + :param project_id: the Google Cloud project ID in which the queue is located (templated) + :param location: The location of the Cloud Tasks queue. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param queue_name: The name of the queue for which the empty-queue check is performed. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "project_id", + "location", + "queue_name", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + location: str, + project_id: str | None = None, + queue_name: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.location = location + self.project_id = project_id + self.queue_name = queue_name + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def poke(self, context: Context) -> bool: + + hook = CloudTasksHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + # TODO uncomment page_size once https://issuetracker.google.com/issues/155978649?pli=1 gets fixed + tasks = hook.list_tasks( + location=self.location, + queue_name=self.queue_name, + # page_size=1 + ) + + self.log.info("tasks exhausted in cloud task queue?: %s", len(tasks) == 0) + + return len(tasks) == 0 diff --git a/airflow/providers/google/cloud/sensors/workflows.py b/airflow/providers/google/cloud/sensors/workflows.py index 8560fd3ee2559..e352fcc4a05ec 100644 --- a/airflow/providers/google/cloud/sensors/workflows.py +++ b/airflow/providers/google/cloud/sensors/workflows.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
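The new `TaskQueueEmptySensor` added above needs only the queue coordinates plus the usual GCP connection arguments. A minimal sketch, with placeholder project, location and queue values:

```python
# Placeholder project, location and queue values; poke_interval and timeout come
# from BaseSensorOperator.
from airflow.providers.google.cloud.sensors.tasks import TaskQueueEmptySensor

wait_for_queue_drain = TaskQueueEmptySensor(
    task_id="wait_for_queue_drain",
    project_id="example-project",
    location="us-central1",
    queue_name="example-queue",
    gcp_conn_id="google_cloud_default",
    poke_interval=60,
    timeout=60 * 60,
)
```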
+from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING, Sequence from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.api_core.retry import Retry @@ -56,14 +57,14 @@ def __init__( workflow_id: str, execution_id: str, location: str, - project_id: str, - success_states: Optional[Set[Execution.State]] = None, - failure_states: Optional[Set[Execution.State]] = None, - retry: Union[Retry, _MethodDefault] = DEFAULT, - request_timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + project_id: str | None = None, + success_states: set[Execution.State] | None = None, + failure_states: set[Execution.State] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + request_timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -83,7 +84,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context'): + def poke(self, context: Context): hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) self.log.info("Checking state of execution %s for workflow %s", self.execution_id, self.workflow_id) execution: Execution = hook.get_execution( diff --git a/airflow/providers/google/cloud/transfers/adls_to_gcs.py b/airflow/providers/google/cloud/transfers/adls_to_gcs.py index 8d142605a5168..fd827ef92b46d 100644 --- a/airflow/providers/google/cloud/transfers/adls_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/adls_to_gcs.py @@ -19,9 +19,11 @@ This module contains Azure Data Lake Storage to Google Cloud Storage operator. 
""" +from __future__ import annotations + import os from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook @@ -95,11 +97,11 @@ class ADLSToGCSOperator(ADLSListOperator): """ template_fields: Sequence[str] = ( - 'src_adls', - 'dest_gcs', - 'google_impersonation_chain', + "src_adls", + "dest_gcs", + "google_impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, @@ -107,11 +109,11 @@ def __init__( src_adls: str, dest_gcs: str, azure_data_lake_conn_id: str, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, replace: bool = False, gzip: bool = False, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -125,7 +127,7 @@ def __init__( self.gzip = gzip self.google_impersonation_chain = google_impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): # use the super to list all files in an Azure Data Lake path files = super().execute(context) g_hook = GCSHook( @@ -146,7 +148,7 @@ def execute(self, context: 'Context'): hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) for obj in files: - with NamedTemporaryFile(mode='wb', delete=True) as f: + with NamedTemporaryFile(mode="wb", delete=True) as f: hook.download_file(local_path=f.name, remote_path=obj) f.flush() dest_gcs_bucket, dest_gcs_prefix = _parse_gcs_url(self.dest_gcs) diff --git a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py index f107eb686e656..5c86e59cec9ba 100644 --- a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow import AirflowException from airflow.models import BaseOperator @@ -62,10 +63,10 @@ class AzureFileShareToGCSOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'share_name', - 'directory_name', - 'prefix', - 'dest_gcs', + "share_name", + "directory_name", + "prefix", + "dest_gcs", ) def __init__( @@ -73,14 +74,14 @@ def __init__( *, share_name: str, dest_gcs: str, - directory_name: Optional[str] = None, - prefix: str = '', - azure_fileshare_conn_id: str = 'azure_fileshare_default', - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, + directory_name: str | None = None, + prefix: str = "", + azure_fileshare_conn_id: str = "azure_fileshare_default", + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, replace: bool = False, gzip: bool = False, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -99,15 +100,15 @@ def __init__( def _check_inputs(self) -> None: if self.dest_gcs and not gcs_object_is_directory(self.dest_gcs): self.log.info( - 'Destination Google Cloud Storage path is not a valid ' + "Destination Google Cloud Storage path is not a valid " '"directory", define a path that ends with a slash "/" or ' - 'leave it empty for the root of the bucket.' + "leave it empty for the root of the bucket." ) raise AirflowException( 'The destination Google Cloud Storage path must end with a slash "/" or be empty.' ) - def execute(self, context: 'Context'): + def execute(self, context: Context): self._check_inputs() azure_fileshare_hook = AzureFileShareHook(self.azure_fileshare_conn_id) files = azure_fileshare_hook.list_files( @@ -144,7 +145,7 @@ def execute(self, context: 'Context'): files = list(set(files) - set(existing_files)) if files: - self.log.info('%s files are going to be synced.', len(files)) + self.log.info("%s files are going to be synced.", len(files)) if self.directory_name is None: raise RuntimeError("The directory_name must be set!.") for file in files: @@ -163,7 +164,7 @@ def execute(self, context: 'Context'): gcs_hook.upload(dest_gcs_bucket, dest_gcs_object, temp_file.name, gzip=self.gzip) self.log.info("All done, uploaded %d files to Google Cloud Storage.", len(files)) else: - self.log.info('There are no new files to sync. Have a nice day!') - self.log.info('In sync, no files needed to be uploaded to Google Cloud Storage') + self.log.info("There are no new files to sync. Have a nice day!") + self.log.info("In sync, no files needed to be uploaded to Google Cloud Storage") return files diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py index 87e6f3f586f4a..b3bda020614b7 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. 
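Because `_check_inputs` above rejects a `dest_gcs` that does not look like a directory, and `execute` requires `directory_name` once files need uploading, a working instantiation looks roughly like this (all names are placeholders):

```python
# dest_gcs must end with "/" (or be empty) per _check_inputs above, and
# directory_name must be set before any files can be uploaded.
from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import (
    AzureFileShareToGCSOperator,
)

sync_fileshare_to_gcs = AzureFileShareToGCSOperator(
    task_id="sync_fileshare_to_gcs",
    share_name="example-share",
    directory_name="exports",
    prefix="2022-01-01/",
    dest_gcs="gs://example-bucket/fileshare/",
    azure_fileshare_conn_id="azure_fileshare_default",
    gcp_conn_id="google_cloud_default",
)
```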
"""This module contains Google BigQuery to BigQuery operator.""" +from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook @@ -56,7 +58,11 @@ class BigQueryToBigQueryOperator(BaseOperator): encryption_configuration = { "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key" } - :param location: The location used for the operation. + :param location: The geographic location of the job. You must specify the location to run the job if + the location to run a job is not in the US or the EU multi-regional location or + the location is in a single region (for example, us-central1). + For more details check: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. @@ -68,28 +74,28 @@ class BigQueryToBigQueryOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_project_dataset_tables', - 'destination_project_dataset_table', - 'labels', - 'impersonation_chain', + "source_project_dataset_tables", + "destination_project_dataset_table", + "labels", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.sql',) - ui_color = '#e6f0e4' + template_ext: Sequence[str] = (".sql",) + ui_color = "#e6f0e4" operator_extra_links = (BigQueryTableLink(),) def __init__( self, *, - source_project_dataset_tables: Union[List[str], str], + source_project_dataset_tables: list[str] | str, destination_project_dataset_table: str, - write_disposition: str = 'WRITE_EMPTY', - create_disposition: str = 'CREATE_IF_NEEDED', - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + write_disposition: str = "WRITE_EMPTY", + create_disposition: str = "CREATE_IF_NEEDED", + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + labels: dict | None = None, + encryption_configuration: dict | None = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -105,9 +111,9 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: self.log.info( - 'Executing copy of %s into: %s', + "Executing copy of %s into: %s", self.source_project_dataset_tables, self.destination_project_dataset_table, ) @@ -129,7 +135,7 @@ def execute(self, context: 'Context') -> None: encryption_configuration=self.encryption_configuration, ) - job = hook.get_job(job_id=job_id).to_api_repr() + job = hook.get_job(job_id=job_id, location=self.location).to_api_repr() conf = job["configuration"]["copy"]["destinationTable"] BigQueryTableLink.persist( context=context, diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 09ac190e0f269..02b34890eb40d 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ 
b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -16,11 +16,19 @@ # specific language governing permissions and limitations # under the License. """This module contains Google BigQuery to Google Cloud Storage operator.""" -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Sequence + +from google.api_core.exceptions import Conflict +from google.api_core.retry import Retry +from google.cloud.bigquery import DEFAULT_RETRY, ExtractJob + +from airflow import AirflowException from airflow.models import BaseOperator -from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink +from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -42,6 +50,7 @@ class BigQueryToGCSOperator(BaseOperator): Storage URI (e.g. gs://some-bucket/some-file.txt). (templated) Follows convention defined here: https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple + :param project_id: Google Cloud Project where the job is running :param compression: Type of compression to use. :param export_format: File format to export. :param field_delimiter: The delimiter to use when extracting to a CSV. @@ -61,36 +70,54 @@ class BigQueryToGCSOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param result_retry: How to retry the `result` call that retrieves rows + :param result_timeout: The number of seconds to wait for `result` method before using `result_retry` + :param job_id: The ID of the job. It will be suffixed with hash of job configuration + unless ``force_rerun`` is True. + The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or + dashes (-). The maximum length is 1,024 characters. If not provided then uuid will + be generated. + :param force_rerun: If True then operator will use hash of uuid as job id suffix + :param reattach_states: Set of BigQuery job's states in case of which we should reattach + to the job. Should be other than final states. 
+ :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( - 'source_project_dataset_table', - 'destination_cloud_storage_uris', - 'labels', - 'impersonation_chain', + "source_project_dataset_table", + "destination_cloud_storage_uris", + "labels", + "impersonation_chain", ) template_ext: Sequence[str] = () - ui_color = '#e4e6f0' + ui_color = "#e4e6f0" operator_extra_links = (BigQueryTableLink(),) def __init__( self, *, source_project_dataset_table: str, - destination_cloud_storage_uris: List[str], - compression: str = 'NONE', - export_format: str = 'CSV', - field_delimiter: str = ',', + destination_cloud_storage_uris: list[str], + project_id: str | None = None, + compression: str = "NONE", + export_format: str = "CSV", + field_delimiter: str = ",", print_header: bool = True, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + labels: dict | None = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + result_retry: Retry = DEFAULT_RETRY, + result_timeout: float | None = None, + job_id: str | None = None, + force_rerun: bool = False, + reattach_states: set[str] | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) - + self.project_id = project_id self.source_project_dataset_table = source_project_dataset_table self.destination_cloud_storage_uris = destination_cloud_storage_uris self.compression = compression @@ -102,10 +129,71 @@ def __init__( self.labels = labels self.location = location self.impersonation_chain = impersonation_chain + self.result_retry = result_retry + self.result_timeout = result_timeout + self.job_id = job_id + self.force_rerun = force_rerun + self.reattach_states: set[str] = reattach_states or set() + self.hook: BigQueryHook | None = None + self.deferrable = deferrable + + @staticmethod + def _handle_job_error(job: ExtractJob) -> None: + if job.error_result: + raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}") + + def _prepare_configuration(self): + source_project, source_dataset, source_table = self.hook.split_tablename( + table_input=self.source_project_dataset_table, + default_project_id=self.project_id or self.hook.project_id, + var_name="source_project_dataset_table", + ) + + configuration: dict[str, Any] = { + "extract": { + "sourceTable": { + "projectId": source_project, + "datasetId": source_dataset, + "tableId": source_table, + }, + "compression": self.compression, + "destinationUris": self.destination_cloud_storage_uris, + "destinationFormat": self.export_format, + } + } + + if self.labels: + configuration["labels"] = self.labels + + if self.export_format == "CSV": + # Only set fieldDelimiter and printHeader fields if using CSV. + # Google does not like it if you set these fields for other export + # formats. + configuration["extract"]["fieldDelimiter"] = self.field_delimiter + configuration["extract"]["printHeader"] = self.print_header + return configuration + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + configuration: dict, + ) -> BigQueryJob: + # Submit a new job without waiting for it to complete. 
- def execute(self, context: 'Context'): + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + timeout=self.result_timeout, + retry=self.result_retry, + nowait=True, + ) + + def execute(self, context: Context): self.log.info( - 'Executing extract of %s into: %s', + "Executing extract of %s into: %s", self.source_project_dataset_table, self.destination_cloud_storage_uris, ) @@ -115,18 +203,41 @@ def execute(self, context: 'Context'): location=self.location, impersonation_chain=self.impersonation_chain, ) - job_id = hook.run_extract( - source_project_dataset_table=self.source_project_dataset_table, - destination_cloud_storage_uris=self.destination_cloud_storage_uris, - compression=self.compression, - export_format=self.export_format, - field_delimiter=self.field_delimiter, - print_header=self.print_header, - labels=self.labels, + self.hook = hook + + configuration = self._prepare_configuration() + job_id = hook.generate_job_id( + job_id=self.job_id, + dag_id=self.dag_id, + task_id=self.task_id, + logical_date=context["logical_date"], + configuration=configuration, + force_rerun=self.force_rerun, ) - job = hook.get_job(job_id=job_id).to_api_repr() - conf = job["configuration"]["extract"]["sourceTable"] + try: + self.log.info("Executing: %s", configuration) + job: ExtractJob = self._submit_job(hook=hook, job_id=job_id, configuration=configuration) + except Conflict: + # If the job already exists retrieve it + job = hook.get_job( + project_id=self.project_id, + location=self.location, + job_id=job_id, + ) + if job.state in self.reattach_states: + # We are reattaching to a job + job.result(timeout=self.result_timeout, retry=self.result_retry) + self._handle_job_error(job) + else: + # Same job configuration so we need force_rerun + raise AirflowException( + f"Job with id: {job_id} already exists and is in {job.state} state. If you " + f"want to force rerun it consider setting `force_rerun=True`." + f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" + ) + + conf = job.to_api_repr()["configuration"]["extract"]["sourceTable"] dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"] BigQueryTableLink.persist( context=context, @@ -135,3 +246,30 @@ def execute(self, context: 'Context'): project_id=project_id, table_id=table_id, ) + + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryInsertJobTrigger( + conn_id=self.gcp_conn_id, + job_id=job_id, + project_id=self.hook.project_id, + ), + method_name="execute_complete", + ) + else: + job.result(timeout=self.result_timeout, retry=self.result_retry) + + def execute_complete(self, context: Context, event: dict[str, Any]): + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py index d8a600eabeb5f..ac62541b4cb2e 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
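With the rework above, the extract runs as an explicit BigQuery job, so `job_id`, `force_rerun`, `reattach_states` and `deferrable` behave as documented in the new parameter descriptions. A hedged sketch of a deferrable CSV export with placeholder table and bucket names:

```python
# Placeholder table and bucket names. reattach_states lists non-final BigQuery
# job states, as the new parameter documentation above requires.
from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator

export_table = BigQueryToGCSOperator(
    task_id="export_table",
    source_project_dataset_table="example-project.example_dataset.example_table",
    destination_cloud_storage_uris=["gs://example-bucket/exports/part-*.csv"],
    export_format="CSV",
    field_delimiter=",",
    print_header=True,
    force_rerun=False,
    reattach_states={"PENDING", "RUNNING"},
    deferrable=True,
)
```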
"""This module contains Google BigQuery to MSSQL operator.""" -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook @@ -75,7 +77,7 @@ class BigQueryToMsSqlOperator(BaseOperator): account from the list granting this role to the originating account (templated). """ - template_fields: Sequence[str] = ('source_project_dataset_table', 'mssql_table', 'impersonation_chain') + template_fields: Sequence[str] = ("source_project_dataset_table", "mssql_table", "impersonation_chain") operator_extra_links = (BigQueryTableLink(),) def __init__( @@ -83,15 +85,15 @@ def __init__( *, source_project_dataset_table: str, mssql_table: str, - selected_fields: Optional[Union[List[str], str]] = None, - gcp_conn_id: str = 'google_cloud_default', - mssql_conn_id: str = 'mssql_default', - database: Optional[str] = None, - delegate_to: Optional[str] = None, + selected_fields: list[str] | str | None = None, + gcp_conn_id: str = "google_cloud_default", + mssql_conn_id: str = "mssql_default", + database: str | None = None, + delegate_to: str | None = None, replace: bool = False, batch_size: int = 1000, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -106,21 +108,21 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain try: - _, self.dataset_id, self.table_id = source_project_dataset_table.split('.') + _, self.dataset_id, self.table_id = source_project_dataset_table.split(".") except ValueError: raise ValueError( - f'Could not parse {source_project_dataset_table} as ..
' + f"Could not parse {source_project_dataset_table} as ..
" ) from None self.source_project_dataset_table = source_project_dataset_table - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: big_query_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, location=self.location, impersonation_chain=self.impersonation_chain, ) - project_id, dataset_id, table_id = self.source_project_dataset_table.split('.') + project_id, dataset_id, table_id = self.source_project_dataset_table.split(".") BigQueryTableLink.persist( context=context, task_instance=self, diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py index 6fbfceaf38d59..6979e983eae09 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google BigQuery to MySQL operator.""" -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook @@ -74,10 +76,10 @@ class BigQueryToMySqlOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'dataset_id', - 'table_id', - 'mysql_table', - 'impersonation_chain', + "dataset_id", + "table_id", + "mysql_table", + "impersonation_chain", ) def __init__( @@ -85,15 +87,15 @@ def __init__( *, dataset_table: str, mysql_table: str, - selected_fields: Optional[Union[List[str], str]] = None, - gcp_conn_id: str = 'google_cloud_default', - mysql_conn_id: str = 'mysql_default', - database: Optional[str] = None, - delegate_to: Optional[str] = None, + selected_fields: list[str] | str | None = None, + gcp_conn_id: str = "google_cloud_default", + mysql_conn_id: str = "mysql_default", + database: str | None = None, + delegate_to: str | None = None, replace: bool = False, batch_size: int = 1000, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + location: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -108,11 +110,11 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain try: - self.dataset_id, self.table_id = dataset_table.split('.') + self.dataset_id, self.table_id = dataset_table.split(".") except ValueError: - raise ValueError(f'Could not parse {dataset_table} as .
') from None + raise ValueError(f"Could not parse {dataset_table} as .
") from None - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: big_query_hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/transfers/calendar_to_gcs.py b/airflow/providers/google/cloud/transfers/calendar_to_gcs.py index 765fbf8da5ef5..08b22420e4d0a 100644 --- a/airflow/providers/google/cloud/transfers/calendar_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/calendar_to_gcs.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json from datetime import datetime from tempfile import NamedTemporaryFile -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -89,25 +90,25 @@ def __init__( destination_bucket: str, api_version: str, calendar_id: str = "primary", - i_cal_uid: Optional[str] = None, - max_attendees: Optional[int] = None, - max_results: Optional[int] = None, - order_by: Optional[str] = None, - private_extended_property: Optional[str] = None, - text_search_query: Optional[str] = None, - shared_extended_property: Optional[str] = None, - show_deleted: Optional[bool] = None, - show_hidden_invitation: Optional[bool] = None, - single_events: Optional[bool] = None, - sync_token: Optional[str] = None, - time_max: Optional[datetime] = None, - time_min: Optional[datetime] = None, - time_zone: Optional[str] = None, - updated_min: Optional[datetime] = None, - destination_path: Optional[str] = None, + i_cal_uid: str | None = None, + max_attendees: int | None = None, + max_results: int | None = None, + order_by: str | None = None, + private_extended_property: str | None = None, + text_search_query: str | None = None, + shared_extended_property: str | None = None, + show_deleted: bool | None = None, + show_hidden_invitation: bool | None = None, + single_events: bool | None = None, + sync_token: str | None = None, + time_max: datetime | None = None, + time_min: datetime | None = None, + time_zone: str | None = None, + updated_min: datetime | None = None, + destination_path: str | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -136,7 +137,7 @@ def __init__( def _upload_data( self, - events: List[Any], + events: list[Any], ) -> str: gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index 248a03d8e1770..e757df46f2b29 100644 --- a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -19,13 +19,14 @@ This module contains operator for copying data from Cassandra to Google Cloud Storage in JSON format. 
""" +from __future__ import annotations import json from base64 import b64encode from datetime import datetime from decimal import Decimal from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NewType, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, NewType, Sequence from uuid import UUID from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time @@ -38,7 +39,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context -NotSetType = NewType('NotSetType', object) +NotSetType = NewType("NotSetType", object) NOT_SET = NotSetType(object()) @@ -84,14 +85,14 @@ class CassandraToGCSOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'cql', - 'bucket', - 'filename', - 'schema_filename', - 'impersonation_chain', + "cql", + "bucket", + "filename", + "schema_filename", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.cql',) - ui_color = '#a0e08c' + template_ext: Sequence[str] = (".cql",) + ui_color = "#a0e08c" def __init__( self, @@ -99,14 +100,14 @@ def __init__( cql: str, bucket: str, filename: str, - schema_filename: Optional[str] = None, + schema_filename: str | None = None, approx_max_file_size_bytes: int = 1900000000, gzip: bool = False, - cassandra_conn_id: str = 'cassandra_default', - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - query_timeout: Union[float, None, NotSetType] = NOT_SET, + cassandra_conn_id: str = "cassandra_default", + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + query_timeout: float | None | NotSetType = NOT_SET, encode_uuid: bool = True, **kwargs, ) -> None: @@ -127,62 +128,62 @@ def __init__( # Default Cassandra to BigQuery type mapping CQL_TYPE_MAP = { - 'BytesType': 'STRING', - 'DecimalType': 'FLOAT', - 'UUIDType': 'STRING', - 'BooleanType': 'BOOL', - 'ByteType': 'INTEGER', - 'AsciiType': 'STRING', - 'FloatType': 'FLOAT', - 'DoubleType': 'FLOAT', - 'LongType': 'INTEGER', - 'Int32Type': 'INTEGER', - 'IntegerType': 'INTEGER', - 'InetAddressType': 'STRING', - 'CounterColumnType': 'INTEGER', - 'DateType': 'TIMESTAMP', - 'SimpleDateType': 'DATE', - 'TimestampType': 'TIMESTAMP', - 'TimeUUIDType': 'STRING', - 'ShortType': 'INTEGER', - 'TimeType': 'TIME', - 'DurationType': 'INTEGER', - 'UTF8Type': 'STRING', - 'VarcharType': 'STRING', + "BytesType": "STRING", + "DecimalType": "FLOAT", + "UUIDType": "STRING", + "BooleanType": "BOOL", + "ByteType": "INTEGER", + "AsciiType": "STRING", + "FloatType": "FLOAT", + "DoubleType": "FLOAT", + "LongType": "INTEGER", + "Int32Type": "INTEGER", + "IntegerType": "INTEGER", + "InetAddressType": "STRING", + "CounterColumnType": "INTEGER", + "DateType": "TIMESTAMP", + "SimpleDateType": "DATE", + "TimestampType": "TIMESTAMP", + "TimeUUIDType": "STRING", + "ShortType": "INTEGER", + "TimeType": "TIME", + "DurationType": "INTEGER", + "UTF8Type": "STRING", + "VarcharType": "STRING", } - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id) query_extra = {} if self.query_timeout is not NOT_SET: - query_extra['timeout'] = self.query_timeout + query_extra["timeout"] = self.query_timeout cursor = hook.get_conn().execute(self.cql, **query_extra) # If a schema is set, create a BQ schema JSON file. 
if self.schema_filename: - self.log.info('Writing local schema file') + self.log.info("Writing local schema file") schema_file = self._write_local_schema_file(cursor) # Flush file before uploading - schema_file['file_handle'].flush() + schema_file["file_handle"].flush() - self.log.info('Uploading schema file to GCS.') + self.log.info("Uploading schema file to GCS.") self._upload_to_gcs(schema_file) - schema_file['file_handle'].close() + schema_file["file_handle"].close() counter = 0 - self.log.info('Writing local data files') + self.log.info("Writing local data files") for file_to_upload in self._write_local_data_files(cursor): # Flush file before uploading - file_to_upload['file_handle'].flush() + file_to_upload["file_handle"].flush() - self.log.info('Uploading chunk file #%d to GCS.', counter) + self.log.info("Uploading chunk file #%d to GCS.", counter) self._upload_to_gcs(file_to_upload) - self.log.info('Removing local file') - file_to_upload['file_handle'].close() + self.log.info("Removing local file") + file_to_upload["file_handle"].close() counter += 1 # Close all sessions and connection associated with this Cassandra cluster @@ -200,16 +201,16 @@ def _write_local_data_files(self, cursor): tmp_file_handle = NamedTemporaryFile(delete=True) file_to_upload = { - 'file_name': self.filename.format(file_no), - 'file_handle': tmp_file_handle, + "file_name": self.filename.format(file_no), + "file_handle": tmp_file_handle, } for row in cursor: row_dict = self.generate_data_dict(row._fields, row) - content = json.dumps(row_dict).encode('utf-8') + content = json.dumps(row_dict).encode("utf-8") tmp_file_handle.write(content) # Append newline to make dumps BigQuery compatible. - tmp_file_handle.write(b'\n') + tmp_file_handle.write(b"\n") if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: file_no += 1 @@ -217,8 +218,8 @@ def _write_local_data_files(self, cursor): yield file_to_upload tmp_file_handle = NamedTemporaryFile(delete=True) file_to_upload = { - 'file_name': self.filename.format(file_no), - 'file_handle': tmp_file_handle, + "file_name": self.filename.format(file_no), + "file_handle": tmp_file_handle, } yield file_to_upload @@ -236,12 +237,12 @@ def _write_local_schema_file(self, cursor): for name, type_ in zip(cursor.column_names, cursor.column_types): schema.append(self.generate_schema_dict(name, type_)) - json_serialized_schema = json.dumps(schema).encode('utf-8') + json_serialized_schema = json.dumps(schema).encode("utf-8") tmp_schema_file_handle.write(json_serialized_schema) schema_file_to_upload = { - 'file_name': self.schema_filename, - 'file_handle': tmp_schema_file_handle, + "file_name": self.schema_filename, + "file_handle": tmp_schema_file_handle, } return schema_file_to_upload @@ -254,27 +255,27 @@ def _upload_to_gcs(self, file_to_upload): ) hook.upload( bucket_name=self.bucket, - object_name=file_to_upload.get('file_name'), - filename=file_to_upload.get('file_handle').name, - mime_type='application/json', + object_name=file_to_upload.get("file_name"), + filename=file_to_upload.get("file_handle").name, + mime_type="application/json", gzip=self.gzip, ) - def generate_data_dict(self, names: Iterable[str], values: Any) -> Dict[str, Any]: + def generate_data_dict(self, names: Iterable[str], values: Any) -> dict[str, Any]: """Generates data structure that will be stored as file in GCS.""" return {n: self.convert_value(v) for n, v in zip(names, values)} - def convert_value(self, value: Optional[Any]) -> Optional[Any]: + def convert_value(self, value: Any | None) -> Any | 
None: """Convert value to BQ type.""" if not value: return value elif isinstance(value, (str, int, float, bool, dict)): return value elif isinstance(value, bytes): - return b64encode(value).decode('ascii') + return b64encode(value).decode("ascii") elif isinstance(value, UUID): if self.encode_uuid: - return b64encode(value.bytes).decode('ascii') + return b64encode(value.bytes).decode("ascii") else: return str(value) elif isinstance(value, (datetime, Date)): @@ -282,23 +283,23 @@ def convert_value(self, value: Optional[Any]) -> Optional[Any]: elif isinstance(value, Decimal): return float(value) elif isinstance(value, Time): - return str(value).split('.')[0] + return str(value).split(".")[0] elif isinstance(value, (list, SortedSet)): return self.convert_array_types(value) - elif hasattr(value, '_fields'): + elif hasattr(value, "_fields"): return self.convert_user_type(value) elif isinstance(value, tuple): return self.convert_tuple_type(value) elif isinstance(value, OrderedMapSerializedKey): return self.convert_map_type(value) else: - raise AirflowException('Unexpected value: ' + str(value)) + raise AirflowException("Unexpected value: " + str(value)) - def convert_array_types(self, value: Union[List[Any], SortedSet]) -> List[Any]: + def convert_array_types(self, value: list[Any] | SortedSet) -> list[Any]: """Maps convert_value over array.""" return [self.convert_value(nested_value) for nested_value in value] - def convert_user_type(self, value: Any) -> Dict[str, Any]: + def convert_user_type(self, value: Any) -> dict[str, Any]: """ Converts a user type to RECORD that contains n fields, where n is the number of attributes. Each element in the user type class will be converted to its @@ -308,46 +309,46 @@ def convert_user_type(self, value: Any) -> Dict[str, Any]: values = [self.convert_value(getattr(value, name)) for name in names] return self.generate_data_dict(names, values) - def convert_tuple_type(self, values: Tuple[Any]) -> Dict[str, Any]: + def convert_tuple_type(self, values: tuple[Any]) -> dict[str, Any]: """ Converts a tuple to RECORD that contains n fields, each will be converted to its corresponding data type in bq and will be named 'field_', where index is determined by the order of the tuple elements defined in cassandra. """ - names = ['field_' + str(i) for i in range(len(values))] + names = ["field_" + str(i) for i in range(len(values))] return self.generate_data_dict(names, values) - def convert_map_type(self, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]: + def convert_map_type(self, value: OrderedMapSerializedKey) -> list[dict[str, Any]]: """ Converts a map to a repeated RECORD that contains two fields: 'key' and 'value', each will be converted to its corresponding data type in BQ. 
""" converted_map = [] for k, v in zip(value.keys(), value.values()): - converted_map.append({'key': self.convert_value(k), 'value': self.convert_value(v)}) + converted_map.append({"key": self.convert_value(k), "value": self.convert_value(v)}) return converted_map @classmethod - def generate_schema_dict(cls, name: str, type_: Any) -> Dict[str, Any]: + def generate_schema_dict(cls, name: str, type_: Any) -> dict[str, Any]: """Generates BQ schema.""" - field_schema: Dict[str, Any] = {} - field_schema.update({'name': name}) - field_schema.update({'type_': cls.get_bq_type(type_)}) - field_schema.update({'mode': cls.get_bq_mode(type_)}) + field_schema: dict[str, Any] = {} + field_schema.update({"name": name}) + field_schema.update({"type_": cls.get_bq_type(type_)}) + field_schema.update({"mode": cls.get_bq_mode(type_)}) fields = cls.get_bq_fields(type_) if fields: - field_schema.update({'fields': fields}) + field_schema.update({"fields": fields}) return field_schema @classmethod - def get_bq_fields(cls, type_: Any) -> List[Dict[str, Any]]: + def get_bq_fields(cls, type_: Any) -> list[dict[str, Any]]: """Converts non simple type value to BQ representation.""" if cls.is_simple_type(type_): return [] # In case of not simple type - names: List[str] = [] - types: List[Any] = [] + names: list[str] = [] + types: list[Any] = [] if cls.is_array_type(type_) and cls.is_record_type(type_.subtypes[0]): names = type_.subtypes[0].fieldnames types = type_.subtypes[0].subtypes @@ -355,10 +356,10 @@ def get_bq_fields(cls, type_: Any) -> List[Dict[str, Any]]: names = type_.fieldnames types = type_.subtypes - if types and not names and type_.cassname == 'TupleType': - names = ['field_' + str(i) for i in range(len(types))] - elif types and not names and type_.cassname == 'MapType': - names = ['key', 'value'] + if types and not names and type_.cassname == "TupleType": + names = ["field_" + str(i) for i in range(len(types))] + elif types and not names and type_.cassname == "MapType": + names = ["key", "value"] return [cls.generate_schema_dict(n, t) for n, t in zip(names, types)] @@ -370,12 +371,12 @@ def is_simple_type(type_: Any) -> bool: @staticmethod def is_array_type(type_: Any) -> bool: """Check if type is an array type.""" - return type_.cassname in ['ListType', 'SetType'] + return type_.cassname in ["ListType", "SetType"] @staticmethod def is_record_type(type_: Any) -> bool: """Checks the record type.""" - return type_.cassname in ['UserType', 'TupleType', 'MapType'] + return type_.cassname in ["UserType", "TupleType", "MapType"] @classmethod def get_bq_type(cls, type_: Any) -> str: @@ -383,18 +384,18 @@ def get_bq_type(cls, type_: Any) -> str: if cls.is_simple_type(type_): return CassandraToGCSOperator.CQL_TYPE_MAP[type_.cassname] elif cls.is_record_type(type_): - return 'RECORD' + return "RECORD" elif cls.is_array_type(type_): return cls.get_bq_type(type_.subtypes[0]) else: - raise AirflowException('Not a supported type_: ' + type_.cassname) + raise AirflowException("Not a supported type_: " + type_.cassname) @classmethod def get_bq_mode(cls, type_: Any) -> str: """Converts type to equivalent BQ mode.""" - if cls.is_array_type(type_) or type_.cassname == 'MapType': - return 'REPEATED' + if cls.is_array_type(type_) or type_.cassname == "MapType": + return "REPEATED" elif cls.is_record_type(type_) or cls.is_simple_type(type_): - return 'NULLABLE' + return "NULLABLE" else: - raise AirflowException('Not a supported type_: ' + type_.cassname) + raise AirflowException("Not a supported type_: " + type_.cassname) 
diff --git a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py index 4222b8ad2befa..bda341bd2e73d 100644 --- a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py @@ -16,10 +16,12 @@ # specific language governing permissions and limitations # under the License. """This module contains Facebook Ad Reporting to GCS operators.""" +from __future__ import annotations + import csv import tempfile from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from facebook_business.adobjects.adsinsights import AdsInsights @@ -95,14 +97,14 @@ def __init__( *, bucket_name: str, object_name: str, - fields: List[str], - parameters: Optional[Dict[str, Any]] = None, + fields: list[str], + parameters: dict[str, Any] | None = None, gzip: bool = False, upload_as_account: bool = False, - api_version: Optional[str] = None, + api_version: str | None = None, gcp_conn_id: str = "google_cloud_default", facebook_conn_id: str = "facebook_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -117,7 +119,7 @@ def __init__( self.upload_as_account = upload_as_account self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): service = FacebookAdsReportingHook( facebook_conn_id=self.facebook_conn_id, api_version=self.api_version ) @@ -157,9 +159,9 @@ def _generate_rows_with_action(self, type_check: bool): def _prepare_rows_for_upload( self, - rows: List[AdsInsights], - converted_rows_with_action: Dict[FlushAction, list], - account_id: Optional[str], + rows: list[AdsInsights], + converted_rows_with_action: dict[FlushAction, list], + account_id: str | None, ): converted_rows = [dict(row) for row in rows] if account_id is not None and self.upload_as_account: @@ -174,7 +176,7 @@ def _prepare_rows_for_upload( self.log.info("Facebook Returned %s data points ", len(converted_rows)) return converted_rows_with_action - def _decide_and_flush(self, converted_rows_with_action: Dict[FlushAction, list]): + def _decide_and_flush(self, converted_rows_with_action: dict[FlushAction, list]): total_data_count = 0 once_action = converted_rows_with_action.get(FlushAction.EXPORT_ONCE) if once_action is not None: @@ -202,7 +204,7 @@ def _decide_and_flush(self, converted_rows_with_action: Dict[FlushAction, list]) raise AirflowException(message) return total_data_count - def _flush_rows(self, converted_rows: Optional[List[Any]], object_name: str): + def _flush_rows(self, converted_rows: list[Any] | None, object_name: str): if converted_rows: headers = converted_rows[0].keys() with tempfile.NamedTemporaryFile("w", suffix=".csv") as csvfile: diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 6821e78433396..ce7da490626dc 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -16,14 +16,21 @@ # specific language governing permissions and limitations # under the License. 
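The Facebook Ads operator above passes `fields` and `parameters` through to the Facebook Insights API via the hook. The sketch below is illustrative only: the field names, report parameters, bucket and connection ids are assumptions, not taken from this change.

```python
# Illustrative Facebook Insights field names and parameters; bucket, object and
# connection ids are placeholders.
from airflow.providers.google.cloud.transfers.facebook_ads_to_gcs import (
    FacebookAdsReportToGcsOperator,
)

facebook_ads_report = FacebookAdsReportToGcsOperator(
    task_id="facebook_ads_report",
    bucket_name="example-bucket",
    object_name="facebook/ads_report.csv",
    fields=["campaign_name", "impressions", "clicks", "spend"],
    parameters={"level": "ad", "date_preset": "yesterday"},
    upload_as_account=False,
    gcp_conn_id="google_cloud_default",
    facebook_conn_id="facebook_default",
)
```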
"""This module contains a Google Cloud Storage to BigQuery operator.""" +from __future__ import annotations import json -import warnings -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence +from google.api_core.exceptions import Conflict +from google.api_core.retry import Retry +from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob + +from airflow import AirflowException from airflow.models import BaseOperator -from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink +from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -56,6 +63,8 @@ class GCSToBigQueryOperator(BaseOperator): :param schema_object: If set, a GCS object path pointing to a .json file that contains the schema for the table. (templated) Parameter must be defined if 'schema_fields' is null and autodetect is False. + :param schema_object_bucket: [Optional] If set, the GCS bucket where the schema object + template is stored. (templated) (Default: the value of ``bucket``) :param source_format: File format to export. :param compression: [Optional] The compression type of the data source. Possible values include GZIP and NONE. @@ -63,7 +72,18 @@ class GCSToBigQueryOperator(BaseOperator): This setting is ignored for Google Cloud Bigtable, Google Cloud Datastore backups and Avro formats. :param create_disposition: The create disposition if the table doesn't exist. - :param skip_leading_rows: Number of rows to skip when loading from a CSV. + :param skip_leading_rows: The number of rows at the top of a CSV file that BigQuery + will skip when loading the data. + When autodetect is on, the behavior is the following: + skip_leading_rows unspecified - Autodetect tries to detect headers in the first row. + If they are not detected, the row is read as data. Otherwise, data is read starting + from the second row. + skip_leading_rows is 0 - Instructs autodetect that there are no headers and data + should be read starting from the first row. + skip_leading_rows = N > 0 - Autodetect skips N-1 rows and tries to detect headers + in row N. If headers are not detected, row N is just skipped. Otherwise, row N is + used to extract column names for the detected schema. + Default value set to None so that autodetect option can detect schema fields. :param write_disposition: The write disposition if the table already exists. :param field_delimiter: The delimiter to use when loading from a CSV. :param max_bad_records: The maximum number of bad records that BigQuery can @@ -126,18 +146,23 @@ class GCSToBigQueryOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). :param labels: [Optional] Labels for the BiqQuery table. - :param description: [Optional] Description for the BigQuery table. + :param description: [Optional] Description for the BigQuery table. This will only be used if the + destination table is newly created. If the table already exists and a value different than the + current description is provided, the job will fail. 
+ :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( - 'bucket', - 'source_objects', - 'schema_object', - 'destination_project_dataset_table', - 'impersonation_chain', + "bucket", + "source_objects", + "schema_object", + "schema_object_bucket", + "destination_project_dataset_table", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.sql',) - ui_color = '#f0eee4' + template_ext: Sequence[str] = (".sql",) + ui_color = "#f0eee4" + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -147,12 +172,13 @@ def __init__( destination_project_dataset_table, schema_fields=None, schema_object=None, - source_format='CSV', - compression='NONE', - create_disposition='CREATE_IF_NEEDED', - skip_leading_rows=0, - write_disposition='WRITE_EMPTY', - field_delimiter=',', + schema_object_bucket=None, + source_format="CSV", + compression="NONE", + create_disposition="CREATE_IF_NEEDED", + skip_leading_rows=None, + write_disposition="WRITE_EMPTY", + field_delimiter=",", max_bad_records=0, quote_character=None, ignore_unknown_values=False, @@ -160,7 +186,7 @@ def __init__( allow_jagged_rows=False, encoding="UTF-8", max_id_key=None, - gcp_conn_id='google_cloud_default', + gcp_conn_id="google_cloud_default", delegate_to=None, schema_update_options=(), src_fmt_configs=None, @@ -170,13 +196,22 @@ def __init__( autodetect=True, encryption_configuration=None, location=None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, labels=None, description=None, + deferrable: bool = False, + result_retry: Retry = DEFAULT_RETRY, + result_timeout: float | None = None, + cancel_on_kill: bool = True, + job_id: str | None = None, + force_rerun: bool = True, + reattach_states: set[str] | None = None, **kwargs, - ): + ) -> None: super().__init__(**kwargs) + self.hook: BigQueryHook | None = None + self.configuration: dict[str, Any] = {} # GCS config if src_fmt_configs is None: @@ -187,6 +222,10 @@ def __init__( self.source_objects = source_objects self.schema_object = schema_object + if schema_object_bucket is None: + schema_object_bucket = bucket + self.schema_object_bucket = schema_object_bucket + # BQ config self.destination_project_dataset_table = destination_project_dataset_table self.schema_fields = schema_fields @@ -220,103 +259,306 @@ def __init__( self.labels = labels self.description = description - def execute(self, context: 'Context'): - bq_hook = BigQueryHook( + self.job_id = job_id + self.deferrable = deferrable + self.result_retry = result_retry + self.result_timeout = result_timeout + self.force_rerun = force_rerun + self.reattach_states: set[str] = reattach_states or set() + self.cancel_on_kill = cancel_on_kill + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + # Submit a new job without waiting for it to complete. 
+ return hook.insert_job( + configuration=self.configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + timeout=self.result_timeout, + retry=self.result_retry, + nowait=True, + ) + + @staticmethod + def _handle_job_error(job: BigQueryJob) -> None: + if job.error_result: + raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}") + + def execute(self, context: Context): + hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, location=self.location, impersonation_chain=self.impersonation_chain, ) - - if not self.schema_fields: - if self.schema_object and self.source_format != 'DATASTORE_BACKUP': - gcs_hook = GCSHook( - gcp_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - impersonation_chain=self.impersonation_chain, - ) - blob = gcs_hook.download( - bucket_name=self.bucket, - object_name=self.schema_object, - ) - schema_fields = json.loads(blob.decode("utf-8")) - else: - schema_fields = None - else: - schema_fields = self.schema_fields + self.hook = hook + job_id = self.hook.generate_job_id( + job_id=self.job_id, + dag_id=self.dag_id, + task_id=self.task_id, + logical_date=context["logical_date"], + configuration=self.configuration, + force_rerun=self.force_rerun, + ) self.source_objects = ( self.source_objects if isinstance(self.source_objects, list) else [self.source_objects] ) - source_uris = [f'gs://{self.bucket}/{source_object}' for source_object in self.source_objects] + source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects] + if not self.schema_fields: + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + if self.schema_object and self.source_format != "DATASTORE_BACKUP": + schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8")) + self.log.info("Autodetected fields from schema object: %s", schema_fields) if self.external_table: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - bq_hook.create_external_table( - external_project_dataset_table=self.destination_project_dataset_table, - schema_fields=schema_fields, - source_uris=source_uris, - source_format=self.source_format, - autodetect=self.autodetect, - compression=self.compression, - skip_leading_rows=self.skip_leading_rows, - field_delimiter=self.field_delimiter, - max_bad_records=self.max_bad_records, - quote_character=self.quote_character, - ignore_unknown_values=self.ignore_unknown_values, - allow_quoted_newlines=self.allow_quoted_newlines, - allow_jagged_rows=self.allow_jagged_rows, - encoding=self.encoding, - src_fmt_configs=self.src_fmt_configs, - encryption_configuration=self.encryption_configuration, - labels=self.labels, - description=self.description, - ) + self.log.info("Creating a new BigQuery table for storing data...") + project_id, dataset_id, table_id = self.hook.split_tablename( + table_input=self.destination_project_dataset_table, + default_project_id=self.hook.project_id or "", + ) + table_resource = { + "tableReference": { + "projectId": project_id, + "datasetId": dataset_id, + "tableId": table_id, + }, + "labels": self.labels, + "description": self.description, + "externalDataConfiguration": { + "sourceUris": source_uris, + "sourceFormat": self.source_format, + "maxBadRecords": self.max_bad_records, + "autodetect": self.autodetect, + "compression": self.compression, + "csvOptions": { + "fieldDelimiter":
self.field_delimiter, + "skipLeadingRows": self.skip_leading_rows, + "quote": self.quote_character, + "allowQuotedNewlines": self.allow_quoted_newlines, + "allowJaggedRows": self.allow_jagged_rows, + }, + }, + "location": self.location, + "encryptionConfiguration": self.encryption_configuration, + } + table_resource_checked_schema = self._check_schema_fields(table_resource) + table = self.hook.create_empty_table( + table_resource=table_resource_checked_schema, + ) + max_id = self._find_max_value_in_column() + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=table.to_api_repr()["tableReference"]["datasetId"], + project_id=table.to_api_repr()["tableReference"]["projectId"], + table_id=table.to_api_repr()["tableReference"]["tableId"], + ) + return max_id else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - bq_hook.run_load( - destination_project_dataset_table=self.destination_project_dataset_table, - schema_fields=schema_fields, - source_uris=source_uris, - source_format=self.source_format, - autodetect=self.autodetect, - create_disposition=self.create_disposition, - skip_leading_rows=self.skip_leading_rows, - write_disposition=self.write_disposition, - field_delimiter=self.field_delimiter, - max_bad_records=self.max_bad_records, - quote_character=self.quote_character, - ignore_unknown_values=self.ignore_unknown_values, - allow_quoted_newlines=self.allow_quoted_newlines, - allow_jagged_rows=self.allow_jagged_rows, - encoding=self.encoding, - schema_update_options=self.schema_update_options, - src_fmt_configs=self.src_fmt_configs, - time_partitioning=self.time_partitioning, - cluster_fields=self.cluster_fields, - encryption_configuration=self.encryption_configuration, - labels=self.labels, - description=self.description, + self.log.info("Using existing BigQuery table for storing data...") + destination_project, destination_dataset, destination_table = self.hook.split_tablename( + table_input=self.destination_project_dataset_table, + default_project_id=self.hook.project_id or "", + var_name="destination_project_dataset_table", + ) + self.configuration = { + "load": { + "autodetect": self.autodetect, + "createDisposition": self.create_disposition, + "destinationTable": { + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, + }, + "destinationTableProperties": { + "description": self.description, + "labels": self.labels, + }, + "sourceFormat": self.source_format, + "skipLeadingRows": self.skip_leading_rows, + "sourceUris": source_uris, + "writeDisposition": self.write_disposition, + "ignoreUnknownValues": self.ignore_unknown_values, + "allowQuotedNewlines": self.allow_quoted_newlines, + "encoding": self.encoding, + }, + } + self.configuration = self._check_schema_fields(self.configuration) + try: + self.log.info("Executing: %s", self.configuration) + job = self._submit_job(self.hook, job_id) + except Conflict: + # If the job already exists retrieve it + job = self.hook.get_job( + project_id=self.hook.project_id, + location=self.location, + job_id=job_id, ) + if job.state in self.reattach_states: + # We are reattaching to a job + job._begin() + self._handle_job_error(job) + else: + # Same job configuration so we need force_rerun + raise AirflowException( + f"Job with id: {job_id} already exists and is in {job.state} state. If you " + f"want to force rerun it consider setting `force_rerun=True`." 
+ f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" + ) - if self.max_id_key: - select_command = f'SELECT MAX({self.max_id_key}) FROM `{self.destination_project_dataset_table}`' - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - job_id = bq_hook.run_query( - sql=select_command, - use_legacy_sql=False, - ) - row = list(bq_hook.get_job(job_id).result()) - if row: - max_id = row[0] if row[0] else 0 - self.log.info( - 'Loaded BQ data with max %s.%s=%s', - self.destination_project_dataset_table, - self.max_id_key, - max_id, + job_types = { + LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"], + CopyJob._JOB_TYPE: ["sourceTable", "destinationTable"], + ExtractJob._JOB_TYPE: ["sourceTable"], + QueryJob._JOB_TYPE: ["destinationTable"], + } + + if self.hook.project_id: + for job_type, tables_prop in job_types.items(): + job_configuration = job.to_api_repr()["configuration"] + if job_type in job_configuration: + for table_prop in tables_prop: + if table_prop in job_configuration[job_type]: + table = job_configuration[job_type][table_prop] + persist_kwargs = { + "context": context, + "task_instance": self, + "project_id": self.hook.project_id, + "table_id": table, + } + if not isinstance(table, str): + persist_kwargs["table_id"] = table["tableId"] + persist_kwargs["dataset_id"] = table["datasetId"] + BigQueryTableLink.persist(**persist_kwargs) + + self.job_id = job.job_id + context["ti"].xcom_push(key="job_id", value=self.job_id) + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryInsertJobTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + project_id=self.hook.project_id, + ), + method_name="execute_complete", ) + else: + job.result(timeout=self.result_timeout, retry=self.result_retry) + max_id = self._find_max_value_in_column() + self._handle_job_error(job) + return max_id + + def execute_complete(self, context: Context, event: dict[str, Any]): + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. 
+ """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) + return self._find_max_value_in_column() + + def _find_max_value_in_column(self): + hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + if self.max_id_key: + self.log.info(f"Selecting the MAX value from BigQuery column '{self.max_id_key}'...") + select_command = ( + f"SELECT MAX({self.max_id_key}) AS max_value " + f"FROM {self.destination_project_dataset_table}" + ) + + self.configuration = { + "query": { + "query": select_command, + "useLegacySql": False, + "schemaUpdateOptions": [], + } + } + job_id = hook.insert_job(configuration=self.configuration, project_id=hook.project_id) + rows = list(hook.get_job(job_id=job_id, location=self.location).result()) + if rows: + for row in rows: + max_id = row[0] if row[0] else 0 + self.log.info( + "Loaded BQ data with MAX value of column %s.%s: %s", + self.destination_project_dataset_table, + self.max_id_key, + max_id, + ) + return str(max_id) else: raise RuntimeError(f"The {select_command} returned no rows!") + + def _check_schema_fields(self, table_resource): + """ + Helper method to detect schema fields if they were not specified by user and autodetect=True. + If source_objects were passed, method reads the second row in CSV file. If there is at least one digit + table_resurce is returned without changes so that BigQuery can determine schema_fields in the + next step. + If there are only characters, the first row with fields is used to construct schema_fields argument + with type 'STRING'. Table_resource is updated with new schema_fileds key and returned back to operator + :param table_resource: Configuration or table_resource dictionary + :return: table_resource: Updated table_resource dict with schema_fields + """ + if not self.autodetect and not self.schema_fields: + raise RuntimeError( + "Table schema was not found. 
Set autodetect=True to " + "automatically set schema fields from source objects or pass " + "schema_fields explicitly" + ) + elif not self.schema_fields: + for source_object in self.source_objects: + gcs_hook = GCSHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) + blob = gcs_hook.download( + bucket_name=self.schema_object_bucket, + object_name=source_object, + ) + fields, values = [item.split(",") for item in blob.decode("utf-8").splitlines()][:2] + import re + + if any(re.match(r"[\d\-\\.]+$", value) for value in values): + return table_resource + else: + schema_fields = [] + for field in fields: + schema_fields.append({"name": field, "type": "STRING", "mode": "NULLABLE"}) + self.schema_fields = schema_fields + if self.external_table: + table_resource["externalDataConfiguration"]["csvOptions"]["skipLeadingRows"] = 1 + elif not self.external_table: + table_resource["load"]["skipLeadingRows"] = 1 + if self.external_table: + table_resource["schema"] = {"fields": self.schema_fields} + elif not self.external_table: + table_resource["load"]["schema"] = {"fields": self.schema_fields} + return table_resource + + def on_kill(self) -> None: + if self.job_id and self.cancel_on_kill: + self.hook.cancel_job(job_id=self.job_id, location=self.location) # type: ignore[union-attr] + else: + self.log.info("Skipping to cancel job: %s.%s", self.location, self.job_id) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index 5a10aa7a32506..b0a15c98991d9 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -16,13 +16,15 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Cloud Storage operator.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook -WILDCARD = '*' +WILDCARD = "*" if TYPE_CHECKING: from airflow.utils.context import Context @@ -89,6 +91,8 @@ class GCSToGCSOperator(BaseOperator): account from the list granting this role to the originating account (templated). :param source_object_required: Whether you want to raise an exception when the source object doesn't exist. It doesn't have any effect when the source objects are folders or patterns. + :param exact_match: When specified, only exact match of the source object (filename) will be + copied. 
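The docstring examples below cover wildcard and prefix copies; for the new exact_match flag specifically, a short sketch (bucket names hypothetical): only an object whose full name equals source_object is copied, while objects that merely share the prefix are skipped.

from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator

copy_one_file = GCSToGCSOperator(
    task_id="copy_one_file",
    source_bucket="src-bucket",
    source_object="data/sales.csv",
    destination_bucket="dst-bucket",
    destination_object="backup/sales.csv",
    exact_match=True,  # e.g. "data/sales.csv.bak" would not be copied
)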
:Example: @@ -161,15 +165,15 @@ class GCSToGCSOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_bucket', - 'source_object', - 'source_objects', - 'destination_bucket', - 'destination_object', - 'delimiter', - 'impersonation_chain', + "source_bucket", + "source_object", + "source_objects", + "destination_bucket", + "destination_object", + "delimiter", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, @@ -182,13 +186,14 @@ def __init__( delimiter=None, move_object=False, replace=True, - gcp_conn_id='google_cloud_default', + gcp_conn_id="google_cloud_default", delegate_to=None, last_modified_time=None, maximum_modified_time=None, is_older_than=None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, source_object_required=False, + exact_match=False, **kwargs, ): super().__init__(**kwargs) @@ -208,8 +213,9 @@ def __init__( self.is_older_than = is_older_than self.impersonation_chain = impersonation_chain self.source_object_required = source_object_required + self.exact_match = exact_match - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GCSHook( gcp_conn_id=self.gcp_conn_id, @@ -228,7 +234,7 @@ def execute(self, context: 'Context'): raise AirflowException(error_msg) if self.source_objects and not all(isinstance(item, str) for item in self.source_objects): - raise AirflowException('At least, one of the `objects` in the `source_objects` is not a string') + raise AirflowException("At least, one of the `objects` in the `source_objects` is not a string") # If source_object is set, default it to source_objects if self.source_object: @@ -236,15 +242,15 @@ def execute(self, context: 'Context'): if self.destination_bucket is None: self.log.warning( - 'destination_bucket is None. Defaulting it to source_bucket (%s)', self.source_bucket + "destination_bucket is None. Defaulting it to source_bucket (%s)", self.source_bucket ) self.destination_bucket = self.source_bucket # An empty source_object means to copy all files if len(self.source_objects) == 0: - self.source_objects = [''] + self.source_objects = [""] # Raise exception if empty string `''` is used twice in source_object, this is to avoid double copy - if self.source_objects.count('') > 1: + if self.source_objects.count("") > 1: raise AirflowException("You can't have two empty strings inside source_object") # Iterate over the source_objects and do the copy @@ -260,8 +266,8 @@ def _ignore_existing_files(self, hook, prefix, **kwargs): # list all files in the Destination GCS bucket # and only keep those files which are present in # Source GCS bucket and not in Destination GCS bucket - delimiter = kwargs.get('delimiter') - objects = kwargs.get('objects') + delimiter = kwargs.get("delimiter") + objects = kwargs.get("objects") if self.destination_object is None: existing_objects = hook.list(self.destination_bucket, prefix=prefix, delimiter=delimiter) else: @@ -277,9 +283,9 @@ def _ignore_existing_files(self, hook, prefix, **kwargs): objects = set(objects) - set(existing_objects) if len(objects) > 0: - self.log.info('%s files are going to be synced: %s.', len(objects), objects) + self.log.info("%s files are going to be synced: %s.", len(objects), objects) else: - self.log.info('There are no new files to sync. Have a nice day!') + self.log.info("There are no new files to sync. 
Have a nice day!") return objects def _copy_source_without_wildcard(self, hook, prefix): @@ -341,6 +347,8 @@ def _copy_source_without_wildcard(self, hook, prefix): raise AirflowException(msg) for source_obj in objects: + if self.exact_match and (source_obj != prefix or not source_obj.endswith(prefix)): + continue if self.destination_object is None: destination_object = source_obj else: @@ -358,7 +366,7 @@ def _copy_source_with_wildcard(self, hook, prefix): ) raise AirflowException(error_msg) - self.log.info('Delimiter ignored because wildcard is in prefix') + self.log.info("Delimiter ignored because wildcard is in prefix") prefix_, delimiter = prefix.split(WILDCARD, 1) objects = hook.list(self.source_bucket, prefix=prefix_, delimiter=delimiter) if not self.replace: @@ -421,7 +429,7 @@ def _copy_single_object(self, hook, source_object, destination_object): return self.log.info( - 'Executing copy of gs://%s/%s to gs://%s/%s', + "Executing copy of gs://%s/%s to gs://%s/%s", self.source_bucket, source_object, self.destination_bucket, diff --git a/airflow/providers/google/cloud/transfers/gcs_to_local.py b/airflow/providers/google/cloud/transfers/gcs_to_local.py index 875fe098c93d1..185e0c164cdac 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_local.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_local.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -66,32 +67,32 @@ class GCSToLocalFilesystemOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'object_name', - 'filename', - 'store_to_xcom_key', - 'impersonation_chain', - 'file_encoding', + "bucket", + "object_name", + "filename", + "store_to_xcom_key", + "impersonation_chain", + "file_encoding", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, *, bucket: str, - object_name: Optional[str] = None, - filename: Optional[str] = None, - store_to_xcom_key: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - file_encoding: str = 'utf-8', + object_name: str | None = None, + filename: str | None = None, + store_to_xcom_key: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + file_encoding: str = "utf-8", **kwargs, ) -> None: # To preserve backward compatibility # TODO: Remove one day if object_name is None: - object_name = kwargs.get('object') + object_name = kwargs.get("object") if object_name is not None: self.object_name = object_name DeprecationWarning("Use 'object_name' instead of 'object'.") @@ -111,8 +112,8 @@ def __init__( self.impersonation_chain = impersonation_chain self.file_encoding = file_encoding - def execute(self, context: 'Context'): - self.log.info('Executing download: %s, %s, %s', self.bucket, self.object_name, self.filename) + def execute(self, context: Context): + self.log.info("Executing download: %s, %s, %s", self.bucket, self.object_name, self.filename) hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -123,8 +124,8 @@ def execute(self, context: 'Context'): file_size = 
hook.get_size(bucket_name=self.bucket, object_name=self.object_name) if file_size < MAX_XCOM_SIZE: file_bytes = hook.download(bucket_name=self.bucket, object_name=self.object_name) - context['ti'].xcom_push(key=self.store_to_xcom_key, value=str(file_bytes, self.file_encoding)) + context["ti"].xcom_push(key=self.store_to_xcom_key, value=str(file_bytes, self.file_encoding)) else: - raise AirflowException('The size of the downloaded file is too large to push to XCom!') + raise AirflowException("The size of the downloaded file is too large to push to XCom!") else: hook.download(bucket_name=self.bucket, object_name=self.object_name, filename=self.filename) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_sftp.py b/airflow/providers/google/cloud/transfers/gcs_to_sftp.py index 1ad74cc3c1d01..834eb6e877e21 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_sftp.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_sftp.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud Storage to SFTP operator.""" +from __future__ import annotations + import os from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -40,7 +42,7 @@ class GCSToSFTPOperator(BaseOperator): with models.DAG( "example_gcs_to_sftp", start_date=datetime(2020, 6, 19), - schedule_interval=None, + schedule=None, ) as dag: # downloads file to /tmp/sftp/folder/subfolder/file.txt copy_file_from_gcs_to_sftp = GCSToSFTPOperator( @@ -113,8 +115,8 @@ def __init__( move_object: bool = False, gcp_conn_id: str = "google_cloud_default", sftp_conn_id: str = "ssh_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -130,7 +132,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.sftp_dirs = None - def execute(self, context: 'Context'): + def execute(self, context: Context): gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -162,7 +164,7 @@ def execute(self, context: 'Context'): self._copy_single_object(gcs_hook, sftp_hook, self.source_object, destination_path) self.log.info("Done. Uploaded '%s' file to %s", self.source_object, destination_path) - def _resolve_destination_path(self, source_object: str, prefix: Optional[str] = None) -> str: + def _resolve_destination_path(self, source_object: str, prefix: str | None = None) -> str: if not self.keep_directory_structure: if prefix: source_object = os.path.relpath(source_object, start=prefix) diff --git a/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py b/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py index 531f90dcbfcfb..746558537561f 100644 --- a/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
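The MAX_XCOM_SIZE guard in GCSToLocalFilesystemOperator above means store_to_xcom_key is only suitable for small objects; larger files should be written to a local filename instead. A usage sketch with hypothetical names:

from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator

read_marker = GCSToLocalFilesystemOperator(
    task_id="read_marker",
    bucket="my-bucket",  # hypothetical
    object_name="markers/latest_run.txt",
    store_to_xcom_key="latest_run_marker",  # decoded with file_encoding and pushed to XCom
    file_encoding="utf-8",
)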
+from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -67,13 +68,13 @@ def __init__( self, *, bucket_name: str, - object_name: Optional[str] = None, + object_name: str | None = None, file_name: str, folder_id: str, - drive_id: Optional[str] = None, + drive_id: str | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -86,7 +87,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): gdrive_hook = GoogleDriveHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/cloud/transfers/gdrive_to_local.py b/airflow/providers/google/cloud/transfers/gdrive_to_local.py index c61f2db99e116..6267c3aacd172 100644 --- a/airflow/providers/google/cloud/transfers/gdrive_to_local.py +++ b/airflow/providers/google/cloud/transfers/gdrive_to_local.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.suite.hooks.drive import GoogleDriveHook @@ -35,6 +36,7 @@ class GoogleDriveToLocalOperator(BaseOperator): :param output_file: Path to downloaded file :param folder_id: The folder id of the folder in which the Google Drive file resides :param file_name: The name of the file residing in Google Drive + :param gcp_conn_id: The GCP connection ID to use when fetching connection info. :param drive_id: Optional. The id of the shared Google Drive in which the file resides. :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any. 
For this to work, the service account making the request must have @@ -63,9 +65,10 @@ def __init__( output_file: str, file_name: str, folder_id: str, - drive_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + drive_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -73,12 +76,14 @@ def __init__( self.folder_id = folder_id self.drive_id = drive_id self.file_name = file_name + self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): - self.log.info('Executing download: %s into %s', self.file_name, self.output_file) + def execute(self, context: Context): + self.log.info("Executing download: %s into %s", self.file_name, self.output_file) gdrive_hook = GoogleDriveHook( + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) diff --git a/airflow/providers/google/cloud/transfers/local_to_gcs.py b/airflow/providers/google/cloud/transfers/local_to_gcs.py index b8fd0b2d3793d..037c1da9833e4 100644 --- a/airflow/providers/google/cloud/transfers/local_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/local_to_gcs.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. """This module contains operator for uploading local file(s) to GCS.""" +from __future__ import annotations + import os from glob import glob -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -57,10 +59,10 @@ class LocalFilesystemToGCSOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'src', - 'dst', - 'bucket', - 'impersonation_chain', + "src", + "dst", + "bucket", + "impersonation_chain", ) def __init__( @@ -69,11 +71,11 @@ def __init__( src, dst, bucket, - gcp_conn_id='google_cloud_default', - mime_type='application/octet-stream', + gcp_conn_id="google_cloud_default", + mime_type="application/octet-stream", delegate_to=None, gzip=False, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -87,7 +89,7 @@ def __init__( self.gzip = gzip self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): """Uploads a file or list of files to Google Cloud Storage""" hook = GCSHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/google/cloud/transfers/mssql_to_gcs.py b/airflow/providers/google/cloud/transfers/mssql_to_gcs.py index 113b0713d146d..0c12c01f4a7d2 100644 --- a/airflow/providers/google/cloud/transfers/mssql_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/mssql_to_gcs.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. 
"""MsSQL to GCS operator.""" +from __future__ import annotations + import datetime import decimal -from typing import Dict from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook @@ -26,7 +27,7 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator): """Copy data from Microsoft SQL Server to Google Cloud Storage - in JSON or CSV format. + in JSON, CSV or Parquet format. :param mssql_conn_id: Reference to a specific MSSQL hook. @@ -52,11 +53,11 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator): """ - ui_color = '#e0a98c' + ui_color = "#e0a98c" - type_map = {3: 'INTEGER', 4: 'TIMESTAMP', 5: 'NUMERIC'} + type_map = {3: "INTEGER", 4: "TIMESTAMP", 5: "NUMERIC"} - def __init__(self, *, mssql_conn_id='mssql_default', **kwargs): + def __init__(self, *, mssql_conn_id="mssql_default", **kwargs): super().__init__(**kwargs) self.mssql_conn_id = mssql_conn_id @@ -72,11 +73,11 @@ def query(self): cursor.execute(self.sql) return cursor - def field_to_bigquery(self, field) -> Dict[str, str]: + def field_to_bigquery(self, field) -> dict[str, str]: return { - 'name': field[0].replace(" ", "_"), - 'type': self.type_map.get(field[1], "STRING"), - 'mode': "NULLABLE", + "name": field[0].replace(" ", "_"), + "type": self.type_map.get(field[1], "STRING"), + "mode": "NULLABLE", } @classmethod diff --git a/airflow/providers/google/cloud/transfers/mysql_to_gcs.py b/airflow/providers/google/cloud/transfers/mysql_to_gcs.py index 2e72eaa774c38..8f9e4dad4f842 100644 --- a/airflow/providers/google/cloud/transfers/mysql_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/mysql_to_gcs.py @@ -16,11 +16,11 @@ # specific language governing permissions and limitations # under the License. """MySQL to GCS operator.""" +from __future__ import annotations import base64 from datetime import date, datetime, time, timedelta from decimal import Decimal -from typing import Dict from MySQLdb.constants import FIELD_TYPE @@ -29,7 +29,7 @@ class MySQLToGCSOperator(BaseSQLToGCSOperator): - """Copy data from MySQL to Google Cloud Storage in JSON or CSV format. + """Copy data from MySQL to Google Cloud Storage in JSON, CSV or Parquet format. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -41,27 +41,27 @@ class MySQLToGCSOperator(BaseSQLToGCSOperator): default timezone. 
""" - ui_color = '#a0e08c' + ui_color = "#a0e08c" type_map = { - FIELD_TYPE.BIT: 'INTEGER', - FIELD_TYPE.DATETIME: 'TIMESTAMP', - FIELD_TYPE.DATE: 'TIMESTAMP', - FIELD_TYPE.DECIMAL: 'FLOAT', - FIELD_TYPE.NEWDECIMAL: 'FLOAT', - FIELD_TYPE.DOUBLE: 'FLOAT', - FIELD_TYPE.FLOAT: 'FLOAT', - FIELD_TYPE.INT24: 'INTEGER', - FIELD_TYPE.LONG: 'INTEGER', - FIELD_TYPE.LONGLONG: 'INTEGER', - FIELD_TYPE.SHORT: 'INTEGER', - FIELD_TYPE.TIME: 'TIME', - FIELD_TYPE.TIMESTAMP: 'TIMESTAMP', - FIELD_TYPE.TINY: 'INTEGER', - FIELD_TYPE.YEAR: 'INTEGER', + FIELD_TYPE.BIT: "INTEGER", + FIELD_TYPE.DATETIME: "TIMESTAMP", + FIELD_TYPE.DATE: "TIMESTAMP", + FIELD_TYPE.DECIMAL: "FLOAT", + FIELD_TYPE.NEWDECIMAL: "FLOAT", + FIELD_TYPE.DOUBLE: "FLOAT", + FIELD_TYPE.FLOAT: "FLOAT", + FIELD_TYPE.INT24: "INTEGER", + FIELD_TYPE.LONG: "INTEGER", + FIELD_TYPE.LONGLONG: "INTEGER", + FIELD_TYPE.SHORT: "INTEGER", + FIELD_TYPE.TIME: "TIME", + FIELD_TYPE.TIMESTAMP: "TIMESTAMP", + FIELD_TYPE.TINY: "INTEGER", + FIELD_TYPE.YEAR: "INTEGER", } - def __init__(self, *, mysql_conn_id='mysql_default', ensure_utc=False, **kwargs): + def __init__(self, *, mysql_conn_id="mysql_default", ensure_utc=False, **kwargs): super().__init__(**kwargs) self.mysql_conn_id = mysql_conn_id self.ensure_utc = ensure_utc @@ -74,22 +74,22 @@ def query(self): if self.ensure_utc: # Ensure TIMESTAMP results are in UTC tz_query = "SET time_zone = '+00:00'" - self.log.info('Executing: %s', tz_query) + self.log.info("Executing: %s", tz_query) cursor.execute(tz_query) - self.log.info('Executing: %s', self.sql) + self.log.info("Executing: %s", self.sql) cursor.execute(self.sql) return cursor - def field_to_bigquery(self, field) -> Dict[str, str]: + def field_to_bigquery(self, field) -> dict[str, str]: field_type = self.type_map.get(field[1], "STRING") # Always allow TIMESTAMP to be nullable. MySQLdb returns None types # for required fields because some MySQL timestamps can't be # represented by Python's datetime (e.g. 0000-00-00 00:00:00). field_mode = "NULLABLE" if field[6] or field_type == "TIMESTAMP" else "REQUIRED" return { - 'name': field[0], - 'type': field_type, - 'mode': field_mode, + "name": field[0], + "type": field_type, + "mode": field_mode, } def convert_type(self, value, schema_type: str, **kwargs): @@ -128,5 +128,5 @@ def convert_type(self, value, schema_type: str, **kwargs): if schema_type == "INTEGER": value = int.from_bytes(value, "big") else: - value = base64.standard_b64encode(value).decode('ascii') + value = base64.standard_b64encode(value).decode("ascii") return value diff --git a/airflow/providers/google/cloud/transfers/oracle_to_gcs.py b/airflow/providers/google/cloud/transfers/oracle_to_gcs.py index 3306c9801063a..fcf8458ef9ec8 100644 --- a/airflow/providers/google/cloud/transfers/oracle_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/oracle_to_gcs.py @@ -15,21 +15,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import base64 import calendar from datetime import date, datetime, timedelta from decimal import Decimal -from typing import Dict -import cx_Oracle +import oracledb from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator from airflow.providers.oracle.hooks.oracle import OracleHook class OracleToGCSOperator(BaseSQLToGCSOperator): - """Copy data from Oracle to Google Cloud Storage in JSON or CSV format. 
+ """Copy data from Oracle to Google Cloud Storage in JSON, CSV or Parquet format. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -42,21 +42,21 @@ class OracleToGCSOperator(BaseSQLToGCSOperator): default timezone. """ - ui_color = '#a0e08c' + ui_color = "#a0e08c" type_map = { - cx_Oracle.DB_TYPE_BINARY_DOUBLE: 'DECIMAL', - cx_Oracle.DB_TYPE_BINARY_FLOAT: 'DECIMAL', - cx_Oracle.DB_TYPE_BINARY_INTEGER: 'INTEGER', - cx_Oracle.DB_TYPE_BOOLEAN: 'BOOLEAN', - cx_Oracle.DB_TYPE_DATE: 'TIMESTAMP', - cx_Oracle.DB_TYPE_NUMBER: 'NUMERIC', - cx_Oracle.DB_TYPE_TIMESTAMP: 'TIMESTAMP', - cx_Oracle.DB_TYPE_TIMESTAMP_LTZ: 'TIMESTAMP', - cx_Oracle.DB_TYPE_TIMESTAMP_TZ: 'TIMESTAMP', + oracledb.DB_TYPE_BINARY_DOUBLE: "DECIMAL", # type: ignore + oracledb.DB_TYPE_BINARY_FLOAT: "DECIMAL", # type: ignore + oracledb.DB_TYPE_BINARY_INTEGER: "INTEGER", # type: ignore + oracledb.DB_TYPE_BOOLEAN: "BOOLEAN", # type: ignore + oracledb.DB_TYPE_DATE: "TIMESTAMP", # type: ignore + oracledb.DB_TYPE_NUMBER: "NUMERIC", # type: ignore + oracledb.DB_TYPE_TIMESTAMP: "TIMESTAMP", # type: ignore + oracledb.DB_TYPE_TIMESTAMP_LTZ: "TIMESTAMP", # type: ignore + oracledb.DB_TYPE_TIMESTAMP_TZ: "TIMESTAMP", # type: ignore } - def __init__(self, *, oracle_conn_id='oracle_default', ensure_utc=False, **kwargs): + def __init__(self, *, oracle_conn_id="oracle_default", ensure_utc=False, **kwargs): super().__init__(**kwargs) self.ensure_utc = ensure_utc self.oracle_conn_id = oracle_conn_id @@ -69,20 +69,20 @@ def query(self): if self.ensure_utc: # Ensure TIMESTAMP results are in UTC tz_query = "SET time_zone = '+00:00'" - self.log.info('Executing: %s', tz_query) + self.log.info("Executing: %s", tz_query) cursor.execute(tz_query) - self.log.info('Executing: %s', self.sql) + self.log.info("Executing: %s", self.sql) cursor.execute(self.sql) return cursor - def field_to_bigquery(self, field) -> Dict[str, str]: + def field_to_bigquery(self, field) -> dict[str, str]: field_type = self.type_map.get(field[1], "STRING") field_mode = "NULLABLE" if not field[6] or field_type == "TIMESTAMP" else "REQUIRED" return { - 'name': field[0], - 'type': field_type, - 'mode': field_mode, + "name": field[0], + "type": field_type, + "mode": field_mode, } def convert_type(self, value, schema_type, **kwargs): @@ -119,5 +119,5 @@ def convert_type(self, value, schema_type, **kwargs): if schema_type == "INTEGER": value = int.from_bytes(value, "big") else: - value = base64.standard_b64encode(value).decode('ascii') + value = base64.standard_b64encode(value).decode("ascii") return value diff --git a/airflow/providers/google/cloud/transfers/postgres_to_gcs.py b/airflow/providers/google/cloud/transfers/postgres_to_gcs.py index 3f3012514df06..b7637db0a9188 100644 --- a/airflow/providers/google/cloud/transfers/postgres_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/postgres_to_gcs.py @@ -16,13 +16,13 @@ # specific language governing permissions and limitations # under the License. """PostgreSQL to GCS operator.""" +from __future__ import annotations import datetime import json import time import uuid from decimal import Decimal -from typing import Dict import pendulum @@ -67,7 +67,7 @@ def description(self): class PostgresToGCSOperator(BaseSQLToGCSOperator): """ - Copy data from Postgres to Google Cloud Storage in JSON or CSV format. + Copy data from Postgres to Google Cloud Storage in JSON, CSV or Parquet format. :param postgres_conn_id: Reference to a specific Postgres hook. 
:param use_server_side_cursor: If server-side cursor should be used for querying postgres. @@ -75,29 +75,29 @@ class PostgresToGCSOperator(BaseSQLToGCSOperator): :param cursor_itersize: How many records are fetched at a time in case of server-side cursor. """ - ui_color = '#a0e08c' + ui_color = "#a0e08c" type_map = { - 1114: 'DATETIME', - 1184: 'TIMESTAMP', - 1082: 'DATE', - 1083: 'TIME', - 1005: 'INTEGER', - 1007: 'INTEGER', - 1016: 'INTEGER', - 20: 'INTEGER', - 21: 'INTEGER', - 23: 'INTEGER', - 16: 'BOOLEAN', - 700: 'FLOAT', - 701: 'FLOAT', - 1700: 'FLOAT', + 1114: "DATETIME", + 1184: "TIMESTAMP", + 1082: "DATE", + 1083: "TIME", + 1005: "INTEGER", + 1007: "INTEGER", + 1016: "INTEGER", + 20: "INTEGER", + 21: "INTEGER", + 23: "INTEGER", + 16: "BOOL", + 700: "FLOAT", + 701: "FLOAT", + 1700: "FLOAT", } def __init__( self, *, - postgres_conn_id='postgres_default', + postgres_conn_id="postgres_default", use_server_side_cursor=False, cursor_itersize=2000, **kwargs, @@ -121,11 +121,11 @@ def query(self): return _PostgresServerSideCursorDecorator(cursor) return cursor - def field_to_bigquery(self, field) -> Dict[str, str]: + def field_to_bigquery(self, field) -> dict[str, str]: return { - 'name': field[0], - 'type': self.type_map.get(field[1], "STRING"), - 'mode': 'REPEATED' if field[1] in (1009, 1005, 1007, 1016) else 'NULLABLE', + "name": field[0], + "type": self.type_map.get(field[1], "STRING"), + "mode": "REPEATED" if field[1] in (1009, 1005, 1007, 1016) else "NULLABLE", } def convert_type(self, value, schema_type, stringify_dict=True): diff --git a/airflow/providers/google/cloud/transfers/presto_to_gcs.py b/airflow/providers/google/cloud/transfers/presto_to_gcs.py index 1b2be5e091a19..e66531825e332 100644 --- a/airflow/providers/google/cloud/transfers/presto_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/presto_to_gcs.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Tuple +from __future__ import annotations + +from typing import Any from prestodb.client import PrestoResult from prestodb.dbapi import Cursor as PrestoCursor @@ -41,11 +43,11 @@ class _PrestoToGCSPrestoCursorAdapter: def __init__(self, cursor: PrestoCursor): self.cursor: PrestoCursor = cursor - self.rows: List[Any] = [] + self.rows: list[Any] = [] self.initialized: bool = False @property - def description(self) -> List[Tuple]: + def description(self) -> list[tuple]: """ This read-only attribute is a sequence of 7-item sequences. @@ -135,13 +137,13 @@ def __next__(self) -> Any: raise StopIteration() return result - def __iter__(self) -> "_PrestoToGCSPrestoCursorAdapter": + def __iter__(self) -> _PrestoToGCSPrestoCursorAdapter: """Return self to make cursors compatible to the iteration protocol""" return self class PrestoToGCSOperator(BaseSQLToGCSOperator): - """Copy data from PrestoDB to Google Cloud Storage in JSON or CSV format. + """Copy data from PrestoDB to Google Cloud Storage in JSON, CSV or Parquet format. :param presto_conn_id: Reference to a specific Presto hook. """ @@ -186,7 +188,7 @@ def query(self): cursor.execute(self.sql) return _PrestoToGCSPrestoCursorAdapter(cursor) - def field_to_bigquery(self, field) -> Dict[str, str]: + def field_to_bigquery(self, field) -> dict[str, str]: """Convert presto field type to BigQuery field type.""" clear_field_type = field[1].upper() # remove type argument e.g. 
DECIMAL(2, 10) => DECIMAL diff --git a/airflow/providers/google/cloud/transfers/s3_to_gcs.py b/airflow/providers/google/cloud/transfers/s3_to_gcs.py index d86e63018abe5..318407fd7084b 100644 --- a/airflow/providers/google/cloud/transfers/s3_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/s3_to_gcs.py @@ -15,8 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -25,7 +27,7 @@ try: from airflow.providers.amazon.aws.operators.s3 import S3ListOperator except ImportError: - from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator + from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator # type: ignore[no-redef] if TYPE_CHECKING: from airflow.utils.context import Context @@ -94,28 +96,28 @@ class S3ToGCSOperator(S3ListOperator): """ template_fields: Sequence[str] = ( - 'bucket', - 'prefix', - 'delimiter', - 'dest_gcs', - 'google_impersonation_chain', + "bucket", + "prefix", + "delimiter", + "dest_gcs", + "google_impersonation_chain", ) - ui_color = '#e09411' + ui_color = "#e09411" def __init__( self, *, bucket, - prefix='', - delimiter='', - aws_conn_id='aws_default', + prefix="", + delimiter="", + aws_conn_id="aws_default", verify=None, - gcp_conn_id='google_cloud_default', + gcp_conn_id="google_cloud_default", dest_gcs=None, delegate_to=None, replace=False, gzip=False, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + google_impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): @@ -131,15 +133,15 @@ def __init__( def _check_inputs(self) -> None: if self.dest_gcs and not gcs_object_is_directory(self.dest_gcs): self.log.info( - 'Destination Google Cloud Storage path is not a valid ' + "Destination Google Cloud Storage path is not a valid " '"directory", define a path that ends with a slash "/" or ' - 'leave it empty for the root of the bucket.' + "leave it empty for the root of the bucket." ) raise AirflowException( 'The destination Google Cloud Storage path must end with a slash "/" or be empty.' ) - def execute(self, context: 'Context'): + def execute(self, context: Context): self._check_inputs() # use the super method to list all the files in an S3 bucket/key files = super().execute(context) @@ -173,9 +175,9 @@ def execute(self, context: 'Context'): files = list(set(files) - set(existing_files)) if len(files) > 0: - self.log.info('%s files are going to be synced: %s.', len(files), files) + self.log.info("%s files are going to be synced: %s.", len(files), files) else: - self.log.info('There are no new files to sync. Have a nice day!') + self.log.info("There are no new files to sync. 
Have a nice day!") if files: hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) @@ -184,7 +186,7 @@ def execute(self, context: 'Context'): # GCS hook builds its own in-memory file so we have to create # and pass the path file_object = hook.get_key(file, self.bucket) - with NamedTemporaryFile(mode='wb', delete=True) as f: + with NamedTemporaryFile(mode="wb", delete=True) as f: file_object.download_fileobj(f) f.flush() @@ -205,6 +207,6 @@ def execute(self, context: 'Context'): self.log.info("All done, uploaded %d files to Google Cloud Storage", len(files)) else: - self.log.info('In sync, no files needed to be uploaded to Google Cloud Storage') + self.log.info("In sync, no files needed to be uploaded to Google Cloud Storage") return files diff --git a/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py b/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py index 2803e3f2c2f25..4764ed33d1cb4 100644 --- a/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import os import tempfile -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -53,12 +54,12 @@ class SalesforceToGcsOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'query', - 'bucket_name', - 'object_name', + "query", + "bucket_name", + "object_name", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} def __init__( self, @@ -68,7 +69,7 @@ def __init__( object_name: str, salesforce_conn_id: str, include_deleted: bool = False, - query_params: Optional[dict] = None, + query_params: dict | None = None, export_format: str = "csv", coerce_to_timestamp: bool = False, record_time_added: bool = False, @@ -89,7 +90,7 @@ def __init__( self.include_deleted = include_deleted self.query_params = query_params - def execute(self, context: 'Context'): + def execute(self, context: Context): salesforce = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id) response = salesforce.make_query( query=self.query, include_deleted=self.include_deleted, query_params=self.query_params diff --git a/airflow/providers/google/cloud/transfers/sftp_to_gcs.py b/airflow/providers/google/cloud/transfers/sftp_to_gcs.py index 8f750af2b5810..75fa7bf0793b8 100644 --- a/airflow/providers/google/cloud/transfers/sftp_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sftp_to_gcs.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains SFTP to Google Cloud Storage operator.""" +from __future__ import annotations + import os from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -84,14 +86,14 @@ def __init__( *, source_path: str, destination_bucket: str, - destination_path: Optional[str] = None, + destination_path: str | None = None, gcp_conn_id: str = "google_cloud_default", sftp_conn_id: str = "ssh_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, mime_type: str = "application/octet-stream", gzip: bool = False, move_object: bool = False, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -107,7 +109,7 @@ def __init__( self.move_object = move_object self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -170,7 +172,7 @@ def _copy_single_object( sftp_hook.delete_file(source_path) @staticmethod - def _set_destination_path(path: Union[str, None]) -> str: + def _set_destination_path(path: str | None) -> str: if path is not None: return path.lstrip("/") if path.startswith("/") else path return "" diff --git a/airflow/providers/google/cloud/transfers/sheets_to_gcs.py b/airflow/providers/google/cloud/transfers/sheets_to_gcs.py index 45b8081f2dad2..38314d1fe1041 100644 --- a/airflow/providers/google/cloud/transfers/sheets_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sheets_to_gcs.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import csv from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -69,11 +70,11 @@ def __init__( *, spreadsheet_id: str, destination_bucket: str, - sheet_filter: Optional[List[str]] = None, - destination_path: Optional[str] = None, + sheet_filter: list[str] | None = None, + destination_path: str | None = None, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -90,7 +91,7 @@ def _upload_data( gcs_hook: GCSHook, hook: GSheetsHook, sheet_range: str, - sheet_values: List[Any], + sheet_values: list[Any], ) -> str: # Construct destination file path sheet = hook.get_spreadsheet(self.spreadsheet_id) @@ -113,7 +114,7 @@ def _upload_data( ) return dest_file_name - def execute(self, context: 'Context'): + def execute(self, context: Context): sheet_hook = GSheetsHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -126,7 +127,7 @@ def execute(self, context: 'Context'): ) # Pull data and upload - destination_array: List[str] = [] + destination_array: list[str] = [] sheet_titles = sheet_hook.get_sheet_titles( spreadsheet_id=self.spreadsheet_id, sheet_filter=self.sheet_filter ) diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py index 46e1ad505d784..e4a3f3e9420c9 100644 --- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py @@ -16,10 +16,12 @@ # specific language governing permissions and limitations # under the License. """Base operator for SQL to GCS operators.""" +from __future__ import annotations + import abc import json from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence import pyarrow as pa import pyarrow.parquet as pq @@ -34,7 +36,7 @@ class BaseSQLToGCSOperator(BaseOperator): """ - Copy data from SQL to Google Cloud Storage in JSON or CSV format. + Copy data from SQL to Google Cloud Storage in JSON, CSV, or Parquet format. :param sql: The SQL to execute. :param bucket: The bucket to upload to. @@ -50,7 +52,9 @@ class BaseSQLToGCSOperator(BaseOperator): filename param docs above). This param allows developers to specify the file size of the splits. Check https://cloud.google.com/storage/quotas to see the maximum allowed file size for a single object. - :param export_format: Desired format of files to be exported. + :param export_format: Desired format of files to be exported. (json, csv or parquet) + :param stringify_dict: Whether to dump Dictionary type objects + (such as JSON columns) as a string. Applies only to CSV/JSON export format. :param field_delimiter: The delimiter to be used for CSV files. :param null_marker: The null marker to be used for CSV files. :param gzip: Option to compress file for upload (does not apply to schemas). 
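The export_format, filename templating and approx_max_file_size_bytes options documented above apply to every BaseSQLToGCSOperator subclass; a sketch using the MySQL variant with a hypothetical connection id, query and bucket:

from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

export_orders = MySQLToGCSOperator(
    task_id="export_orders",
    mysql_conn_id="mysql_default",
    sql="SELECT * FROM orders",
    bucket="my-export-bucket",
    filename="exports/orders_{}.parquet",  # {} becomes the chunk number when the size limit splits files
    export_format="parquet",
    approx_max_file_size_bytes=50 * 1024 * 1024,  # start a new chunk past roughly 50 MB
)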
@@ -71,21 +75,22 @@ class BaseSQLToGCSOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param upload_metadata: whether to upload the row count metadata as blob metadata :param exclude_columns: set of columns to exclude from transmission """ template_fields: Sequence[str] = ( - 'sql', - 'bucket', - 'filename', - 'schema_filename', - 'schema', - 'parameters', - 'impersonation_chain', + "sql", + "bucket", + "filename", + "schema_filename", + "schema", + "parameters", + "impersonation_chain", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#a0e08c' + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#a0e08c" def __init__( self, @@ -93,17 +98,19 @@ def __init__( sql: str, bucket: str, filename: str, - schema_filename: Optional[str] = None, + schema_filename: str | None = None, approx_max_file_size_bytes: int = 1900000000, - export_format: str = 'json', - field_delimiter: str = ',', - null_marker: Optional[str] = None, + export_format: str = "json", + stringify_dict: bool = False, + field_delimiter: str = ",", + null_marker: str | None = None, gzip: bool = False, - schema: Optional[Union[str, list]] = None, - parameters: Optional[dict] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + schema: str | list | None = None, + parameters: dict | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + upload_metadata: bool = False, exclude_columns=None, **kwargs, ) -> None: @@ -117,6 +124,7 @@ def __init__( self.schema_filename = schema_filename self.approx_max_file_size_bytes = approx_max_file_size_bytes self.export_format = export_format.lower() + self.stringify_dict = stringify_dict self.field_delimiter = field_delimiter self.null_marker = null_marker self.gzip = gzip @@ -125,41 +133,66 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain + self.upload_metadata = upload_metadata self.exclude_columns = exclude_columns - def execute(self, context: 'Context'): + def execute(self, context: Context): self.log.info("Executing query") cursor = self.query() # If a schema is set, create a BQ schema JSON file. 
if self.schema_filename: - self.log.info('Writing local schema file') + self.log.info("Writing local schema file") schema_file = self._write_local_schema_file(cursor) # Flush file before uploading - schema_file['file_handle'].flush() + schema_file["file_handle"].flush() - self.log.info('Uploading schema file to GCS.') + self.log.info("Uploading schema file to GCS.") self._upload_to_gcs(schema_file) - schema_file['file_handle'].close() + schema_file["file_handle"].close() counter = 0 - self.log.info('Writing local data files') + files = [] + total_row_count = 0 + total_files = 0 + self.log.info("Writing local data files") for file_to_upload in self._write_local_data_files(cursor): # Flush file before uploading - file_to_upload['file_handle'].flush() + file_to_upload["file_handle"].flush() - self.log.info('Uploading chunk file #%d to GCS.', counter) + self.log.info("Uploading chunk file #%d to GCS.", counter) self._upload_to_gcs(file_to_upload) - self.log.info('Removing local file') - file_to_upload['file_handle'].close() + self.log.info("Removing local file") + file_to_upload["file_handle"].close() + + # Metadata to be outputted to Xcom + total_row_count += file_to_upload["file_row_count"] + total_files += 1 + files.append( + { + "file_name": file_to_upload["file_name"], + "file_mime_type": file_to_upload["file_mime_type"], + "file_row_count": file_to_upload["file_row_count"], + } + ) + counter += 1 - def convert_types(self, schema, col_type_dict, row, stringify_dict=False) -> list: + file_meta = { + "bucket": self.bucket, + "total_row_count": total_row_count, + "total_files": total_files, + "files": files, + } + + return file_meta + + def convert_types(self, schema, col_type_dict, row) -> list: """Convert values from DBAPI to output-friendly formats.""" return [ - self.convert_type(value, col_type_dict.get(name), stringify_dict=stringify_dict) + self.convert_type(value, col_type_dict.get(name), stringify_dict=self.stringify_dict) for name, value in zip(schema, row) ] @@ -171,6 +204,8 @@ def _write_local_data_files(self, cursor): names in GCS, and values are file handles to local files that contain the data for the GCS objects. 
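The metadata assembled in execute() above becomes the operator's return value, and therefore its XCom return_value, roughly of this shape (the numbers and file names are illustrative only):

file_meta = {
    "bucket": "my-export-bucket",
    "total_row_count": 12345,
    "total_files": 2,
    "files": [
        {"file_name": "exports/orders_0.csv", "file_mime_type": "text/csv", "file_row_count": 10000},
        {"file_name": "exports/orders_1.csv", "file_mime_type": "text/csv", "file_row_count": 2345},
    ],
}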
""" + import os + org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description)) schema = [column for column in org_schema if column not in self.exclude_columns] @@ -178,31 +213,33 @@ def _write_local_data_files(self, cursor): file_no = 0 tmp_file_handle = NamedTemporaryFile(delete=True) - if self.export_format == 'csv': - file_mime_type = 'text/csv' - elif self.export_format == 'parquet': - file_mime_type = 'application/octet-stream' + if self.export_format == "csv": + file_mime_type = "text/csv" + elif self.export_format == "parquet": + file_mime_type = "application/octet-stream" else: - file_mime_type = 'application/json' + file_mime_type = "application/json" file_to_upload = { - 'file_name': self.filename.format(file_no), - 'file_handle': tmp_file_handle, - 'file_mime_type': file_mime_type, + "file_name": self.filename.format(file_no), + "file_handle": tmp_file_handle, + "file_mime_type": file_mime_type, + "file_row_count": 0, } - if self.export_format == 'csv': + if self.export_format == "csv": csv_writer = self._configure_csv_file(tmp_file_handle, schema) - if self.export_format == 'parquet': + if self.export_format == "parquet": parquet_schema = self._convert_parquet_schema(cursor) parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema) for row in cursor: - if self.export_format == 'csv': + file_to_upload["file_row_count"] += 1 + if self.export_format == "csv": row = self.convert_types(schema, col_type_dict, row) if self.null_marker is not None: row = [value if value is not None else self.null_marker for value in row] csv_writer.writerow(row) - elif self.export_format == 'parquet': + elif self.export_format == "parquet": row = self.convert_types(schema, col_type_dict, row) if self.null_marker is not None: row = [value if value is not None else self.null_marker for value in row] @@ -210,7 +247,7 @@ def _write_local_data_files(self, cursor): tbl = pa.Table.from_pydict(row_pydic, parquet_schema) parquet_writer.write_table(tbl) else: - row = self.convert_types(schema, col_type_dict, row, stringify_dict=False) + row = self.convert_types(schema, col_type_dict, row) row_dict = dict(zip(schema, row)) tmp_file_handle.write( @@ -218,34 +255,42 @@ def _write_local_data_files(self, cursor): ) # Append newline to make dumps BigQuery compatible. - tmp_file_handle.write(b'\n') + tmp_file_handle.write(b"\n") # Stop if the file exceeds the file size limit. 
- if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: + fppos = tmp_file_handle.tell() + tmp_file_handle.seek(0, os.SEEK_END) + file_size = tmp_file_handle.tell() + tmp_file_handle.seek(fppos, os.SEEK_SET) + + if file_size >= self.approx_max_file_size_bytes: file_no += 1 - if self.export_format == 'parquet': + if self.export_format == "parquet": parquet_writer.close() yield file_to_upload tmp_file_handle = NamedTemporaryFile(delete=True) file_to_upload = { - 'file_name': self.filename.format(file_no), - 'file_handle': tmp_file_handle, - 'file_mime_type': file_mime_type, + "file_name": self.filename.format(file_no), + "file_handle": tmp_file_handle, + "file_mime_type": file_mime_type, + "file_row_count": 0, } - if self.export_format == 'csv': + if self.export_format == "csv": csv_writer = self._configure_csv_file(tmp_file_handle, schema) - if self.export_format == 'parquet': + if self.export_format == "parquet": parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema) - if self.export_format == 'parquet': + if self.export_format == "parquet": parquet_writer.close() - yield file_to_upload + # Last file may have 0 rows, don't yield if empty + if file_to_upload["file_row_count"] > 0: + yield file_to_upload def _configure_csv_file(self, file_handle, schema): """Configure a csv writer with the file_handle and write schema as headers for the new file. """ - csv_writer = csv.writer(file_handle, encoding='utf-8', delimiter=self.field_delimiter) + csv_writer = csv.writer(file_handle, encoding="utf-8", delimiter=self.field_delimiter) csv_writer.writerow(schema) return csv_writer @@ -255,21 +300,21 @@ def _configure_parquet_file(self, file_handle, parquet_schema): def _convert_parquet_schema(self, cursor): type_map = { - 'INTEGER': pa.int64(), - 'FLOAT': pa.float64(), - 'NUMERIC': pa.float64(), - 'BIGNUMERIC': pa.float64(), - 'BOOL': pa.bool_(), - 'STRING': pa.string(), - 'BYTES': pa.binary(), - 'DATE': pa.date32(), - 'DATETIME': pa.date64(), - 'TIMESTAMP': pa.timestamp('s'), + "INTEGER": pa.int64(), + "FLOAT": pa.float64(), + "NUMERIC": pa.float64(), + "BIGNUMERIC": pa.float64(), + "BOOL": pa.bool_(), + "STRING": pa.string(), + "BYTES": pa.binary(), + "DATE": pa.date32(), + "DATETIME": pa.date64(), + "TIMESTAMP": pa.timestamp("s"), } columns = [field[0] for field in cursor.description] bq_fields = [self.field_to_bigquery(field) for field in cursor.description] - bq_types = [bq_field.get('type') if bq_field is not None else None for bq_field in bq_fields] + bq_types = [bq_field.get("type") if bq_field is not None else None for bq_field in bq_fields] pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types] parquet_schema = pa.schema(zip(columns, pq_types)) return parquet_schema @@ -279,7 +324,7 @@ def query(self): """Execute DBAPI query.""" @abc.abstractmethod - def field_to_bigquery(self, field) -> Dict[str, str]: + def field_to_bigquery(self, field) -> dict[str, str]: """Convert a DBAPI field to BigQuery schema format.""" @abc.abstractmethod @@ -294,16 +339,16 @@ def _get_col_type_dict(self): elif isinstance(self.schema, list): schema = self.schema elif self.schema is not None: - self.log.warning('Using default schema due to unexpected type. Should be a string or list.') + self.log.warning("Using default schema due to unexpected type. 
Should be a string or list.") col_type_dict = {} try: - col_type_dict = {col['name']: col['type'] for col in schema} + col_type_dict = {col["name"]: col["type"] for col in schema} except KeyError: self.log.warning( - 'Using default schema due to missing name or type. Please ' - 'refer to: https://cloud.google.com/bigquery/docs/schemas' - '#specifying_a_json_schema_file' + "Using default schema due to missing name or type. Please " + "refer to: https://cloud.google.com/bigquery/docs/schemas" + "#specifying_a_json_schema_file" ) return col_type_dict @@ -331,15 +376,15 @@ def _write_local_schema_file(self, cursor): if isinstance(schema, list): schema = json.dumps(schema, sort_keys=True) - self.log.info('Using schema for %s', self.schema_filename) + self.log.info("Using schema for %s", self.schema_filename) self.log.debug("Current schema: %s", schema) tmp_schema_file_handle = NamedTemporaryFile(delete=True) - tmp_schema_file_handle.write(schema.encode('utf-8')) + tmp_schema_file_handle.write(schema.encode("utf-8")) schema_file_to_upload = { - 'file_name': self.schema_filename, - 'file_handle': tmp_schema_file_handle, - 'file_mime_type': 'application/json', + "file_name": self.schema_filename, + "file_handle": tmp_schema_file_handle, + "file_mime_type": "application/json", } return schema_file_to_upload @@ -350,10 +395,16 @@ def _upload_to_gcs(self, file_to_upload): delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) + is_data_file = file_to_upload.get("file_name") != self.schema_filename + metadata = None + if is_data_file and self.upload_metadata: + metadata = {"row_count": file_to_upload["file_row_count"]} + hook.upload( self.bucket, - file_to_upload.get('file_name'), - file_to_upload.get('file_handle').name, - mime_type=file_to_upload.get('file_mime_type'), - gzip=self.gzip if file_to_upload.get('file_name') != self.schema_filename else False, + file_to_upload.get("file_name"), + file_to_upload.get("file_handle").name, + mime_type=file_to_upload.get("file_mime_type"), + gzip=self.gzip if is_data_file else False, + metadata=metadata, ) diff --git a/airflow/providers/google/cloud/transfers/trino_to_gcs.py b/airflow/providers/google/cloud/transfers/trino_to_gcs.py index 6d2e2e223b3dc..ce104dc565649 100644 --- a/airflow/providers/google/cloud/transfers/trino_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/trino_to_gcs.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Tuple +from __future__ import annotations + +from typing import Any from trino.client import TrinoResult from trino.dbapi import Cursor as TrinoCursor @@ -41,11 +43,11 @@ class _TrinoToGCSTrinoCursorAdapter: def __init__(self, cursor: TrinoCursor): self.cursor: TrinoCursor = cursor - self.rows: List[Any] = [] + self.rows: list[Any] = [] self.initialized: bool = False @property - def description(self) -> List[Tuple]: + def description(self) -> list[tuple]: """ This read-only attribute is a sequence of 7-item sequences. @@ -135,13 +137,13 @@ def __next__(self) -> Any: raise StopIteration() return result - def __iter__(self) -> "_TrinoToGCSTrinoCursorAdapter": + def __iter__(self) -> _TrinoToGCSTrinoCursorAdapter: """Return self to make cursors compatible to the iteration protocol""" return self class TrinoToGCSOperator(BaseSQLToGCSOperator): - """Copy data from TrinoDB to Google Cloud Storage in JSON or CSV format. 
+ """Copy data from TrinoDB to Google Cloud Storage in JSON, CSV or Parquet format. :param trino_conn_id: Reference to a specific Trino hook. """ @@ -186,7 +188,7 @@ def query(self): cursor.execute(self.sql) return _TrinoToGCSTrinoCursorAdapter(cursor) - def field_to_bigquery(self, field) -> Dict[str, str]: + def field_to_bigquery(self, field) -> dict[str, str]: """Convert trino field type to BigQuery field type.""" clear_field_type = field[1].upper() # remove type argument e.g. DECIMAL(2, 10) => DECIMAL diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py new file mode 100644 index 0000000000000..271b02e3fcb02 --- /dev/null +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -0,0 +1,529 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import Any, AsyncIterator, SupportsAbs + +from aiohttp import ClientSession +from aiohttp.client_exceptions import ClientResponseError + +from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class BigQueryInsertJobTrigger(BaseTrigger): + """ + BigQueryInsertJobTrigger run on the trigger worker to perform insert operation + + :param conn_id: Reference to google cloud connection id + :param job_id: The ID of the job. It will be suffixed with hash of job configuration + :param project_id: Google Cloud Project where the job is running + :param dataset_id: The dataset ID of the requested table. (templated) + :param table_id: The table ID of the requested table. 
(templated) + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + conn_id: str, + job_id: str | None, + project_id: str | None, + dataset_id: str | None = None, + table_id: str | None = None, + poll_interval: float = 4.0, + ): + super().__init__() + self.log.info("Using the connection %s .", conn_id) + self.conn_id = conn_id + self.job_id = job_id + self._job_conn = None + self.dataset_id = dataset_id + self.project_id = project_id + self.table_id = table_id + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BigQueryInsertJobTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger", + { + "conn_id": self.conn_id, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + self.log.debug("Response from hook: %s", response_from_hook) + + if response_from_hook == "success": + yield TriggerEvent( + { + "job_id": self.job_id, + "status": "success", + "message": "Job completed", + } + ) + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook}) + + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + + def _get_async_hook(self) -> BigQueryAsyncHook: + return BigQueryAsyncHook(gcp_conn_id=self.conn_id) + + +class BigQueryCheckTrigger(BigQueryInsertJobTrigger): + """BigQueryCheckTrigger run on the trigger worker""" + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BigQueryCheckTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger", + { + "conn_id": self.conn_id, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + if response_from_hook == "success": + query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id) + + records = hook.get_records(query_results) + + # If empty list, then no records are available + if not records: + yield TriggerEvent( + { + "status": "success", + "records": None, + } + ) + else: + # Extract only first record from the query results + first_record = records.pop(0) + yield TriggerEvent( + { + "status": "success", + "records": first_record, + } + ) + return + + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", 
self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook}) + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + + +class BigQueryGetDataTrigger(BigQueryInsertJobTrigger): + """BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class""" + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BigQueryInsertJobTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger", + { + "conn_id": self.conn_id, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent with response data""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + if response_from_hook == "success": + query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id) + records = hook.get_records(query_results) + self.log.debug("Response from hook: %s", response_from_hook) + yield TriggerEvent( + { + "status": "success", + "message": response_from_hook, + "records": records, + } + ) + return + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook}) + return + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + +class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): + """ + BigQueryIntervalCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class + + :param conn_id: Reference to google cloud connection id + :param first_job_id: The ID of the first job + :param second_job_id: The ID of the second job + :param project_id: Google Cloud Project where the job is running + :param dataset_id: The dataset ID of the requested table. (templated) + :param table: table name + :param metrics_thresholds: dictionary of ratios indexed by metrics + :param date_filter_column: name of the column used to filter by date + :param days_back: number of days between ds and the ds we want to check + against + :param ratio_formula: formula used to compute the ratio between the two job results (for example max_over_min) + :param ignore_zero: whether to ignore zero values in the check + :param table_id: The table ID of the requested table.
(templated) + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + conn_id: str, + first_job_id: str, + second_job_id: str, + project_id: str | None, + table: str, + metrics_thresholds: dict[str, int], + date_filter_column: str | None = "ds", + days_back: SupportsAbs[int] = -7, + ratio_formula: str = "max_over_min", + ignore_zero: bool = True, + dataset_id: str | None = None, + table_id: str | None = None, + poll_interval: float = 4.0, + ): + super().__init__( + conn_id=conn_id, + job_id=first_job_id, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + poll_interval=poll_interval, + ) + self.conn_id = conn_id + self.first_job_id = first_job_id + self.second_job_id = second_job_id + self.project_id = project_id + self.table = table + self.metrics_thresholds = metrics_thresholds + self.date_filter_column = date_filter_column + self.days_back = days_back + self.ratio_formula = ratio_formula + self.ignore_zero = ignore_zero + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BigQueryCheckTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger", + { + "conn_id": self.conn_id, + "first_job_id": self.first_job_id, + "second_job_id": self.second_job_id, + "project_id": self.project_id, + "table": self.table, + "metrics_thresholds": self.metrics_thresholds, + "date_filter_column": self.date_filter_column, + "days_back": self.days_back, + "ratio_formula": self.ratio_formula, + "ignore_zero": self.ignore_zero, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent""" + hook = self._get_async_hook() + while True: + try: + first_job_response_from_hook = await hook.get_job_status( + job_id=self.first_job_id, project_id=self.project_id + ) + second_job_response_from_hook = await hook.get_job_status( + job_id=self.second_job_id, project_id=self.project_id + ) + + if first_job_response_from_hook == "success" and second_job_response_from_hook == "success": + first_query_results = await hook.get_job_output( + job_id=self.first_job_id, project_id=self.project_id + ) + + second_query_results = await hook.get_job_output( + job_id=self.second_job_id, project_id=self.project_id + ) + + first_records = hook.get_records(first_query_results) + + second_records = hook.get_records(second_query_results) + + # If empty list, then no records are available + if not first_records: + first_job_row: str | None = None + else: + # Extract only first record from the query results + first_job_row = first_records.pop(0) + + # If empty list, then no records are available + if not second_records: + second_job_row: str | None = None + else: + # Extract only first record from the query results + second_job_row = second_records.pop(0) + + hook.interval_check( + first_job_row, + second_job_row, + self.metrics_thresholds, + self.ignore_zero, + self.ratio_formula, + ) + + yield TriggerEvent( + { + "status": "success", + "message": "Job completed", + "first_row_data": first_job_row, + "second_row_data": second_job_row, + } + ) + return + elif first_job_response_from_hook == "pending" or second_job_response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent( + {"status": "error", "message": second_job_response_from_hook, 
"data": None} + ) + return + + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + +class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): + """ + BigQueryValueCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class + + :param conn_id: Reference to google cloud connection id + :param sql: the sql to be executed + :param pass_value: pass value + :param job_id: The ID of the job + :param project_id: Google Cloud Project where the job is running + :param tolerance: certain metrics for tolerance + :param dataset_id: The dataset ID of the requested table. (templated) + :param table_id: The table ID of the requested table. (templated) + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + conn_id: str, + sql: str, + pass_value: int | float | str, + job_id: str | None, + project_id: str | None, + tolerance: Any = None, + dataset_id: str | None = None, + table_id: str | None = None, + poll_interval: float = 4.0, + ): + super().__init__( + conn_id=conn_id, + job_id=job_id, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + poll_interval=poll_interval, + ) + self.sql = sql + self.pass_value = pass_value + self.tolerance = tolerance + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BigQueryValueCheckTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger", + { + "conn_id": self.conn_id, + "pass_value": self.pass_value, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "sql": self.sql, + "table_id": self.table_id, + "tolerance": self.tolerance, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + if response_from_hook == "success": + query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id) + records = hook.get_records(query_results) + records = records.pop(0) if records else None + hook.value_check(self.sql, self.pass_value, records, self.tolerance) + yield TriggerEvent({"status": "success", "message": "Job completed", "records": records}) + return + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook, "records": None}) + return + + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + +class BigQueryTableExistenceTrigger(BaseTrigger): + """ + Initialize the BigQuery Table Existence Trigger with needed parameters + + :param project_id: Google Cloud Project where the job is running + :param dataset_id: The dataset ID of the requested table. + :param table_id: The table ID of the requested table. 
+ :param gcp_conn_id: Reference to google cloud connection id + :param hook_params: params to pass to the hook + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + project_id: str, + dataset_id: str, + table_id: str, + gcp_conn_id: str, + hook_params: dict[str, Any], + poll_interval: float = 4.0, + ): + self.dataset_id = dataset_id + self.project_id = project_id + self.table_id = table_id + self.gcp_conn_id: str = gcp_conn_id + self.poll_interval = poll_interval + self.hook_params = hook_params + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BigQueryTableExistenceTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger", + { + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "gcp_conn_id": self.gcp_conn_id, + "poll_interval": self.poll_interval, + "hook_params": self.hook_params, + }, + ) + + def _get_async_hook(self) -> BigQueryTableAsyncHook: + return BigQueryTableAsyncHook(gcp_conn_id=self.gcp_conn_id) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Run until the table exists in Google BigQuery.""" + while True: + try: + hook = self._get_async_hook() + response = await self._table_exists( + hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id + ) + if response: + yield TriggerEvent({"status": "success", "message": "success"}) + return + await asyncio.sleep(self.poll_interval) + except Exception as e: + self.log.exception("Exception occurred while checking for Table existence") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + async def _table_exists( + self, hook: BigQueryTableAsyncHook, dataset: str, table_id: str, project_id: str + ) -> bool: + """ + Create a client session and call BigQueryTableAsyncHook to check whether the table exists in + Google BigQuery. + + :param hook: BigQueryTableAsyncHook Hook class + :param dataset: The name of the dataset in which to look for the table. + :param table_id: The name of the table to check the existence of. + :param project_id: The Google Cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + """ + async with ClientSession() as session: + try: + client = await hook.get_table_client( + dataset=dataset, table_id=table_id, project_id=project_id, session=session + ) + response = await client.get() + return True if response else False + except ClientResponseError as err: + if err.status == 404: + return False + raise err diff --git a/airflow/providers/google/cloud/triggers/cloud_composer.py b/airflow/providers/google/cloud/triggers/cloud_composer.py index e1e5e009a54b6..20613b70abe8d 100644 --- a/airflow/providers/google/cloud/triggers/cloud_composer.py +++ b/airflow/providers/google/cloud/triggers/cloud_composer.py @@ -16,22 +16,14 @@ # specific language governing permissions and limitations # under the License.
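To make the contract of the new BigQuery triggers above concrete, here is a minimal, illustrative sketch (not taken from this patch) of how a deferrable operator might hand off to `BigQueryInsertJobTrigger` and resume when its `run()` yields a `TriggerEvent`. The operator class and its `execute_complete` wiring are assumptions for illustration only; only the trigger's constructor and event payload come from the code above.

```python
from __future__ import annotations

from typing import Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger


class ExampleDeferrableBigQueryOperator(BaseOperator):
    """Illustrative only: waits for a BigQuery job by deferring to the trigger above."""

    def __init__(self, *, job_id: str, project_id: str, gcp_conn_id: str = "google_cloud_default", **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id

    def execute(self, context):
        # Hand the polling loop to the triggerer; the worker slot is freed while we wait.
        self.defer(
            trigger=BigQueryInsertJobTrigger(
                conn_id=self.gcp_conn_id,
                job_id=self.job_id,
                project_id=self.project_id,
                poll_interval=10.0,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event: dict[str, Any]):
        # The event payload mirrors what BigQueryInsertJobTrigger.run() yields.
        if event["status"] == "error":
            raise AirflowException(event["message"])
        return event["job_id"]
```

The same handoff shape applies to the other trigger subclasses above; only the contents of the yielded event differ.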
+from __future__ import annotations + import asyncio -import logging -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Sequence from airflow import AirflowException -from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook - -try: - from airflow.triggers.base import BaseTrigger, TriggerEvent -except ImportError: - logging.getLogger(__name__).warning( - 'Deferrable Operators only work starting Airflow 2.2', - exc_info=True, - ) - BaseTrigger = object # type: ignore - TriggerEvent = None # type: ignore +from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent class CloudComposerExecutionTrigger(BaseTrigger): @@ -43,30 +35,28 @@ def __init__( region: str, operation_name: str, gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - delegate_to: Optional[str] = None, + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, pooling_period_seconds: int = 30, ): super().__init__() self.project_id = project_id self.region = region self.operation_name = operation_name - self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.delegate_to = delegate_to - self.pooling_period_seconds = pooling_period_seconds - self.gcp_hook = CloudComposerHook( + self.gcp_hook = CloudComposerAsyncHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, delegate_to=self.delegate_to, ) - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: return ( - 'airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExecutionTrigger', + "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExecutionTrigger", { "project_id": self.project_id, "region": self.region, @@ -80,7 +70,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: async def run(self): while True: - operation = self.gcp_hook.get_operation(operation_name=self.operation_name) + operation = await self.gcp_hook.get_operation(operation_name=self.operation_name) if operation.done: break elif operation.error.message: diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py new file mode 100644 index 0000000000000..d4d12a3f8df89 --- /dev/null +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
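The `CloudComposerExecutionTrigger` change above swaps the blocking `CloudComposerHook` for `CloudComposerAsyncHook` and awaits `get_operation`. This matters because all triggers on a triggerer process share one asyncio event loop, so a blocking network call stalls every other trigger. A minimal sketch of the non-blocking polling shape (the helper name is illustrative; the hook call matches the diff):

```python
import asyncio


async def poll_operation(hook, operation_name: str, interval: int = 30):
    """Illustrative helper: poll a long-running operation without blocking the event loop."""
    while True:
        # Awaiting the async hook lets other triggers make progress while this request is in flight.
        operation = await hook.get_operation(operation_name=operation_name)
        if operation.done:
            return operation
        # asyncio.sleep yields control back to the loop instead of blocking the whole process.
        await asyncio.sleep(interval)
```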
+"""This module contains Google Dataproc triggers.""" +from __future__ import annotations + +import asyncio +from typing import Sequence + +from google.cloud.dataproc_v1 import JobStatus + +from airflow import AirflowException +from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class DataprocBaseTrigger(BaseTrigger): + """ + Trigger that periodically polls information from Dataproc API to verify job status. + Implementation leverages asynchronous transport. + """ + + def __init__( + self, + job_id: str, + region: str, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + delegate_to: str | None = None, + polling_interval_seconds: int = 30, + ): + super().__init__() + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.job_id = job_id + self.project_id = project_id + self.region = region + self.polling_interval_seconds = polling_interval_seconds + self.delegate_to = delegate_to + self.hook = DataprocAsyncHook( + delegate_to=self.delegate_to, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def serialize(self): + return ( + "airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger", + { + "job_id": self.job_id, + "project_id": self.project_id, + "region": self.region, + "gcp_conn_id": self.gcp_conn_id, + "delegate_to": self.delegate_to, + "impersonation_chain": self.impersonation_chain, + "polling_interval_seconds": self.polling_interval_seconds, + }, + ) + + async def run(self): + while True: + job = await self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id) + state = job.status.state + self.log.info("Dataproc job: %s is in state: %s", self.job_id, state) + if state in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED): + if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED): + break + elif state == JobStatus.State.ERROR: + raise AirflowException(f"Dataproc job execution failed {self.job_id}") + await asyncio.sleep(self.polling_interval_seconds) + yield TriggerEvent({"job_id": self.job_id, "job_state": state}) diff --git a/airflow/providers/google/cloud/utils/bigquery.py b/airflow/providers/google/cloud/utils/bigquery.py new file mode 100644 index 0000000000000..03753c2423691 --- /dev/null +++ b/airflow/providers/google/cloud/utils/bigquery.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +def bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str: + """ + Helper method that casts a BigQuery row to the appropriate data types. 
+ This is useful because BigQuery returns all fields as strings. + """ + if string_field is None: + return None + elif bq_type == "INTEGER": + return int(string_field) + elif bq_type in ("FLOAT", "TIMESTAMP"): + return float(string_field) + elif bq_type == "BOOLEAN": + if string_field not in ["true", "false"]: + raise ValueError(f"{string_field} must have value 'true' or 'false'") + return string_field == "true" + else: + return string_field diff --git a/airflow/providers/google/cloud/utils/bigquery_get_data.py b/airflow/providers/google/cloud/utils/bigquery_get_data.py index d5a4668297b81..39ab1ef35b989 100644 --- a/airflow/providers/google/cloud/utils/bigquery_get_data.py +++ b/airflow/providers/google/cloud/utils/bigquery_get_data.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from collections.abc import Iterator from logging import Logger -from typing import List, Optional, Union from google.cloud.bigquery.table import Row @@ -30,14 +30,14 @@ def bigquery_get_data( table_id: str, big_query_hook: BigQueryHook, batch_size: int, - selected_fields: Optional[Union[List[str], str]], + selected_fields: list[str] | str | None, ) -> Iterator: - logger.info('Fetching Data from:') - logger.info('Dataset: %s ; Table: %s', dataset_id, table_id) + logger.info("Fetching Data from:") + logger.info("Dataset: %s ; Table: %s", dataset_id, table_id) i = 0 while True: - rows: List[Row] = big_query_hook.list_rows( + rows: list[Row] = big_query_hook.list_rows( dataset_id=dataset_id, table_id=table_id, max_results=batch_size, @@ -46,10 +46,10 @@ def bigquery_get_data( ) if len(rows) == 0: - logger.info('Job Finished') + logger.info("Job Finished") return - logger.info('Total Extracted rows: %s', len(rows) + i * batch_size) + logger.info("Total Extracted rows: %s", len(rows) + i * batch_size) yield [row.values() for row in rows] diff --git a/airflow/providers/google/cloud/utils/credentials_provider.py b/airflow/providers/google/cloud/utils/credentials_provider.py index 1cf33ea70b056..f10dd3d4946c2 100644 --- a/airflow/providers/google/cloud/utils/credentials_provider.py +++ b/airflow/providers/google/cloud/utils/credentials_provider.py @@ -19,11 +19,13 @@ This module contains a mechanism for providing temporary Google Cloud authentication. """ +from __future__ import annotations + import json import logging import tempfile from contextlib import ExitStack, contextmanager -from typing import Collection, Dict, Generator, Optional, Sequence, Tuple, Union +from typing import Collection, Generator, Sequence from urllib.parse import urlencode import google.auth @@ -40,13 +42,13 @@ log = logging.getLogger(__name__) AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT = "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT" -_DEFAULT_SCOPES: Sequence[str] = ('https://www.googleapis.com/auth/cloud-platform',) +_DEFAULT_SCOPES: Sequence[str] = ("https://www.googleapis.com/auth/cloud-platform",) def build_gcp_conn( - key_file_path: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - project_id: Optional[str] = None, + key_file_path: str | None = None, + scopes: Sequence[str] | None = None, + project_id: str | None = None, ) -> str: """ Builds a uri that can be used as :envvar:`AIRFLOW_CONN_{CONN_ID}` with provided service key, @@ -58,16 +60,14 @@ def build_gcp_conn( :return: String representing Airflow connection. 
""" conn = "google-cloud-platform://?{}" - extras = "extra__google_cloud_platform" - query_params = {} if key_file_path: - query_params[f"{extras}__key_path"] = key_file_path + query_params["key_path"] = key_file_path if scopes: scopes_string = ",".join(scopes) - query_params[f"{extras}__scope"] = scopes_string + query_params["scope"] = scopes_string if project_id: - query_params[f"{extras}__projects"] = project_id + query_params["projects"] = project_id query = urlencode(query_params) return conn.format(query) @@ -75,8 +75,8 @@ def build_gcp_conn( @contextmanager def provide_gcp_credentials( - key_file_path: Optional[str] = None, - key_file_dict: Optional[Dict] = None, + key_file_path: str | None = None, + key_file_dict: dict | None = None, ) -> Generator[None, None, None]: """ Context manager that provides a Google Cloud credentials for application supporting @@ -111,9 +111,9 @@ def provide_gcp_credentials( @contextmanager def provide_gcp_connection( - key_file_path: Optional[str] = None, - scopes: Optional[Sequence] = None, - project_id: Optional[str] = None, + key_file_path: str | None = None, + scopes: Sequence | None = None, + project_id: str | None = None, ) -> Generator[None, None, None]: """ Context manager that provides a temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT` @@ -135,9 +135,9 @@ def provide_gcp_connection( @contextmanager def provide_gcp_conn_and_credentials( - key_file_path: Optional[str] = None, - scopes: Optional[Sequence] = None, - project_id: Optional[str] = None, + key_file_path: str | None = None, + scopes: Sequence | None = None, + project_id: str | None = None, ) -> Generator[None, None, None]: """ Context manager that provides both: @@ -173,6 +173,9 @@ class _CredentialProvider(LoggingMixin): :param key_path: Path to Google Cloud Service Account key file (JSON). :param keyfile_dict: A dict representing Cloud Service Account as in the Credential JSON file + :param key_secret_name: Keyfile Secret Name in GCP Secret Manager. + :param key_secret_project_id: Project ID to read the secrets from. If not passed, the project ID from + default credentials will be used. :param scopes: OAuth scopes for the connection :param delegate_to: The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service account making the request must have @@ -191,14 +194,15 @@ class to configure Logger. 
def __init__( self, - key_path: Optional[str] = None, - keyfile_dict: Optional[Dict[str, str]] = None, - key_secret_name: Optional[str] = None, - scopes: Optional[Collection[str]] = None, - delegate_to: Optional[str] = None, + key_path: str | None = None, + keyfile_dict: dict[str, str] | None = None, + key_secret_name: str | None = None, + key_secret_project_id: str | None = None, + scopes: Collection[str] | None = None, + delegate_to: str | None = None, disable_logging: bool = False, - target_principal: Optional[str] = None, - delegates: Optional[Sequence[str]] = None, + target_principal: str | None = None, + delegates: Sequence[str] | None = None, ) -> None: super().__init__() key_options = [key_path, key_secret_name, keyfile_dict] @@ -210,13 +214,14 @@ def __init__( self.key_path = key_path self.keyfile_dict = keyfile_dict self.key_secret_name = key_secret_name + self.key_secret_project_id = key_secret_project_id self.scopes = scopes self.delegate_to = delegate_to self.disable_logging = disable_logging self.target_principal = target_principal self.delegates = delegates - def get_credentials_and_project(self) -> Tuple[google.auth.credentials.Credentials, str]: + def get_credentials_and_project(self) -> tuple[google.auth.credentials.Credentials, str]: """ Get current credentials and project ID. @@ -232,7 +237,7 @@ def get_credentials_and_project(self) -> Tuple[google.auth.credentials.Credentia credentials, project_id = self._get_credentials_using_adc() if self.delegate_to: - if hasattr(credentials, 'with_subject'): + if hasattr(credentials, "with_subject"): credentials = credentials.with_subject(self.delegate_to) else: raise AirflowException( @@ -254,10 +259,10 @@ def get_credentials_and_project(self) -> Tuple[google.auth.credentials.Credentia return credentials, project_id def _get_credentials_using_keyfile_dict(self): - self._log_debug('Getting connection using JSON Dict') + self._log_debug("Getting connection using JSON Dict") # Depending on how the JSON was formatted, it may contain # escaped newlines. Convert those to actual newlines. 
- self.keyfile_dict['private_key'] = self.keyfile_dict['private_key'].replace('\\n', '\n') + self.keyfile_dict["private_key"] = self.keyfile_dict["private_key"].replace("\\n", "\n") credentials = google.oauth2.service_account.Credentials.from_service_account_info( self.keyfile_dict, scopes=self.scopes ) @@ -265,13 +270,13 @@ def _get_credentials_using_keyfile_dict(self): return credentials, project_id def _get_credentials_using_key_path(self): - if self.key_path.endswith('.p12'): - raise AirflowException('Legacy P12 key file are not supported, use a JSON key file.') + if self.key_path.endswith(".p12"): + raise AirflowException("Legacy P12 key file are not supported, use a JSON key file.") - if not self.key_path.endswith('.json'): - raise AirflowException('Unrecognised extension for key file.') + if not self.key_path.endswith(".json"): + raise AirflowException("Unrecognised extension for key file.") - self._log_debug('Getting connection using JSON key file %s', self.key_path) + self._log_debug("Getting connection using JSON key file %s", self.key_path) credentials = google.oauth2.service_account.Credentials.from_service_account_file( self.key_path, scopes=self.scopes ) @@ -279,23 +284,26 @@ def _get_credentials_using_key_path(self): return credentials, project_id def _get_credentials_using_key_secret_name(self): - self._log_debug('Getting connection using JSON key data from GCP secret: %s', self.key_secret_name) + self._log_debug("Getting connection using JSON key data from GCP secret: %s", self.key_secret_name) # Use ADC to access GCP Secret Manager. adc_credentials, adc_project_id = google.auth.default(scopes=self.scopes) secret_manager_client = _SecretManagerClient(credentials=adc_credentials) if not secret_manager_client.is_valid_secret_name(self.key_secret_name): - raise AirflowException('Invalid secret name specified for fetching JSON key data.') + raise AirflowException("Invalid secret name specified for fetching JSON key data.") - secret_value = secret_manager_client.get_secret(self.key_secret_name, adc_project_id) + secret_value = secret_manager_client.get_secret( + secret_id=self.key_secret_name, + project_id=self.key_secret_project_id if self.key_secret_project_id else adc_project_id, + ) if secret_value is None: raise AirflowException(f"Failed getting value of secret {self.key_secret_name}.") try: keyfile_dict = json.loads(secret_value) except json.decoder.JSONDecodeError: - raise AirflowException('Key data read from GCP Secret Manager is not valid JSON.') + raise AirflowException("Key data read from GCP Secret Manager is not valid JSON.") credentials = google.oauth2.service_account.Credentials.from_service_account_info( keyfile_dict, scopes=self.scopes @@ -305,7 +313,7 @@ def _get_credentials_using_key_secret_name(self): def _get_credentials_using_adc(self): self._log_info( - 'Getting connection using `google.auth.default()` since no key file is defined for hook.' + "Getting connection using `google.auth.default()` since no key file is defined for hook." 
) credentials, project_id = google.auth.default(scopes=self.scopes) return credentials, project_id @@ -319,26 +327,25 @@ def _log_debug(self, *args, **kwargs) -> None: self.log.debug(*args, **kwargs) -def get_credentials_and_project_id(*args, **kwargs) -> Tuple[google.auth.credentials.Credentials, str]: +def get_credentials_and_project_id(*args, **kwargs) -> tuple[google.auth.credentials.Credentials, str]: """Returns the Credentials object for Google API and the associated project_id.""" return _CredentialProvider(*args, **kwargs).get_credentials_and_project() -def _get_scopes(scopes: Optional[str] = None) -> Sequence[str]: +def _get_scopes(scopes: str | None = None) -> Sequence[str]: """ Parse a comma-separated string containing OAuth2 scopes if `scopes` is provided. Otherwise, default scope will be returned. :param scopes: A comma-separated string containing OAuth2 scopes :return: Returns the scope defined in the connection configuration, or the default scope - :rtype: Sequence[str] """ - return [s.strip() for s in scopes.split(',')] if scopes else _DEFAULT_SCOPES + return [s.strip() for s in scopes.split(",")] if scopes else _DEFAULT_SCOPES def _get_target_principal_and_delegates( - impersonation_chain: Optional[Union[str, Sequence[str]]] = None -) -> Tuple[Optional[str], Optional[Sequence[str]]]: + impersonation_chain: str | Sequence[str] | None = None, +) -> tuple[str | None, Sequence[str] | None]: """ Analyze contents of impersonation_chain and return target_principal (the service account to directly impersonate using short-term credentials, if any) and optional list of delegates @@ -348,7 +355,6 @@ def _get_target_principal_and_delegates( account :return: Returns the tuple of target_principal and delegates - :rtype: Tuple[Optional[str], Optional[Sequence[str]]] """ if not impersonation_chain: return None, None @@ -366,10 +372,9 @@ def _get_project_id_from_service_account_email(service_account_email: str) -> st :param service_account_email: email of the service account. :return: Returns the project_id of the provided service account. - :rtype: str """ try: - return service_account_email.split('@')[1].split('.')[0] + return service_account_email.split("@")[1].split(".")[0] except IndexError: raise AirflowException( f"Could not extract project_id from service account's email: {service_account_email}." diff --git a/airflow/providers/google/cloud/utils/dataform.py b/airflow/providers/google/cloud/utils/dataform.py new file mode 100644 index 0000000000000..d47b8228a1a08 --- /dev/null +++ b/airflow/providers/google/cloud/utils/dataform.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
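Before the new Dataform helpers that follow, a quick illustrative sketch of the two credentials_provider changes above: `build_gcp_conn()` now emits bare query keys (`key_path`, `scope`, `projects`) instead of the `extra__google_cloud_platform__` prefix, and the new `key_secret_project_id` argument lets the keyfile secret live in a different project than the ADC default. All names and paths below are made up for the example.

```python
from airflow.providers.google.cloud.utils.credentials_provider import (
    build_gcp_conn,
    get_credentials_and_project_id,
)

# The generated connection URI now uses unprefixed query keys.
uri = build_gcp_conn(
    key_file_path="/files/sa.json",  # illustrative path
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
    project_id="my-project",  # illustrative project
)
print(uri)  # google-cloud-platform://?key_path=...&scope=...&projects=my-project

# The service-account keyfile secret may now be read from a project other than the ADC one.
credentials, project_id = get_credentials_and_project_id(
    key_secret_name="airflow-sa-keyfile",  # illustrative secret name
    key_secret_project_id="security-project",  # new parameter added by this change
)
```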
+ +from __future__ import annotations + +import json +from enum import Enum +from typing import Mapping + +from airflow.providers.google.cloud.operators.dataform import ( + DataformInstallNpmPackagesOperator, + DataformMakeDirectoryOperator, + DataformWriteFileOperator, +) + + +class DataformLocations(str, Enum): + """Enum for storing available locations for resources in Dataform.""" + + US = "US" + EUROPE = "EU" + + +def make_initialization_workspace_flow( + project_id: str, + region: str, + repository_id: str, + workspace_id: str, + package_name: str | None = None, + without_installation: bool = False, +) -> tuple: + """ + Creates flow which simulates the initialization of the default project. + :param project_id: Required. The ID of the Google Cloud project where workspace located. + :param region: Required. The ID of the Google Cloud region where workspace located. + :param repository_id: Required. The ID of the Dataform repository where workspace located. + :param workspace_id: Required. The ID of the Dataform workspace which requires initialization. + :param package_name: Name of the package. If value is not provided then workspace_id will be used. + :param without_installation: Defines should installation of npm packages be added to flow. + """ + make_definitions_directory = DataformMakeDirectoryOperator( + task_id="make-definitions-directory", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + directory_path="definitions", + ) + + first_view_content = b""" + -- This is an example SQLX file to help you learn the basics of Dataform. + -- Visit https://cloud.google.com/dataform/docs/how-to for more information on how to configure + -- your SQL workflow. + + -- You can delete this file, then commit and push your changes to your repository when you are ready. + + -- Config blocks allow you to configure, document, and test your data assets. + config { + type: "view", // Creates a view in BigQuery. Try changing to "table" instead. + columns: { + test: "A description for the test column", // Column descriptions are pushed to BigQuery. + } + } + + -- The rest of a SQLX file contains your SELECT statement used to create the table. + + SELECT 1 as test + """ + make_first_view_file = DataformWriteFileOperator( + task_id="write-first-view", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + filepath="definitions/first_view.sqlx", + contents=first_view_content, + ) + + second_view_content = b""" + config { type: "view" } + + -- Use the ref() function to manage dependencies. 
+ -- Learn more about ref() and other built in functions + -- here: https://cloud.google.com/dataform/docs/dataform-core + + SELECT test from ${ref("first_view")} + """ + make_second_view_file = DataformWriteFileOperator( + task_id="write-second-view", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + filepath="definitions/second_view.sqlx", + contents=second_view_content, + ) + + make_includes_directory = DataformMakeDirectoryOperator( + task_id="make-includes-directory", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + directory_path="includes", + ) + + gitignore_contents = b""" + node_modules/ + """ + make_gitignore_file = DataformWriteFileOperator( + task_id="write-gitignore-file", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + filepath=".gitignore", + contents=gitignore_contents, + ) + + default_location: str = define_default_location(region).value + dataform_config_content = json.dumps( + { + "defaultSchema": "dataform", + "assertionSchema": "dataform_assertions", + "warehouse": "bigquery", + "defaultDatabase": project_id, + "defaultLocation": default_location, + }, + indent=4, + ).encode() + make_dataform_config_file = DataformWriteFileOperator( + task_id="write-dataform-config-file", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + filepath="dataform.json", + contents=dataform_config_content, + ) + + package_name = package_name if package_name else workspace_id + package_json_content = json.dumps( + { + "name": package_name, + "dependencies": { + "@dataform/core": "2.0.1", + }, + }, + indent=4, + ).encode() + make_package_json_file = DataformWriteFileOperator( + task_id="write-package-json", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + filepath="package.json", + contents=package_json_content, + ) + + ( + make_definitions_directory + >> make_first_view_file + >> make_second_view_file + >> make_gitignore_file + >> make_dataform_config_file + >> make_package_json_file + ) + + if without_installation: + make_package_json_file >> make_includes_directory + else: + install_npm_packages = DataformInstallNpmPackagesOperator( + task_id="install-npm-packages", + project_id=project_id, + region=region, + repository_id=repository_id, + workspace_id=workspace_id, + ) + make_package_json_file >> install_npm_packages >> make_includes_directory + + return make_definitions_directory, make_includes_directory + + +def define_default_location(region: str) -> DataformLocations: + if "us" in region: + return DataformLocations.US + elif "europe" in region: + return DataformLocations.EUROPE + + regions_mapping: Mapping[str, DataformLocations] = {} + + return regions_mapping[region] diff --git a/airflow/providers/google/cloud/utils/field_sanitizer.py b/airflow/providers/google/cloud/utils/field_sanitizer.py index 6d2814b7e67c2..de43e8c894989 100644 --- a/airflow/providers/google/cloud/utils/field_sanitizer.py +++ b/airflow/providers/google/cloud/utils/field_sanitizer.py @@ -96,8 +96,7 @@ arrays - the sanitizer iterates through all dictionaries in the array and searches components in all elements of the array. 
""" - -from typing import List +from __future__ import annotations from airflow.exceptions import AirflowException from airflow.utils.log.logging_mixin import LoggingMixin @@ -116,7 +115,7 @@ class GcpBodyFieldSanitizer(LoggingMixin): """ - def __init__(self, sanitize_specs: List[str]) -> None: + def __init__(self, sanitize_specs: list[str]) -> None: super().__init__() self._sanitize_specs = sanitize_specs diff --git a/airflow/providers/google/cloud/utils/field_validator.py b/airflow/providers/google/cloud/utils/field_validator.py index 974c3b4559bbc..70ad1827af371 100644 --- a/airflow/providers/google/cloud/utils/field_validator.py +++ b/airflow/providers/google/cloud/utils/field_validator.py @@ -102,7 +102,7 @@ Forward-compatibility notes --------------------------- Certain decisions are crucial to allow the client APIs to work also with future API -versions. Since body attached is passed to the API’s call, this is entirely +versions. Since body attached is passed to the API's call, this is entirely possible to pass-through any new fields in the body (for future API versions) - albeit without validation on the client side - they can and will still be validated on the server side usually. @@ -120,7 +120,7 @@ remains successful). This is very nice feature to protect against typos in names. * For unions, newly added union variants can be added by future calls and they will pass validation, however the content or presence of those fields will not be validated. - This means that it’s possible to send a new non-validated union field together with an + This means that it's possible to send a new non-validated union field together with an old validated field and this problem will not be detected by the client. In such case warning will be printed. * When you add validator to an operator, you should also add ``validate_body`` parameter @@ -129,14 +129,15 @@ backwards-incompatible changes that might sometimes occur in the APIs. """ +from __future__ import annotations import re -from typing import Callable, Dict, Sequence +from typing import Callable, Sequence from airflow.exceptions import AirflowException from airflow.utils.log.logging_mixin import LoggingMixin -COMPOSITE_FIELD_TYPES = ['union', 'dict', 'list'] +COMPOSITE_FIELD_TYPES = ["union", "dict", "list"] class GcpFieldValidationException(AirflowException): @@ -165,9 +166,9 @@ def _int_greater_than_zero(value): name="an_union", type="union", fields=[ - dict(name="variant_1", regexp=r'^.+$'), - dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'), - dict(name="variant_3", type="dict", fields=[dict(name="url", regexp=r'^.+$')]), + dict(name="variant_1", regexp=r"^.+$"), + dict(name="variant_2", regexp=r"^.+$", api_version="v1beta2"), + dict(name="variant_3", type="dict", fields=[dict(name="url", regexp=r"^.+$")]), dict(name="variant_4"), ], ), @@ -187,7 +188,7 @@ class GcpBodyFieldValidator(LoggingMixin): """ - def __init__(self, validation_specs: Sequence[Dict], api_version: str) -> None: + def __init__(self, validation_specs: Sequence[dict], api_version: str) -> None: super().__init__() self._validation_specs = validation_specs self._api_version = api_version @@ -195,12 +196,12 @@ def __init__(self, validation_specs: Sequence[Dict], api_version: str) -> None: @staticmethod def _get_field_name_with_parent(field_name, parent): if parent: - return parent + '.' + field_name + return parent + "." 
+ field_name return field_name @staticmethod def _sanity_checks( - children_validation_specs: Dict, + children_validation_specs: dict, field_type: str, full_field_path: str, regexp: str, @@ -208,7 +209,7 @@ def _sanity_checks( custom_validation: Callable, value, ) -> None: - if value is None and field_type != 'union': + if value is None and field_type != "union": raise GcpFieldValidationException( f"The required body field '{full_field_path}' is missing. Please add it." ) @@ -251,12 +252,12 @@ def _validate_is_empty(full_field_path: str, value: str) -> None: f"The body field '{full_field_path}' can't be empty. Please provide a value." ) - def _validate_dict(self, children_validation_specs: Dict, full_field_path: str, value: Dict) -> None: + def _validate_dict(self, children_validation_specs: dict, full_field_path: str, value: dict) -> None: for child_validation_spec in children_validation_specs: self._validate_field( validation_spec=child_validation_spec, dictionary_to_validate=value, parent=full_field_path ) - all_dict_keys = [spec['name'] for spec in children_validation_specs] + all_dict_keys = [spec["name"] for spec in children_validation_specs] for field_name in value.keys(): if field_name not in all_dict_keys: self.log.warning( @@ -271,7 +272,7 @@ def _validate_dict(self, children_validation_specs: Dict, full_field_path: str, ) def _validate_union( - self, children_validation_specs: Dict, full_field_path: str, dictionary_to_validate: Dict + self, children_validation_specs: dict, full_field_path: str, dictionary_to_validate: dict ) -> None: field_found = False found_field_name = None @@ -284,7 +285,7 @@ def _validate_union( parent=full_field_path, force_optional=True, ) - field_name = child_validation_spec['name'] + field_name = child_validation_spec["name"] if new_field_found and field_found: raise GcpFieldValidationException( f"The mutually exclusive fields '{field_name}' and '{found_field_name}' belonging to " @@ -303,7 +304,7 @@ def _validate_union( "supports the new API version.", full_field_path, dictionary_to_validate, - [field['name'] for field in children_validation_specs], + [field["name"] for field in children_validation_specs], ) def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, force_optional=False): @@ -317,14 +318,14 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, (all union fields have force_optional set to True) :return: True if the field is present """ - field_name = validation_spec['name'] - field_type = validation_spec.get('type') - optional = validation_spec.get('optional') - regexp = validation_spec.get('regexp') - allow_empty = validation_spec.get('allow_empty') - children_validation_specs = validation_spec.get('fields') - required_api_version = validation_spec.get('api_version') - custom_validation = validation_spec.get('custom_validation') + field_name = validation_spec["name"] + field_type = validation_spec.get("type") + optional = validation_spec.get("optional") + regexp = validation_spec.get("regexp") + allow_empty = validation_spec.get("allow_empty") + children_validation_specs = validation_spec.get("fields") + required_api_version = validation_spec.get("api_version") + custom_validation = validation_spec.get("custom_validation") full_field_path = self._get_field_name_with_parent(field_name=field_name, parent=parent) if required_api_version and required_api_version != self._api_version: @@ -359,7 +360,7 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, 
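# A hedged usage sketch of the body validator whose spec handling is reformatted above. The spec
# keys ("name", "type", "regexp", "fields", "api_version", ...) and the union layout come from
# this module's docstring; the body being validated is illustrative only.
from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator

SPECS = [
    dict(name="name", allow_empty=False),
    dict(
        name="an_union",
        type="union",
        fields=[
            dict(name="variant_1", regexp=r"^.+$"),
            dict(name="variant_2", regexp=r"^.+$", api_version="v1beta2"),
        ],
    ),
]

validator = GcpBodyFieldValidator(SPECS, api_version="v1")
validator.validate({"name": "some-name", "variant_1": "some-value"})  # raises GcpFieldValidationException if invalid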
self._validate_is_empty(full_field_path, value) if regexp: self._validate_regexp(full_field_path, regexp, value) - elif field_type == 'dict': + elif field_type == "dict": if not isinstance(value, dict): raise GcpFieldValidationException( f"The field '{full_field_path}' should be of dictionary type according to " @@ -375,7 +376,7 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, ) else: self._validate_dict(children_validation_specs, full_field_path, value) - elif field_type == 'union': + elif field_type == "union": if not children_validation_specs: raise GcpValidationSpecificationException( f"The union field '{full_field_path}' has no nested fields defined in " @@ -383,7 +384,7 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, "Unions should have at least one nested field defined." ) self._validate_union(children_validation_specs, full_field_path, dictionary_to_validate) - elif field_type == 'list': + elif field_type == "list": if not isinstance(value, list): raise GcpFieldValidationException( f"The field '{full_field_path}' should be of list type according to " @@ -428,18 +429,18 @@ def validate(self, body_to_validate: dict) -> None: f"There was an error when validating: body '{body_to_validate}': '{e}'" ) all_field_names = [ - spec['name'] + spec["name"] for spec in self._validation_specs - if spec.get('type') != 'union' and spec.get('api_version') != self._api_version + if spec.get("type") != "union" and spec.get("api_version") != self._api_version ] - all_union_fields = [spec for spec in self._validation_specs if spec.get('type') == 'union'] + all_union_fields = [spec for spec in self._validation_specs if spec.get("type") == "union"] for union_field in all_union_fields: all_field_names.extend( [ - nested_union_spec['name'] - for nested_union_spec in union_field['fields'] - if nested_union_spec.get('type') != 'union' - and nested_union_spec.get('api_version') != self._api_version + nested_union_spec["name"] + for nested_union_spec in union_field["fields"] + if nested_union_spec.get("type") != "union" + and nested_union_spec.get("api_version") != self._api_version ] ) for field_name in body_to_validate.keys(): diff --git a/airflow/providers/google/cloud/utils/helpers.py b/airflow/providers/google/cloud/utils/helpers.py index ae402c9be1d91..28c7aa07c3f24 100644 --- a/airflow/providers/google/cloud/utils/helpers.py +++ b/airflow/providers/google/cloud/utils/helpers.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains helper functions for Google Cloud operators.""" -from typing import Optional +from __future__ import annotations -def normalize_directory_path(source_object: Optional[str]) -> Optional[str]: +def normalize_directory_path(source_object: str | None) -> str | None: """Makes sure dir path ends with a slash""" return source_object + "/" if source_object and not source_object.endswith("/") else source_object diff --git a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py index 3de91247f2fbd..1d6dc5d437ca9 100644 --- a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py +++ b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py @@ -14,15 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -# """This module contains helper functions for MLEngine operators.""" +from __future__ import annotations import base64 import json import os import re -from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import Callable, Iterable, TypeVar from urllib.parse import urlsplit import dill @@ -40,20 +39,20 @@ def create_evaluate_ops( task_prefix: str, data_format: str, - input_paths: List[str], + input_paths: list[str], prediction_path: str, - metric_fn_and_keys: Tuple[T, Iterable[str]], + metric_fn_and_keys: tuple[T, Iterable[str]], validate_fn: T, - batch_prediction_job_id: Optional[str] = None, - region: Optional[str] = None, - project_id: Optional[str] = None, - dataflow_options: Optional[Dict] = None, - model_uri: Optional[str] = None, - model_name: Optional[str] = None, - version_name: Optional[str] = None, - dag: Optional[DAG] = None, + batch_prediction_job_id: str | None = None, + region: str | None = None, + project_id: str | None = None, + dataflow_options: dict | None = None, + model_uri: str | None = None, + model_name: str | None = None, + version_name: str | None = None, + dag: DAG | None = None, py_interpreter="python3", -): +) -> tuple[MLEngineStartBatchPredictionJobOperator, BeamRunPythonPipelineOperator, PythonOperator]: """ Creates Operators needed for model evaluation and returns. @@ -183,7 +182,6 @@ def validate_err_and_count(summary): issues check: https://issues.apache.org/jira/browse/BEAM-1251 :returns: a tuple of three operators, (prediction, summary, validation) - :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator, PythonOperator) """ batch_prediction_job_id = batch_prediction_job_id or "" @@ -206,11 +204,11 @@ def validate_err_and_count(summary): if dag is not None and dag.default_args is not None: default_args = dag.default_args - project_id = project_id or default_args.get('project_id') - region = region or default_args['region'] - model_name = model_name or default_args.get('model_name') - version_name = version_name or default_args.get('version_name') - dataflow_options = dataflow_options or default_args.get('dataflow_default_options') + project_id = project_id or default_args.get("project_id") + region = region or default_args["region"] + model_name = model_name or default_args.get("model_name") + version_name = version_name or default_args.get("version_name") + dataflow_options = dataflow_options or default_args.get("dataflow_default_options") evaluate_prediction = MLEngineStartBatchPredictionJobOperator( task_id=(task_prefix + "-prediction"), @@ -229,15 +227,15 @@ def validate_err_and_count(summary): metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode() evaluate_summary = BeamRunPythonPipelineOperator( task_id=(task_prefix + "-summary"), - py_file=os.path.join(os.path.dirname(__file__), 'mlengine_prediction_summary.py'), + py_file=os.path.join(os.path.dirname(__file__), "mlengine_prediction_summary.py"), default_pipeline_options=dataflow_options, pipeline_options={ "prediction_path": prediction_path, "metric_fn_encoded": metric_fn_encoded, - "metric_keys": ','.join(metric_keys), + "metric_keys": ",".join(metric_keys), }, py_interpreter=py_interpreter, - py_requirements=['apache-beam[gcp]>=2.14.0'], + py_requirements=["apache-beam[gcp]>=2.14.0"], dag=dag, ) evaluate_summary.set_upstream(evaluate_prediction) diff --git a/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py b/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py index 
482d809c6e631..0b3b7c5633039 100644 --- a/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +++ b/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py @@ -106,6 +106,7 @@ def metric_fn(inst): pcoll """ +from __future__ import annotations import argparse import base64 @@ -115,9 +116,10 @@ def metric_fn(inst): import apache_beam as beam import dill +from apache_beam.coders.coders import Coder -class JsonCoder: +class JsonCoder(Coder): """JSON encoder/decoder.""" @staticmethod @@ -200,7 +202,7 @@ def run(argv=None): | "Write" >> beam.io.WriteToText( prediction_summary_path, - shard_name_template='', # without trailing -NNNNN-of-NNNNN. + shard_name_template="", # without trailing -NNNNN-of-NNNNN. coder=JsonCoder(), ) ) diff --git a/airflow/providers/google/common/auth_backend/google_openid.py b/airflow/providers/google/common/auth_backend/google_openid.py index 496ac29616686..eef4d5e31205b 100644 --- a/airflow/providers/google/common/auth_backend/google_openid.py +++ b/airflow/providers/google/common/auth_backend/google_openid.py @@ -16,14 +16,16 @@ # specific language governing permissions and limitations # under the License. """Authentication backend that use Google credentials for authorization.""" +from __future__ import annotations + import logging from functools import wraps -from typing import Callable, Optional, TypeVar, cast +from typing import Callable, TypeVar, cast import google import google.auth.transport.requests import google.oauth2.id_token -from flask import Response, _request_ctx_stack, current_app, request as flask_request # type: ignore +from flask import Response, current_app, request as flask_request # type: ignore from google.auth import exceptions from google.auth.transport.requests import AuthorizedSession from google.oauth2 import service_account @@ -53,7 +55,7 @@ def init_app(_): """Initializes authentication.""" -def _get_id_token_from_request(request) -> Optional[str]: +def _get_id_token_from_request(request) -> str | None: authorization_header = request.headers.get("Authorization") if not authorization_header: @@ -68,7 +70,7 @@ def _get_id_token_from_request(request) -> Optional[str]: return id_token -def _verify_id_token(id_token: str) -> Optional[str]: +def _verify_id_token(id_token: str) -> str | None: try: request_adapter = google.auth.transport.requests.Request() id_info = google.oauth2.id_token.verify_token(id_token, request_adapter, AUDIENCE) @@ -88,7 +90,7 @@ def _verify_id_token(id_token: str) -> Optional[str]: def _lookup_user(user_email: str): - security_manager = current_app.appbuilder.sm + security_manager = current_app.appbuilder.sm # type: ignore[attr-defined] user = security_manager.find_user(email=user_email) if not user: @@ -101,8 +103,7 @@ def _lookup_user(user_email: str): def _set_current_user(user): - ctx = _request_ctx_stack.top - ctx.user = user + current_app.appbuilder.sm.lm._update_request_context_with_user(user=user) # type: ignore[attr-defined] T = TypeVar("T", bound=Callable) diff --git a/airflow/providers/google/common/consts.py b/airflow/providers/google/common/consts.py index 049a2989098e5..f8d7209901d73 100644 --- a/airflow/providers/google/common/consts.py +++ b/airflow/providers/google/common/consts.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
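# JsonCoder above now derives from apache_beam's Coder base class, so it can be handed to
# WriteToText(coder=...). A minimal sketch of the same idea; the encode/decode bodies here are an
# assumption, not copied from the provider module.
import json

from apache_beam.coders.coders import Coder


class JsonCoder(Coder):
    """Encodes/decodes one JSON document per record."""

    @staticmethod
    def encode(x):
        return json.dumps(x).encode()

    @staticmethod
    def decode(x):
        return json.loads(x)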
+from __future__ import annotations + from google.api_core.gapic_v1.client_info import ClientInfo from airflow import version -GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME = 'execute_complete' +GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME = "execute_complete" -CLIENT_INFO = ClientInfo(client_library_version='airflow_v' + version.version) +CLIENT_INFO = ClientInfo(client_library_version="airflow_v" + version.version) diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index d9fe5daba5443..cfad2db0a5e73 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module contains a Google Cloud API base hook.""" +from __future__ import annotations + import functools import json import logging @@ -25,17 +26,21 @@ import warnings from contextlib import ExitStack, contextmanager from subprocess import check_output -from typing import Any, Callable, Dict, Generator, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Any, Callable, Generator, Sequence, TypeVar, cast import google.auth import google.auth.credentials import google.oauth2.service_account import google_auth_httplib2 +import requests import tenacity +from asgiref.sync import sync_to_async from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests from google.api_core.gapic_v1.client_info import ClientInfo -from google.auth import _cloud_sdk +from google.auth import _cloud_sdk, compute_engine from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR, CREDENTIALS +from google.auth.exceptions import RefreshError +from google.auth.transport import _http_client from googleapiclient import discovery from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseDownload, build_http, set_user_agent @@ -53,16 +58,15 @@ log = logging.getLogger(__name__) - # Constants used by the mechanism of repeating requests in reaction to exceeding the temporary quota. INVALID_KEYS = [ - 'DefaultRequestsPerMinutePerProject', - 'DefaultRequestsPerMinutePerUser', - 'RequestsPerMinutePerProject', + "DefaultRequestsPerMinutePerProject", + "DefaultRequestsPerMinutePerUser", + "RequestsPerMinutePerProject", "Resource has been exhausted (e.g. check quota).", ] INVALID_REASONS = [ - 'userRateLimitExceeded', + "userRateLimitExceeded", ] @@ -114,13 +118,26 @@ def __init__(self): # A fake project_id to use in functions decorated by fallback_to_default_project_id -# This allows the 'project_id' argument to be of type str instead of Optional[str], +# This allows the 'project_id' argument to be of type str instead of str | None, # making it easier to type hint the function body without dealing with the None # case that can never happen at runtime. PROVIDE_PROJECT_ID: str = cast(str, None) T = TypeVar("T", bound=Callable) -RT = TypeVar('RT') +RT = TypeVar("RT") + + +def get_field(extras: dict, field_name: str): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the 'extra__google_cloud_platform__' prefix " + "when using this method." 
+ ) + if field_name in extras: + return extras[field_name] or None + prefixed_name = f"extra__google_cloud_platform__{field_name}" + return extras.get(prefixed_name) or None class GoogleBaseHook(BaseHook): @@ -161,13 +178,13 @@ class GoogleBaseHook(BaseHook): account from the list granting this role to the originating account. """ - conn_name_attr = 'gcp_conn_id' - default_conn_name = 'google_cloud_default' - conn_type = 'google_cloud_platform' - hook_name = 'Google Cloud' + conn_name_attr = "gcp_conn_id" + default_conn_name = "google_cloud_default" + conn_type = "google_cloud_platform" + hook_name = "Google Cloud" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext @@ -175,23 +192,18 @@ def get_connection_form_widgets() -> Dict[str, Any]: from wtforms.validators import NumberRange return { - "extra__google_cloud_platform__project": StringField( - lazy_gettext('Project Id'), widget=BS3TextFieldWidget() + "project": StringField(lazy_gettext("Project Id"), widget=BS3TextFieldWidget()), + "key_path": StringField(lazy_gettext("Keyfile Path"), widget=BS3TextFieldWidget()), + "keyfile_dict": PasswordField(lazy_gettext("Keyfile JSON"), widget=BS3PasswordFieldWidget()), + "scope": StringField(lazy_gettext("Scopes (comma separated)"), widget=BS3TextFieldWidget()), + "key_secret_name": StringField( + lazy_gettext("Keyfile Secret Name (in GCP Secret Manager)"), widget=BS3TextFieldWidget() ), - "extra__google_cloud_platform__key_path": StringField( - lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget() + "key_secret_project_id": StringField( + lazy_gettext("Keyfile Secret Project Id (in GCP Secret Manager)"), widget=BS3TextFieldWidget() ), - "extra__google_cloud_platform__keyfile_dict": PasswordField( - lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget() - ), - "extra__google_cloud_platform__scope": StringField( - lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget() - ), - "extra__google_cloud_platform__key_secret_name": StringField( - lazy_gettext('Keyfile Secret Name (in GCP Secret Manager)'), widget=BS3TextFieldWidget() - ), - "extra__google_cloud_platform__num_retries": IntegerField( - lazy_gettext('Number of Retries'), + "num_retries": IntegerField( + lazy_gettext("Number of Retries"), validators=[NumberRange(min=0)], widget=BS3TextFieldWidget(), default=5, @@ -199,41 +211,42 @@ def get_connection_form_widgets() -> Dict[str, Any]: } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['host', 'schema', 'login', 'password', 'port', 'extra'], + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], "relabeling": {}, } def __init__( self, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__() self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.extras = self.get_connection(self.gcp_conn_id).extra_dejson # type: Dict - self._cached_credentials: 
Optional[google.auth.credentials.Credentials] = None - self._cached_project_id: Optional[str] = None + self.extras: dict = self.get_connection(self.gcp_conn_id).extra_dejson + self._cached_credentials: google.auth.credentials.Credentials | None = None + self._cached_project_id: str | None = None - def _get_credentials_and_project_id(self) -> Tuple[google.auth.credentials.Credentials, Optional[str]]: + def get_credentials_and_project_id(self) -> tuple[google.auth.credentials.Credentials, str | None]: """Returns the Credentials object for Google API and the associated project_id""" if self._cached_credentials is not None: return self._cached_credentials, self._cached_project_id - key_path: Optional[str] = self._get_field('key_path', None) + key_path: str | None = self._get_field("key_path", None) try: - keyfile_dict: Optional[str] = self._get_field('keyfile_dict', None) - keyfile_dict_json: Optional[Dict[str, str]] = None + keyfile_dict: str | None = self._get_field("keyfile_dict", None) + keyfile_dict_json: dict[str, str] | None = None if keyfile_dict: keyfile_dict_json = json.loads(keyfile_dict) except json.decoder.JSONDecodeError: - raise AirflowException('Invalid key JSON.') - key_secret_name: Optional[str] = self._get_field('key_secret_name', None) + raise AirflowException("Invalid key JSON.") + key_secret_name: str | None = self._get_field("key_secret_name", None) + key_secret_project_id: str | None = self._get_field("key_secret_project_id", None) target_principal, delegates = _get_target_principal_and_delegates(self.impersonation_chain) @@ -241,13 +254,14 @@ def _get_credentials_and_project_id(self) -> Tuple[google.auth.credentials.Crede key_path=key_path, keyfile_dict=keyfile_dict_json, key_secret_name=key_secret_name, + key_secret_project_id=key_secret_project_id, scopes=self.scopes, delegate_to=self.delegate_to, target_principal=target_principal, delegates=delegates, ) - overridden_project_id = self._get_field('project') + overridden_project_id = self._get_field("project") if overridden_project_id: project_id = overridden_project_id @@ -256,14 +270,19 @@ def _get_credentials_and_project_id(self) -> Tuple[google.auth.credentials.Crede return credentials, project_id - def _get_credentials(self) -> google.auth.credentials.Credentials: + def get_credentials(self) -> google.auth.credentials.Credentials: """Returns the Credentials object for Google API""" - credentials, _ = self._get_credentials_and_project_id() + credentials, _ = self.get_credentials_and_project_id() return credentials def _get_access_token(self) -> str: """Returns a valid access token from Google API Credentials""" - return self._get_credentials().token + credentials = self.get_credentials() + auth_req = google.auth.transport.requests.Request() + # credentials.token is None + # Need to refresh credentials to populate the token + credentials.refresh(auth_req) + return credentials.token @functools.lru_cache(maxsize=None) def _get_credentials_email(self) -> str: @@ -273,21 +292,32 @@ def _get_credentials_email(self) -> str: If a service account is used, it returns the service account. If user authentication (e.g. gcloud auth) is used, it returns the e-mail account of that user. 
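# _get_access_token() above now refreshes the credentials before reading .token, because newly
# built Credentials objects start with token=None. A hedged, generic sketch of that pattern using
# Application Default Credentials; the scope is illustrative.
import google.auth
import google.auth.transport.requests

credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = google.auth.transport.requests.Request()
credentials.refresh(auth_request)  # populates credentials.token
access_token = credentials.token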
""" - credentials = self._get_credentials() - service_account_email = getattr(credentials, 'service_account_email', None) + credentials = self.get_credentials() + + if isinstance(credentials, compute_engine.Credentials): + try: + credentials.refresh(_http_client.Request()) + except RefreshError as msg: + """ + If the Compute Engine metadata service can't be reached in this case the instance has not + credentials. + """ + self.log.debug(msg) + + service_account_email = getattr(credentials, "service_account_email", None) if service_account_email: return service_account_email http_authorized = self._authorize() - oauth2_client = discovery.build('oauth2', "v1", http=http_authorized, cache_discovery=False) - return oauth2_client.tokeninfo().execute()['email'] + oauth2_client = discovery.build("oauth2", "v1", http=http_authorized, cache_discovery=False) + return oauth2_client.tokeninfo().execute()["email"] def _authorize(self) -> google_auth_httplib2.AuthorizedHttp: """ Returns an authorized HTTP object to be used to build a Google cloud service hook connection. """ - credentials = self._get_credentials() + credentials = self.get_credentials() http = build_http() http = set_user_agent(http, "airflow/" + version.version) authed_http = google_auth_httplib2.AuthorizedHttp(credentials, http=http) @@ -300,21 +330,16 @@ def _get_field(self, f: str, default: Any = None) -> Any: to the hook page, which allow admins to specify service_account, key_path, etc. They get formatted as shown below. """ - long_f = f'extra__google_cloud_platform__{f}' - if hasattr(self, 'extras') and long_f in self.extras: - return self.extras[long_f] - else: - return default + return hasattr(self, "extras") and get_field(self.extras, f) or default @property - def project_id(self) -> Optional[str]: + def project_id(self) -> str | None: """ Returns project id. :return: id of the project - :rtype: str """ - _, project_id = self._get_credentials_and_project_id() + _, project_id = self.get_credentials_and_project_id() return project_id @property @@ -323,19 +348,18 @@ def num_retries(self) -> int: Returns num_retries from Connection. :return: the number of times each API request should be retried - :rtype: int """ - field_value = self._get_field('num_retries', default=5) + field_value = self._get_field("num_retries", default=5) if field_value is None: return 5 - if isinstance(field_value, str) and field_value.strip() == '': + if isinstance(field_value, str) and field_value.strip() == "": return 5 try: return int(field_value) except ValueError: raise AirflowException( f"The num_retries field should be a integer. " - f"Current value: \"{field_value}\" (type: {type(field_value)}). " + f'Current value: "{field_value}" (type: {type(field_value)}). ' f"Please check the connection configuration." ) @@ -363,9 +387,8 @@ def scopes(self) -> Sequence[str]: Return OAuth 2.0 scopes. 
:return: Returns the scope defined in the connection configuration, or the default scope - :rtype: Sequence[str] """ - scope_value = self._get_field('scope', None) # type: Optional[str] + scope_value: str | None = self._get_field("scope", None) return _get_scopes(scope_value) @@ -378,10 +401,10 @@ def quota_retry(*args, **kwargs) -> Callable: def decorator(fun: Callable): default_kwargs = { - 'wait': tenacity.wait_exponential(multiplier=1, max=100), - 'retry': retry_if_temporary_quota(), - 'before': tenacity.before_log(log, logging.DEBUG), - 'after': tenacity.after_log(log, logging.DEBUG), + "wait": tenacity.wait_exponential(multiplier=1, max=100), + "retry": retry_if_temporary_quota(), + "before": tenacity.before_log(log, logging.DEBUG), + "after": tenacity.after_log(log, logging.DEBUG), } default_kwargs.update(**kwargs) return tenacity.retry(*args, **default_kwargs)(fun) @@ -398,10 +421,10 @@ def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]: def decorator(fun: T): default_kwargs = { - 'wait': tenacity.wait_exponential(multiplier=1, max=300), - 'retry': retry_if_operation_in_progress(), - 'before': tenacity.before_log(log, logging.DEBUG), - 'after': tenacity.after_log(log, logging.DEBUG), + "wait": tenacity.wait_exponential(multiplier=1, max=300), + "retry": retry_if_operation_in_progress(), + "before": tenacity.before_log(log, logging.DEBUG), + "after": tenacity.after_log(log, logging.DEBUG), } default_kwargs.update(**kwargs) return cast(T, tenacity.retry(*args, **default_kwargs)(fun)) @@ -426,11 +449,11 @@ def inner_wrapper(self: GoogleBaseHook, *args, **kwargs) -> RT: raise AirflowException( "You must use keyword arguments in this methods rather than positional" ) - if 'project_id' in kwargs: - kwargs['project_id'] = kwargs['project_id'] or self.project_id + if "project_id" in kwargs: + kwargs["project_id"] = kwargs["project_id"] or self.project_id else: - kwargs['project_id'] = self.project_id - if not kwargs['project_id']: + kwargs["project_id"] = self.project_id + if not kwargs["project_id"]: raise AirflowException( "The project id must be passed either as " "keyword project_id parameter or as project_id extra " @@ -459,7 +482,7 @@ def wrapper(self: GoogleBaseHook, *args, **kwargs): return cast(T, wrapper) @contextmanager - def provide_gcp_credential_file_as_context(self) -> Generator[Optional[str], None, None]: + def provide_gcp_credential_file_as_context(self) -> Generator[str | None, None, None]: """ Context manager that provides a Google Cloud credentials for application supporting `Application Default Credentials (ADC) strategy `__. @@ -467,20 +490,20 @@ def provide_gcp_credential_file_as_context(self) -> Generator[Optional[str], Non It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable. """ - key_path: Optional[str] = self._get_field('key_path', None) - keyfile_dict: Optional[str] = self._get_field('keyfile_dict', None) + key_path: str | None = self._get_field("key_path", None) + keyfile_dict: str | None = self._get_field("keyfile_dict", None) if key_path and keyfile_dict: raise AirflowException( "The `keyfile_dict` and `key_path` fields are mutually exclusive. " "Please provide only one value." 
) elif key_path: - if key_path.endswith('.p12'): - raise AirflowException('Legacy P12 key file are not supported, use a JSON key file.') + if key_path.endswith(".p12"): + raise AirflowException("Legacy P12 key file are not supported, use a JSON key file.") with patch_environ({CREDENTIALS: key_path}): yield key_path elif keyfile_dict: - with tempfile.NamedTemporaryFile(mode='w+t') as conf_file: + with tempfile.NamedTemporaryFile(mode="w+t") as conf_file: conf_file.write(keyfile_dict) conf_file.flush() with patch_environ({CREDENTIALS: conf_file.name}): @@ -562,3 +585,42 @@ def download_content_from_request(file_handle, request: dict, chunk_size: int) - while done is False: _, done = downloader.next_chunk() file_handle.flush() + + def test_connection(self): + """Test the Google cloud connectivity from UI""" + status, message = False, "" + try: + token = self._get_access_token() + url = f"https://www.googleapis.com/oauth2/v3/tokeninfo?access_token={token}" + response = requests.post(url) + if response.status_code == 200: + status = True + message = "Connection successfully tested" + except Exception as e: + status = False + message = str(e) + + return status, message + + +class GoogleBaseAsyncHook(BaseHook): + """GoogleBaseAsyncHook inherits from BaseHook class, run on the trigger worker""" + + sync_hook_class: Any = None + + def __init__(self, **kwargs: Any): + self._hook_kwargs = kwargs + self._sync_hook = None + + async def get_sync_hook(self) -> Any: + """ + Sync version of the Google Cloud Hooks makes blocking calls in ``__init__`` so we don't inherit + from it. + """ + if not self._sync_hook: + self._sync_hook = await sync_to_async(self.sync_hook_class)(**self._hook_kwargs) + return self._sync_hook + + async def service_file_as_context(self) -> Any: + sync_hook = await self.get_sync_hook() + return await sync_to_async(sync_hook.provide_gcp_credential_file_as_context)() diff --git a/airflow/providers/google/common/hooks/discovery_api.py b/airflow/providers/google/common/hooks/discovery_api.py index bad4c7945f27e..9e5efc22bf526 100644 --- a/airflow/providers/google/common/hooks/discovery_api.py +++ b/airflow/providers/google/common/hooks/discovery_api.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module allows you to connect to the Google Discovery API Service and query it.""" -from typing import Optional, Sequence, Union +from __future__ import annotations + +from typing import Sequence from googleapiclient.discovery import Resource, build @@ -45,15 +46,15 @@ class GoogleDiscoveryApiHook(GoogleBaseHook): account from the list granting this role to the originating account. """ - _conn = None # type: Optional[Resource] + _conn: Resource | None = None def __init__( self, api_service_name: str, api_version: str, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -68,7 +69,6 @@ def get_conn(self) -> Resource: Creates an authenticated api client for the given api service name and credentials. :return: the authenticated api service. 
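# GoogleBaseAsyncHook above wraps a blocking sync hook with asgiref's sync_to_async so the
# triggerer's event loop is not blocked. A hedged, generic sketch of that bridge; SlowClient is
# a stand-in class, not a provider hook.
import asyncio

from asgiref.sync import sync_to_async


class SlowClient:
    def __init__(self, conn_id: str):
        self.conn_id = conn_id  # imagine blocking I/O happening here

    def fetch(self) -> str:
        return f"data for {self.conn_id}"  # and here


async def main() -> str:
    client = await sync_to_async(SlowClient)("google_cloud_default")  # blocking __init__ runs in a thread
    return await sync_to_async(client.fetch)()


print(asyncio.run(main()))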
- :rtype: Resource """ self.log.info("Authenticating Google API Client") @@ -96,7 +96,6 @@ def query(self, endpoint: str, data: dict, paginate: bool = False, num_retries: :param paginate: If set to True, it will collect all pages of data. :param num_retries: Define the number of retries for the requests being made if it fails. :return: the API response from the passed endpoint. - :rtype: dict """ google_api_conn_client = self.get_conn() @@ -104,7 +103,7 @@ def query(self, endpoint: str, data: dict, paginate: bool = False, num_retries: return api_response def _call_api_request(self, google_api_conn_client, endpoint, data, paginate, num_retries): - api_endpoint_parts = endpoint.split('.') + api_endpoint_parts = endpoint.split(".") google_api_endpoint_instance = self._build_api_request( google_api_conn_client, api_sub_functions=api_endpoint_parts[1:], api_endpoint_params=data @@ -150,7 +149,7 @@ def _build_next_api_request( google_api_conn_client = getattr(google_api_conn_client, sub_function) google_api_conn_client = google_api_conn_client() else: - google_api_conn_client = getattr(google_api_conn_client, sub_function + '_next') + google_api_conn_client = getattr(google_api_conn_client, sub_function + "_next") google_api_conn_client = google_api_conn_client(api_endpoint_instance, api_response) return google_api_conn_client diff --git a/airflow/providers/google/common/links/storage.py b/airflow/providers/google/common/links/storage.py index 7934d95d33419..42bb710937154 100644 --- a/airflow/providers/google/common/links/storage.py +++ b/airflow/providers/google/common/links/storage.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a link for GCS Storage assets.""" -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.models import BaseOperator from airflow.providers.google.cloud.links.base import BaseGoogleLink @@ -36,11 +38,11 @@ class StorageLink(BaseGoogleLink): format_str = GCS_STORAGE_LINK @staticmethod - def persist(context: "Context", task_instance, uri: str): + def persist(context: Context, task_instance, uri: str, project_id: str | None): task_instance.xcom_push( context=context, key=StorageLink.key, - value={"uri": uri, "project_id": task_instance.project_id}, + value={"uri": uri, "project_id": project_id}, ) @@ -52,7 +54,7 @@ class FileDetailsLink(BaseGoogleLink): format_str = GCS_FILE_DETAILS_LINK @staticmethod - def persist(context: "Context", task_instance: BaseOperator, uri: str, project_id: Optional[str]): + def persist(context: Context, task_instance: BaseOperator, uri: str, project_id: str | None): task_instance.xcom_push( context=context, key=FileDetailsLink.key, diff --git a/airflow/providers/google/common/utils/id_token_credentials.py b/airflow/providers/google/common/utils/id_token_credentials.py index 099bc96d852dd..6cac058992ff2 100644 --- a/airflow/providers/google/common/utils/id_token_credentials.py +++ b/airflow/providers/google/common/utils/id_token_credentials.py @@ -28,10 +28,10 @@ RefreshError """ +from __future__ import annotations import json import os -from typing import Optional import google.auth.transport import google.oauth2 @@ -58,8 +58,8 @@ def refresh(self, request): def _load_credentials_from_file( - filename: str, target_audience: Optional[str] -) -> Optional[google_auth_credentials.Credentials]: + filename: str, target_audience: str | None +) -> google_auth_credentials.Credentials | None: """ Loads 
credentials from a file. @@ -67,7 +67,6 @@ def _load_credentials_from_file( :param filename: The full path to the credentials file. :return: Loaded credentials - :rtype: google.auth.credentials.Credentials :raise google.auth.exceptions.DefaultCredentialsError: if the file is in the wrong format or is missing. """ if not os.path.exists(filename): @@ -108,8 +107,8 @@ def _load_credentials_from_file( def _get_explicit_environ_credentials( - target_audience: Optional[str], -) -> Optional[google_auth_credentials.Credentials]: + target_audience: str | None, +) -> google_auth_credentials.Credentials | None: """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment variable.""" explicit_file = os.environ.get(environment_vars.CREDENTIALS) @@ -124,8 +123,8 @@ def _get_explicit_environ_credentials( def _get_gcloud_sdk_credentials( - target_audience: Optional[str], -) -> Optional[google_auth_credentials.Credentials]: + target_audience: str | None, +) -> google_auth_credentials.Credentials | None: """Gets the credentials and project ID from the Cloud SDK.""" from google.auth import _cloud_sdk @@ -141,8 +140,8 @@ def _get_gcloud_sdk_credentials( def _get_gce_credentials( - target_audience: Optional[str], request: Optional[google.auth.transport.Request] = None -) -> Optional[google_auth_credentials.Credentials]: + target_audience: str | None, request: google.auth.transport.Request | None = None +) -> google_auth_credentials.Credentials | None: """Gets credentials and project ID from the GCE Metadata Service.""" # Ping requires a transport, but we want application default credentials # to require no arguments. So, we'll use the _http_client transport which @@ -170,7 +169,7 @@ def _get_gce_credentials( def get_default_id_token_credentials( - target_audience: Optional[str], request: google.auth.transport.Request = None + target_audience: str | None, request: google.auth.transport.Request = None ) -> google_auth_credentials.Credentials: """Gets the default ID Token credentials for the current environment. @@ -185,7 +184,6 @@ def get_default_id_token_credentials( is running on Compute Engine. If not specified, then it will use the standard library http client to make requests. :return: the current environment's credentials. - :rtype: google.auth.credentials.Credentials :raises ~google.auth.exceptions.DefaultCredentialsError: If no credentials were found, or if the credentials found were invalid. """ diff --git a/airflow/providers/google/config_templates/default_config.cfg b/airflow/providers/google/config_templates/default_config.cfg index cdc264bc17228..8b9742256465d 100644 --- a/airflow/providers/google/config_templates/default_config.cfg +++ b/airflow/providers/google/config_templates/default_config.cfg @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. - # This is the template for Airflow's default configuration. When Airflow is # imported, it looks for a configuration file at $AIRFLOW_HOME/airflow.cfg. If # it doesn't exist, Airflow uses this template to generate it by replacing diff --git a/airflow/providers/google/firebase/example_dags/example_firestore.py b/airflow/providers/google/firebase/example_dags/example_firestore.py deleted file mode 100644 index 19e0071fef71a..0000000000000 --- a/airflow/providers/google/firebase/example_dags/example_firestore.py +++ /dev/null @@ -1,143 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Example Airflow DAG that shows interactions with Google Cloud Firestore. - -Prerequisites -============= - -This example uses two Google Cloud projects: - -* ``GCP_PROJECT_ID`` - It contains a bucket and a firestore database. -* ``G_FIRESTORE_PROJECT_ID`` - it contains the Data Warehouse based on the BigQuery service. - -Saving in a bucket should be possible from the ``G_FIRESTORE_PROJECT_ID`` project. -Reading from a bucket should be possible from the ``GCP_PROJECT_ID`` project. - -The bucket and dataset should be located in the same region. - -If you want to run this example, you must do the following: - -1. Create Google Cloud project and enable the BigQuery API -2. Create the Firebase project -3. Create a bucket in the same location as the Firebase project -4. Grant Firebase admin account permissions to manage BigQuery. This is required to create a dataset. -5. Create a bucket in Firebase project and -6. Give read/write access for Firebase admin to bucket to step no. 5. -7. Create collection in the Firestore database. -""" - -import os -from datetime import datetime -from urllib.parse import urlparse - -from airflow import models -from airflow.models.baseoperator import chain -from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, - BigQueryCreateExternalTableOperator, - BigQueryDeleteDatasetOperator, - BigQueryInsertJobOperator, -) -from airflow.providers.google.firebase.operators.firestore import CloudFirestoreExportDatabaseOperator - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-gcp-project") -FIRESTORE_PROJECT_ID = os.environ.get("G_FIRESTORE_PROJECT_ID", "example-firebase-project") - -EXPORT_DESTINATION_URL = os.environ.get("GCP_FIRESTORE_ARCHIVE_URL", "gs://INVALID BUCKET NAME/namespace/") -BUCKET_NAME = urlparse(EXPORT_DESTINATION_URL).hostname -EXPORT_PREFIX = urlparse(EXPORT_DESTINATION_URL).path - -EXPORT_COLLECTION_ID = os.environ.get("GCP_FIRESTORE_COLLECTION_ID", "firestore_collection_id") -DATASET_NAME = os.environ.get("GCP_FIRESTORE_DATASET_NAME", "test_firestore_export") -DATASET_LOCATION = os.environ.get("GCP_FIRESTORE_DATASET_LOCATION", "EU") - -if BUCKET_NAME is None: - raise ValueError("Bucket name is required. 
Please set GCP_FIRESTORE_ARCHIVE_URL env variable.") - -with models.DAG( - "example_google_firestore", - start_date=datetime(2021, 1, 1), - schedule_interval='@once', - catchup=False, - tags=["example"], -) as dag: - # [START howto_operator_export_database_to_gcs] - export_database_to_gcs = CloudFirestoreExportDatabaseOperator( - task_id="export_database_to_gcs", - project_id=FIRESTORE_PROJECT_ID, - body={"outputUriPrefix": EXPORT_DESTINATION_URL, "collectionIds": [EXPORT_COLLECTION_ID]}, - ) - # [END howto_operator_export_database_to_gcs] - - create_dataset = BigQueryCreateEmptyDatasetOperator( - task_id="create_dataset", - dataset_id=DATASET_NAME, - location=DATASET_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - delete_dataset = BigQueryDeleteDatasetOperator( - task_id="delete_dataset", dataset_id=DATASET_NAME, project_id=GCP_PROJECT_ID, delete_contents=True - ) - - # [START howto_operator_create_external_table_multiple_types] - create_external_table_multiple_types = BigQueryCreateExternalTableOperator( - task_id="create_external_table", - bucket=BUCKET_NAME, - table_resource={ - "tableReference": { - "projectId": GCP_PROJECT_ID, - "datasetId": DATASET_NAME, - "tableId": "firestore_data", - }, - "schema": { - "fields": [ - {"name": "name", "type": "STRING"}, - {"name": "post_abbr", "type": "STRING"}, - ] - }, - "externalDataConfiguration": { - "sourceFormat": "DATASTORE_BACKUP", - "compression": "NONE", - "csvOptions": {"skipLeadingRows": 1}, - }, - }, - ) - # [END howto_operator_create_external_table_multiple_types] - - read_data_from_gcs_multiple_types = BigQueryInsertJobOperator( - task_id="execute_query", - configuration={ - "query": { - "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.firestore_data`", - "useLegacySql": False, - } - }, - ) - - chain( - # Firestore - export_database_to_gcs, - # BigQuery - create_dataset, - create_external_table_multiple_types, - read_data_from_gcs_multiple_types, - delete_dataset, - ) diff --git a/airflow/providers/google/firebase/hooks/firestore.py b/airflow/providers/google/firebase/hooks/firestore.py index 8ba0e6101dc0f..ac873c56bf777 100644 --- a/airflow/providers/google/firebase/hooks/firestore.py +++ b/airflow/providers/google/firebase/hooks/firestore.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. """Hook for Google Cloud Firestore service""" +from __future__ import annotations import time -from typing import Any, Dict, Optional, Sequence, Union +from typing import Sequence from googleapiclient.discovery import build, build_from_document @@ -51,14 +52,14 @@ class CloudFirestoreHook(GoogleBaseHook): account from the list granting this role to the originating account. """ - _conn = None # type: Optional[Any] + _conn = None def __init__( self, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -87,7 +88,7 @@ def get_conn(self): @GoogleBaseHook.fallback_to_default_project_id def export_documents( - self, body: Dict, database_id: str = "(default)", project_id: Optional[str] = None + self, body: dict, database_id: str = "(default)", project_id: str | None = None ) -> None: """ Starts a export with the specified configuration. 
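# A hedged usage sketch of the export_documents() signature shown above; the connection id,
# bucket and collection are illustrative, and the body shape mirrors the example DAG removed
# earlier in this diff.
from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook

hook = CloudFirestoreHook(gcp_conn_id="google_cloud_default")
hook.export_documents(
    body={
        "outputUriPrefix": "gs://example-bucket/namespace/",
        "collectionIds": ["example_collection"],
    },
    database_id="(default)",
    project_id="example-firebase-project",
)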
@@ -119,7 +120,6 @@ def _wait_for_operation_to_complete(self, operation_name: str) -> None: :param operation_name: The name of the operation. :return: The response returned by the operation. - :rtype: dict :exception: AirflowException in case error is returned. """ service = self.get_conn() diff --git a/airflow/providers/google/firebase/operators/firestore.py b/airflow/providers/google/firebase/operators/firestore.py index 77227ca275a9b..af7ec700bc319 100644 --- a/airflow/providers/google/firebase/operators/firestore.py +++ b/airflow/providers/google/firebase/operators/firestore.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -62,12 +63,12 @@ class CloudFirestoreExportDatabaseOperator(BaseOperator): def __init__( self, *, - body: Dict, + body: dict, database_id: str = "(default)", - project_id: Optional[str] = None, + project_id: str | None = None, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -83,7 +84,7 @@ def _validate_inputs(self) -> None: if not self.body: raise AirflowException("The required parameter 'body' is missing") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = CloudFirestoreHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, diff --git a/airflow/providers/google/go_module_utils.py b/airflow/providers/google/go_module_utils.py index 1c554e32209e2..a05590dd0a7f0 100644 --- a/airflow/providers/google/go_module_utils.py +++ b/airflow/providers/google/go_module_utils.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """Utilities initializing and managing Go modules.""" +from __future__ import annotations + import os from airflow.utils.process_utils import execute_in_subprocess diff --git a/airflow/providers/google/leveldb/example_dags/example_leveldb.py b/airflow/providers/google/leveldb/example_dags/example_leveldb.py deleted file mode 100644 index c703c86c96a0c..0000000000000 --- a/airflow/providers/google/leveldb/example_dags/example_leveldb.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example use of LevelDB operators. 
-""" - -from datetime import datetime - -from airflow import models -from airflow.providers.google.leveldb.operators.leveldb import LevelDBOperator - -with models.DAG( - 'example_leveldb', - start_date=datetime(2021, 1, 1), - schedule_interval='@once', - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_leveldb_get_key] - get_key_leveldb_task = LevelDBOperator(task_id='get_key_leveldb', command='get', key=b'key') - # [END howto_operator_leveldb_get_key] - # [START howto_operator_leveldb_put_key] - put_key_leveldb_task = LevelDBOperator( - task_id='put_key_leveldb', - command='put', - key=b'another_key', - value=b'another_value', - ) - # [END howto_operator_leveldb_put_key] - get_key_leveldb_task >> put_key_leveldb_task diff --git a/airflow/providers/google/leveldb/hooks/leveldb.py b/airflow/providers/google/leveldb/hooks/leveldb.py index bff3f6af9072e..730af48db9493 100644 --- a/airflow/providers/google/leveldb/hooks/leveldb.py +++ b/airflow/providers/google/leveldb/hooks/leveldb.py @@ -15,25 +15,15 @@ # specific language governing permissions and limitations # under the License. """Hook for Level DB""" -from typing import List, Optional +from __future__ import annotations + +from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException +from airflow.hooks.base import BaseHook try: import plyvel from plyvel import DB - - from airflow.exceptions import AirflowException - from airflow.hooks.base import BaseHook - except ImportError as e: - # Plyvel is an optional feature and if imports are missing, it should be silently ignored - # As of Airflow 2.3 and above the operator can throw OptionalProviderFeatureException - try: - from airflow.exceptions import AirflowOptionalProviderFeatureException - except ImportError: - # However, in order to keep backwards-compatibility with Airflow 2.1 and 2.2, if the - # 2.3 exception cannot be imported, the original ImportError should be raised. - # This try/except can be removed when the provider depends on Airflow >= 2.3.0 - raise e from None raise AirflowOptionalProviderFeatureException(e) DB_NOT_INITIALIZED_BEFORE = "The `get_conn` method should be called before!" @@ -49,18 +39,18 @@ class LevelDBHook(BaseHook): `LevelDB Connection Documentation `__ """ - conn_name_attr = 'leveldb_conn_id' - default_conn_name = 'leveldb_default' - conn_type = 'leveldb' - hook_name = 'LevelDB' + conn_name_attr = "leveldb_conn_id" + default_conn_name = "leveldb_default" + conn_type = "leveldb" + hook_name = "LevelDB" def __init__(self, leveldb_conn_id: str = default_conn_name): super().__init__() self.leveldb_conn_id = leveldb_conn_id self.connection = self.get_connection(leveldb_conn_id) - self.db: Optional[plyvel.DB] = None + self.db: plyvel.DB | None = None - def get_conn(self, name: str = '/tmp/testdb/', create_if_missing: bool = False, **kwargs) -> DB: + def get_conn(self, name: str = "/tmp/testdb/", create_if_missing: bool = False, **kwargs) -> DB: """ Creates `Plyvel DB `__ @@ -68,7 +58,6 @@ def get_conn(self, name: str = '/tmp/testdb/', create_if_missing: bool = False, :param create_if_missing: whether a new database should be created if needed :param kwargs: other options of creation plyvel.DB. See more in the link above. 
:returns: DB - :rtype: plyvel.DB """ if self.db is not None: return self.db @@ -86,10 +75,10 @@ def run( self, command: str, key: bytes, - value: Optional[bytes] = None, - keys: Optional[List[bytes]] = None, - values: Optional[List[bytes]] = None, - ) -> Optional[bytes]: + value: bytes | None = None, + keys: list[bytes] | None = None, + values: list[bytes] | None = None, + ) -> bytes | None: """ Execute operation with leveldb @@ -97,20 +86,19 @@ def run( ``"put"``, ``"get"``, ``"delete"``, ``"write_batch"``. :param key: key for command(put,get,delete) execution(, e.g. ``b'key'``, ``b'another-key'``) :param value: value for command(put) execution(bytes, e.g. ``b'value'``, ``b'another-value'``) - :param keys: keys for command(write_batch) execution(List[bytes], e.g. ``[b'key', b'another-key'])`` + :param keys: keys for command(write_batch) execution(list[bytes], e.g. ``[b'key', b'another-key'])`` :param values: values for command(write_batch) execution e.g. ``[b'value'``, ``b'another-value']`` :returns: value from get or None - :rtype: Optional[bytes] """ - if command == 'put': + if command == "put": if not value: raise Exception("Please provide `value`!") return self.put(key, value) - elif command == 'get': + elif command == "get": return self.get(key) - elif command == 'delete': + elif command == "delete": return self.delete(key) - elif command == 'write_batch': + elif command == "write_batch": if not keys: raise Exception("Please provide `keys`!") if not values: @@ -136,7 +124,6 @@ def get(self, key: bytes) -> bytes: :param key: key for get execution, e.g. ``b'key'``, ``b'another-key'`` :returns: value of key from db.get - :rtype: bytes """ if not self.db: raise Exception(DB_NOT_INITIALIZED_BEFORE) @@ -152,7 +139,7 @@ def delete(self, key: bytes): raise Exception(DB_NOT_INITIALIZED_BEFORE) self.db.delete(key) - def write_batch(self, keys: List[bytes], values: List[bytes]): + def write_batch(self, keys: list[bytes], values: list[bytes]): """ Write batch of values in a leveldb db by keys diff --git a/airflow/providers/google/leveldb/operators/leveldb.py b/airflow/providers/google/leveldb/operators/leveldb.py index 8f9ae6c2727af..7ff2da63ff066 100644 --- a/airflow/providers/google/leveldb/operators/leveldb.py +++ b/airflow/providers/google/leveldb/operators/leveldb.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from airflow.models import BaseOperator from airflow.providers.google.leveldb.hooks.leveldb import LevelDBHook @@ -35,7 +37,7 @@ class LevelDBOperator(BaseOperator): ``"put"``, ``"get"``, ``"delete"``, ``"write_batch"``. :param key: key for command(put,get,delete) execution(, e.g. ``b'key'``, ``b'another-key'``) :param value: value for command(put) execution(bytes, e.g. ``b'value'``, ``b'another-value'``) - :param keys: keys for command(write_batch) execution(List[bytes], e.g. ``[b'key', b'another-key'])`` + :param keys: keys for command(write_batch) execution(list[bytes], e.g. ``[b'key', b'another-key'])`` :param values: values for command(write_batch) execution e.g. 
``[b'value'``, ``b'another-value']`` :param leveldb_conn_id: :param create_if_missing: whether a new database should be created if needed @@ -48,13 +50,13 @@ def __init__( *, command: str, key: bytes, - value: Optional[bytes] = None, - keys: Optional[List[bytes]] = None, - values: Optional[List[bytes]] = None, - leveldb_conn_id: str = 'leveldb_default', - name: str = '/tmp/testdb/', + value: bytes | None = None, + keys: list[bytes] | None = None, + values: list[bytes] | None = None, + leveldb_conn_id: str = "leveldb_default", + name: str = "/tmp/testdb/", create_if_missing: bool = True, - create_db_extra_options: Optional[Dict[str, Any]] = None, + create_db_extra_options: dict[str, Any] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -68,13 +70,12 @@ def __init__( self.create_if_missing = create_if_missing self.create_db_extra_options = create_db_extra_options or {} - def execute(self, context: 'Context') -> Optional[str]: + def execute(self, context: Context) -> str | None: """ Execute command in LevelDB :returns: value from get(str, not bytes, to prevent error in json.dumps in serialize_value in xcom.py) - or None(Optional[str]) - :rtype: Optional[str] + or str | None """ leveldb_hook = LevelDBHook(leveldb_conn_id=self.leveldb_conn_id) leveldb_hook.get_conn( diff --git a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py deleted file mode 100644 index 4c32391fb0a3f..0000000000000 --- a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py +++ /dev/null @@ -1,169 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG that shows how to use CampaignManager. 
-""" -import os -import time -from datetime import datetime - -from airflow import models -from airflow.providers.google.marketing_platform.operators.campaign_manager import ( - GoogleCampaignManagerBatchInsertConversionsOperator, - GoogleCampaignManagerBatchUpdateConversionsOperator, - GoogleCampaignManagerDeleteReportOperator, - GoogleCampaignManagerDownloadReportOperator, - GoogleCampaignManagerInsertReportOperator, - GoogleCampaignManagerRunReportOperator, -) -from airflow.providers.google.marketing_platform.sensors.campaign_manager import ( - GoogleCampaignManagerReportSensor, -) - -PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789") -FLOODLIGHT_ACTIVITY_ID = int(os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345)) -FLOODLIGHT_CONFIGURATION_ID = int(os.environ.get("FLOODLIGHT_CONFIGURATION_ID", 12345)) -ENCRYPTION_ENTITY_ID = int(os.environ.get("ENCRYPTION_ENTITY_ID", 12345)) -DEVICE_ID = os.environ.get("DEVICE_ID", "12345") -BUCKET = os.environ.get("MARKETING_BUCKET", "test-cm-bucket") -REPORT_NAME = "test-report" -REPORT = { - "type": "STANDARD", - "name": REPORT_NAME, - "criteria": { - "dateRange": { - "kind": "dfareporting#dateRange", - "relativeDateRange": "LAST_365_DAYS", - }, - "dimensions": [{"kind": "dfareporting#sortedDimension", "name": "dfa:advertiser"}], - "metricNames": ["dfa:activeViewImpressionDistributionViewable"], - }, -} - -CONVERSION = { - "kind": "dfareporting#conversion", - "floodlightActivityId": FLOODLIGHT_ACTIVITY_ID, - "floodlightConfigurationId": FLOODLIGHT_CONFIGURATION_ID, - "mobileDeviceId": DEVICE_ID, - "ordinal": "0", - "quantity": 42, - "value": 123.4, - "timestampMicros": int(time.time()) * 1000000, - "customVariables": [ - { - "kind": "dfareporting#customFloodlightVariable", - "type": "U4", - "value": "value", - } - ], -} - -CONVERSION_UPDATE = { - "kind": "dfareporting#conversion", - "floodlightActivityId": FLOODLIGHT_ACTIVITY_ID, - "floodlightConfigurationId": FLOODLIGHT_CONFIGURATION_ID, - "mobileDeviceId": DEVICE_ID, - "ordinal": "0", - "quantity": 42, - "value": 123.4, -} - -with models.DAG( - "example_campaign_manager", - schedule_interval='@once', # Override to match your needs, - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - # [START howto_campaign_manager_insert_report_operator] - create_report = GoogleCampaignManagerInsertReportOperator( - profile_id=PROFILE_ID, report=REPORT, task_id="create_report" - ) - report_id = create_report.output["report_id"] - # [END howto_campaign_manager_insert_report_operator] - - # [START howto_campaign_manager_run_report_operator] - run_report = GoogleCampaignManagerRunReportOperator( - profile_id=PROFILE_ID, report_id=report_id, task_id="run_report" - ) - file_id = run_report.output["file_id"] - # [END howto_campaign_manager_run_report_operator] - - # [START howto_campaign_manager_wait_for_operation] - wait_for_report = GoogleCampaignManagerReportSensor( - task_id="wait_for_report", - profile_id=PROFILE_ID, - report_id=report_id, - file_id=file_id, - ) - # [END howto_campaign_manager_wait_for_operation] - - # [START howto_campaign_manager_get_report_operator] - get_report = GoogleCampaignManagerDownloadReportOperator( - task_id="get_report", - profile_id=PROFILE_ID, - report_id=report_id, - file_id=file_id, - report_name="test_report.csv", - bucket_name=BUCKET, - ) - # [END howto_campaign_manager_get_report_operator] - - # [START howto_campaign_manager_delete_report_operator] - delete_report = GoogleCampaignManagerDeleteReportOperator( - profile_id=PROFILE_ID, 
report_name=REPORT_NAME, task_id="delete_report" - ) - # [END howto_campaign_manager_delete_report_operator] - - wait_for_report >> get_report >> delete_report - - # Task dependencies created via `XComArgs`: - # create_report >> run_report - # create_report >> wait_for_report - # create_report >> get_report - # run_report >> get_report - # run_report >> wait_for_report - - # [START howto_campaign_manager_insert_conversions] - insert_conversion = GoogleCampaignManagerBatchInsertConversionsOperator( - task_id="insert_conversion", - profile_id=PROFILE_ID, - conversions=[CONVERSION], - encryption_source="AD_SERVING", - encryption_entity_type="DCM_ADVERTISER", - encryption_entity_id=ENCRYPTION_ENTITY_ID, - ) - # [END howto_campaign_manager_insert_conversions] - - # [START howto_campaign_manager_update_conversions] - update_conversion = GoogleCampaignManagerBatchUpdateConversionsOperator( - task_id="update_conversion", - profile_id=PROFILE_ID, - conversions=[CONVERSION_UPDATE], - encryption_source="AD_SERVING", - encryption_entity_type="DCM_ADVERTISER", - encryption_entity_id=ENCRYPTION_ENTITY_ID, - max_failed_updates=1, - ) - # [END howto_campaign_manager_update_conversions] - - insert_conversion >> update_conversion - - -if __name__ == "__main__": - dag.clear() - dag.run() diff --git a/airflow/providers/google/marketing_platform/example_dags/example_display_video.py b/airflow/providers/google/marketing_platform/example_dags/example_display_video.py index d2e2d07b002af..dc9325e95ee34 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_display_video.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_display_video.py @@ -18,11 +18,14 @@ """ Example Airflow DAG that shows how to use DisplayVideo. """ +from __future__ import annotations + import os from datetime import datetime -from typing import Dict +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator from airflow.providers.google.marketing_platform.hooks.display_video import GoogleDisplayVideo360Hook from airflow.providers.google.marketing_platform.operators.display_video import ( @@ -73,26 +76,25 @@ PARAMETERS = {"dataRange": "LAST_14_DAYS", "timezoneCode": "America/New_York"} -CREATE_SDF_DOWNLOAD_TASK_BODY_REQUEST: Dict = { +CREATE_SDF_DOWNLOAD_TASK_BODY_REQUEST: dict = { "version": SDF_VERSION, "advertiserId": ADVERTISER_ID, "inventorySourceFilter": {"inventorySourceIds": []}, } -DOWNLOAD_LINE_ITEMS_REQUEST: Dict = {"filterType": ADVERTISER_ID, "format": "CSV", "fileSpec": "EWF"} +DOWNLOAD_LINE_ITEMS_REQUEST: dict = {"filterType": ADVERTISER_ID, "format": "CSV", "fileSpec": "EWF"} # [END howto_display_video_env_variables] START_DATE = datetime(2021, 1, 1) with models.DAG( "example_display_video", - schedule_interval='@once', # Override to match your needs, start_date=START_DATE, catchup=False, ) as dag1: # [START howto_google_display_video_createquery_report_operator] create_report = GoogleDisplayVideo360CreateReportOperator(body=REPORT, task_id="create_report") - report_id = create_report.output["report_id"] + report_id = cast(str, XComArg(create_report, key="report_id")) # [END howto_google_display_video_createquery_report_operator] # [START howto_google_display_video_runquery_report_operator] @@ -129,17 +131,16 @@ with models.DAG( "example_display_video_misc", - schedule_interval='@once', # Override to match your needs, start_date=START_DATE, catchup=False, 
) as dag2: # [START howto_google_display_video_upload_multiple_entity_read_files_to_big_query] upload_erf_to_bq = GCSToBigQueryOperator( - task_id='upload_erf_to_bq', + task_id="upload_erf_to_bq", bucket=BUCKET, source_objects=ERF_SOURCE_OBJECT, destination_project_dataset_table=f"{BQ_DATA_SET}.gcs_to_bq_table", - write_disposition='WRITE_TRUNCATE', + write_disposition="WRITE_TRUNCATE", ) # [END howto_google_display_video_upload_multiple_entity_read_files_to_big_query] @@ -163,7 +164,6 @@ with models.DAG( "example_display_video_sdf", - schedule_interval='@once', # Override to match your needs, start_date=START_DATE, catchup=False, ) as dag3: diff --git a/airflow/providers/google/marketing_platform/hooks/analytics.py b/airflow/providers/google/marketing_platform/hooks/analytics.py index 975c4b0d80053..524eaaa8717bc 100644 --- a/airflow/providers/google/marketing_platform/hooks/analytics.py +++ b/airflow/providers/google/marketing_platform/hooks/analytics.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from __future__ import annotations + +from typing import Any from googleapiclient.discovery import Resource, build from googleapiclient.http import MediaFileUpload @@ -31,9 +33,9 @@ def __init__(self, api_version: str = "v3", *args, **kwargs): self.api_version = api_version self._conn = None - def _paginate(self, resource: Resource, list_args: Optional[Dict[str, Any]] = None) -> List[dict]: + def _paginate(self, resource: Resource, list_args: dict[str, Any] | None = None) -> list[dict]: list_args = list_args or {} - result: List[dict] = [] + result: list[dict] = [] while True: # start index has value 1 request = resource.list(start_index=len(result) + 1, **list_args) @@ -58,7 +60,7 @@ def get_conn(self) -> Resource: ) return self._conn - def list_accounts(self) -> List[Dict[str, Any]]: + def list_accounts(self) -> list[dict[str, Any]]: """Lists accounts list from Google Analytics 360.""" self.log.info("Retrieving accounts list...") conn = self.get_conn() @@ -68,7 +70,7 @@ def list_accounts(self) -> List[Dict[str, Any]]: def get_ad_words_link( self, account_id: str, web_property_id: str, web_property_ad_words_link_id: str - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Returns a web property-Google Ads link to which the user has access. @@ -77,7 +79,6 @@ def get_ad_words_link( :param web_property_ad_words_link_id: to retrieve the Google Ads link for. :returns: web property-Google Ads - :rtype: Dict """ self.log.info("Retrieving ad words links...") ad_words_link = ( @@ -93,7 +94,7 @@ def get_ad_words_link( ) return ad_words_link - def list_ad_words_links(self, account_id: str, web_property_id: str) -> List[Dict[str, Any]]: + def list_ad_words_links(self, account_id: str, web_property_id: str) -> list[dict[str, Any]]: """ Lists webProperty-Google Ads links for a given web property. @@ -101,7 +102,6 @@ def list_ad_words_links(self, account_id: str, web_property_id: str) -> List[Dic :param web_property_id: Web property UA-string to retrieve the Google Ads links for. :returns: list of entity Google Ads links. 
- :rtype: list """ self.log.info("Retrieving ad words list...") conn = self.get_conn() @@ -153,7 +153,7 @@ def delete_upload_data( account_id: str, web_property_id: str, custom_data_source_id: str, - delete_request_body: Dict[str, Any], + delete_request_body: dict[str, Any], ) -> None: """ Deletes the uploaded data for a given account/property/dataset @@ -178,7 +178,7 @@ def delete_upload_data( body=delete_request_body, ).execute() - def list_uploads(self, account_id, web_property_id, custom_data_source_id) -> List[Dict[str, Any]]: + def list_uploads(self, account_id, web_property_id, custom_data_source_id) -> list[dict[str, Any]]: """ Get list of data upload from GA diff --git a/airflow/providers/google/marketing_platform/hooks/campaign_manager.py b/airflow/providers/google/marketing_platform/hooks/campaign_manager.py index e529546318988..f64d3f3b3cf6a 100644 --- a/airflow/providers/google/marketing_platform/hooks/campaign_manager.py +++ b/airflow/providers/google/marketing_platform/hooks/campaign_manager.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Campaign Manager hook.""" -from typing import Any, Dict, List, Optional, Sequence, Union +from __future__ import annotations + +from typing import Any, Sequence from googleapiclient import http from googleapiclient.discovery import Resource, build @@ -28,14 +30,14 @@ class GoogleCampaignManagerHook(GoogleBaseHook): """Hook for Google Campaign Manager.""" - _conn = None # type: Optional[Resource] + _conn: Resource | None = None def __init__( self, api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -71,7 +73,7 @@ def delete_report(self, profile_id: str, report_id: str) -> Any: ) return response - def insert_report(self, profile_id: str, report: Dict[str, Any]) -> Any: + def insert_report(self, profile_id: str, report: dict[str, Any]) -> Any: """ Creates a report. @@ -89,11 +91,11 @@ def insert_report(self, profile_id: str, report: Dict[str, Any]) -> Any: def list_reports( self, profile_id: str, - max_results: Optional[int] = None, - scope: Optional[str] = None, - sort_field: Optional[str] = None, - sort_order: Optional[str] = None, - ) -> List[dict]: + max_results: int | None = None, + scope: str | None = None, + sort_field: str | None = None, + sort_order: str | None = None, + ) -> list[dict]: """ Retrieves list of reports. @@ -103,7 +105,7 @@ def list_reports( :param sort_field: The field by which to sort the list. :param sort_order: Order of sorted results. """ - reports: List[dict] = [] + reports: list[dict] = [] conn = self.get_conn() request = conn.reports().list( profileId=profile_id, @@ -136,7 +138,7 @@ def patch_report(self, profile_id: str, report_id: str, update_mask: dict) -> An ) return response - def run_report(self, profile_id: str, report_id: str, synchronous: Optional[bool] = None) -> Any: + def run_report(self, profile_id: str, report_id: str, synchronous: bool | None = None) -> Any: """ Runs a report. 
@@ -203,12 +205,12 @@ def get_report_file(self, file_id: str, profile_id: str, report_id: str) -> http @staticmethod def _conversions_batch_request( - conversions: List[Dict[str, Any]], + conversions: list[dict[str, Any]], encryption_entity_type: str, encryption_entity_id: int, encryption_source: str, kind: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return { "kind": kind, "conversions": conversions, @@ -223,7 +225,7 @@ def _conversions_batch_request( def conversions_batch_insert( self, profile_id: str, - conversions: List[Dict[str, Any]], + conversions: list[dict[str, Any]], encryption_entity_type: str, encryption_entity_id: int, encryption_source: str, @@ -258,8 +260,8 @@ def conversions_batch_insert( ) .execute(num_retries=self.num_retries) ) - if response.get('hasFailures', False): - errored_conversions = [stat['errors'] for stat in response['status'] if 'errors' in stat] + if response.get("hasFailures", False): + errored_conversions = [stat["errors"] for stat in response["status"] if "errors" in stat] if len(errored_conversions) > max_failed_inserts: raise AirflowException(errored_conversions) return response @@ -267,7 +269,7 @@ def conversions_batch_insert( def conversions_batch_update( self, profile_id: str, - conversions: List[Dict[str, Any]], + conversions: list[dict[str, Any]], encryption_entity_type: str, encryption_entity_id: int, encryption_source: str, @@ -302,8 +304,8 @@ def conversions_batch_update( ) .execute(num_retries=self.num_retries) ) - if response.get('hasFailures', False): - errored_conversions = [stat['errors'] for stat in response['status'] if 'errors' in stat] + if response.get("hasFailures", False): + errored_conversions = [stat["errors"] for stat in response["status"] if "errors" in stat] if len(errored_conversions) > max_failed_updates: raise AirflowException(errored_conversions) return response diff --git a/airflow/providers/google/marketing_platform/hooks/display_video.py b/airflow/providers/google/marketing_platform/hooks/display_video.py index 73de3a7634dad..05bba7a8ce185 100644 --- a/airflow/providers/google/marketing_platform/hooks/display_video.py +++ b/airflow/providers/google/marketing_platform/hooks/display_video.py @@ -16,8 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google DisplayVideo hook.""" +from __future__ import annotations -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Sequence from googleapiclient.discovery import Resource, build @@ -27,14 +28,14 @@ class GoogleDisplayVideo360Hook(GoogleBaseHook): """Hook for Google Display & Video 360.""" - _conn: Optional[Resource] = None + _conn: Resource | None = None def __init__( self, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -68,7 +69,7 @@ def get_conn_to_display_video(self) -> Resource: return self._conn @staticmethod - def erf_uri(partner_id, entity_type) -> List[str]: + def erf_uri(partner_id, entity_type) -> list[str]: """ Return URI for all Entity Read Files in bucket. 
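The conversions_batch_insert / conversions_batch_update hunks above share one piece of behaviour worth spelling out: the Campaign Manager batch endpoints report partial failures via "hasFailures", and the hook only raises once the number of errored conversions exceeds the caller's max_failed_inserts / max_failed_updates budget. A minimal sketch of that check in isolation (the standalone helper and the sample response are illustrative only, not part of the provider):

from airflow.exceptions import AirflowException


def check_conversions_response(response: dict, max_failed: int = 0) -> dict:
    """Raise when more conversion statuses carry errors than the caller allows."""
    if response.get("hasFailures", False):
        errored = [stat["errors"] for stat in response["status"] if "errors" in stat]
        if len(errored) > max_failed:
            raise AirflowException(errored)
    return response


# One failed conversion out of two is tolerated when max_failed=1:
check_conversions_response(
    {"hasFailures": True, "status": [{"errors": [{"code": "NOT_FOUND"}]}, {}]},
    max_failed=1,
)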
@@ -86,7 +87,7 @@ def erf_uri(partner_id, entity_type) -> List[str]: """ return [f"gdbm-{partner_id}/entity/{{{{ ds_nodash }}}}.*.{entity_type}.json"] - def create_query(self, query: Dict[str, Any]) -> dict: + def create_query(self, query: dict[str, Any]) -> dict: """ Creates a query. @@ -112,14 +113,12 @@ def get_query(self, query_id: str) -> dict: response = self.get_conn().queries().getquery(queryId=query_id).execute(num_retries=self.num_retries) return response - def list_queries( - self, - ) -> List[Dict]: + def list_queries(self) -> list[dict]: """Retrieves stored queries.""" response = self.get_conn().queries().listqueries().execute(num_retries=self.num_retries) - return response.get('queries', []) + return response.get("queries", []) - def run_query(self, query_id: str, params: Optional[Dict[str, Any]]) -> None: + def run_query(self, query_id: str, params: dict[str, Any] | None) -> None: """ Runs a stored query to generate a report. @@ -133,13 +132,12 @@ def run_query(self, query_id: str, params: Optional[Dict[str, Any]]) -> None: .execute(num_retries=self.num_retries) ) - def upload_line_items(self, line_items: Any) -> List[Dict[str, Any]]: + def upload_line_items(self, line_items: Any) -> list[dict[str, Any]]: """ Uploads line items in CSV format. :param line_items: downloaded data from GCS and passed to the body request :return: response body. - :rtype: List[Dict[str, Any]] """ request_body = { "lineItems": line_items, @@ -155,7 +153,7 @@ def upload_line_items(self, line_items: Any) -> List[Dict[str, Any]]: ) return response - def download_line_items(self, request_body: Dict[str, Any]) -> List[Any]: + def download_line_items(self, request_body: dict[str, Any]) -> list[Any]: """ Retrieves line items in CSV format. @@ -171,7 +169,7 @@ def download_line_items(self, request_body: Dict[str, Any]) -> List[Any]: ) return response["lineItems"] - def create_sdf_download_operation(self, body_request: Dict[str, Any]) -> Dict[str, Any]: + def create_sdf_download_operation(self, body_request: dict[str, Any]) -> dict[str, Any]: """ Creates an SDF Download Task and Returns an Operation. diff --git a/airflow/providers/google/marketing_platform/hooks/search_ads.py b/airflow/providers/google/marketing_platform/hooks/search_ads.py index 8fb382b20aed9..8115833b1afa6 100644 --- a/airflow/providers/google/marketing_platform/hooks/search_ads.py +++ b/airflow/providers/google/marketing_platform/hooks/search_ads.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Search Ads 360 hook.""" -from typing import Any, Dict, Optional, Sequence, Union +from __future__ import annotations + +from typing import Any, Sequence from googleapiclient.discovery import build @@ -26,14 +28,14 @@ class GoogleSearchAdsHook(GoogleBaseHook): """Hook for Google Search Ads 360.""" - _conn = None # type: Optional[Any] + _conn = None def __init__( self, api_version: str = "v2", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -54,7 +56,7 @@ def get_conn(self): ) return self._conn - def insert_report(self, report: Dict[str, Any]) -> Any: + def insert_report(self, report: dict[str, Any]) -> Any: """ Inserts a report request into the reporting system. 
diff --git a/airflow/providers/google/marketing_platform/operators/analytics.py b/airflow/providers/google/marketing_platform/operators/analytics.py index 8d5d90c986577..8767b6de1a916 100644 --- a/airflow/providers/google/marketing_platform/operators/analytics.py +++ b/airflow/providers/google/marketing_platform/operators/analytics.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Analytics 360 operators.""" +from __future__ import annotations + import csv from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -65,7 +67,7 @@ def __init__( *, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -74,7 +76,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> List[Dict[str, Any]]: + def execute(self, context: Context) -> list[dict[str, Any]]: hook = GoogleAnalyticsHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, @@ -126,7 +128,7 @@ def __init__( web_property_id: str, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -138,7 +140,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> Dict[str, Any]: + def execute(self, context: Context) -> dict[str, Any]: hook = GoogleAnalyticsHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, @@ -191,7 +193,7 @@ def __init__( web_property_id: str, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -202,7 +204,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> List[Dict[str, Any]]: + def execute(self, context: Context) -> list[dict[str, Any]]: hook = GoogleAnalyticsHook( api_version=self.api_version, gcp_conn_id=self.gcp_conn_id, @@ -259,9 +261,9 @@ def __init__( custom_data_source_id: str, resumable_upload: bool = False, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, api_version: str = "v3", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -276,7 +278,7 @@ def __init__( self.api_version = api_version self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -341,9 +343,9 @@ def __init__( web_property_id: str, custom_data_source_id: str, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, api_version: str = 
"v3", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -356,7 +358,7 @@ def __init__( self.api_version = api_version self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: ga_hook = GoogleAnalyticsHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -420,9 +422,9 @@ def __init__( storage_bucket: str, storage_name_object: str, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - custom_dimension_header_mapping: Optional[Dict[str, str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + custom_dimension_header_mapping: dict[str, str] | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -434,7 +436,7 @@ def __init__( self.impersonation_chain = impersonation_chain def _modify_column_headers( - self, tmp_file_location: str, custom_dimension_header_mapping: Dict[str, str] + self, tmp_file_location: str, custom_dimension_header_mapping: dict[str, str] ) -> None: # Check headers self.log.info("Checking if file contains headers") @@ -466,7 +468,7 @@ def _modify_column_headers( with open(tmp_file_location, "w") as write_file: write_file.writelines(all_data) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/marketing_platform/operators/campaign_manager.py b/airflow/providers/google/marketing_platform/operators/campaign_manager.py index 45730a30d26c4..e9dcfc95dd302 100644 --- a/airflow/providers/google/marketing_platform/operators/campaign_manager.py +++ b/airflow/providers/google/marketing_platform/operators/campaign_manager.py @@ -16,10 +16,12 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google CampaignManager operators.""" +from __future__ import annotations + import json import tempfile import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from googleapiclient import http @@ -76,12 +78,12 @@ def __init__( self, *, profile_id: str, - report_name: Optional[str] = None, - report_id: Optional[str] = None, + report_name: str | None = None, + report_id: str | None = None, api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -98,7 +100,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = GoogleCampaignManagerHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -173,13 +175,13 @@ def __init__( report_id: str, file_id: str, bucket_name: str, - report_name: Optional[str] = None, + report_name: str | None = None, gzip: bool = True, chunk_size: int = 10 * 1024 * 1024, api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -209,7 +211,7 @@ def _set_bucket_name(name: str) -> str: bucket = name if not name.startswith("gs://") else name[5:] return bucket.strip("/") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = GoogleCampaignManagerHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -295,11 +297,11 @@ def __init__( self, *, profile_id: str, - report: Dict[str, Any], + report: dict[str, Any], api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -312,11 +314,11 @@ def __init__( def prepare_template(self) -> None: # If .json is passed then we have to read the file - if isinstance(self.report, str) and self.report.endswith('.json'): + if isinstance(self.report, str) and self.report.endswith(".json"): with open(self.report) as file: self.report = json.load(file) - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GoogleCampaignManagerHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -379,8 +381,8 @@ def __init__( synchronous: bool = False, api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -392,7 +394,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GoogleCampaignManagerHook( gcp_conn_id=self.gcp_conn_id, 
delegate_to=self.delegate_to, @@ -461,15 +463,15 @@ def __init__( self, *, profile_id: str, - conversions: List[Dict[str, Any]], + conversions: list[dict[str, Any]], encryption_entity_type: str, encryption_entity_id: int, encryption_source: str, max_failed_inserts: int = 0, api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -484,7 +486,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GoogleCampaignManagerHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -552,15 +554,15 @@ def __init__( self, *, profile_id: str, - conversions: List[Dict[str, Any]], + conversions: list[dict[str, Any]], encryption_entity_type: str, encryption_entity_id: int, encryption_source: str, max_failed_updates: int = 0, api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -575,7 +577,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GoogleCampaignManagerHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/marketing_platform/operators/display_video.py b/airflow/providers/google/marketing_platform/operators/display_video.py index de102a14d8f31..d33ffb46d7603 100644 --- a/airflow/providers/google/marketing_platform/operators/display_video.py +++ b/airflow/providers/google/marketing_platform/operators/display_video.py @@ -16,13 +16,15 @@ # specific language governing permissions and limitations # under the License. 
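The prepare_template methods touched above (Campaign Manager, Display & Video 360 and Search Ads alike) follow one convention: the templated report/body field may be an inline dict, or a string ending in ".json", in which case the rendered path is opened and parsed before execution. A hedged usage sketch of the two styles, loosely based on the deleted example DAG (profile id, report body and file name are placeholders):

from airflow.providers.google.marketing_platform.operators.campaign_manager import (
    GoogleCampaignManagerInsertReportOperator,
)

# Style 1: inline dict, passed straight to the API.
create_report_inline = GoogleCampaignManagerInsertReportOperator(
    task_id="create_report_inline",
    profile_id="123456789",
    report={"type": "STANDARD", "name": "test-report"},
)

# Style 2: path to a .json file; prepare_template() loads it after rendering.
create_report_from_file = GoogleCampaignManagerInsertReportOperator(
    task_id="create_report_from_file",
    profile_id="123456789",
    report="report.json",
)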
"""This module contains Google DisplayVideo operators.""" +from __future__ import annotations + import csv import json import shutil import tempfile import urllib.request -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Any, Sequence +from urllib.parse import urlsplit from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -71,11 +73,11 @@ class GoogleDisplayVideo360CreateReportOperator(BaseOperator): def __init__( self, *, - body: Dict[str, Any], + body: dict[str, Any], api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -87,11 +89,11 @@ def __init__( def prepare_template(self) -> None: # If .json is passed then we have to read the file - if isinstance(self.body, str) and self.body.endswith('.json'): + if isinstance(self.body, str) and self.body.endswith(".json"): with open(self.body) as file: self.body = json.load(file) - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -143,12 +145,12 @@ class GoogleDisplayVideo360DeleteReportOperator(BaseOperator): def __init__( self, *, - report_id: Optional[str] = None, - report_name: Optional[str] = None, + report_id: str | None = None, + report_name: str | None = None, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -165,7 +167,7 @@ def __init__( if not (report_name or report_id): raise AirflowException("Provide one of the values: `report_name` or `report_id`.") - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -230,13 +232,13 @@ def __init__( *, report_id: str, bucket_name: str, - report_name: Optional[str] = None, + report_name: str | None = None, gzip: bool = True, chunk_size: int = 10 * 1024 * 1024, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -260,7 +262,7 @@ def _set_bucket_name(name: str) -> str: bucket = name if not name.startswith("gs://") else name[5:] return bucket.strip("/") - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -280,7 +282,7 @@ def execute(self, context: 'Context'): # If no custom report_name provided, use DV360 name file_url = resource["metadata"]["googleCloudStoragePathForLatestReport"] - report_name = self.report_name or urlparse(file_url).path.split("/")[-1] + report_name = self.report_name or urlsplit(file_url).path.split("/")[-1] report_name = self._resolve_file_name(report_name) # Download the 
report @@ -348,11 +350,11 @@ def __init__( self, *, report_id: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: dict[str, Any] | None = None, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -363,7 +365,7 @@ def __init__( self.parameters = parameters self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -405,14 +407,14 @@ class GoogleDisplayVideo360DownloadLineItemsOperator(BaseOperator): def __init__( self, *, - request_body: Dict[str, Any], + request_body: dict[str, Any], bucket_name: str, object_name: str, gzip: bool = False, api_version: str = "v1.1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -425,7 +427,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -439,7 +441,7 @@ def execute(self, context: 'Context') -> str: ) self.log.info("Retrieving report...") - content: List[str] = hook.download_line_items(request_body=self.request_body) + content: list[str] = hook.download_line_items(request_body=self.request_body) with tempfile.NamedTemporaryFile("w+") as temp_file: writer = csv.writer(temp_file) @@ -487,8 +489,8 @@ def __init__( object_name: str, api_version: str = "v1.1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -499,7 +501,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -565,11 +567,11 @@ class GoogleDisplayVideo360CreateSDFDownloadTaskOperator(BaseOperator): def __init__( self, *, - body_request: Dict[str, Any], + body_request: dict[str, Any], api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -579,7 +581,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> Dict[str, Any]: + def execute(self, context: Context) -> dict[str, Any]: hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -645,8 +647,8 @@ def __init__( gzip: bool = False, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - 
delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -659,7 +661,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/marketing_platform/operators/search_ads.py b/airflow/providers/google/marketing_platform/operators/search_ads.py index 674dc448f6d7a..87f3e14be7105 100644 --- a/airflow/providers/google/marketing_platform/operators/search_ads.py +++ b/airflow/providers/google/marketing_platform/operators/search_ads.py @@ -16,9 +16,11 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Search Ads operators.""" +from __future__ import annotations + import json from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -66,11 +68,11 @@ class GoogleSearchAdsInsertReportOperator(BaseOperator): def __init__( self, *, - report: Dict[str, Any], + report: dict[str, Any], api_version: str = "v2", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -82,11 +84,11 @@ def __init__( def prepare_template(self) -> None: # If .json is passed then we have to read the file - if isinstance(self.report, str) and self.report.endswith('.json'): + if isinstance(self.report, str) and self.report.endswith(".json"): with open(self.report) as file: self.report = json.load(file) - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GoogleSearchAdsHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -145,13 +147,13 @@ def __init__( *, report_id: str, bucket_name: str, - report_name: Optional[str] = None, + report_name: str | None = None, gzip: bool = True, chunk_size: int = 10 * 1024 * 1024, api_version: str = "v2", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -187,7 +189,7 @@ def _handle_report_fragment(fragment: bytes) -> bytes: return fragment_records[1] return b"" - def execute(self, context: 'Context'): + def execute(self, context: Context): hook = GoogleSearchAdsHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -206,8 +208,8 @@ def execute(self, context: 'Context'): report_name = self._resolve_file_name(report_name) response = hook.get(report_id=self.report_id) - if not response['isReportReady']: - raise AirflowException(f'Report {self.report_id} is not ready yet') + if not response["isReportReady"]: + raise AirflowException(f"Report {self.report_id} is not ready yet") # Resolve report fragments fragments_count = len(response["files"]) diff --git 
a/airflow/providers/google/marketing_platform/sensors/campaign_manager.py b/airflow/providers/google/marketing_platform/sensors/campaign_manager.py index 705fe22cf4e64..88e727d50ce73 100644 --- a/airflow/providers/google/marketing_platform/sensors/campaign_manager.py +++ b/airflow/providers/google/marketing_platform/sensors/campaign_manager.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Campaign Manager sensor.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.google.marketing_platform.hooks.campaign_manager import GoogleCampaignManagerHook from airflow.sensors.base import BaseSensorOperator @@ -62,7 +64,7 @@ class GoogleCampaignManagerReportSensor(BaseSensorOperator): "impersonation_chain", ) - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = GoogleCampaignManagerHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -81,10 +83,10 @@ def __init__( file_id: str, api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, mode: str = "reschedule", poke_interval: int = 60 * 5, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/google/marketing_platform/sensors/display_video.py b/airflow/providers/google/marketing_platform/sensors/display_video.py index 5af14fc1da5ea..5605e3cbd4bc0 100644 --- a/airflow/providers/google/marketing_platform/sensors/display_video.py +++ b/airflow/providers/google/marketing_platform/sensors/display_video.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
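The report sensor above keeps mode="reschedule" and a five-minute poke_interval as defaults, so the worker slot is released between checks while Campaign Manager finishes processing the report file. A minimal wiring sketch mirroring the deleted example DAG (the profile id is a placeholder; report_id and file_id would normally be templated from upstream XComs, as shown here with illustrative task ids):

from airflow.providers.google.marketing_platform.sensors.campaign_manager import (
    GoogleCampaignManagerReportSensor,
)

wait_for_report = GoogleCampaignManagerReportSensor(
    task_id="wait_for_report",
    profile_id="123456789",
    report_id="{{ ti.xcom_pull('create_report', key='report_id') }}",
    file_id="{{ ti.xcom_pull('run_report', key='file_id') }}",
    mode="reschedule",     # default: free the slot between pokes
    poke_interval=60 * 5,  # default: check every five minutes
)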
- """Sensor for detecting the completion of DV360 reports.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow import AirflowException from airflow.providers.google.marketing_platform.hooks.display_video import GoogleDisplayVideo360Hook @@ -61,8 +62,8 @@ def __init__( report_id: str, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -73,7 +74,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -122,10 +123,10 @@ def __init__( operation_name: str, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, mode: str = "reschedule", poke_interval: int = 60 * 5, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, *args, **kwargs, ) -> None: @@ -138,7 +139,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: hook = GoogleDisplayVideo360Hook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/marketing_platform/sensors/search_ads.py b/airflow/providers/google/marketing_platform/sensors/search_ads.py index 9c5e7c2b3a1cc..b978c6706ecd9 100644 --- a/airflow/providers/google/marketing_platform/sensors/search_ads.py +++ b/airflow/providers/google/marketing_platform/sensors/search_ads.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Search Ads sensor.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.google.marketing_platform.hooks.search_ads import GoogleSearchAdsHook from airflow.sensors.base import BaseSensorOperator @@ -64,10 +66,10 @@ def __init__( report_id: str, api_version: str = "v2", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, + delegate_to: str | None = None, mode: str = "reschedule", poke_interval: int = 5 * 60, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(mode=mode, poke_interval=poke_interval, **kwargs) @@ -77,13 +79,13 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context'): + def poke(self, context: Context): hook = GoogleSearchAdsHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) - self.log.info('Checking status of %s report.', self.report_id) + self.log.info("Checking status of %s report.", self.report_id) response = hook.get(report_id=self.report_id) - return response['isReportReady'] + return response["isReportReady"] diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index e61f2dfc3c482..319c061432d93 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -29,6 +29,12 @@ description: | - `Google Workspace `__ (formerly Google Suite) versions: + - 8.5.0 + - 8.4.0 + - 8.3.0 + - 8.2.0 + - 8.1.0 + - 8.0.0 - 7.0.0 - 6.8.0 - 6.7.0 @@ -48,8 +54,72 @@ versions: - 2.0.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + # Google has very clear rules on what dependencies should be used. All the limits below + # follow strict guidelines of Google Libraries as quoted here: + # While this issue is open, dependents of google-api-core, google-cloud-core. and google-auth + # should preserve >1, <3 pins on these packages. + # https://github.com/googleapis/google-cloud-python/issues/10566 + # Some of Google Packages are limited to <2.0.0 because 2.0.0 releases of the libraries + # Introduced breaking changes across the board. 
Those libraries should be upgraded soon + # TODO: Upgrade all Google libraries that are limited to <2.0.0 + - PyOpenSSL + - asgiref>=3.5.2 + - gcloud_aio_auth>=4.0.0 + - gcloud-aio-bigquery>=6.1.2 + - gcloud-aio-storage + - google-ads>=15.1.1 + - google-api-core>=2.7.0,<3.0.0 + - google-api-python-client>=1.6.0,<2.0.0 + - google-auth>=1.0.0 + - google-auth-httplib2>=0.0.1 + - google-cloud-aiplatform>=1.7.1,<2.0.0 + - google-cloud-automl>=2.1.0 + - google-cloud-bigquery-datatransfer>=3.0.0 + - google-cloud-bigtable>=1.0.0,<2.0.0 + - google-cloud-build>=3.0.0 + - google-cloud-compute>=0.1.0,<2.0.0 + - google-cloud-container>=2.2.0,<3.0.0 + - google-cloud-dataform>=0.2.0 + - google-cloud-datacatalog>=3.0.0 + - google-cloud-dataplex>=0.1.0 + - google-cloud-dataproc>=3.1.0 + - google-cloud-dataproc-metastore>=1.2.0,<2.0.0 + - google-cloud-dlp>=0.11.0,<2.0.0 + - google-cloud-kms>=2.0.0 + - google-cloud-language>=1.1.1,<2.0.0 + - google-cloud-logging>=2.1.1 + - google-cloud-memcache>=0.2.0 + - google-cloud-monitoring>=2.0.0 + - google-cloud-os-login>=2.0.0 + - google-cloud-orchestration-airflow>=1.0.0,<2.0.0 + - google-cloud-pubsub>=2.0.0 + - google-cloud-redis>=2.0.0 + - google-cloud-secret-manager>=0.2.0,<2.0.0 + - google-cloud-spanner>=1.10.0,<2.0.0 + - google-cloud-speech>=0.36.3,<2.0.0 + - google-cloud-storage>=1.30,<3.0.0 + - google-cloud-tasks>=2.0.0 + - google-cloud-texttospeech>=0.4.0,<2.0.0 + - google-cloud-translate>=1.5.0,<2.0.0 + - google-cloud-videointelligence>=1.7.0,<2.0.0 + - google-cloud-vision>=0.35.2,<2.0.0 + - google-cloud-workflows>=0.1.0,<2.0.0 + - grpcio-gcp>=0.2.2 + - httpx + - json-merge-patch>=0.2 + - looker-sdk>=22.2.0 + - pandas-gbq + - pandas>=0.17.1 + - sqlalchemy-bigquery>=1.2.1 + # A transient dependency of google-cloud-bigquery-datatransfer, but we + # further constrain it since older versions are buggy. + - proto-plus>=1.19.6 + # Google bigtable client require protobuf <= 3.20.0. 
We can remove the limitation + # when this limitation is removed + - protobuf<=3.20.0 integrations: - integration-name: Google Analytics360 @@ -99,6 +169,11 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst tags: [google] + - integration-name: Google Cloud Dataform + external-doc-url: https://cloud.google.com/dataform/ + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/dataform.rst + tags: [google] - integration-name: Google Cloud Data Loss Prevention (DLP) external-doc-url: https://cloud.google.com/dlp/ how-to-guide: @@ -509,6 +584,9 @@ operators: - integration-name: Google Looker python-modules: - airflow.providers.google.cloud.operators.looker + - integration-name: Google Cloud Dataform + python-modules: + - airflow.providers.google.cloud.operators.dataform sensors: - integration-name: Google BigQuery @@ -520,6 +598,9 @@ sensors: - integration-name: Google Bigtable python-modules: - airflow.providers.google.cloud.sensors.bigtable + - integration-name: Google Cloud Composer + python-modules: + - airflow.providers.google.cloud.sensors.cloud_composer - integration-name: Google Cloud Storage Transfer Service python-modules: - airflow.providers.google.cloud.sensors.cloud_storage_transfer_service @@ -529,6 +610,9 @@ sensors: - integration-name: Google Data Fusion python-modules: - airflow.providers.google.cloud.sensors.datafusion + - integration-name: Google Dataprep + python-modules: + - airflow.providers.google.cloud.sensors.dataprep - integration-name: Google Dataplex python-modules: - airflow.providers.google.cloud.sensors.dataplex @@ -559,6 +643,12 @@ sensors: - integration-name: Google Looker python-modules: - airflow.providers.google.cloud.sensors.looker + - integration-name: Google Cloud Dataform + python-modules: + - airflow.providers.google.cloud.sensors.dataform + - integration-name: Google Cloud Tasks + python-modules: + - airflow.providers.google.cloud.sensors.tasks hooks: - integration-name: Google Ads @@ -727,6 +817,9 @@ hooks: - integration-name: Google Looker python-modules: - airflow.providers.google.cloud.hooks.looker + - integration-name: Google Cloud Dataform + python-modules: + - airflow.providers.google.cloud.hooks.dataform transfers: - source-integration-name: Presto @@ -737,7 +830,7 @@ transfers: target-integration-name: Google Cloud Storage (GCS) how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst python-module: airflow.providers.google.cloud.transfers.trino_to_gcs - - source-integration-name: SQL + - source-integration-name: Common SQL target-integration-name: Google Cloud Storage (GCS) python-module: airflow.providers.google.cloud.transfers.sql_to_gcs - source-integration-name: Google Cloud Storage (GCS) @@ -816,7 +909,7 @@ transfers: target-integration-name: Google Spreadsheet how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/gcs_to_sheets.rst python-module: airflow.providers.google.suite.transfers.gcs_to_sheets - - source-integration-name: SQL + - source-integration-name: Common SQL target-integration-name: Google Spreadsheet how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/sql_to_sheets.rst python-module: airflow.providers.google.suite.transfers.sql_to_sheets @@ -851,14 +944,6 @@ transfers: python-module: airflow.providers.google.cloud.transfers.mssql_to_gcs how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/mssql_to_gcs.rst -hook-class-names: # deprecated - to be removed after providers 
add dependency on Airflow 2.2.0+ - - airflow.providers.google.common.hooks.base_google.GoogleBaseHook - - airflow.providers.google.cloud.hooks.dataprep.GoogleDataprepHook - - airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook - - airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook - - airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook - - airflow.providers.google.cloud.hooks.bigquery.BigQueryHook - - airflow.providers.google.leveldb.hooks.leveldb.LevelDBHook connection-types: - hook-class-name: airflow.providers.google.common.hooks.base_google.GoogleBaseHook @@ -879,7 +964,9 @@ connection-types: extra-links: - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink - - airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink + - airflow.providers.google.cloud.links.dataform.DataformRepositoryLink + - airflow.providers.google.cloud.links.dataform.DataformWorkspaceLink + - airflow.providers.google.cloud.links.dataform.DataformWorkflowInvocationLink - airflow.providers.google.cloud.operators.datafusion.DataFusionInstanceLink - airflow.providers.google.cloud.operators.datafusion.DataFusionPipelineLink - airflow.providers.google.cloud.operators.datafusion.DataFusionPipelinesLink @@ -887,15 +974,24 @@ extra-links: - airflow.providers.google.cloud.links.cloud_sql.CloudSQLInstanceDatabaseLink - airflow.providers.google.cloud.links.dataplex.DataplexTaskLink - airflow.providers.google.cloud.links.dataplex.DataplexTasksLink + - airflow.providers.google.cloud.links.dataplex.DataplexLakeLink - airflow.providers.google.cloud.links.bigquery.BigQueryDatasetLink - airflow.providers.google.cloud.links.bigquery.BigQueryTableLink - airflow.providers.google.cloud.links.bigquery_dts.BigQueryDataTransferConfigLink + - airflow.providers.google.cloud.links.compute.ComputeInstanceDetailsLink + - airflow.providers.google.cloud.links.compute.ComputeInstanceTemplateDetailsLink + - airflow.providers.google.cloud.links.compute.ComputeInstanceGroupManagerDetailsLink - airflow.providers.google.cloud.links.cloud_tasks.CloudTasksQueueLink - airflow.providers.google.cloud.links.cloud_tasks.CloudTasksLink + - airflow.providers.google.cloud.links.datacatalog.DataCatalogEntryGroupLink + - airflow.providers.google.cloud.links.datacatalog.DataCatalogEntryLink + - airflow.providers.google.cloud.links.datacatalog.DataCatalogTagTemplateLink - airflow.providers.google.cloud.links.dataproc.DataprocLink - airflow.providers.google.cloud.links.dataproc.DataprocListLink - airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink - airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreLink + - airflow.providers.google.cloud.links.dataprep.DataprepFlowLink + - airflow.providers.google.cloud.links.dataprep.DataprepJobGroupLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIModelLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIModelListLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIModelExportLink @@ -908,6 +1004,9 @@ extra-links: - airflow.providers.google.cloud.links.vertex_ai.VertexAIBatchPredictionJobListLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIEndpointLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIEndpointListLink + - airflow.providers.google.cloud.links.workflows.WorkflowsWorkflowDetailsLink + - 
airflow.providers.google.cloud.links.workflows.WorkflowsListOfWorkflowsLink + - airflow.providers.google.cloud.links.workflows.WorkflowsExecutionLink - airflow.providers.google.cloud.operators.cloud_composer.CloudComposerEnvironmentLink - airflow.providers.google.cloud.operators.cloud_composer.CloudComposerEnvironmentsLink - airflow.providers.google.cloud.links.dataflow.DataflowJobLink @@ -920,14 +1019,59 @@ extra-links: - airflow.providers.google.cloud.links.spanner.SpannerInstanceLink - airflow.providers.google.cloud.links.stackdriver.StackdriverNotificationsLink - airflow.providers.google.cloud.links.stackdriver.StackdriverPoliciesLink + - airflow.providers.google.cloud.links.kubernetes_engine.KubernetesEngineClusterLink + - airflow.providers.google.cloud.links.kubernetes_engine.KubernetesEnginePodLink + - airflow.providers.google.cloud.links.pubsub.PubSubSubscriptionLink + - airflow.providers.google.cloud.links.pubsub.PubSubTopicLink + - airflow.providers.google.cloud.links.cloud_memorystore.MemcachedInstanceDetailsLink + - airflow.providers.google.cloud.links.cloud_memorystore.MemcachedInstanceListLink + - airflow.providers.google.cloud.links.cloud_memorystore.RedisInstanceDetailsLink + - airflow.providers.google.cloud.links.cloud_memorystore.RedisInstanceListLink + - airflow.providers.google.cloud.links.cloud_build.CloudBuildLink + - airflow.providers.google.cloud.links.cloud_build.CloudBuildListLink + - airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggersListLink + - airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggerDetailsLink + - airflow.providers.google.cloud.links.life_sciences.LifeSciencesLink + - airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsDetailsLink + - airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsListLink + - airflow.providers.google.cloud.links.cloud_storage_transfer.CloudStorageTransferListLink + - airflow.providers.google.cloud.links.cloud_storage_transfer.CloudStorageTransferJobLink + - airflow.providers.google.cloud.links.cloud_storage_transfer.CloudStorageTransferDetailsLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPDeidentifyTemplatesListLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPDeidentifyTemplateDetailsLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPJobTriggersListLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPJobTriggerDetailsLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPJobsListLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPJobDetailsLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPInspectTemplatesListLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPInspectTemplateDetailsLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPInfoTypesListLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPInfoTypeDetailsLink + - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPPossibleInfoTypesListLink + - airflow.providers.google.cloud.links.mlengine.MLEngineModelLink + - airflow.providers.google.cloud.links.mlengine.MLEngineModelsListLink + - airflow.providers.google.cloud.links.mlengine.MLEngineJobDetailsLink + - airflow.providers.google.cloud.links.mlengine.MLEngineJobSListLink + - airflow.providers.google.cloud.links.mlengine.MLEngineModelVersionDetailsLink - airflow.providers.google.common.links.storage.StorageLink - 
airflow.providers.google.common.links.storage.FileDetailsLink additional-extras: - apache.beam: apache-beam[gcp] - leveldb: plyvel - facebook: apache-airlfow-providers-facebook>=2.2.0 - amazon: apache-airlfow-providers-facebook>=2.6.0 + - name: apache.beam + dependencies: + - apache-beam[gcp] + - name: leveldb + dependencies: + - plyvel + - name: oracle + dependencies: + - apache-airflow-providers-oracle>=3.1.0 + - name: facebook + dependencies: + - apache-airflow-providers-facebook>=2.2.0 + - name: amazon + dependencies: + - apache-airflow-providers-amazon>=2.6.0 secrets-backends: - airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend diff --git a/airflow/providers/google/suite/example_dags/example_gcs_to_gdrive.py b/airflow/providers/google/suite/example_dags/example_gcs_to_gdrive.py deleted file mode 100644 index e4602c464c736..0000000000000 --- a/airflow/providers/google/suite/example_dags/example_gcs_to_gdrive.py +++ /dev/null @@ -1,59 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example DAG using GoogleCloudStorageToGoogleDriveOperator. 
-""" -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.suite.transfers.gcs_to_gdrive import GCSToGoogleDriveOperator - -GCS_TO_GDRIVE_BUCKET = os.environ.get("GCS_TO_DRIVE_BUCKET", "example-object") - -with models.DAG( - "example_gcs_to_gdrive", - schedule_interval=None, # Override to match your needs, - start_date=datetime(2021, 1, 1), - catchup=False, - tags=['example'], -) as dag: - # [START howto_operator_gcs_to_gdrive_copy_single_file] - copy_single_file = GCSToGoogleDriveOperator( - task_id="copy_single_file", - source_bucket=GCS_TO_GDRIVE_BUCKET, - source_object="sales/january.avro", - destination_object="copied_sales/january-backup.avro", - ) - # [END howto_operator_gcs_to_gdrive_copy_single_file] - # [START howto_operator_gcs_to_gdrive_copy_files] - copy_files = GCSToGoogleDriveOperator( - task_id="copy_files", - source_bucket=GCS_TO_GDRIVE_BUCKET, - source_object="sales/*", - destination_object="copied_sales/", - ) - # [END howto_operator_gcs_to_gdrive_copy_files] - # [START howto_operator_gcs_to_gdrive_move_files] - move_files = GCSToGoogleDriveOperator( - task_id="move_files", - source_bucket=GCS_TO_GDRIVE_BUCKET, - source_object="sales/*.avro", - move_object=True, - ) - # [END howto_operator_gcs_to_gdrive_move_files] diff --git a/airflow/providers/google/suite/example_dags/example_gcs_to_sheets.py b/airflow/providers/google/suite/example_dags/example_gcs_to_sheets.py deleted file mode 100644 index 7385d49f29fce..0000000000000 --- a/airflow/providers/google/suite/example_dags/example_gcs_to_sheets.py +++ /dev/null @@ -1,53 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.transfers.sheets_to_gcs import GoogleSheetsToGCSOperator -from airflow.providers.google.suite.transfers.gcs_to_sheets import GCSToGoogleSheetsOperator - -BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-test-bucket3") -SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "example-spreadsheetID") -NEW_SPREADSHEET_ID = os.environ.get("NEW_SPREADSHEET_ID", "1234567890qwerty") - -with models.DAG( - "example_gcs_to_sheets", - start_date=datetime(2021, 1, 1), - schedule_interval='@once', # Override to match your needs - catchup=False, - tags=["example"], -) as dag: - - upload_sheet_to_gcs = GoogleSheetsToGCSOperator( - task_id="upload_sheet_to_gcs", - destination_bucket=BUCKET, - spreadsheet_id=SPREADSHEET_ID, - ) - - # [START upload_gcs_to_sheets] - upload_gcs_to_sheet = GCSToGoogleSheetsOperator( - task_id="upload_gcs_to_sheet", - bucket_name=BUCKET, - object_name="{{ task_instance.xcom_pull('upload_sheet_to_gcs')[0] }}", - spreadsheet_id=NEW_SPREADSHEET_ID, - ) - # [END upload_gcs_to_sheets] - - upload_sheet_to_gcs >> upload_gcs_to_sheet diff --git a/airflow/providers/google/suite/example_dags/example_local_to_drive.py b/airflow/providers/google/suite/example_dags/example_local_to_drive.py deleted file mode 100644 index 6e985b796a60b..0000000000000 --- a/airflow/providers/google/suite/example_dags/example_local_to_drive.py +++ /dev/null @@ -1,56 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example DAG using LocalFilesystemToGoogleDriveOperator. 
-""" - -from datetime import datetime -from pathlib import Path - -from airflow import models -from airflow.providers.google.suite.transfers.local_to_drive import LocalFilesystemToGoogleDriveOperator - -SINGLE_FILE_LOCAL_PATHS = [Path("test1")] -MULTIPLE_FILES_LOCAL_PATHS = [Path("test1"), Path("test2")] -DRIVE_FOLDER = Path("test-folder") - -with models.DAG( - "example_local_to_drive", - schedule_interval='@once', # Override to match your needs - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - # [START howto_operator_local_to_drive_upload_single_file] - upload_single_file = LocalFilesystemToGoogleDriveOperator( - task_id="upload_single_file", - local_paths=SINGLE_FILE_LOCAL_PATHS, - drive_folder=DRIVE_FOLDER, - ) - # [END howto_operator_local_to_drive_upload_single_file] - - # [START howto_operator_local_to_drive_upload_multiple_files] - upload_multiple_files = LocalFilesystemToGoogleDriveOperator( - task_id="upload_multiple_files", - local_paths=MULTIPLE_FILES_LOCAL_PATHS, - drive_folder=DRIVE_FOLDER, - ignore_if_missing=True, - ) - # [END howto_operator_local_to_drive_upload_multiple_files] - - upload_single_file >> upload_multiple_files diff --git a/airflow/providers/google/suite/example_dags/example_sheets.py b/airflow/providers/google/suite/example_dags/example_sheets.py deleted file mode 100644 index e5b3f820b24de..0000000000000 --- a/airflow/providers/google/suite/example_dags/example_sheets.py +++ /dev/null @@ -1,75 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -from datetime import datetime - -from airflow import models -from airflow.operators.bash import BashOperator -from airflow.providers.google.cloud.transfers.sheets_to_gcs import GoogleSheetsToGCSOperator -from airflow.providers.google.suite.operators.sheets import GoogleSheetsCreateSpreadsheetOperator -from airflow.providers.google.suite.transfers.gcs_to_sheets import GCSToGoogleSheetsOperator - -GCS_BUCKET = os.environ.get("SHEETS_GCS_BUCKET", "test28397ye") -SPREADSHEET_ID = os.environ.get("SPREADSHEET_ID", "1234567890qwerty") -NEW_SPREADSHEET_ID = os.environ.get("NEW_SPREADSHEET_ID", "1234567890qwerty") - -SPREADSHEET = { - "properties": {"title": "Test1"}, - "sheets": [{"properties": {"title": "Sheet1"}}], -} - -with models.DAG( - "example_sheets_gcs", - schedule_interval='@once', # Override to match your needs, - start_date=datetime(2021, 1, 1), - catchup=False, - tags=["example"], -) as dag: - # [START upload_sheet_to_gcs] - upload_sheet_to_gcs = GoogleSheetsToGCSOperator( - task_id="upload_sheet_to_gcs", - destination_bucket=GCS_BUCKET, - spreadsheet_id=SPREADSHEET_ID, - ) - # [END upload_sheet_to_gcs] - - # [START create_spreadsheet] - create_spreadsheet = GoogleSheetsCreateSpreadsheetOperator( - task_id="create_spreadsheet", spreadsheet=SPREADSHEET - ) - # [END create_spreadsheet] - - # [START print_spreadsheet_url] - print_spreadsheet_url = BashOperator( - task_id="print_spreadsheet_url", - bash_command=f"echo {create_spreadsheet.output['spreadsheet_url']}", - ) - # [END print_spreadsheet_url] - - # [START upload_gcs_to_sheet] - upload_gcs_to_sheet = GCSToGoogleSheetsOperator( - task_id="upload_gcs_to_sheet", - bucket_name=GCS_BUCKET, - object_name="{{ task_instance.xcom_pull('upload_sheet_to_gcs')[0] }}", - spreadsheet_id=NEW_SPREADSHEET_ID, - ) - # [END upload_gcs_to_sheet] - - create_spreadsheet >> print_spreadsheet_url - upload_sheet_to_gcs >> upload_gcs_to_sheet diff --git a/airflow/providers/google/suite/example_dags/example_sql_to_sheets.py b/airflow/providers/google/suite/example_dags/example_sql_to_sheets.py deleted file mode 100644 index 88aae41d67c96..0000000000000 --- a/airflow/providers/google/suite/example_dags/example_sql_to_sheets.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime - -from airflow import models -from airflow.providers.google.suite.transfers.sql_to_sheets import SQLToGoogleSheetsOperator - -SQL = "select 1 as my_col" -NEW_SPREADSHEET_ID = "123" - -with models.DAG( - "example_sql_to_sheets", - start_date=datetime(2021, 1, 1), - schedule_interval=None, # Override to match your needs - catchup=False, - tags=["example"], -) as dag: - - # [START upload_sql_to_sheets] - upload_gcs_to_sheet = SQLToGoogleSheetsOperator( - task_id="upload_sql_to_sheet", - sql=SQL, - sql_conn_id="database_conn_id", - spreadsheet_id=NEW_SPREADSHEET_ID, - ) - # [END upload_sql_to_sheets] diff --git a/airflow/providers/google/suite/hooks/calendar.py b/airflow/providers/google/suite/hooks/calendar.py index 567cc8ade0581..485656716076c 100644 --- a/airflow/providers/google/suite/hooks/calendar.py +++ b/airflow/providers/google/suite/hooks/calendar.py @@ -15,11 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Calendar API hook""" +from __future__ import annotations from datetime import datetime -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Sequence from googleapiclient.discovery import build @@ -51,9 +51,9 @@ class GoogleCalendarHook(GoogleBaseHook): def __init__( self, api_version: str, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -70,32 +70,31 @@ def get_conn(self) -> Any: Retrieves connection to Google Calendar. :return: Google Calendar services object. 
- :rtype: Any """ if not self._conn: http_authorized = self._authorize() - self._conn = build('calendar', self.api_version, http=http_authorized, cache_discovery=False) + self._conn = build("calendar", self.api_version, http=http_authorized, cache_discovery=False) return self._conn def get_events( self, - calendar_id: str = 'primary', - i_cal_uid: Optional[str] = None, - max_attendees: Optional[int] = None, - max_results: Optional[int] = None, - order_by: Optional[str] = None, - private_extended_property: Optional[str] = None, - q: Optional[str] = None, - shared_extended_property: Optional[str] = None, - show_deleted: Optional[bool] = False, - show_hidden_invitation: Optional[bool] = False, - single_events: Optional[bool] = False, - sync_token: Optional[str] = None, - time_max: Optional[datetime] = None, - time_min: Optional[datetime] = None, - time_zone: Optional[str] = None, - updated_min: Optional[datetime] = None, + calendar_id: str = "primary", + i_cal_uid: str | None = None, + max_attendees: int | None = None, + max_results: int | None = None, + order_by: str | None = None, + private_extended_property: str | None = None, + q: str | None = None, + shared_extended_property: str | None = None, + show_deleted: bool | None = False, + show_hidden_invitation: bool | None = False, + single_events: bool | None = False, + sync_token: str | None = None, + time_max: datetime | None = None, + time_min: datetime | None = None, + time_zone: str | None = None, + updated_min: datetime | None = None, ) -> list: """ Gets events from Google Calendar from a single calendar_id @@ -126,7 +125,6 @@ def get_events( Default is no filter :param time_zone: Optional. Time zone used in response. Default is calendars time zone. :param updated_min: Optional. Lower bound for an event's last modification time - :rtype: List """ service = self.get_conn() page_token = None @@ -163,13 +161,13 @@ def get_events( def create_event( self, - event: Dict[str, Any], - calendar_id: str = 'primary', - conference_data_version: Optional[int] = 0, - max_attendees: Optional[int] = None, - send_notifications: Optional[bool] = False, - send_updates: Optional[str] = 'false', - supports_attachments: Optional[bool] = False, + event: dict[str, Any], + calendar_id: str = "primary", + conference_data_version: int | None = 0, + max_attendees: int | None = None, + send_notifications: bool | None = False, + send_updates: str | None = "false", + supports_attachments: bool | None = False, ) -> dict: """ Create event on the specified calendar @@ -184,7 +182,6 @@ def create_event( :param send_updates: Optional. Default is "false". Acceptable values as "all", "none", ``"externalOnly"`` https://developers.google.com/calendar/api/v3/reference/events#resource - :rtype: Dict """ if "start" not in event or "end" not in event: raise AirflowException( diff --git a/airflow/providers/google/suite/hooks/drive.py b/airflow/providers/google/suite/hooks/drive.py index 94390503aae67..9e3b4a5437d49 100644 --- a/airflow/providers/google/suite/hooks/drive.py +++ b/airflow/providers/google/suite/hooks/drive.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
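A minimal, hedged sketch of calling the Calendar hook after the typing changes above; it assumes a configured ``google_cloud_default`` connection with Calendar API scopes, and ``primary`` is simply the default calendar id:

    from airflow.providers.google.suite.hooks.calendar import GoogleCalendarHook

    # api_version is a required constructor argument of this hook.
    hook = GoogleCalendarHook(api_version="v3", gcp_conn_id="google_cloud_default")

    # Fetch a handful of upcoming events; order_by="startTime" requires
    # single_events=True in the Calendar API.
    events = hook.get_events(
        calendar_id="primary",
        max_results=10,
        single_events=True,
        order_by="startTime",
    )
    for event in events:
        print(event.get("summary"))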
"""Hook for Google Drive service""" -from typing import IO, Any, Optional, Sequence, Union +from __future__ import annotations + +from typing import IO, Any, Sequence from googleapiclient.discovery import Resource, build from googleapiclient.http import HttpRequest, MediaFileUpload @@ -43,14 +45,14 @@ class GoogleDriveHook(GoogleBaseHook): account from the list granting this role to the originating account. """ - _conn = None # type: Optional[Resource] + _conn: Resource | None = None def __init__( self, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -122,37 +124,50 @@ def get_media_request(self, file_id: str) -> HttpRequest: :param file_id: The Google Drive file id :return: request - :rtype: HttpRequest """ service = self.get_conn() request = service.files().get_media(fileId=file_id) return request - def exists(self, folder_id: str, file_name: str, drive_id: Optional[str] = None): + def exists( + self, folder_id: str, file_name: str, drive_id: str | None = None, *, include_trashed: bool = True + ) -> bool: """ Checks to see if a file exists within a Google Drive folder :param folder_id: The id of the Google Drive folder in which the file resides :param file_name: The name of a file in Google Drive :param drive_id: Optional. The id of the shared Google Drive in which the file resides. + :param include_trashed: Whether to include objects in trash or not, default True as in Google API. + :return: True if the file exists, False otherwise - :rtype: bool """ - return bool(self.get_file_id(folder_id=folder_id, file_name=file_name, drive_id=drive_id)) + return bool( + self.get_file_id( + folder_id=folder_id, file_name=file_name, include_trashed=include_trashed, drive_id=drive_id + ) + ) - def get_file_id(self, folder_id: str, file_name: str, drive_id: Optional[str] = None): + def get_file_id( + self, folder_id: str, file_name: str, drive_id: str | None = None, *, include_trashed: bool = True + ) -> dict: """ Returns the file id of a Google Drive file :param folder_id: The id of the Google Drive folder in which the file resides :param file_name: The name of a file in Google Drive :param drive_id: Optional. The id of the shared Google Drive in which the file resides. + :param include_trashed: Whether to include objects in trash or not, default True as in Google API. + :return: Google Drive file id if the file exists, otherwise None - :rtype: str if file exists else None """ query = f"name = '{file_name}'" if folder_id: query += f" and parents in '{folder_id}'" + + if not include_trashed: + query += " and trashed=false" + service = self.get_conn() if drive_id: files = ( @@ -176,8 +191,8 @@ def get_file_id(self, folder_id: str, file_name: str, drive_id: Optional[str] = .execute(num_retries=self.num_retries) ) file_metadata = {} - if files['files']: - file_metadata = {"id": files['files'][0]['id'], "mime_type": files['files'][0]['mimeType']} + if files["files"]: + file_metadata = {"id": files["files"][0]["id"], "mime_type": files["files"][0]["mimeType"]} return file_metadata def upload_file( @@ -200,7 +215,6 @@ def upload_file( :param resumable: True if this is a resumable upload. False means upload in a single request. 
:return: File ID - :rtype: str """ service = self.get_conn() directory_path, _, file_name = remote_location.rpartition("/") diff --git a/airflow/providers/google/suite/hooks/sheets.py b/airflow/providers/google/suite/hooks/sheets.py index 3e95407b25f4b..a4aed79e0457f 100644 --- a/airflow/providers/google/suite/hooks/sheets.py +++ b/airflow/providers/google/suite/hooks/sheets.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """This module contains a Google Sheets API hook""" +from __future__ import annotations -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Sequence from googleapiclient.discovery import build @@ -49,10 +49,10 @@ class GSheetsHook(GoogleBaseHook): def __init__( self, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v4', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v4", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, ) -> None: super().__init__( gcp_conn_id=gcp_conn_id, @@ -69,11 +69,10 @@ def get_conn(self) -> Any: Retrieves connection to Google Sheets. :return: Google Sheets services object. - :rtype: Any """ if not self._conn: http_authorized = self._authorize() - self._conn = build('sheets', self.api_version, http=http_authorized, cache_discovery=False) + self._conn = build("sheets", self.api_version, http=http_authorized, cache_discovery=False) return self._conn @@ -81,9 +80,9 @@ def get_values( self, spreadsheet_id: str, range_: str, - major_dimension: str = 'DIMENSION_UNSPECIFIED', - value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER', + major_dimension: str = "DIMENSION_UNSPECIFIED", + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", ) -> list: """ Gets values from Google Sheet from a single range @@ -98,7 +97,6 @@ def get_values( :param date_time_render_option: Determines how dates should be rendered in the output. SERIAL_NUMBER or FORMATTED_STRING :return: An array of sheet values from the specified sheet. - :rtype: List """ service = self.get_conn() @@ -115,15 +113,15 @@ def get_values( .execute(num_retries=self.num_retries) ) - return response['values'] + return response.get("values", []) def batch_get_values( self, spreadsheet_id: str, - ranges: List, - major_dimension: str = 'DIMENSION_UNSPECIFIED', - value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER', + ranges: list, + major_dimension: str = "DIMENSION_UNSPECIFIED", + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", ) -> dict: """ Gets values from Google Sheet from a list of ranges @@ -138,7 +136,6 @@ def batch_get_values( :param date_time_render_option: Determines how dates should be rendered in the output. SERIAL_NUMBER or FORMATTED_STRING :return: Google Sheets API response. 
- :rtype: Dict """ service = self.get_conn() @@ -161,12 +158,12 @@ def update_values( self, spreadsheet_id: str, range_: str, - values: List, - major_dimension: str = 'ROWS', - value_input_option: str = 'RAW', + values: list, + major_dimension: str = "ROWS", + value_input_option: str = "RAW", include_values_in_response: bool = False, - value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER', + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", ) -> dict: """ Updates values from Google Sheet from a single range @@ -186,7 +183,6 @@ def update_values( :param date_time_render_option: Determines how dates should be rendered in the output. SERIAL_NUMBER or FORMATTED_STRING :return: Google Sheets API response. - :rtype: Dict """ service = self.get_conn() body = {"range": range_, "majorDimension": major_dimension, "values": values} @@ -211,13 +207,13 @@ def update_values( def batch_update_values( self, spreadsheet_id: str, - ranges: List, - values: List, - major_dimension: str = 'ROWS', - value_input_option: str = 'RAW', + ranges: list, + values: list, + major_dimension: str = "ROWS", + value_input_option: str = "RAW", include_values_in_response: bool = False, - value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER', + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", ) -> dict: """ Updates values from Google Sheet for multiple ranges @@ -237,7 +233,6 @@ def batch_update_values( :param date_time_render_option: Determines how dates should be rendered in the output. SERIAL_NUMBER or FORMATTED_STRING :return: Google Sheets API response. - :rtype: Dict """ if len(ranges) != len(values): raise AirflowException( @@ -270,13 +265,13 @@ def append_values( self, spreadsheet_id: str, range_: str, - values: List, - major_dimension: str = 'ROWS', - value_input_option: str = 'RAW', - insert_data_option: str = 'OVERWRITE', + values: list, + major_dimension: str = "ROWS", + value_input_option: str = "RAW", + insert_data_option: str = "OVERWRITE", include_values_in_response: bool = False, - value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER', + value_render_option: str = "FORMATTED_VALUE", + date_time_render_option: str = "SERIAL_NUMBER", ) -> dict: """ Append values from Google Sheet from a single range @@ -298,7 +293,6 @@ def append_values( :param date_time_render_option: Determines how dates should be rendered in the output. SERIAL_NUMBER or FORMATTED_STRING :return: Google Sheets API response. - :rtype: Dict """ service = self.get_conn() body = {"range": range_, "majorDimension": major_dimension, "values": values} @@ -329,7 +323,6 @@ def clear(self, spreadsheet_id: str, range_: str) -> dict: :param spreadsheet_id: The Google Sheet ID to interact with :param range_: The A1 notation of the values to retrieve. :return: Google Sheets API response. - :rtype: Dict """ service = self.get_conn() @@ -350,7 +343,6 @@ def batch_clear(self, spreadsheet_id: str, ranges: list) -> dict: :param spreadsheet_id: The Google Sheet ID to interact with :param ranges: The A1 notation of the values to retrieve. :return: Google Sheets API response. 
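One behavioural consequence of switching to ``response.get("values", [])`` above is that reading an empty range now returns an empty list instead of raising ``KeyError``. A small sketch, assuming a reachable spreadsheet and a configured ``google_cloud_default`` connection (ids are placeholders):

    from airflow.providers.google.suite.hooks.sheets import GSheetsHook

    hook = GSheetsHook(gcp_conn_id="google_cloud_default")
    rows = hook.get_values(spreadsheet_id="1234567890qwerty", range_="Sheet1!A1:B2")
    if not rows:
        print("range is empty")
    else:
        for row in rows:
            print(row)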
- :rtype: Dict """ service = self.get_conn() body = {"ranges": ranges} @@ -379,7 +371,7 @@ def get_spreadsheet(self, spreadsheet_id: str): ) return response - def get_sheet_titles(self, spreadsheet_id: str, sheet_filter: Optional[List[str]] = None): + def get_sheet_titles(self, spreadsheet_id: str, sheet_filter: list[str] | None = None): """ Retrieves the sheet titles from a spreadsheet matching the given id and sheet filter. @@ -392,15 +384,15 @@ def get_sheet_titles(self, spreadsheet_id: str, sheet_filter: Optional[List[str] if sheet_filter: titles = [ - sh['properties']['title'] - for sh in response['sheets'] - if sh['properties']['title'] in sheet_filter + sh["properties"]["title"] + for sh in response["sheets"] + if sh["properties"]["title"] in sheet_filter ] else: - titles = [sh['properties']['title'] for sh in response['sheets']] + titles = [sh["properties"]["title"] for sh in response["sheets"]] return titles - def create_spreadsheet(self, spreadsheet: Dict[str, Any]) -> Dict[str, Any]: + def create_spreadsheet(self, spreadsheet: dict[str, Any]) -> dict[str, Any]: """ Creates a spreadsheet, returning the newly created spreadsheet. @@ -408,10 +400,10 @@ def create_spreadsheet(self, spreadsheet: Dict[str, Any]) -> Dict[str, Any]: https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets#Spreadsheet :return: An spreadsheet object. """ - self.log.info("Creating spreadsheet: %s", spreadsheet['properties']['title']) + self.log.info("Creating spreadsheet: %s", spreadsheet["properties"]["title"]) response = ( self.get_conn().spreadsheets().create(body=spreadsheet).execute(num_retries=self.num_retries) ) - self.log.info("Spreadsheet: %s created", spreadsheet['properties']['title']) + self.log.info("Spreadsheet: %s created", spreadsheet["properties"]["title"]) return response diff --git a/airflow/providers/google/suite/operators/sheets.py b/airflow/providers/google/suite/operators/sheets.py index 48f93af2839e9..71bdae50cc6fb 100644 --- a/airflow/providers/google/suite/operators/sheets.py +++ b/airflow/providers/google/suite/operators/sheets.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Sequence from airflow.models import BaseOperator from airflow.providers.google.suite.hooks.sheets import GSheetsHook @@ -53,10 +54,10 @@ class GoogleSheetsCreateSpreadsheetOperator(BaseOperator): def __init__( self, *, - spreadsheet: Dict[str, Any], + spreadsheet: dict[str, Any], gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -65,7 +66,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: Any) -> Dict[str, Any]: + def execute(self, context: Any) -> dict[str, Any]: hook = GSheetsHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/suite/sensors/drive.py b/airflow/providers/google/suite/sensors/drive.py index 5729deb316153..c00650ddbc8a9 100644 --- a/airflow/providers/google/suite/sensors/drive.py +++ b/airflow/providers/google/suite/sensors/drive.py @@ -16,8 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Drive sensors.""" +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.providers.google.suite.hooks.drive import GoogleDriveHook from airflow.sensors.base import BaseSensorOperator @@ -49,22 +50,22 @@ class GoogleDriveFileExistenceSensor(BaseSensorOperator): """ template_fields: Sequence[str] = ( - 'folder_id', - 'file_name', - 'drive_id', - 'impersonation_chain', + "folder_id", + "file_name", + "drive_id", + "impersonation_chain", ) - ui_color = '#f0eee4' + ui_color = "#f0eee4" def __init__( self, *, folder_id: str, file_name: str, - drive_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + drive_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: @@ -76,8 +77,8 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def poke(self, context: 'Context') -> bool: - self.log.info('Sensor is checking for the file %s in the folder %s', self.file_name, self.folder_id) + def poke(self, context: Context) -> bool: + self.log.info("Sensor is checking for the file %s in the folder %s", self.file_name, self.folder_id) hook = GoogleDriveHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/suite/transfers/gcs_to_gdrive.py b/airflow/providers/google/suite/transfers/gcs_to_gdrive.py index f58119e1612d5..2130afb77d68b 100644 --- a/airflow/providers/google/suite/transfers/gcs_to_gdrive.py +++ b/airflow/providers/google/suite/transfers/gcs_to_gdrive.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. 
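For context on the sensor touched above, a hedged example of wiring ``GoogleDriveFileExistenceSensor`` into a DAG; the folder id and file name are placeholders and the schedule mirrors the example DAGs removed elsewhere in this change:

    from datetime import datetime

    from airflow import models
    from airflow.providers.google.suite.sensors.drive import GoogleDriveFileExistenceSensor

    with models.DAG(
        "example_gdrive_file_sensor",
        schedule_interval=None,  # Override to match your needs
        start_date=datetime(2021, 1, 1),
        catchup=False,
    ) as dag:
        wait_for_report = GoogleDriveFileExistenceSensor(
            task_id="wait_for_report",
            folder_id="1a2b3c",
            file_name="report.csv",
            poke_interval=60,
        )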
"""This module contains a Google Cloud Storage to Google Drive transfer operator.""" +from __future__ import annotations + import tempfile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -90,11 +92,11 @@ def __init__( *, source_bucket: str, source_object: str, - destination_object: Optional[str] = None, + destination_object: str | None = None, move_object: bool = False, gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -106,10 +108,10 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - self.gcs_hook = None # type: Optional[GCSHook] - self.gdrive_hook = None # type: Optional[GoogleDriveHook] + self.gcs_hook: GCSHook | None = None + self.gdrive_hook: GoogleDriveHook | None = None - def execute(self, context: 'Context'): + def execute(self, context: Context): self.gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/google/suite/transfers/gcs_to_sheets.py b/airflow/providers/google/suite/transfers/gcs_to_sheets.py index 591ae77ca9545..08326e91f391d 100644 --- a/airflow/providers/google/suite/transfers/gcs_to_sheets.py +++ b/airflow/providers/google/suite/transfers/gcs_to_sheets.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import csv from tempfile import NamedTemporaryFile -from typing import Any, Optional, Sequence, Union +from typing import Any, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -66,8 +67,8 @@ def __init__( object_name: str, spreadsheet_range: str = "Sheet1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/google/suite/transfers/local_to_drive.py b/airflow/providers/google/suite/transfers/local_to_drive.py index ce94793eecdf1..228def2c6080e 100644 --- a/airflow/providers/google/suite/transfers/local_to_drive.py +++ b/airflow/providers/google/suite/transfers/local_to_drive.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. 
"""This file contains Google Drive operators""" +from __future__ import annotations import os from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowFailException from airflow.models import BaseOperator @@ -63,25 +64,24 @@ class LocalFilesystemToGoogleDriveOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account :return: Remote file ids after upload - :rtype: Sequence[str] """ template_fields = ( - 'local_paths', - 'drive_folder', + "local_paths", + "drive_folder", ) def __init__( self, - local_paths: Union[Sequence[Path], Sequence[str]], - drive_folder: Union[Path, str], + local_paths: Sequence[Path] | Sequence[str], + drive_folder: Path | str, gcp_conn_id: str = "google_cloud_default", delete: bool = False, ignore_if_missing: bool = False, chunk_size: int = 100 * 1024 * 1024, resumable: bool = False, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -95,7 +95,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: "Context") -> List[str]: + def execute(self, context: Context) -> list[str]: hook = GoogleDriveHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py b/airflow/providers/google/suite/transfers/sql_to_sheets.py index 8384868199b6f..f8ee694408286 100644 --- a/airflow/providers/google/suite/transfers/sql_to_sheets.py +++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- +from __future__ import annotations import datetime import logging import numbers from contextlib import closing -from typing import Any, Iterable, Mapping, Optional, Sequence, Union +from typing import Any, Iterable, Mapping, Sequence -from airflow.operators.sql import BaseSQLOperator +from airflow.providers.common.sql.operators.sql import BaseSQLOperator from airflow.providers.google.suite.hooks.sheets import GSheetsHook @@ -68,12 +68,12 @@ def __init__( sql: str, spreadsheet_id: str, sql_conn_id: str, - parameters: Optional[Union[Mapping, Iterable]] = None, - database: Optional[str] = None, + parameters: Iterable | Mapping | None = None, + database: str | None = None, spreadsheet_range: str = "Sheet1", gcp_conn_id: str = "google_cloud_default", - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/grpc/.latest-doc-only-change.txt b/airflow/providers/grpc/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/grpc/.latest-doc-only-change.txt +++ b/airflow/providers/grpc/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/grpc/CHANGELOG.rst b/airflow/providers/grpc/CHANGELOG.rst index 215b24d936a83..5f894e33365b3 100644 --- a/airflow/providers/grpc/CHANGELOG.rst +++ b/airflow/providers/grpc/CHANGELOG.rst @@ -16,9 +16,62 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +In GrpcHook, non-prefixed extra fields are supported and are preferred. E.g. ``auth_type`` will +be preferred if ``extra__grpc__auth_type`` is also present. + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Look for 'extra__' instead of 'extra_' in 'get_field' (#27489)`` +* ``Allow and prefer non-prefixed extra fields for GrpcHook (#27045)`` + +Bug Fixes +~~~~~~~~~ + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/grpc/hooks/grpc.py b/airflow/providers/grpc/hooks/grpc.py index 1b575d73f4bdf..eb12fd92d8310 100644 --- a/airflow/providers/grpc/hooks/grpc.py +++ b/airflow/providers/grpc/hooks/grpc.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """GRPC Hook""" -from typing import Any, Callable, Dict, Generator, List, Optional +from __future__ import annotations + +from typing import Any, Callable, Generator import grpc from google import auth as google_auth @@ -45,35 +46,31 @@ class GrpcHook(BaseHook): A callable that accepts the connection as its only arg. """ - conn_name_attr = 'grpc_conn_id' - default_conn_name = 'grpc_default' - conn_type = 'grpc' - hook_name = 'GRPC Connection' + conn_name_attr = "grpc_conn_id" + default_conn_name = "grpc_default" + conn_type = "grpc" + hook_name = "GRPC Connection" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { - "extra__grpc__auth_type": StringField( - lazy_gettext('Grpc Auth Type'), widget=BS3TextFieldWidget() - ), - "extra__grpc__credential_pem_file": StringField( - lazy_gettext('Credential Keyfile Path'), widget=BS3TextFieldWidget() - ), - "extra__grpc__scopes": StringField( - lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget() + "auth_type": StringField(lazy_gettext("Grpc Auth Type"), widget=BS3TextFieldWidget()), + "credential_pem_file": StringField( + lazy_gettext("Credential Keyfile Path"), widget=BS3TextFieldWidget() ), + "scopes": StringField(lazy_gettext("Scopes (comma separated)"), widget=BS3TextFieldWidget()), } def __init__( self, grpc_conn_id: str = default_conn_name, - interceptors: Optional[List[Callable]] = None, - custom_connection_func: Optional[Callable] = None, + interceptors: list[Callable] | None = None, + custom_connection_func: Callable | None = None, ) -> None: super().__init__() self.grpc_conn_id = grpc_conn_id @@ -125,7 +122,7 @@ def get_conn(self) -> grpc.Channel: return channel def run( - self, stub_class: Callable, call_func: str, streaming: bool = False, data: Optional[dict] = None + self, stub_class: Callable, call_func: str, streaming: bool = False, data: dict | None = None ) -> Generator: """Call gRPC function and yield response to caller""" if data is None: @@ -150,12 +147,17 @@ def run( ) raise ex - def _get_field(self, field_name: str) -> str: - """ - Fetches a field from extras, and returns it. This is some Airflow - magic. The grpc hook type adds custom UI elements - to the hook page, which allow admins to specify scopes, credential pem files, etc. - They get formatted as shown below. 
- """ - full_field_name = f'extra__grpc__{field_name}' - return self.extras[full_field_name] + def _get_field(self, field_name: str): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = "extra__grpc__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." + ) + if field_name in self.extras: + return self.extras[field_name] + prefixed_name = f"{backcompat_prefix}{field_name}" + if prefixed_name in self.extras: + return self.extras[prefixed_name] + raise KeyError(f"Param {field_name} not found in extra dict") diff --git a/airflow/providers/grpc/operators/grpc.py b/airflow/providers/grpc/operators/grpc.py index 5cca48b4fff31..3a263a5195f24 100644 --- a/airflow/providers/grpc/operators/grpc.py +++ b/airflow/providers/grpc/operators/grpc.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence from airflow.models import BaseOperator from airflow.providers.grpc.hooks.grpc import GrpcHook @@ -43,7 +44,7 @@ class GrpcOperator(BaseOperator): :param log_response: A flag to indicate if we need to log the response """ - template_fields: Sequence[str] = ('stub_class', 'call_func', 'data') + template_fields: Sequence[str] = ("stub_class", "call_func", "data") template_fields_renderers = {"data": "py"} def __init__( @@ -52,11 +53,11 @@ def __init__( stub_class: Callable, call_func: str, grpc_conn_id: str = "grpc_default", - data: Optional[dict] = None, - interceptors: Optional[List[Callable]] = None, - custom_connection_func: Optional[Callable] = None, + data: dict | None = None, + interceptors: list[Callable] | None = None, + custom_connection_func: Callable | None = None, streaming: bool = False, - response_callback: Optional[Callable] = None, + response_callback: Callable | None = None, log_response: bool = False, **kwargs, ) -> None: @@ -78,7 +79,7 @@ def _get_grpc_hook(self) -> GrpcHook: custom_connection_func=self.custom_connection_func, ) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = self._get_grpc_hook() self.log.info("Calling gRPC service") @@ -88,7 +89,7 @@ def execute(self, context: 'Context') -> None: for response in responses: self._handle_response(response, context) - def _handle_response(self, response: Any, context: 'Context') -> None: + def _handle_response(self, response: Any, context: Context) -> None: if self.log_response: self.log.info(repr(response)) if self.response_callback: diff --git a/airflow/providers/grpc/provider.yaml b/airflow/providers/grpc/provider.yaml index 507124ff9bb2c..86cd9be1ecacd 100644 --- a/airflow/providers/grpc/provider.yaml +++ b/airflow/providers/grpc/provider.yaml @@ -22,6 +22,8 @@ description: | `gRPC `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,9 +32,16 @@ versions: - 1.1.0 - 1.0.1 - 1.0.0 - -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + # Google has very clear rules on what dependencies should be used. All the limits below + # follow strict guidelines of Google Libraries as quoted here: + # While this issue is open, dependents of google-api-core, google-cloud-core. 
and google-auth + # should preserve >1, <3 pins on these packages. + # https://github.com/googleapis/google-cloud-python/issues/10566 + - google-auth>=1.0.0, <3.0.0 + - google-auth-httplib2>=0.0.1 + - grpcio>=1.15.0 integrations: - integration-name: gRPC @@ -49,9 +58,6 @@ hooks: python-modules: - airflow.providers.grpc.hooks.grpc -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.grpc.hooks.grpc.GrpcHook - connection-types: - hook-class-name: airflow.providers.grpc.hooks.grpc.GrpcHook connection-type: grpc diff --git a/airflow/providers/hashicorp/.latest-doc-only-change.txt b/airflow/providers/hashicorp/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/hashicorp/.latest-doc-only-change.txt +++ b/airflow/providers/hashicorp/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/hashicorp/CHANGELOG.rst b/airflow/providers/hashicorp/CHANGELOG.rst index 8ed5448024e76..bef98f6564b67 100644 --- a/airflow/providers/hashicorp/CHANGELOG.rst +++ b/airflow/providers/hashicorp/CHANGELOG.rst @@ -16,9 +16,74 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Add Airflow specific warning classes (#25799)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Pass kwargs from vault hook to hvac client (#26680)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Use newer kubernetes authentication method in internal vault client (#25351)`` + + +3.0.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``pydocstyle D202 added (#24221)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.0 ..... @@ -27,11 +92,11 @@ Features * ``Update secrets backends to use get_conn_value instead of get_conn_uri (#22348)`` -.. Review and move the new changes to one of the sections above: * ``Prepare mid-April provider documentation. (#22819)`` * ``Clean up in-line f-string concatenation (#23591)`` * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + 2.1.4 ..... diff --git a/airflow/providers/hashicorp/_internal_client/vault_client.py b/airflow/providers/hashicorp/_internal_client/vault_client.py index ee36c21f7e8d5..076a8696667f4 100644 --- a/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -14,37 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys -from typing import List, Optional +from __future__ import annotations import hvac - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from hvac.api.auth_methods import Kubernetes from hvac.exceptions import InvalidPath, VaultError from requests import Response +from airflow.compat.functools import cached_property from airflow.utils.log.logging_mixin import LoggingMixin -DEFAULT_KUBERNETES_JWT_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/token' +DEFAULT_KUBERNETES_JWT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" DEFAULT_KV_ENGINE_VERSION = 2 -VALID_KV_VERSIONS: List[int] = [1, 2] -VALID_AUTH_TYPES: List[str] = [ - 'approle', - 'aws_iam', - 'azure', - 'github', - 'gcp', - 'kubernetes', - 'ldap', - 'radius', - 'token', - 'userpass', +VALID_KV_VERSIONS: list[int] = [1, 2] +VALID_AUTH_TYPES: list[str] = [ + "approle", + "aws_iam", + "azure", + "github", + "gcp", + "kubernetes", + "ldap", + "radius", + "token", + "userpass", ] @@ -91,28 +86,28 @@ class _VaultClient(LoggingMixin): def __init__( self, - url: Optional[str] = None, - auth_type: str = 'token', - auth_mount_point: Optional[str] = None, + url: str | None = None, + auth_type: str = "token", + auth_mount_point: str | None = None, mount_point: str = "secret", - kv_engine_version: Optional[int] = None, - token: Optional[str] = None, - token_path: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - key_id: Optional[str] = None, - secret_id: Optional[str] = None, - role_id: Optional[str] = None, - kubernetes_role: Optional[str] = None, - kubernetes_jwt_path: Optional[str] = '/var/run/secrets/kubernetes.io/serviceaccount/token', - gcp_key_path: Optional[str] = None, - gcp_keyfile_dict: Optional[dict] = None, - gcp_scopes: Optional[str] = None, - azure_tenant_id: Optional[str] = None, - azure_resource: Optional[str] = None, - radius_host: Optional[str] = None, - radius_secret: Optional[str] = None, - radius_port: Optional[int] = None, + kv_engine_version: int | None = None, + token: str | None = None, + token_path: str | None = None, + username: str | None = None, + password: str | None = None, + key_id: str | None = None, + secret_id: str | None = None, + role_id: str | None = None, + kubernetes_role: str | None = None, + kubernetes_jwt_path: str | None = "/var/run/secrets/kubernetes.io/serviceaccount/token", + gcp_key_path: str | None = None, + gcp_keyfile_dict: dict | None = None, + gcp_scopes: str | None = None, + 
azure_tenant_id: str | None = None, + azure_resource: str | None = None, + radius_host: str | None = None, + radius_secret: str | None = None, + radius_port: int | None = None, **kwargs, ): super().__init__() @@ -178,14 +173,13 @@ def client(self): it is still authenticated to Vault, and invalidates the cache if this is not the case. - :rtype: hvac.Client :return: Vault Client """ if not self._client.is_authenticated(): # Invalidate the cache: # https://github.com/pydanny/cached-property#invalidating-the-cache - self.__dict__.pop('_client', None) + self.__dict__.pop("_client", None) return self._client @cached_property @@ -193,16 +187,15 @@ def _client(self) -> hvac.Client: """ Return an authenticated Hashicorp Vault client. - :rtype: hvac.Client :return: Vault Client """ _client = hvac.Client(url=self.url, **self.kwargs) if self.auth_type == "approle": self._auth_approle(_client) - elif self.auth_type == 'aws_iam': + elif self.auth_type == "aws_iam": self._auth_aws_iam(_client) - elif self.auth_type == 'azure': + elif self.auth_type == "azure": self._auth_azure(_client) elif self.auth_type == "gcp": self._auth_gcp(_client) @@ -261,9 +254,11 @@ def _auth_kubernetes(self, _client: hvac.Client) -> None: with open(self.kubernetes_jwt_path) as f: jwt = f.read().strip() if self.auth_mount_point: - _client.auth_kubernetes(role=self.kubernetes_role, jwt=jwt, mount_point=self.auth_mount_point) + Kubernetes(_client.adapter).login( + role=self.kubernetes_role, jwt=jwt, mount_point=self.auth_mount_point + ) else: - _client.auth_kubernetes(role=self.kubernetes_role, jwt=jwt) + Kubernetes(_client.adapter).login(role=self.kubernetes_role, jwt=jwt) def _auth_github(self, _client: hvac.Client) -> None: if self.auth_mount_point: @@ -329,7 +324,7 @@ def _set_token(self, _client: hvac.Client) -> None: else: _client.token = self.token - def get_secret(self, secret_path: str, secret_version: Optional[int] = None) -> Optional[dict]: + def get_secret(self, secret_path: str, secret_version: int | None = None) -> dict | None: """ Get secret value from the KV engine. @@ -360,12 +355,11 @@ def get_secret(self, secret_path: str, secret_version: Optional[int] = None) -> return_data = response["data"] if self.kv_engine_version == 1 else response["data"]["data"] return return_data - def get_secret_metadata(self, secret_path: str) -> Optional[dict]: + def get_secret_metadata(self, secret_path: str) -> dict | None: """ Reads secret metadata (including versions) from the engine. It is only valid for KV version 2. :param secret_path: The path of the secret. - :rtype: dict :return: secret metadata. This is a Dict containing metadata for the secret. See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. @@ -382,8 +376,8 @@ def get_secret_metadata(self, secret_path: str) -> Optional[dict]: return None def get_secret_including_metadata( - self, secret_path: str, secret_version: Optional[int] = None - ) -> Optional[dict]: + self, secret_path: str, secret_version: int | None = None + ) -> dict | None: """ Reads secret including metadata. It is only valid for KV version 2. @@ -392,7 +386,6 @@ def get_secret_including_metadata( :param secret_path: The path of the secret. :param secret_version: Specifies the version of Secret to return. If not set, the latest version is returned. (Can only be used in case of version 2 of KV). - :rtype: dict :return: The key info. This is a Dict with "data" mapping keeping secret and "metadata" mapping keeping metadata of the secret. 
""" @@ -412,7 +405,7 @@ def get_secret_including_metadata( return None def create_or_update_secret( - self, secret_path: str, secret: dict, method: Optional[str] = None, cas: Optional[int] = None + self, secret_path: str, secret: dict, method: str | None = None, cas: int | None = None ) -> Response: """ Creates or updates secret. @@ -426,7 +419,6 @@ def create_or_update_secret( allowed. If set to 0 a write will only be allowed if the key doesn't exist. If the index is non-zero the write will only be allowed if the key's current version matches the version specified in the cas parameter. Only valid for KV engine version 2. - :rtype: requests.Response :return: The response of the create_or_update_secret request. See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v1.html diff --git a/airflow/providers/hashicorp/hooks/vault.py b/airflow/providers/hashicorp/hooks/vault.py index ce351ff2c52cb..36022d41417d0 100644 --- a/airflow/providers/hashicorp/hooks/vault.py +++ b/airflow/providers/hashicorp/hooks/vault.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Hook for HashiCorp Vault""" +from __future__ import annotations + import json import warnings -from typing import Optional, Tuple import hvac from hvac.exceptions import VaultError @@ -30,6 +30,7 @@ DEFAULT_KV_ENGINE_VERSION, _VaultClient, ) +from airflow.utils.helpers import merge_dicts class VaultHook(BaseHook): @@ -97,36 +98,37 @@ class VaultHook(BaseHook): """ - conn_name_attr = 'vault_conn_id' - default_conn_name = 'vault_default' - conn_type = 'vault' - hook_name = 'Hashicorp Vault' + conn_name_attr = "vault_conn_id" + default_conn_name = "vault_default" + conn_type = "vault" + hook_name = "Hashicorp Vault" def __init__( self, vault_conn_id: str = default_conn_name, - auth_type: Optional[str] = None, - auth_mount_point: Optional[str] = None, - kv_engine_version: Optional[int] = None, - role_id: Optional[str] = None, - kubernetes_role: Optional[str] = None, - kubernetes_jwt_path: Optional[str] = None, - token_path: Optional[str] = None, - gcp_key_path: Optional[str] = None, - gcp_scopes: Optional[str] = None, - azure_tenant_id: Optional[str] = None, - azure_resource: Optional[str] = None, - radius_host: Optional[str] = None, - radius_port: Optional[int] = None, + auth_type: str | None = None, + auth_mount_point: str | None = None, + kv_engine_version: int | None = None, + role_id: str | None = None, + kubernetes_role: str | None = None, + kubernetes_jwt_path: str | None = None, + token_path: str | None = None, + gcp_key_path: str | None = None, + gcp_scopes: str | None = None, + azure_tenant_id: str | None = None, + azure_resource: str | None = None, + radius_host: str | None = None, + radius_port: int | None = None, + **kwargs, ): super().__init__() self.connection = self.get_connection(vault_conn_id) if not auth_type: - auth_type = self.connection.extra_dejson.get('auth_type') or "token" + auth_type = self.connection.extra_dejson.get("auth_type") or "token" if not auth_mount_point: - auth_mount_point = self.connection.extra_dejson.get('auth_mount_point') + auth_mount_point = self.connection.extra_dejson.get("auth_mount_point") if not kv_engine_version: conn_version = self.connection.extra_dejson.get("kv_engine_version") @@ -135,6 +137,11 @@ def __init__( except ValueError: raise VaultError(f"The version is not an int: {conn_version}. 
") + client_kwargs = self.connection.extra_dejson.get("client_kwargs", {}) + + if kwargs: + client_kwargs = merge_dicts(client_kwargs, kwargs) + if auth_type == "approle": if role_id: warnings.warn( @@ -143,8 +150,8 @@ def __init__( DeprecationWarning, stacklevel=2, ) - elif self.connection.extra_dejson.get('role_id'): - role_id = self.connection.extra_dejson.get('role_id') + elif self.connection.extra_dejson.get("role_id"): + role_id = self.connection.extra_dejson.get("role_id") warnings.warn( """The usage of role_id in connection extra for AppRole authentication has been deprecated. Please use connection login.""", @@ -156,37 +163,41 @@ def __init__( if auth_type == "aws_iam": if not role_id: - role_id = self.connection.extra_dejson.get('role_id') + role_id = self.connection.extra_dejson.get("role_id") azure_resource, azure_tenant_id = ( self._get_azure_parameters_from_connection(azure_resource, azure_tenant_id) - if auth_type == 'azure' + if auth_type == "azure" else (None, None) ) gcp_key_path, gcp_keyfile_dict, gcp_scopes = ( self._get_gcp_parameters_from_connection(gcp_key_path, gcp_scopes) - if auth_type == 'gcp' + if auth_type == "gcp" else (None, None, None) ) kubernetes_jwt_path, kubernetes_role = ( self._get_kubernetes_parameters_from_connection(kubernetes_jwt_path, kubernetes_role) - if auth_type == 'kubernetes' + if auth_type == "kubernetes" else (None, None) ) radius_host, radius_port = ( self._get_radius_parameters_from_connection(radius_host, radius_port) - if auth_type == 'radius' + if auth_type == "radius" else (None, None) ) - if self.connection.conn_type == 'vault': - conn_protocol = 'http' - elif self.connection.conn_type == 'vaults': - conn_protocol = 'https' - elif self.connection.conn_type == 'http': - conn_protocol = 'http' - elif self.connection.conn_type == 'https': - conn_protocol = 'https' + key_id = self.connection.extra_dejson.get("key_id") + if not key_id: + key_id = self.connection.login + + if self.connection.conn_type == "vault": + conn_protocol = "http" + elif self.connection.conn_type == "vaults": + conn_protocol = "https" + elif self.connection.conn_type == "http": + conn_protocol = "http" + elif self.connection.conn_type == "https": + conn_protocol = "https" else: raise VaultError("The url schema must be one of ['http', 'https', 'vault', 'vaults' ]") @@ -195,36 +206,40 @@ def __init__( url += f":{self.connection.port}" # Schema is really path in the Connection definition. 
This is pretty confusing because of URL schema - mount_point = self.connection.schema if self.connection.schema else 'secret' - - self.vault_client = _VaultClient( - url=url, - auth_type=auth_type, - auth_mount_point=auth_mount_point, - mount_point=mount_point, - kv_engine_version=kv_engine_version, - token=self.connection.password, - token_path=token_path, - username=self.connection.login, - password=self.connection.password, - key_id=self.connection.login, - secret_id=self.connection.password, - role_id=role_id, - kubernetes_role=kubernetes_role, - kubernetes_jwt_path=kubernetes_jwt_path, - gcp_key_path=gcp_key_path, - gcp_keyfile_dict=gcp_keyfile_dict, - gcp_scopes=gcp_scopes, - azure_tenant_id=azure_tenant_id, - azure_resource=azure_resource, - radius_host=radius_host, - radius_secret=self.connection.password, - radius_port=radius_port, + mount_point = self.connection.schema if self.connection.schema else "secret" + + client_kwargs.update( + **dict( + url=url, + auth_type=auth_type, + auth_mount_point=auth_mount_point, + mount_point=mount_point, + kv_engine_version=kv_engine_version, + token=self.connection.password, + token_path=token_path, + username=self.connection.login, + password=self.connection.password, + key_id=self.connection.login, + secret_id=self.connection.password, + role_id=role_id, + kubernetes_role=kubernetes_role, + kubernetes_jwt_path=kubernetes_jwt_path, + gcp_key_path=gcp_key_path, + gcp_keyfile_dict=gcp_keyfile_dict, + gcp_scopes=gcp_scopes, + azure_tenant_id=azure_tenant_id, + azure_resource=azure_resource, + radius_host=radius_host, + radius_secret=self.connection.password, + radius_port=radius_port, + ) ) + self.vault_client = _VaultClient(**client_kwargs) + def _get_kubernetes_parameters_from_connection( - self, kubernetes_jwt_path: Optional[str], kubernetes_role: Optional[str] - ) -> Tuple[str, Optional[str]]: + self, kubernetes_jwt_path: str | None, kubernetes_role: str | None + ) -> tuple[str, str | None]: if not kubernetes_jwt_path: kubernetes_jwt_path = self.connection.extra_dejson.get("kubernetes_jwt_path") if not kubernetes_jwt_path: @@ -235,9 +250,9 @@ def _get_kubernetes_parameters_from_connection( def _get_gcp_parameters_from_connection( self, - gcp_key_path: Optional[str], - gcp_scopes: Optional[str], - ) -> Tuple[Optional[str], Optional[dict], Optional[str]]: + gcp_key_path: str | None, + gcp_scopes: str | None, + ) -> tuple[str | None, dict | None, str | None]: if not gcp_scopes: gcp_scopes = self.connection.extra_dejson.get("gcp_scopes") if not gcp_key_path: @@ -247,8 +262,8 @@ def _get_gcp_parameters_from_connection( return gcp_key_path, gcp_keyfile_dict, gcp_scopes def _get_azure_parameters_from_connection( - self, azure_resource: Optional[str], azure_tenant_id: Optional[str] - ) -> Tuple[Optional[str], Optional[str]]: + self, azure_resource: str | None, azure_tenant_id: str | None + ) -> tuple[str | None, str | None]: if not azure_tenant_id: azure_tenant_id = self.connection.extra_dejson.get("azure_tenant_id") if not azure_resource: @@ -256,8 +271,8 @@ def _get_azure_parameters_from_connection( return azure_resource, azure_tenant_id def _get_radius_parameters_from_connection( - self, radius_host: Optional[str], radius_port: Optional[int] - ) -> Tuple[Optional[str], Optional[int]]: + self, radius_host: str | None, radius_port: int | None + ) -> tuple[str | None, int | None]: if not radius_port: radius_port_str = self.connection.extra_dejson.get("radius_port") if radius_port_str: @@ -273,12 +288,11 @@ def get_conn(self) -> hvac.Client: """ 
Retrieves connection to Vault. - :rtype: hvac.Client :return: connection used. """ return self.vault_client.client - def get_secret(self, secret_path: str, secret_version: Optional[int] = None) -> Optional[dict]: + def get_secret(self, secret_path: str, secret_version: int | None = None) -> dict | None: """ Get secret value from the engine. @@ -289,17 +303,15 @@ def get_secret(self, secret_path: str, secret_version: Optional[int] = None) -> and https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. :param secret_path: Path of the secret - :rtype: dict :return: secret stored in the vault as a dictionary """ return self.vault_client.get_secret(secret_path=secret_path, secret_version=secret_version) - def get_secret_metadata(self, secret_path: str) -> Optional[dict]: + def get_secret_metadata(self, secret_path: str) -> dict | None: """ Reads secret metadata (including versions) from the engine. It is only valid for KV version 2. :param secret_path: Path to read from - :rtype: dict :return: secret metadata. This is a Dict containing metadata for the secret. See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v2.html for details. @@ -308,8 +320,8 @@ def get_secret_metadata(self, secret_path: str) -> Optional[dict]: return self.vault_client.get_secret_metadata(secret_path=secret_path) def get_secret_including_metadata( - self, secret_path: str, secret_version: Optional[int] = None - ) -> Optional[dict]: + self, secret_path: str, secret_version: int | None = None + ) -> dict | None: """ Reads secret including metadata. It is only valid for KV version 2. @@ -317,7 +329,6 @@ def get_secret_including_metadata( :param secret_path: Path of the secret :param secret_version: Optional version of key to read - can only be used in case of version 2 of KV - :rtype: dict :return: key info. This is a Dict with "data" mapping keeping secret and "metadata" mapping keeping metadata of the secret. @@ -327,7 +338,7 @@ def get_secret_including_metadata( ) def create_or_update_secret( - self, secret_path: str, secret: dict, method: Optional[str] = None, cas: Optional[int] = None + self, secret_path: str, secret: dict, method: str | None = None, cas: int | None = None ) -> Response: """ Creates or updates secret. @@ -341,7 +352,6 @@ def create_or_update_secret( allowed. If set to 0 a write will only be allowed if the key doesn't exist. If the index is non-zero the write will only be allowed if the key's current version matches the version specified in the cas parameter. Only valid for KV engine version 2. - :rtype: requests.Response :return: The response of the create_or_update_secret request. 
See https://hvac.readthedocs.io/en/stable/usage/secrets_engines/kv_v1.html diff --git a/airflow/providers/hashicorp/provider.yaml b/airflow/providers/hashicorp/provider.yaml index dcb891e97aaa2..0706ece7ccde2 100644 --- a/airflow/providers/hashicorp/provider.yaml +++ b/airflow/providers/hashicorp/provider.yaml @@ -22,6 +22,10 @@ description: | Hashicorp including `Hashicorp Vault `__ versions: + - 3.2.0 + - 3.1.0 + - 3.0.1 + - 3.0.0 - 2.2.0 - 2.1.4 - 2.1.3 @@ -33,8 +37,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - hvac>=0.10 integrations: - integration-name: Hashicorp Vault @@ -47,9 +52,6 @@ hooks: python-modules: - airflow.providers.hashicorp.hooks.vault -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.hashicorp.hooks.vault.VaultHook - connection-types: - hook-class-name: airflow.providers.hashicorp.hooks.vault.VaultHook connection-type: vault diff --git a/airflow/providers/hashicorp/secrets/vault.py b/airflow/providers/hashicorp/secrets/vault.py index 4ff25d71b48e9..9c22ff71d66b8 100644 --- a/airflow/providers/hashicorp/secrets/vault.py +++ b/airflow/providers/hashicorp/secrets/vault.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. """Objects relating to sourcing connections & variables from Hashicorp Vault""" +from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient from airflow.secrets import BaseSecretsBackend @@ -86,44 +88,44 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin): def __init__( self, - connections_path: str = 'connections', - variables_path: str = 'variables', - config_path: str = 'config', - url: Optional[str] = None, - auth_type: str = 'token', - auth_mount_point: Optional[str] = None, - mount_point: str = 'secret', + connections_path: str = "connections", + variables_path: str = "variables", + config_path: str = "config", + url: str | None = None, + auth_type: str = "token", + auth_mount_point: str | None = None, + mount_point: str = "secret", kv_engine_version: int = 2, - token: Optional[str] = None, - token_path: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - key_id: Optional[str] = None, - secret_id: Optional[str] = None, - role_id: Optional[str] = None, - kubernetes_role: Optional[str] = None, - kubernetes_jwt_path: str = '/var/run/secrets/kubernetes.io/serviceaccount/token', - gcp_key_path: Optional[str] = None, - gcp_keyfile_dict: Optional[dict] = None, - gcp_scopes: Optional[str] = None, - azure_tenant_id: Optional[str] = None, - azure_resource: Optional[str] = None, - radius_host: Optional[str] = None, - radius_secret: Optional[str] = None, - radius_port: Optional[int] = None, + token: str | None = None, + token_path: str | None = None, + username: str | None = None, + password: str | None = None, + key_id: str | None = None, + secret_id: str | None = None, + role_id: str | None = None, + kubernetes_role: str | None = None, + kubernetes_jwt_path: str = "/var/run/secrets/kubernetes.io/serviceaccount/token", + gcp_key_path: str | None = None, + gcp_keyfile_dict: dict | None = None, + gcp_scopes: str | None = None, + azure_tenant_id: str | None = None, + azure_resource: str | None = None, + radius_host: str | None = None, + radius_secret: str | None = None, + 
radius_port: int | None = None, **kwargs, ): super().__init__() if connections_path is not None: - self.connections_path = connections_path.rstrip('/') + self.connections_path = connections_path.rstrip("/") else: self.connections_path = connections_path if variables_path is not None: - self.variables_path = variables_path.rstrip('/') + self.variables_path = variables_path.rstrip("/") else: self.variables_path = variables_path if config_path is not None: - self.config_path = config_path.rstrip('/') + self.config_path = config_path.rstrip("/") else: self.config_path = config_path self.mount_point = mount_point @@ -154,11 +156,10 @@ def __init__( **kwargs, ) - def get_response(self, conn_id: str) -> Optional[dict]: + def get_response(self, conn_id: str) -> dict | None: """ Get data from Vault - :rtype: dict :return: The data from the Vault path if exists """ if self.connections_path is None: @@ -167,21 +168,19 @@ def get_response(self, conn_id: str) -> Optional[dict]: secret_path = self.build_path(self.connections_path, conn_id) return self.vault_client.get_secret(secret_path=secret_path) - def get_conn_uri(self, conn_id: str) -> Optional[str]: + def get_conn_uri(self, conn_id: str) -> str | None: """ Get serialized representation of connection :param conn_id: The connection id - :rtype: str :return: The connection uri retrieved from the secret """ - # Since VaultBackend implements `get_connection`, `get_conn_uri` is not used. So we # don't need to implement (or direct users to use) method `get_conn_value` instead warnings.warn( f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " "in a future release.", - PendingDeprecationWarning, + DeprecationWarning, stacklevel=2, ) response = self.get_response(conn_id) @@ -192,12 +191,11 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]: if TYPE_CHECKING: from airflow.models.connection import Connection - def get_connection(self, conn_id: str) -> 'Optional[Connection]': + def get_connection(self, conn_id: str) -> Connection | None: """ Get connection from Vault as secret. Prioritize conn_uri if exists, if not fall back to normal Connection creation. 
- :rtype: Connection :return: A Connection object constructed from Vault data """ # The Connection needs to be locally imported because otherwise we get into cyclic import @@ -214,12 +212,11 @@ def get_connection(self, conn_id: str) -> 'Optional[Connection]': return Connection(conn_id, **response) - def get_variable(self, key: str) -> Optional[str]: + def get_variable(self, key: str) -> str | None: """ Get Airflow Variable :param key: Variable Key - :rtype: str :return: Variable Value retrieved from the vault """ if self.variables_path is None: @@ -229,12 +226,11 @@ def get_variable(self, key: str) -> Optional[str]: response = self.vault_client.get_secret(secret_path=secret_path) return response.get("value") if response else None - def get_config(self, key: str) -> Optional[str]: + def get_config(self, key: str) -> str | None: """ Get Airflow Configuration :param key: Configuration Option Key - :rtype: str :return: Configuration Option Value retrieved from the vault """ if self.config_path is None: diff --git a/airflow/providers/http/.latest-doc-only-change.txt b/airflow/providers/http/.latest-doc-only-change.txt index 95db1e9c1ea1d..ff7136e07d744 100644 --- a/airflow/providers/http/.latest-doc-only-change.txt +++ b/airflow/providers/http/.latest-doc-only-change.txt @@ -1 +1 @@ -d9567eb106929b21329c01171fd398fbef2dc6c6 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/http/CHANGELOG.rst b/airflow/providers/http/CHANGELOG.rst index 5d36f25eb3330..a3a693ba54039 100644 --- a/airflow/providers/http/CHANGELOG.rst +++ b/airflow/providers/http/CHANGELOG.rst @@ -16,9 +16,70 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +The SimpleHTTPOperator, HttpSensor and HttpHook use now TCP_KEEPALIVE by default. +You can disable it by setting ``tcp_keep_alive`` to False and you can control keepalive parameters +by new ``tcp_keep_alive_*`` parameters added to constructor of the Hook, Operator and Sensor. Setting the +TCP_KEEPALIVE prevents some firewalls from closing a long-running connection that has long periods of +inactivity by sending empty TCP packets periodically. This has a very small impact on network traffic, +and potentially prevents the idle/hanging connections from being closed automatically by the firewalls. + +* ``Add TCP_KEEPALIVE option to http provider (#24967)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``fix document about response_check in HttpSensor (#24708)`` + * ``Fix HttpHook.run_with_advanced_retry document error (#24380)`` + * ``Remove 'xcom_push' flag from providers (#24823)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate HTTP example DAGs to new design AIP-47 (#23991)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.2 ..... diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py index 54c4b3125d77d..1ec77e4b46014 100644 --- a/airflow/providers/http/hooks/http.py +++ b/airflow/providers/http/hooks/http.py @@ -15,11 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Optional, Union +from __future__ import annotations + +from typing import Any, Callable import requests import tenacity from requests.auth import HTTPBasicAuth +from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook @@ -34,18 +37,27 @@ class HttpHook(BaseHook): API url i.e https://www.google.com/ and optional authentication credentials. Default headers can also be specified in the Extra field in json format. :param auth_type: The auth type for the service + :param tcp_keep_alive: Enable TCP Keep Alive for the connection. + :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``). 
+ :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``) + :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to + ``socket.TCP_KEEPINTVL``) """ - conn_name_attr = 'http_conn_id' - default_conn_name = 'http_default' - conn_type = 'http' - hook_name = 'HTTP' + conn_name_attr = "http_conn_id" + default_conn_name = "http_default" + conn_type = "http" + hook_name = "HTTP" def __init__( self, - method: str = 'POST', + method: str = "POST", http_conn_id: str = default_conn_name, auth_type: Any = HTTPBasicAuth, + tcp_keep_alive: bool = True, + tcp_keep_alive_idle: int = 120, + tcp_keep_alive_count: int = 20, + tcp_keep_alive_interval: int = 30, ) -> None: super().__init__() self.http_conn_id = http_conn_id @@ -53,10 +65,14 @@ def __init__( self.base_url: str = "" self._retry_obj: Callable[..., Any] self.auth_type: Any = auth_type + self.tcp_keep_alive = tcp_keep_alive + self.keep_alive_idle = tcp_keep_alive_idle + self.keep_alive_count = tcp_keep_alive_count + self.keep_alive_interval = tcp_keep_alive_interval # headers may be passed through directly or in the "extra" field in the connection # definition - def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session: + def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session: """ Returns http session for use with requests @@ -83,7 +99,7 @@ def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session try: session.headers.update(conn.extra_dejson) except TypeError: - self.log.warning('Connection to %s has invalid extra field.', conn.host) + self.log.warning("Connection to %s has invalid extra field.", conn.host) if headers: session.headers.update(headers) @@ -91,10 +107,10 @@ def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session def run( self, - endpoint: Optional[str] = None, - data: Optional[Union[Dict[str, Any], str]] = None, - headers: Optional[Dict[str, Any]] = None, - extra_options: Optional[Dict[str, Any]] = None, + endpoint: str | None = None, + data: dict[str, Any] | str | None = None, + headers: dict[str, Any] | None = None, + extra_options: dict[str, Any] | None = None, **request_kwargs: Any, ) -> Any: r""" @@ -115,10 +131,15 @@ def run( url = self.url_from_endpoint(endpoint) - if self.method == 'GET': + if self.tcp_keep_alive: + keep_alive_adapter = TCPKeepAliveAdapter( + idle=self.keep_alive_idle, count=self.keep_alive_count, interval=self.keep_alive_interval + ) + session.mount(url, keep_alive_adapter) + if self.method == "GET": # GET uses params req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs) - elif self.method == 'HEAD': + elif self.method == "HEAD": # HEAD doesn't use params req = requests.Request(self.method, url, headers=headers, **request_kwargs) else: @@ -147,7 +168,7 @@ def run_and_check( self, session: requests.Session, prepped_request: requests.PreparedRequest, - extra_options: Dict[Any, Any], + extra_options: dict[Any, Any], ) -> Any: """ Grabs extra options like timeout and actually runs the request, @@ -170,7 +191,7 @@ def run_and_check( ) # Send the request. 
- send_kwargs: Dict[str, Any] = { + send_kwargs: dict[str, Any] = { "timeout": extra_options.get("timeout"), "allow_redirects": extra_options.get("allow_redirects", True), } @@ -179,15 +200,15 @@ def run_and_check( try: response = session.send(prepped_request, **send_kwargs) - if extra_options.get('check_response', True): + if extra_options.get("check_response", True): self.check_response(response) return response except requests.exceptions.ConnectionError as ex: - self.log.warning('%s Tenacity will retry to execute the operation', ex) + self.log.warning("%s Tenacity will retry to execute the operation", ex) raise ex - def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any: + def run_with_advanced_retry(self, _retry_args: dict[Any, Any], *args: Any, **kwargs: Any) -> Any: """ Runs Hook.run() with a Tenacity decorator attached to it. This is useful for connectors which might be disturbed by intermittent issues and should not @@ -203,7 +224,7 @@ def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], *args: Any, **kwa retry_args = dict( wait=tenacity.wait_exponential(), stop=tenacity.stop_after_attempt(10), - retry=requests.exceptions.ConnectionError, + retry=tenacity.retry_if_exception_type(Exception), ) hook.run_with_advanced_retry(endpoint="v1/test", _retry_args=retry_args) @@ -212,16 +233,16 @@ def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], *args: Any, **kwa return self._retry_obj(self.run, *args, **kwargs) - def url_from_endpoint(self, endpoint: Optional[str]) -> str: + def url_from_endpoint(self, endpoint: str | None) -> str: """Combine base url with endpoint""" - if self.base_url and not self.base_url.endswith('/') and endpoint and not endpoint.startswith('/'): - return self.base_url + '/' + endpoint - return (self.base_url or '') + (endpoint or '') + if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"): + return self.base_url + "/" + endpoint + return (self.base_url or "") + (endpoint or "") def test_connection(self): """Test HTTP Connection""" try: self.run() - return True, 'Connection successfully tested' + return True, "Connection successfully tested" except Exception as e: return False, str(e) diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py index 0622b8e7d19e9..df108fd627698 100644 --- a/airflow/providers/http/operators/http.py +++ b/airflow/providers/http/operators/http.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Type +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Sequence from requests.auth import AuthBase, HTTPBasicAuth @@ -54,30 +56,39 @@ class SimpleHttpOperator(BaseOperator): 'requests' documentation (options to modify timeout, ssl, etc.) :param log_response: Log the response (default: False) :param auth_type: The auth type for the service + :param tcp_keep_alive: Enable TCP Keep Alive for the connection. + :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``). 
+ :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``) + :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to + ``socket.TCP_KEEPINTVL``) """ template_fields: Sequence[str] = ( - 'endpoint', - 'data', - 'headers', + "endpoint", + "data", + "headers", ) - template_fields_renderers = {'headers': 'json', 'data': 'py'} + template_fields_renderers = {"headers": "json", "data": "py"} template_ext: Sequence[str] = () - ui_color = '#f4a460' + ui_color = "#f4a460" def __init__( self, *, - endpoint: Optional[str] = None, - method: str = 'POST', + endpoint: str | None = None, + method: str = "POST", data: Any = None, - headers: Optional[Dict[str, str]] = None, - response_check: Optional[Callable[..., bool]] = None, - response_filter: Optional[Callable[..., Any]] = None, - extra_options: Optional[Dict[str, Any]] = None, - http_conn_id: str = 'http_default', + headers: dict[str, str] | None = None, + response_check: Callable[..., bool] | None = None, + response_filter: Callable[..., Any] | None = None, + extra_options: dict[str, Any] | None = None, + http_conn_id: str = "http_default", log_response: bool = False, - auth_type: Type[AuthBase] = HTTPBasicAuth, + auth_type: type[AuthBase] = HTTPBasicAuth, + tcp_keep_alive: bool = True, + tcp_keep_alive_idle: int = 120, + tcp_keep_alive_count: int = 20, + tcp_keep_alive_interval: int = 30, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -91,13 +102,23 @@ def __init__( self.extra_options = extra_options or {} self.log_response = log_response self.auth_type = auth_type - if kwargs.get('xcom_push') is not None: - raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead") + self.tcp_keep_alive = tcp_keep_alive + self.tcp_keep_alive_idle = tcp_keep_alive_idle + self.tcp_keep_alive_count = tcp_keep_alive_count + self.tcp_keep_alive_interval = tcp_keep_alive_interval - def execute(self, context: 'Context') -> Any: + def execute(self, context: Context) -> Any: from airflow.utils.operator_helpers import determine_kwargs - http = HttpHook(self.method, http_conn_id=self.http_conn_id, auth_type=self.auth_type) + http = HttpHook( + self.method, + http_conn_id=self.http_conn_id, + auth_type=self.auth_type, + tcp_keep_alive=self.tcp_keep_alive, + tcp_keep_alive_idle=self.tcp_keep_alive_idle, + tcp_keep_alive_count=self.tcp_keep_alive_count, + tcp_keep_alive_interval=self.tcp_keep_alive_interval, + ) self.log.info("Calling HTTP method") diff --git a/airflow/providers/http/provider.yaml b/airflow/providers/http/provider.yaml index f3cfdffb83d25..28d08832d7aa0 100644 --- a/airflow/providers/http/provider.yaml +++ b/airflow/providers/http/provider.yaml @@ -22,6 +22,9 @@ description: | `Hypertext Transfer Protocol (HTTP) `__ versions: + - 4.1.0 + - 4.0.0 + - 3.0.0 - 2.1.2 - 2.1.1 - 2.1.0 @@ -33,6 +36,12 @@ versions: - 1.1.0 - 1.0.0 +dependencies: + # The 2.26.0 release of requests got rid of the chardet LGPL mandatory dependency, allowing us to + # release it as a requirement for airflow + - requests>=2.26.0 + - requests_toolbelt + integrations: - integration-name: Hypertext Transfer Protocol (HTTP) external-doc-url: https://www.w3.org/Protocols/ @@ -55,9 +64,6 @@ hooks: python-modules: - airflow.providers.http.hooks.http -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.http.hooks.http.HttpHook - connection-types: - hook-class-name: airflow.providers.http.hooks.http.HttpHook 
connection-type: http diff --git a/airflow/providers/http/sensors/http.py b/airflow/providers/http/sensors/http.py index 0ca93b106e653..f691fd0a69225 100644 --- a/airflow/providers/http/sensors/http.py +++ b/airflow/providers/http/sensors/http.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Sequence from airflow.exceptions import AirflowException from airflow.providers.http.hooks.http import HttpHook @@ -38,15 +40,18 @@ class HttpSensor(BaseSensorOperator): The response check can access the template context to the operator: + .. code-block:: python + def response_check(response, task_instance): # The task_instance is injected, so you can pull data form xcom # Other context variables such as dag, ds, execution_date are also available. - xcom_data = task_instance.xcom_pull(task_ids='pushing_task') + xcom_data = task_instance.xcom_pull(task_ids="pushing_task") # In practice you would do something more sensible with this data.. print(xcom_data) return True - HttpSensor(task_id='my_http_sensor', ..., response_check=response_check) + + HttpSensor(task_id="my_http_sensor", ..., response_check=response_check) .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -64,20 +69,29 @@ def response_check(response, task_instance): It should return True for 'pass' and False otherwise. :param extra_options: Extra options for the 'requests' library, see the 'requests' documentation (options to modify timeout, ssl, etc.) + :param tcp_keep_alive: Enable TCP Keep Alive for the connection. + :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``). 
+ :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``) + :param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to + ``socket.TCP_KEEPINTVL``) """ - template_fields: Sequence[str] = ('endpoint', 'request_params', 'headers') + template_fields: Sequence[str] = ("endpoint", "request_params", "headers") def __init__( self, *, endpoint: str, - http_conn_id: str = 'http_default', - method: str = 'GET', - request_params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None, - response_check: Optional[Callable[..., bool]] = None, - extra_options: Optional[Dict[str, Any]] = None, + http_conn_id: str = "http_default", + method: str = "GET", + request_params: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + response_check: Callable[..., bool] | None = None, + extra_options: dict[str, Any] | None = None, + tcp_keep_alive: bool = True, + tcp_keep_alive_idle: int = 120, + tcp_keep_alive_count: int = 20, + tcp_keep_alive_interval: int = 30, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -88,15 +102,26 @@ def __init__( self.headers = headers or {} self.extra_options = extra_options or {} self.response_check = response_check + self.tcp_keep_alive = tcp_keep_alive + self.tcp_keep_alive_idle = tcp_keep_alive_idle + self.tcp_keep_alive_count = tcp_keep_alive_count + self.tcp_keep_alive_interval = tcp_keep_alive_interval - self.hook = HttpHook(method=method, http_conn_id=http_conn_id) - - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: from airflow.utils.operator_helpers import determine_kwargs - self.log.info('Poking: %s', self.endpoint) + hook = HttpHook( + method=self.method, + http_conn_id=self.http_conn_id, + tcp_keep_alive=self.tcp_keep_alive, + tcp_keep_alive_idle=self.tcp_keep_alive_idle, + tcp_keep_alive_count=self.tcp_keep_alive_count, + tcp_keep_alive_interval=self.tcp_keep_alive_interval, + ) + + self.log.info("Poking: %s", self.endpoint) try: - response = self.hook.run( + response = hook.run( self.endpoint, data=self.request_params, headers=self.headers, diff --git a/airflow/providers/imap/.latest-doc-only-change.txt b/airflow/providers/imap/.latest-doc-only-change.txt index 7c8d093986bf5..ff7136e07d744 100644 --- a/airflow/providers/imap/.latest-doc-only-change.txt +++ b/airflow/providers/imap/.latest-doc-only-change.txt @@ -1 +1 @@ -5b2fe0e74013cd08d1f76f5c115f2c8f990ff9bc +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/imap/CHANGELOG.rst b/airflow/providers/imap/CHANGELOG.rst index d2cf6060f5ef2..32c36a482f373 100644 --- a/airflow/providers/imap/CHANGELOG.rst +++ b/airflow/providers/imap/CHANGELOG.rst @@ -16,9 +16,51 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/imap/hooks/imap.py b/airflow/providers/imap/hooks/imap.py index 56499ddac3fa4..59c056ba8d856 100644 --- a/airflow/providers/imap/hooks/imap.py +++ b/airflow/providers/imap/hooks/imap.py @@ -20,11 +20,13 @@ and also to download it. It uses the imaplib library that is already integrated in python 3. """ +from __future__ import annotations + import email import imaplib import os import re -from typing import Any, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Iterable from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook @@ -43,23 +45,23 @@ class ImapHook(BaseHook): that contains the information used to authenticate the client. """ - conn_name_attr = 'imap_conn_id' - default_conn_name = 'imap_default' - conn_type = 'imap' - hook_name = 'IMAP' + conn_name_attr = "imap_conn_id" + default_conn_name = "imap_default" + conn_type = "imap" + hook_name = "IMAP" def __init__(self, imap_conn_id: str = default_conn_name) -> None: super().__init__() self.imap_conn_id = imap_conn_id - self.mail_client: Optional[Union[imaplib.IMAP4_SSL, imaplib.IMAP4]] = None + self.mail_client: imaplib.IMAP4_SSL | imaplib.IMAP4 | None = None - def __enter__(self) -> 'ImapHook': + def __enter__(self) -> ImapHook: return self.get_conn() def __exit__(self, exc_type, exc_val, exc_tb): self.mail_client.logout() - def get_conn(self) -> 'ImapHook': + def get_conn(self) -> ImapHook: """ Login to the mail server. @@ -67,7 +69,6 @@ def get_conn(self) -> 'ImapHook': to automatically open and close the connection to the mail server. :return: an authorized ImapHook object. 
- :rtype: ImapHook """ if not self.mail_client: conn = self.get_connection(self.imap_conn_id) @@ -76,9 +77,9 @@ def get_conn(self) -> 'ImapHook': return self - def _build_client(self, conn: Connection) -> Union[imaplib.IMAP4_SSL, imaplib.IMAP4]: - IMAP: Union[Type[imaplib.IMAP4_SSL], Type[imaplib.IMAP4]] - if conn.extra_dejson.get('use_ssl', True): + def _build_client(self, conn: Connection) -> imaplib.IMAP4_SSL | imaplib.IMAP4: + IMAP: type[imaplib.IMAP4_SSL] | type[imaplib.IMAP4] + if conn.extra_dejson.get("use_ssl", True): IMAP = imaplib.IMAP4_SSL else: IMAP = imaplib.IMAP4 @@ -91,7 +92,7 @@ def _build_client(self, conn: Connection) -> Union[imaplib.IMAP4_SSL, imaplib.IM return mail_client def has_mail_attachment( - self, name: str, *, check_regex: bool = False, mail_folder: str = 'INBOX', mail_filter: str = 'All' + self, name: str, *, check_regex: bool = False, mail_folder: str = "INBOX", mail_filter: str = "All" ) -> bool: """ Checks the mail folder for mails containing attachments with the given name. @@ -102,7 +103,6 @@ def has_mail_attachment( :param mail_filter: If set other than 'All' only specific mails will be checked. See :py:meth:`imaplib.IMAP4.search` for details. :returns: True if there is an attachment with the given name and False if not. - :rtype: bool """ mail_attachments = self._retrieve_mails_attachments_by_name( name, check_regex, True, mail_folder, mail_filter @@ -115,10 +115,10 @@ def retrieve_mail_attachments( *, check_regex: bool = False, latest_only: bool = False, - mail_folder: str = 'INBOX', - mail_filter: str = 'All', - not_found_mode: str = 'raise', - ) -> List[Tuple]: + mail_folder: str = "INBOX", + mail_filter: str = "All", + not_found_mode: str = "raise", + ) -> list[tuple]: """ Retrieves mail's attachments in the mail folder by its name. @@ -134,7 +134,6 @@ def retrieve_mail_attachments( if set to 'warn' it will only print a warning and if set to 'ignore' it won't notify you at all. :returns: a list of tuple each containing the attachment filename and its payload. - :rtype: a list of tuple """ mail_attachments = self._retrieve_mails_attachments_by_name( name, check_regex, latest_only, mail_folder, mail_filter @@ -152,9 +151,9 @@ def download_mail_attachments( *, check_regex: bool = False, latest_only: bool = False, - mail_folder: str = 'INBOX', - mail_filter: str = 'All', - not_found_mode: str = 'raise', + mail_folder: str = "INBOX", + mail_filter: str = "All", + not_found_mode: str = "raise", ) -> None: """ Downloads mail's attachments in the mail folder by its name to the local directory. @@ -183,18 +182,18 @@ def download_mail_attachments( self._create_files(mail_attachments, local_output_directory) def _handle_not_found_mode(self, not_found_mode: str) -> None: - if not_found_mode == 'raise': - raise AirflowException('No mail attachments found!') - if not_found_mode == 'warn': - self.log.warning('No mail attachments found!') - elif not_found_mode == 'ignore': + if not_found_mode == "raise": + raise AirflowException("No mail attachments found!") + if not_found_mode == "warn": + self.log.warning("No mail attachments found!") + elif not_found_mode == "ignore": pass # Do not notify if the attachment has not been found. 
else: self.log.error('Invalid "not_found_mode" %s', not_found_mode) def _retrieve_mails_attachments_by_name( self, name: str, check_regex: bool, latest_only: bool, mail_folder: str, mail_filter: str - ) -> List: + ) -> list: if not self.mail_client: raise Exception("The 'mail_client' should be initialized before!") @@ -225,25 +224,25 @@ def _list_mail_ids_desc(self, mail_filter: str) -> Iterable[str]: def _fetch_mail_body(self, mail_id: str) -> str: if not self.mail_client: raise Exception("The 'mail_client' should be initialized before!") - _, data = self.mail_client.fetch(mail_id, '(RFC822)') + _, data = self.mail_client.fetch(mail_id, "(RFC822)") mail_body = data[0][1] # type: ignore # The mail body is always in this specific location - mail_body_str = mail_body.decode('utf-8') # type: ignore + mail_body_str = mail_body.decode("utf-8") # type: ignore return mail_body_str def _check_mail_body( self, response_mail_body: str, name: str, check_regex: bool, latest_only: bool - ) -> List[Tuple[Any, Any]]: + ) -> list[tuple[Any, Any]]: mail = Mail(response_mail_body) if mail.has_attachments(): return mail.get_attachments_by_name(name, check_regex, find_first=latest_only) return [] - def _create_files(self, mail_attachments: List, local_output_directory: str) -> None: + def _create_files(self, mail_attachments: list, local_output_directory: str) -> None: for name, payload in mail_attachments: if self._is_symlink(name): - self.log.error('Can not create file because it is a symlink!') + self.log.error("Can not create file because it is a symlink!") elif self._is_escaping_current_directory(name): - self.log.error('Can not create file because it is escaping the current directory!') + self.log.error("Can not create file because it is escaping the current directory!") else: self._create_file(name, payload, local_output_directory) @@ -253,19 +252,19 @@ def _is_symlink(self, name: str) -> bool: return os.path.islink(name) def _is_escaping_current_directory(self, name: str) -> bool: - return '../' in name + return "../" in name def _correct_path(self, name: str, local_output_directory: str) -> str: return ( local_output_directory + name - if local_output_directory.endswith('/') - else local_output_directory + '/' + name + if local_output_directory.endswith("/") + else local_output_directory + "/" + name ) def _create_file(self, name: str, payload: Any, local_output_directory: str) -> None: file_path = self._correct_path(name, local_output_directory) - with open(file_path, 'wb') as file: + with open(file_path, "wb") as file: file.write(payload) @@ -285,13 +284,12 @@ def has_attachments(self) -> bool: Checks the mail for a attachments. :returns: True if it has attachments and False if not. - :rtype: bool """ - return self.mail.get_content_maintype() == 'multipart' + return self.mail.get_content_maintype() == "multipart" def get_attachments_by_name( self, name: str, check_regex: bool, find_first: bool = False - ) -> List[Tuple[Any, Any]]: + ) -> list[tuple[Any, Any]]: """ Gets all attachments by name for the mail. @@ -300,7 +298,6 @@ def get_attachments_by_name( :param find_first: If set to True it will only find the first match and then quit. :returns: a list of tuples each containing name and payload where the attachments name matches the given name. 
- :rtype: list(tuple) """ attachments = [] @@ -310,14 +307,14 @@ def get_attachments_by_name( ) if found_attachment: file_name, file_payload = attachment.get_file() - self.log.info('Found attachment: %s', file_name) + self.log.info("Found attachment: %s", file_name) attachments.append((file_name, file_payload)) if find_first: break return attachments - def _iterate_attachments(self) -> Iterable['MailPart']: + def _iterate_attachments(self) -> Iterable[MailPart]: for part in self.mail.walk(): mail_part = MailPart(part) if mail_part.is_attachment(): @@ -339,17 +336,15 @@ def is_attachment(self) -> bool: Checks if the part is a valid mail attachment. :returns: True if it is an attachment and False if not. - :rtype: bool """ - return self.part.get_content_maintype() != 'multipart' and self.part.get('Content-Disposition') + return self.part.get_content_maintype() != "multipart" and self.part.get("Content-Disposition") - def has_matching_name(self, name: str) -> Optional[Tuple[Any, Any]]: + def has_matching_name(self, name: str) -> tuple[Any, Any] | None: """ Checks if the given name matches the part's name. :param name: The name to look for. :returns: True if it matches the name (including regular expression). - :rtype: tuple """ return re.match(name, self.part.get_filename()) # type: ignore @@ -359,15 +354,13 @@ def has_equal_name(self, name: str) -> bool: :param name: The name to look for. :returns: True if it is equal to the given name. - :rtype: bool """ return self.part.get_filename() == name - def get_file(self) -> Tuple: + def get_file(self) -> tuple: """ Gets the file including name and payload. :returns: the part's name and payload. - :rtype: tuple """ return self.part.get_filename(), self.part.get_payload(decode=True) diff --git a/airflow/providers/imap/provider.yaml b/airflow/providers/imap/provider.yaml index 53c00f08e920e..09b1ed36232b9 100644 --- a/airflow/providers/imap/provider.yaml +++ b/airflow/providers/imap/provider.yaml @@ -18,10 +18,13 @@ --- package-name: apache-airflow-providers-imap name: Internet Message Access Protocol (IMAP) + description: | `Internet Message Access Protocol (IMAP) `__ versions: + - 3.1.0 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -32,6 +35,8 @@ versions: - 1.0.1 - 1.0.0 +dependencies: [] + integrations: - integration-name: Internet Message Access Protocol (IMAP) external-doc-url: https://tools.ietf.org/html/rfc3501 @@ -48,9 +53,6 @@ hooks: python-modules: - airflow.providers.imap.hooks.imap -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.imap.hooks.imap.ImapHook - connection-types: - hook-class-name: airflow.providers.imap.hooks.imap.ImapHook connection-type: imap diff --git a/airflow/providers/imap/sensors/imap_attachment.py b/airflow/providers/imap/sensors/imap_attachment.py index a18b09f36ca54..6451930e17e26 100644 --- a/airflow/providers/imap/sensors/imap_attachment.py +++ b/airflow/providers/imap/sensors/imap_attachment.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """This module allows you to poke for attachments on a mail server.""" +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.providers.imap.hooks.imap import ImapHook @@ -39,16 +41,16 @@ class ImapAttachmentSensor(BaseSensorOperator): :param imap_conn_id: The :ref:`imap connection id ` to run the sensor against. 
""" - template_fields: Sequence[str] = ('attachment_name', 'mail_filter') + template_fields: Sequence[str] = ("attachment_name", "mail_filter") def __init__( self, *, attachment_name, check_regex=False, - mail_folder='INBOX', - mail_filter='All', - conn_id='imap_default', + mail_folder="INBOX", + mail_filter="All", + conn_id="imap_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -59,15 +61,14 @@ def __init__( self.mail_filter = mail_filter self.conn_id = conn_id - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: """ Pokes for a mail attachment on the mail server. :param context: The context that is being provided when poking. :return: True if attachment with the given name is present and False if not. - :rtype: bool """ - self.log.info('Poking for %s', self.attachment_name) + self.log.info("Poking for %s", self.attachment_name) with ImapHook(imap_conn_id=self.conn_id) as imap_hook: return imap_hook.has_mail_attachment( diff --git a/airflow/providers/influxdb/.latest-doc-only-change.txt b/airflow/providers/influxdb/.latest-doc-only-change.txt index 029fd1fd22aec..ff7136e07d744 100644 --- a/airflow/providers/influxdb/.latest-doc-only-change.txt +++ b/airflow/providers/influxdb/.latest-doc-only-change.txt @@ -1 +1 @@ -2d109401b3566aef613501691d18cf7e4c776cd2 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/influxdb/CHANGELOG.rst b/airflow/providers/influxdb/CHANGELOG.rst index f2af84c54d06c..c160e8fd431e9 100644 --- a/airflow/providers/influxdb/CHANGELOG.rst +++ b/airflow/providers/influxdb/CHANGELOG.rst @@ -17,9 +17,55 @@ specific language governing permissions and limitations under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +2.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate Influx example DAGs to new design #22449 (#24136)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Bump pre-commit hook versions (#22887)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 1.1.3 ..... diff --git a/airflow/providers/influxdb/example_dags/example_influxdb_query.py b/airflow/providers/influxdb/example_dags/example_influxdb_query.py deleted file mode 100644 index 21b6e8fbf584a..0000000000000 --- a/airflow/providers/influxdb/example_dags/example_influxdb_query.py +++ /dev/null @@ -1,39 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from datetime import datetime - -from airflow.models.dag import DAG -from airflow.providers.influxdb.operators.influxdb import InfluxDBOperator - -dag = DAG( - 'example_influxdb_operator', - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) - -# [START howto_operator_influxdb] - -query_influxdb_task = InfluxDBOperator( - influxdb_conn_id='influxdb_conn_id', - task_id='query_influxdb', - sql='from(bucket:"test-influx") |> range(start: -10m, stop: {{ds}})', - dag=dag, -) - -# [END howto_operator_influxdb] diff --git a/airflow/providers/influxdb/hooks/influxdb.py b/airflow/providers/influxdb/hooks/influxdb.py index a72f862ba5dd9..40f276cde4cad 100644 --- a/airflow/providers/influxdb/hooks/influxdb.py +++ b/airflow/providers/influxdb/hooks/influxdb.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module allows to connect to a InfluxDB database. @@ -23,8 +22,7 @@ FluxTable """ - -from typing import Dict, List +from __future__ import annotations import pandas as pd from influxdb_client import InfluxDBClient @@ -45,17 +43,17 @@ class InfluxDBHook(BaseHook): :param influxdb_conn_id: Reference to :ref:`Influxdb connection id `. 
""" - conn_name_attr = 'influxdb_conn_id' - default_conn_name = 'influxdb_default' - conn_type = 'influxdb' - hook_name = 'Influxdb' + conn_name_attr = "influxdb_conn_id" + default_conn_name = "influxdb_default" + conn_type = "influxdb" + hook_name = "Influxdb" def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.influxdb_conn_id = conn_id self.connection = kwargs.pop("connection", None) self.client = None - self.extras: Dict = {} + self.extras: dict = {} self.uri = None self.org_name = None @@ -68,7 +66,7 @@ def get_uri(self, conn: Connection): based on SSL or other InfluxDB host requirements """ - conn_scheme = 'https' if conn.schema is None else conn.schema + conn_scheme = "https" if conn.schema is None else conn.schema conn_port = 7687 if conn.port is None else conn.port return f"{conn_scheme}://{conn.host}:{conn_port}" @@ -81,22 +79,22 @@ def get_conn(self) -> InfluxDBClient: self.extras = self.connection.extra_dejson.copy() self.uri = self.get_uri(self.connection) - self.log.info('URI: %s', self.uri) + self.log.info("URI: %s", self.uri) if self.client is not None: return self.client - token = self.connection.extra_dejson.get('token') - self.org_name = self.connection.extra_dejson.get('org_name') + token = self.connection.extra_dejson.get("token") + self.org_name = self.connection.extra_dejson.get("org_name") - self.log.info('URI: %s', self.uri) - self.log.info('Organization: %s', self.org_name) + self.log.info("URI: %s", self.uri) + self.log.info("Organization: %s", self.org_name) self.client = self.get_client(self.uri, token, self.org_name) return self.client - def query(self, query) -> List[FluxTable]: + def query(self, query) -> list[FluxTable]: """ Function to to run the query. Note: The bucket name diff --git a/airflow/providers/influxdb/operators/influxdb.py b/airflow/providers/influxdb/operators/influxdb.py index 6222b22202061..73abc1cf2a1c1 100644 --- a/airflow/providers/influxdb/operators/influxdb.py +++ b/airflow/providers/influxdb/operators/influxdb.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator @@ -37,20 +39,20 @@ class InfluxDBOperator(BaseOperator): :param influxdb_conn_id: Reference to :ref:`Influxdb connection id `. 
""" - template_fields: Sequence[str] = ('sql',) + template_fields: Sequence[str] = ("sql",) def __init__( self, *, sql: str, - influxdb_conn_id: str = 'influxdb_default', + influxdb_conn_id: str = "influxdb_default", **kwargs, ) -> None: super().__init__(**kwargs) self.influxdb_conn_id = influxdb_conn_id self.sql = sql - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) + def execute(self, context: Context) -> None: + self.log.info("Executing: %s", self.sql) self.hook = InfluxDBHook(conn_id=self.influxdb_conn_id) self.hook.query(self.sql) diff --git a/airflow/providers/influxdb/provider.yaml b/airflow/providers/influxdb/provider.yaml index c9e91f21b7db2..0518b1e5a507b 100644 --- a/airflow/providers/influxdb/provider.yaml +++ b/airflow/providers/influxdb/provider.yaml @@ -17,15 +17,26 @@ --- package-name: apache-airflow-providers-influxdb + name: Influxdb + description: | `InfluxDB `__ + +dependencies: + - apache-airflow>=2.3.0 + - influxdb-client>=1.19.0 + - requests>=2.26.0 + versions: + - 2.1.0 + - 2.0.0 - 1.1.3 - 1.1.2 - 1.1.1 - 1.1.0 - 1.0.0 + integrations: - integration-name: Influxdb external-doc-url: https://www.influxdata.com/ diff --git a/airflow/providers/jdbc/CHANGELOG.rst b/airflow/providers/jdbc/CHANGELOG.rst index e0f252bc5b7d6..1d4b276fdc88b 100644 --- a/airflow/providers/jdbc/CHANGELOG.rst +++ b/airflow/providers/jdbc/CHANGELOG.rst @@ -16,9 +16,110 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Allow and prefer non-prefixed extra fields for JdbcHook (#27044)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` +* ``Look for 'extra__' instead of 'extra_' in 'get_field' (#27489)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.3+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* In JdbcHook, non-prefixed extra fields are supported and are preferred. E.g. ``drv_path`` will + be preferred if ``extra__jdbc__drv_path`` is also present. + +3.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Review and move the new changes to one of the sections above: + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Adding configurable fetch_all_handler for JdbcOperator (#25412)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Handler parameter from 'JdbcOperator' to 'JdbcHook.run' (#23817)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate JDBC example DAGs to new design #22450 (#24137)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Replace usage of 'DummyOperator' with 'EmptyOperator' (#22974)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/jdbc/example_dags/example_jdbc_queries.py b/airflow/providers/jdbc/example_dags/example_jdbc_queries.py deleted file mode 100644 index 5be0598c4c503..0000000000000 --- a/airflow/providers/jdbc/example_dags/example_jdbc_queries.py +++ /dev/null @@ -1,60 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""Example DAG demonstrating the usage of the JdbcOperator.""" - -from datetime import datetime, timedelta - -from airflow import DAG - -try: - from airflow.operators.empty import EmptyOperator -except ModuleNotFoundError: - from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore -from airflow.providers.jdbc.operators.jdbc import JdbcOperator - -with DAG( - dag_id='example_jdbc_operator', - schedule_interval='0 0 * * *', - start_date=datetime(2021, 1, 1), - dagrun_timeout=timedelta(minutes=60), - tags=['example'], - catchup=False, -) as dag: - - run_this_last = EmptyOperator(task_id='run_this_last') - - # [START howto_operator_jdbc_template] - delete_data = JdbcOperator( - task_id='delete_data', - sql='delete from my_schema.my_table where dt = {{ ds }}', - jdbc_conn_id='my_jdbc_connection', - autocommit=True, - ) - # [END howto_operator_jdbc_template] - - # [START howto_operator_jdbc] - insert_data = JdbcOperator( - task_id='insert_data', - sql='insert into my_schema.my_table select dt, value from my_schema.source_data', - jdbc_conn_id='my_jdbc_connection', - autocommit=True, - ) - # [END howto_operator_jdbc] - - delete_data >> insert_data >> run_this_last diff --git a/airflow/providers/jdbc/hooks/jdbc.py b/airflow/providers/jdbc/hooks/jdbc.py index 734afecb5ba18..56947af0ebb2e 100644 --- a/airflow/providers/jdbc/hooks/jdbc.py +++ b/airflow/providers/jdbc/hooks/jdbc.py @@ -15,13 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any import jaydebeapi -from airflow.hooks.dbapi import DbApiHook from airflow.models.connection import Connection +from airflow.providers.common.sql.hooks.sql import DbApiHook class JdbcHook(DbApiHook): @@ -33,41 +34,53 @@ class JdbcHook(DbApiHook): Raises an airflow error if the given connection id doesn't exist. 
""" - conn_name_attr = 'jdbc_conn_id' - default_conn_name = 'jdbc_default' - conn_type = 'jdbc' - hook_name = 'JDBC Connection' + conn_name_attr = "jdbc_conn_id" + default_conn_name = "jdbc_default" + conn_type = "jdbc" + hook_name = "JDBC Connection" supports_autocommit = True @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { - "extra__jdbc__drv_path": StringField(lazy_gettext('Driver Path'), widget=BS3TextFieldWidget()), - "extra__jdbc__drv_clsname": StringField( - lazy_gettext('Driver Class'), widget=BS3TextFieldWidget() - ), + "drv_path": StringField(lazy_gettext("Driver Path"), widget=BS3TextFieldWidget()), + "drv_clsname": StringField(lazy_gettext("Driver Class"), widget=BS3TextFieldWidget()), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['port', 'schema', 'extra'], - "relabeling": {'host': 'Connection URL'}, + "hidden_fields": ["port", "schema", "extra"], + "relabeling": {"host": "Connection URL"}, } + def _get_field(self, extras: dict, field_name: str): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = "extra__jdbc__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." + ) + if field_name in extras: + return extras[field_name] or None + prefixed_name = f"{backcompat_prefix}{field_name}" + return extras.get(prefixed_name) or None + def get_conn(self) -> jaydebeapi.Connection: conn: Connection = self.get_connection(getattr(self, self.conn_name_attr)) + extras = conn.extra_dejson host: str = conn.host login: str = conn.login psw: str = conn.password - jdbc_driver_loc: Optional[str] = conn.extra_dejson.get('extra__jdbc__drv_path') - jdbc_driver_name: Optional[str] = conn.extra_dejson.get('extra__jdbc__drv_clsname') + jdbc_driver_loc: str | None = self._get_field(extras, "drv_path") + jdbc_driver_name: str | None = self._get_field(extras, "drv_clsname") conn = jaydebeapi.connect( jclassname=jdbc_driver_name, @@ -94,6 +107,5 @@ def get_autocommit(self, conn: jaydebeapi.Connection) -> bool: :param conn: The connection. :return: connection autocommit setting. - :rtype: bool """ return conn.jconn.getAutoCommit() diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py index 2c023d9afe9ba..228d30ad62eb2 100644 --- a/airflow/providers/jdbc/operators/jdbc.py +++ b/airflow/providers/jdbc/operators/jdbc.py @@ -15,22 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +import warnings +from typing import Sequence -from airflow.models import BaseOperator -from airflow.providers.jdbc.hooks.jdbc import JdbcHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -if TYPE_CHECKING: - from airflow.utils.context import Context - -def fetch_all_handler(cursor): - """Handler for DbApiHook.run() to return results""" - return cursor.fetchall() - - -class JdbcOperator(BaseOperator): +class JdbcOperator(SQLExecuteQueryOperator): """ Executes sql code in a database using jdbc driver. @@ -49,28 +42,16 @@ class JdbcOperator(BaseOperator): :param parameters: (optional) the parameters to render the SQL query with. """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#ededed' - - def __init__( - self, - *, - sql: Union[str, List[str]], - jdbc_conn_id: str = 'jdbc_default', - autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.parameters = parameters - self.sql = sql - self.jdbc_conn_id = jdbc_conn_id - self.autocommit = autocommit - self.hook = None - - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) - hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id) - return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=fetch_all_handler) + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#ededed" + + def __init__(self, *, jdbc_conn_id: str = "jdbc_default", **kwargs) -> None: + super().__init__(conn_id=jdbc_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. 
+ Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/jdbc/provider.yaml b/airflow/providers/jdbc/provider.yaml index 475ffaf34ec64..e229585814ae7 100644 --- a/airflow/providers/jdbc/provider.yaml +++ b/airflow/providers/jdbc/provider.yaml @@ -22,6 +22,11 @@ description: | `Java Database Connectivity (JDBC) `__ versions: + - 3.3.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -31,8 +36,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - jaydebeapi>=1.1.1 integrations: - integration-name: Java Database Connectivity (JDBC) @@ -51,8 +58,6 @@ hooks: python-modules: - airflow.providers.jdbc.hooks.jdbc -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.jdbc.hooks.jdbc.JdbcHook connection-types: - hook-class-name: airflow.providers.jdbc.hooks.jdbc.JdbcHook diff --git a/airflow/providers/jenkins/.latest-doc-only-change.txt b/airflow/providers/jenkins/.latest-doc-only-change.txt index ab24993f57139..ff7136e07d744 100644 --- a/airflow/providers/jenkins/.latest-doc-only-change.txt +++ b/airflow/providers/jenkins/.latest-doc-only-change.txt @@ -1 +1 @@ -8b6b0848a3cacf9999477d6af4d2a87463f03026 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/jenkins/CHANGELOG.rst b/airflow/providers/jenkins/CHANGELOG.rst index 65597d7438d61..fa6a3660f8a25 100644 --- a/airflow/providers/jenkins/CHANGELOG.rst +++ b/airflow/providers/jenkins/CHANGELOG.rst @@ -16,9 +16,60 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + + +Bug Fixes +~~~~~~~~~ + +* ``Bug Fix for 'apache-airflow-providers-jenkins' 'JenkinsJobTriggerOperator' (#22802)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
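Returning to the JdbcOperator change a few hunks above: the class is now a deprecated shim over SQLExecuteQueryOperator, so existing DAGs keep working but emit a DeprecationWarning at instantiation. A migration sketch; the DAG id, connection id and SQL text are placeholders, not values from this PR.

```python
# Migration sketch for the deprecated JdbcOperator shown above; the DAG id,
# connection id and SQL text are placeholders, not values taken from this PR.
from datetime import datetime

from airflow import DAG
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.jdbc.operators.jdbc import JdbcOperator

with DAG(
    dag_id="example_jdbc_migration",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Old style: still works, but raises a DeprecationWarning and simply
    # forwards jdbc_conn_id to SQLExecuteQueryOperator as conn_id.
    delete_data_old = JdbcOperator(
        task_id="delete_data_old",
        jdbc_conn_id="my_jdbc_connection",
        sql="DELETE FROM my_schema.my_table WHERE dt = {{ ds }}",
        autocommit=True,
    )

    # Preferred replacement going forward.
    delete_data = SQLExecuteQueryOperator(
        task_id="delete_data",
        conn_id="my_jdbc_connection",
        sql="DELETE FROM my_schema.my_table WHERE dt = {{ ds }}",
        autocommit=True,
    )
```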
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate Jenkins example DAGs to new design #22451 (#24138)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.0 ..... diff --git a/airflow/providers/jenkins/hooks/jenkins.py b/airflow/providers/jenkins/hooks/jenkins.py index f78a80500dac1..db1284a59a53b 100644 --- a/airflow/providers/jenkins/hooks/jenkins.py +++ b/airflow/providers/jenkins/hooks/jenkins.py @@ -15,9 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - -from typing import Optional +from __future__ import annotations import jenkins @@ -29,40 +27,40 @@ class JenkinsHook(BaseHook): """Hook to manage connection to jenkins server""" - conn_name_attr = 'conn_id' - default_conn_name = 'jenkins_default' - conn_type = 'jenkins' - hook_name = 'Jenkins' + conn_name_attr = "conn_id" + default_conn_name = "jenkins_default" + conn_type = "jenkins" + hook_name = "Jenkins" def __init__(self, conn_id: str = default_conn_name) -> None: super().__init__() connection = self.get_connection(conn_id) self.connection = connection - connection_prefix = 'http' + connection_prefix = "http" # connection.extra contains info about using https (true) or http (false) if to_boolean(connection.extra): - connection_prefix = 'https' - url = f'{connection_prefix}://{connection.host}:{connection.port}' - self.log.info('Trying to connect to %s', url) + connection_prefix = "https" + url = f"{connection_prefix}://{connection.host}:{connection.port}" + self.log.info("Trying to connect to %s", url) self.jenkins_server = jenkins.Jenkins(url, connection.login, connection.password) def get_jenkins_server(self) -> jenkins.Jenkins: """Get jenkins server""" return self.jenkins_server - def get_build_building_state(self, job_name: str, build_number: Optional[int]) -> bool: + def get_build_building_state(self, job_name: str, build_number: int | None) -> bool: """Get build building state""" try: if not build_number: self.log.info("Build number not specified, getting latest build info from Jenkins") job_info = self.jenkins_server.get_job_info(job_name) - build_number_to_check = job_info['lastBuild']['number'] + build_number_to_check = job_info["lastBuild"]["number"] else: build_number_to_check = build_number self.log.info("Getting build info for %s build number: #%s", job_name, build_number_to_check) build_info = self.jenkins_server.get_build_info(job_name, build_number_to_check) - building = build_info['building'] + building = build_info["building"] return building except jenkins.JenkinsException as err: - raise AirflowException(f'Jenkins call failed with error : {err}') + raise AirflowException(f"Jenkins call failed with error : {err}") diff --git a/airflow/providers/jenkins/operators/jenkins_job_trigger.py b/airflow/providers/jenkins/operators/jenkins_job_trigger.py index b7dcb25913b27..9f86e183be0f6 100644 --- a/airflow/providers/jenkins/operators/jenkins_job_trigger.py +++ b/airflow/providers/jenkins/operators/jenkins_job_trigger.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations import ast import json import socket import time -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any, Iterable, List, Mapping, Sequence, Union from urllib.error import HTTPError, URLError import jenkins @@ -32,10 +33,10 @@ from airflow.providers.jenkins.hooks.jenkins import JenkinsHook JenkinsRequest = Mapping[str, Any] -ParamType = Optional[Union[str, Dict, List]] +ParamType = Union[str, dict, List, None] -def jenkins_request_with_headers(jenkins_server: Jenkins, req: Request) -> Optional[JenkinsRequest]: +def jenkins_request_with_headers(jenkins_server: Jenkins, req: Request) -> JenkinsRequest | None: """ We need to get the headers in addition to the body answer to get the location from them @@ -55,19 +56,19 @@ def jenkins_request_with_headers(jenkins_server: Jenkins, req: Request) -> Optio raise jenkins.EmptyResponseException( f"Error communicating with server[{jenkins_server.server}]: empty response" ) - return {'body': response_body.decode('utf-8'), 'headers': response_headers} + return {"body": response_body.decode("utf-8"), "headers": response_headers} except HTTPError as e: # Jenkins's funky authentication means its nigh impossible to distinguish errors. if e.code in [401, 403, 500]: - raise JenkinsException(f'Error in request. Possibly authentication failed [{e.code}]: {e.reason}') + raise JenkinsException(f"Error in request. Possibly authentication failed [{e.code}]: {e.reason}") elif e.code == 404: - raise jenkins.NotFoundException('Requested item could not be found') + raise jenkins.NotFoundException("Requested item could not be found") else: raise except socket.timeout as e: - raise jenkins.TimeoutException(f'Error in request: {e}') + raise jenkins.TimeoutException(f"Error in request: {e}") except URLError as e: - raise JenkinsException(f'Error in request: {e.reason}') + raise JenkinsException(f"Error in request: {e.reason}") return None @@ -89,9 +90,9 @@ class JenkinsJobTriggerOperator(BaseOperator): :param allowed_jenkins_states: Iterable of allowed result jenkins states, default is ``['SUCCESS']`` """ - template_fields: Sequence[str] = ('parameters',) - template_ext: Sequence[str] = ('.json',) - ui_color = '#f9ec86' + template_fields: Sequence[str] = ("parameters",) + template_ext: Sequence[str] = (".json",) + ui_color = "#f9ec86" def __init__( self, @@ -101,7 +102,7 @@ def __init__( parameters: ParamType = None, sleep_time: int = 10, max_try_before_job_appears: int = 10, - allowed_jenkins_states: Optional[Iterable[str]] = None, + allowed_jenkins_states: Iterable[str] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -110,9 +111,9 @@ def __init__( self.sleep_time = max(sleep_time, 1) self.jenkins_connection_id = jenkins_connection_id self.max_try_before_job_appears = max_try_before_job_appears - self.allowed_jenkins_states = list(allowed_jenkins_states) if allowed_jenkins_states else ['SUCCESS'] + self.allowed_jenkins_states = list(allowed_jenkins_states) if allowed_jenkins_states else ["SUCCESS"] - def build_job(self, jenkins_server: Jenkins, params: ParamType = None) -> Optional[JenkinsRequest]: + def build_job(self, jenkins_server: Jenkins, params: ParamType = None) -> JenkinsRequest | None: """ This function makes an API call to Jenkins to trigger a build for 'job_name' It returned a dict with 2 keys : body and headers. 
@@ -129,7 +130,7 @@ def build_job(self, jenkins_server: Jenkins, params: ParamType = None) -> Option if params and isinstance(params, str): params = ast.literal_eval(params) - request = Request(method='POST', url=jenkins_server.build_job_url(self.job_name, params, None)) + request = Request(method="POST", url=jenkins_server.build_job_url(self.job_name, params, None)) return jenkins_request_with_headers(jenkins_server, request) def poll_job_in_queue(self, location: str, jenkins_server: Jenkins) -> int: @@ -148,32 +149,32 @@ def poll_job_in_queue(self, location: str, jenkins_server: Jenkins) -> int: :return: The build_number corresponding to the triggered job """ try_count = 0 - location += '/api/json' + location += "/api/json" # TODO Use get_queue_info instead # once it will be available in python-jenkins (v > 0.4.15) - self.log.info('Polling jenkins queue at the url %s', location) + self.log.info("Polling jenkins queue at the url %s", location) while try_count < self.max_try_before_job_appears: try: location_answer = jenkins_request_with_headers( - jenkins_server, Request(method='POST', url=location) + jenkins_server, Request(method="POST", url=location) ) # we don't want to fail the operator, this will continue to poll # until max_try_before_job_appears reached except (HTTPError, JenkinsException): - self.log.warning('polling failed, retrying', exc_info=True) + self.log.warning("polling failed, retrying", exc_info=True) try_count += 1 time.sleep(self.sleep_time) continue if location_answer is not None: - json_response = json.loads(location_answer['body']) + json_response = json.loads(location_answer["body"]) if ( - 'executable' in json_response - and json_response['executable'] is not None - and 'number' in json_response['executable'] + "executable" in json_response + and json_response["executable"] is not None + and "number" in json_response["executable"] ): - build_number = json_response['executable']['number'] - self.log.info('Job executed on Jenkins side with the build number %s', build_number) + build_number = json_response["executable"]["number"] + self.log.info("Job executed on Jenkins side with the build number %s", build_number) return build_number try_count += 1 time.sleep(self.sleep_time) @@ -186,9 +187,9 @@ def get_hook(self) -> JenkinsHook: """Instantiate jenkins hook""" return JenkinsHook(self.jenkins_connection_id) - def execute(self, context: Mapping[Any, Any]) -> Optional[str]: + def execute(self, context: Mapping[Any, Any]) -> str | None: self.log.info( - 'Triggering the job %s on the jenkins : %s with the parameters : %s', + "Triggering the job %s on the jenkins : %s with the parameters : %s", self.job_name, self.jenkins_connection_id, self.parameters, @@ -196,7 +197,7 @@ def execute(self, context: Mapping[Any, Any]) -> Optional[str]: jenkins_server = self.get_hook().get_jenkins_server() jenkins_response = self.build_job(jenkins_server, self.parameters) if jenkins_response: - build_number = self.poll_job_in_queue(jenkins_response['headers']['Location'], jenkins_server) + build_number = self.poll_job_in_queue(jenkins_response["headers"]["Location"], jenkins_server) time.sleep(self.sleep_time) keep_polling_job = True @@ -205,30 +206,30 @@ def execute(self, context: Mapping[Any, Any]) -> Optional[str]: while keep_polling_job: try: build_info = jenkins_server.get_build_info(name=self.job_name, number=build_number) - if build_info['result'] is not None: + if build_info["result"] is not None: keep_polling_job = False # Check if job ended with not allowed state. 
- if build_info['result'] not in self.allowed_jenkins_states: + if build_info["result"] not in self.allowed_jenkins_states: raise AirflowException( f"Jenkins job failed, final state : {build_info['result']}. " f"Find more information on job url : {build_info['url']}" ) else: - self.log.info('Waiting for job to complete : %s , build %s', self.job_name, build_number) + self.log.info("Waiting for job to complete : %s , build %s", self.job_name, build_number) time.sleep(self.sleep_time) except jenkins.NotFoundException as err: - raise AirflowException(f'Jenkins job status check failed. Final error was: {err.resp.status}') + raise AirflowException(f"Jenkins job status check failed. Final error was: {err.resp.status}") except jenkins.JenkinsException as err: raise AirflowException( - f'Jenkins call failed with error : {err}, if you have parameters ' - 'double check them, jenkins sends back ' - 'this exception for unknown parameters' - 'You can also check logs for more details on this exception ' - '(jenkins_url/log/rss)' + f"Jenkins call failed with error : {err}, if you have parameters " + "double check them, jenkins sends back " + "this exception for unknown parameters" + "You can also check logs for more details on this exception " + "(jenkins_url/log/rss)" ) if build_info: # If we can we return the url of the job # for later use (like retrieving an artifact) - return build_info['url'] + return build_info["url"] return None diff --git a/airflow/providers/jenkins/provider.yaml b/airflow/providers/jenkins/provider.yaml index fbbdefe27cecf..a188b1728045a 100644 --- a/airflow/providers/jenkins/provider.yaml +++ b/airflow/providers/jenkins/provider.yaml @@ -22,6 +22,8 @@ description: | `Jenkins `__ versions: + - 3.1.0 + - 3.0.0 - 2.1.0 - 2.0.7 - 2.0.6 @@ -35,8 +37,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - python-jenkins>=1.0.0 integrations: - integration-name: Jenkins @@ -58,9 +61,6 @@ sensors: python-modules: - 'airflow.providers.jenkins.sensors.jenkins' -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.jenkins.hooks.jenkins.JenkinsHook - connection-types: - hook-class-name: airflow.providers.jenkins.hooks.jenkins.JenkinsHook connection-type: jenkins diff --git a/airflow/providers/jenkins/sensors/jenkins.py b/airflow/providers/jenkins/sensors/jenkins.py index 8bb400c7294d2..dbfc9932d69e6 100644 --- a/airflow/providers/jenkins/sensors/jenkins.py +++ b/airflow/providers/jenkins/sensors/jenkins.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
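The JenkinsJobTriggerOperator changes above are quote normalization plus the PEP 604 optional syntax; the trigger-then-poll flow (build_job, poll_job_in_queue, then repeated get_build_info calls) is unchanged. A usage sketch with placeholder job name, parameters and connection id:

```python
# Usage sketch for JenkinsJobTriggerOperator; the job name, parameters and
# connection id below are placeholders, not values from this PR.
from datetime import datetime

from airflow import DAG
from airflow.providers.jenkins.operators.jenkins_job_trigger import JenkinsJobTriggerOperator

with DAG(
    dag_id="example_jenkins_trigger",  # illustrative DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    trigger_build = JenkinsJobTriggerOperator(
        task_id="trigger_build",
        jenkins_connection_id="jenkins_default",
        job_name="my-jenkins-job",
        parameters={"GIT_BRANCH": "main"},  # rendered via the templated "parameters" field
        sleep_time=10,                      # seconds between queue/build polls
        max_try_before_job_appears=10,      # queue polls before giving up
        allowed_jenkins_states=["SUCCESS", "UNSTABLE"],
    )
    # execute() returns the build URL when available, so downstream tasks can pull it from XCom.
```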
+from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from airflow.providers.jenkins.hooks.jenkins import JenkinsHook from airflow.sensors.base import BaseSensorOperator @@ -40,7 +41,7 @@ def __init__( *, jenkins_connection_id: str, job_name: str, - build_number: Optional[int] = None, + build_number: int | None = None, **kwargs, ): super().__init__(**kwargs) @@ -48,7 +49,7 @@ def __init__( self.build_number = build_number self.jenkins_connection_id = jenkins_connection_id - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.log.info("Poking jenkins job %s", self.job_name) hook = JenkinsHook(self.jenkins_connection_id) is_building = hook.get_build_building_state(self.job_name, self.build_number) diff --git a/airflow/providers/jira/.latest-doc-only-change.txt b/airflow/providers/jira/.latest-doc-only-change.txt deleted file mode 100644 index 28124098645cf..0000000000000 --- a/airflow/providers/jira/.latest-doc-only-change.txt +++ /dev/null @@ -1 +0,0 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 diff --git a/airflow/providers/jira/CHANGELOG.rst b/airflow/providers/jira/CHANGELOG.rst deleted file mode 100644 index b608184efe1af..0000000000000 --- a/airflow/providers/jira/CHANGELOG.rst +++ /dev/null @@ -1,114 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - - -Changelog ---------- - -2.0.5 -..... - - -Bug Fixes -~~~~~~~~~ - -* ``Use JiraHook instead of JiraOperator for JiraSensor`` - -2.0.4 -..... - -Bug Fixes -~~~~~~~~~ - -* ``Fix mistakenly added install_requires for all providers (#22382)`` - -2.0.3 -..... - -Misc -~~~~~ - -* ``Add Trove classifiers in PyPI (Framework :: Apache Airflow :: Provider)`` - -2.0.2 -..... - -Misc -~~~~ - -* ``Support for Python 3.10`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Add documentation for January 2021 providers release (#21257)`` - * ``Fixed changelog for January 2022 (delayed) provider's release (#21439)`` - * ``Fix K8S changelog to be PyPI-compatible (#20614)`` - * ``Fix template_fields type to have MyPy friendly Sequence type (#20571)`` - * ``Fix mypy providers (#20190)`` - * ``Remove ':type' lines now sphinx-autoapi supports typehints (#20951)`` - * ``Update documentation for provider December 2021 release (#20523)`` - * ``Use typed Context EVERYWHERE (#20565)`` - -2.0.1 -..... - -Misc -~~~~ - -* ``Optimise connection importing for Airflow 2.2.0`` - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. 
Do not delete the lines(!): - * ``Update description about the new ''connection-types'' provider meta-data (#17767)`` - * ``Import Hooks lazily individually in providers manager (#17682)`` - * ``Prepares docs for Rc2 release of July providers (#17116)`` - * ``Prepare documentation for July release of providers. (#17015)`` - * ``Removes pylint from our toolchain (#16682)`` - -2.0.0 -..... - -Breaking changes -~~~~~~~~~~~~~~~~ - -* ``Auto-apply apply_default decorator (#15667)`` - -.. warning:: Due to apply_default decorator removal, this version of the provider requires Airflow 2.1.0+. - If your Airflow version is < 2.1.0, and you want to install this provider version, first upgrade - Airflow to at least version 2.1.0. Otherwise your Airflow package version will be upgraded - automatically and you will have to manually run ``airflow upgrade db`` to complete the migration. - -.. Below changes are excluded from the changelog. Move them to - appropriate section above if needed. Do not delete the lines(!): - * ``Updated documentation for June 2021 provider release (#16294)`` - * ``More documentation update for June providers release (#16405)`` - * ``Synchronizes updated changelog after buggfix release (#16464)`` - -1.0.2 -..... - -* ``Fix 'logging.exception' redundancy (#14823)`` - -1.0.1 -..... - -Updated documentation and readme files. - -1.0.0 -..... - -Initial version of the provider. diff --git a/airflow/providers/jira/hooks/jira.py b/airflow/providers/jira/hooks/jira.py deleted file mode 100644 index 27ef4b0c3916d..0000000000000 --- a/airflow/providers/jira/hooks/jira.py +++ /dev/null @@ -1,88 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Hook for JIRA""" -from typing import Any, Optional - -from jira import JIRA -from jira.exceptions import JIRAError - -from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook - - -class JiraHook(BaseHook): - """ - Jira interaction hook, a Wrapper around JIRA Python SDK. - - :param jira_conn_id: reference to a pre-defined Jira Connection - """ - - default_conn_name = 'jira_default' - conn_type = "jira" - conn_name_attr = "jira_conn_id" - hook_name = "JIRA" - - def __init__(self, jira_conn_id: str = default_conn_name, proxies: Optional[Any] = None) -> None: - super().__init__() - self.jira_conn_id = jira_conn_id - self.proxies = proxies - self.client: Optional[JIRA] = None - self.get_conn() - - def get_conn(self) -> JIRA: - if not self.client: - self.log.debug('Creating Jira client for conn_id: %s', self.jira_conn_id) - - get_server_info = True - validate = True - extra_options = {} - if not self.jira_conn_id: - raise AirflowException('Failed to create jira client. 
no jira_conn_id provided') - - conn = self.get_connection(self.jira_conn_id) - if conn.extra is not None: - extra_options = conn.extra_dejson - # only required attributes are taken for now, - # more can be added ex: async, logging, max_retries - - # verify - if 'verify' in extra_options and extra_options['verify'].lower() == 'false': - extra_options['verify'] = False - - # validate - if 'validate' in extra_options and extra_options['validate'].lower() == 'false': - validate = False - - if 'get_server_info' in extra_options and extra_options['get_server_info'].lower() == 'false': - get_server_info = False - - try: - self.client = JIRA( - conn.host, - options=extra_options, - basic_auth=(conn.login, conn.password), - get_server_info=get_server_info, - validate=validate, - proxies=self.proxies, - ) - except JIRAError as jira_error: - raise AirflowException(f'Failed to create jira client, jira error: {str(jira_error)}') - except Exception as e: - raise AirflowException(f'Failed to create jira client, error: {str(e)}') - - return self.client diff --git a/airflow/providers/jira/operators/jira.py b/airflow/providers/jira/operators/jira.py deleted file mode 100644 index e9a45bdd45525..0000000000000 --- a/airflow/providers/jira/operators/jira.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence - -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.jira.hooks.jira import JIRAError, JiraHook - -if TYPE_CHECKING: - from airflow.utils.context import Context - - -class JiraOperator(BaseOperator): - """ - JiraOperator to interact and perform action on Jira issue tracking system. - This operator is designed to use Jira Python SDK: http://jira.readthedocs.io - - :param jira_conn_id: reference to a pre-defined Jira Connection - :param jira_method: method name from Jira Python SDK to be called - :param jira_method_args: required method parameters for the jira_method. 
(templated) - :param result_processor: function to further process the response from Jira - :param get_jira_resource_method: function or operator to get jira resource - on which the provided jira_method will be executed - """ - - template_fields: Sequence[str] = ("jira_method_args",) - - def __init__( - self, - *, - jira_method: str, - jira_conn_id: str = 'jira_default', - jira_method_args: Optional[dict] = None, - result_processor: Optional[Callable] = None, - get_jira_resource_method: Optional[Callable] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.jira_conn_id = jira_conn_id - self.method_name = jira_method - self.jira_method_args = jira_method_args - self.result_processor = result_processor - self.get_jira_resource_method = get_jira_resource_method - - def execute(self, context: 'Context') -> Any: - try: - if self.get_jira_resource_method is not None: - # if get_jira_resource_method is provided, jira_method will be executed on - # resource returned by executing the get_jira_resource_method. - # This makes all the provided methods of JIRA sdk accessible and usable - # directly at the JiraOperator without additional wrappers. - # ref: http://jira.readthedocs.io/en/latest/api.html - if isinstance(self.get_jira_resource_method, JiraOperator): - resource = self.get_jira_resource_method.execute(**context) - else: - resource = self.get_jira_resource_method(**context) - else: - # Default method execution is on the top level jira client resource - hook = JiraHook(jira_conn_id=self.jira_conn_id) - resource = hook.client - - # Current Jira-Python SDK (1.0.7) has issue with pickling the jira response. - # ex: self.xcom_push(context, key='operator_response', value=jira_response) - # This could potentially throw error if jira_result is not picklable - jira_result = getattr(resource, self.method_name)(**self.jira_method_args) - if self.result_processor: - return self.result_processor(context, jira_result) - - return jira_result - - except JIRAError as jira_error: - raise AirflowException(f"Failed to execute jiraOperator, error: {str(jira_error)}") - except Exception as e: - raise AirflowException(f"Jira operator error: {str(e)}") diff --git a/airflow/providers/jira/provider.yaml b/airflow/providers/jira/provider.yaml deleted file mode 100644 index c522822fd6422..0000000000000 --- a/airflow/providers/jira/provider.yaml +++ /dev/null @@ -1,64 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
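The hunks above and below delete the standalone jira provider wholesale (changelog, hook, operator, sensors, provider.yaml), so DAGs importing these modules will break once the package is gone. This diff itself does not show a replacement; the sketch below is only a defensive-import idea, and the airflow.providers.atlassian.jira path is an assumption about a successor package, not something this PR establishes.

```python
# Defensive import sketch for DAG code that still references the removed jira provider.
# The "atlassian.jira" location is an assumption, not established by this PR.
try:
    from airflow.providers.atlassian.jira.hooks.jira import JiraHook  # assumed successor package
except ImportError:  # older environments that still ship apache-airflow-providers-jira
    from airflow.providers.jira.hooks.jira import JiraHook

# Either way, the hook is constructed from a pre-defined Jira connection id.
hook = JiraHook(jira_conn_id="jira_default")
```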
- ---- -package-name: apache-airflow-providers-jira -name: Jira -description: | - `Atlassian Jira `__ - -versions: - - 2.0.5 - - 2.0.4 - - 2.0.3 - - 2.0.2 - - 2.0.1 - - 2.0.0 - - 1.0.2 - - 1.0.1 - - 1.0.0 - -additional-dependencies: - - apache-airflow>=2.1.0 - -integrations: - - integration-name: Atlassian Jira - external-doc-url: https://www.atlassian.com/pl/software/jira - logo: /integration-logos/jira/Jira.png - tags: [software] - -operators: - - integration-name: Atlassian Jira - python-modules: - - airflow.providers.jira.operators.jira - -sensors: - - integration-name: Atlassian Jira - python-modules: - - airflow.providers.jira.sensors.jira - -hooks: - - integration-name: Atlassian Jira - python-modules: - - airflow.providers.jira.hooks.jira - -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.jira.hooks.jira.JiraHook - -connection-types: - - hook-class-name: airflow.providers.jira.hooks.jira.JiraHook - connection-type: jira diff --git a/airflow/providers/jira/sensors/jira.py b/airflow/providers/jira/sensors/jira.py deleted file mode 100644 index 9b16b6557bf8a..0000000000000 --- a/airflow/providers/jira/sensors/jira.py +++ /dev/null @@ -1,137 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence - -from jira.resources import Issue, Resource - -from airflow.providers.jira.hooks.jira import JiraHook -from airflow.providers.jira.operators.jira import JIRAError -from airflow.sensors.base import BaseSensorOperator - -if TYPE_CHECKING: - from airflow.utils.context import Context - - -class JiraSensor(BaseSensorOperator): - """ - Monitors a jira ticket for any change. 
- - :param jira_conn_id: reference to a pre-defined Jira Connection - :param method_name: method name from jira-python-sdk to be execute - :param method_params: parameters for the method method_name - :param result_processor: function that return boolean and act as a sensor response - """ - - def __init__( - self, - *, - method_name: str, - jira_conn_id: str = 'jira_default', - method_params: Optional[dict] = None, - result_processor: Optional[Callable] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.jira_conn_id = jira_conn_id - self.result_processor = None - if result_processor is not None: - self.result_processor = result_processor - self.method_name = method_name - self.method_params = method_params - - def poke(self, context: 'Context') -> Any: - hook = JiraHook(jira_conn_id=self.jira_conn_id) - resource = hook.get_conn() - jira_result = getattr(resource, self.method_name)(**self.method_params) - if self.result_processor is None: - return jira_result - return self.result_processor(context, jira_result) - - -class JiraTicketSensor(JiraSensor): - """ - Monitors a jira ticket for given change in terms of function. - - :param jira_conn_id: reference to a pre-defined Jira Connection - :param ticket_id: id of the ticket to be monitored - :param field: field of the ticket to be monitored - :param expected_value: expected value of the field - :param result_processor: function that return boolean and act as a sensor response - """ - - template_fields: Sequence[str] = ("ticket_id",) - - def __init__( - self, - *, - jira_conn_id: str = 'jira_default', - ticket_id: Optional[str] = None, - field: Optional[str] = None, - expected_value: Optional[str] = None, - field_checker_func: Optional[Callable] = None, - **kwargs, - ) -> None: - - self.jira_conn_id = jira_conn_id - self.ticket_id = ticket_id - self.field = field - self.expected_value = expected_value - if field_checker_func is None: - field_checker_func = self.issue_field_checker - - super().__init__(jira_conn_id=jira_conn_id, result_processor=field_checker_func, **kwargs) - - def poke(self, context: 'Context') -> Any: - self.log.info('Jira Sensor checking for change in ticket: %s', self.ticket_id) - - self.method_name = "issue" - self.method_params = {'id': self.ticket_id, 'fields': self.field} - return JiraSensor.poke(self, context=context) - - def issue_field_checker(self, issue: Issue) -> Optional[bool]: - """Check issue using different conditions to prepare to evaluate sensor.""" - result = None - try: - if issue is not None and self.field is not None and self.expected_value is not None: - - field_val = getattr(issue.fields, self.field) - if field_val is not None: - if isinstance(field_val, list): - result = self.expected_value in field_val - elif isinstance(field_val, str): - result = self.expected_value.lower() == field_val.lower() - elif isinstance(field_val, Resource) and getattr(field_val, 'name'): - result = self.expected_value.lower() == field_val.name.lower() - else: - self.log.warning( - "Not implemented checker for issue field %s which " - "is neither string nor list nor Jira Resource", - self.field, - ) - - except JIRAError as jira_error: - self.log.error("Jira error while checking with expected value: %s", jira_error) - except Exception: - self.log.exception("Error while checking with expected value %s:", self.expected_value) - if result is True: - self.log.info( - "Issue field %s has expected value %s, returning success", self.field, self.expected_value - ) - else: - self.log.info("Issue field %s don't 
have expected value %s yet.", self.field, self.expected_value) - return result diff --git a/airflow/providers/microsoft/azure/CHANGELOG.rst b/airflow/providers/microsoft/azure/CHANGELOG.rst index aa8db5364cd52..5bfd784cc8a29 100644 --- a/airflow/providers/microsoft/azure/CHANGELOG.rst +++ b/airflow/providers/microsoft/azure/CHANGELOG.rst @@ -16,9 +16,143 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +5.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +* In AzureFileShareHook, if both ``extra__azure_fileshare__foo`` and ``foo`` existed in connection extra + dict, the prefixed version would be used; now, the non-prefixed version will be preferred. +* ``Remove deprecated classes (#27417)`` +* In Azure Batch ``vm_size`` and ``vm_node_agent_sku_id`` parameters are required. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add azure, google, authentication library limits to eaager upgrade (#27535)`` +* ``Allow and prefer non-prefixed extra fields for remaining azure (#27220)`` +* ``Allow and prefer non-prefixed extra fields for AzureFileShareHook (#27041)`` +* ``Allow and prefer non-prefixed extra fields for AzureDataExplorerHook (#27219)`` +* ``Allow and prefer non-prefixed extra fields for AzureDataFactoryHook (#27047)`` +* ``Update WasbHook to reflect preference for unprefixed extra (#27024)`` +* ``Look for 'extra__' instead of 'extra_' in 'get_field' (#27489)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix Azure Batch errors revealed by added typing to azure batch lib (#27601)`` +* ``Fix separator getting added to variables_prefix when empty (#26749)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Upgrade dependencies in order to avoid backtracking (#27531)`` + * ``Suppress any Exception in wasb task handler (#27495)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update azure-storage-blob version (#25426)`` + + +4.3.0 +..... + +Features +~~~~~~~~ + +* ``Add DataFlow operations to Azure DataFactory hook (#26345)`` +* ``Add network_profile param in AzureContainerInstancesOperator (#26117)`` +* ``Add Azure synapse operator (#26038)`` +* ``Auto tail file logs in Web UI (#26169)`` +* ``Implement Azure Service Bus Topic Create, Delete Operators (#25436)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix AzureBatchOperator false negative task status (#25844)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.2.0 +..... 
+ +Features +~~~~~~~~ + +* ``Add 'test_connection' method to AzureContainerInstanceHook (#25362)`` +* ``Add test_connection to Azure Batch hook (#25235)`` +* ``Bump typing-extensions and mypy for ParamSpec (#25088)`` +* ``Implement Azure Service Bus (Update and Receive) Subscription Operator (#25029)`` +* ``Set default wasb Azure http logging level to warning; fixes #16224 (#18896)`` + +4.1.0 +..... + +Features +~~~~~~~~ + +* ``Add 'test_connection' method to AzureCosmosDBHook (#25018)`` +* ``Add test_connection method to AzureFileShareHook (#24843)`` +* ``Add test_connection method to Azure WasbHook (#24771)`` +* ``Implement Azure service bus subscription Operators (#24625)`` +* ``Implement Azure Service Bus Queue Operators (#24038)`` + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Pass connection extra parameters to wasb BlobServiceClient (#24154)`` + + +Misc +~~~~ + +* ``Apply per-run log templates to log handlers (#24153)`` +* ``Migrate Microsoft example DAGs to new design #22452 - azure (#24141)`` +* ``Add typing to Azure Cosmos Client Hook (#23941)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Clean up f-strings in logging calls (#23597)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.9.0 ..... diff --git a/airflow/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py b/airflow/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py deleted file mode 100644 index 9833dc3190205..0000000000000 --- a/airflow/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime, timedelta - -from airflow.models import DAG, BaseOperator - -try: - from airflow.operators.empty import EmptyOperator -except ModuleNotFoundError: - from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore -from airflow.providers.microsoft.azure.operators.data_factory import AzureDataFactoryRunPipelineOperator -from airflow.providers.microsoft.azure.sensors.data_factory import AzureDataFactoryPipelineRunStatusSensor -from airflow.utils.edgemodifier import Label - -with DAG( - dag_id="example_adf_run_pipeline", - start_date=datetime(2021, 8, 13), - schedule_interval="@daily", - catchup=False, - default_args={ - "retries": 1, - "retry_delay": timedelta(minutes=3), - "azure_data_factory_conn_id": "azure_data_factory", - "factory_name": "my-data-factory", # This can also be specified in the ADF connection. - "resource_group_name": "my-resource-group", # This can also be specified in the ADF connection. - }, - default_view="graph", -) as dag: - begin = EmptyOperator(task_id="begin") - end = EmptyOperator(task_id="end") - - # [START howto_operator_adf_run_pipeline] - run_pipeline1: BaseOperator = AzureDataFactoryRunPipelineOperator( - task_id="run_pipeline1", - pipeline_name="pipeline1", - parameters={"myParam": "value"}, - ) - # [END howto_operator_adf_run_pipeline] - - # [START howto_operator_adf_run_pipeline_async] - run_pipeline2: BaseOperator = AzureDataFactoryRunPipelineOperator( - task_id="run_pipeline2", - pipeline_name="pipeline2", - wait_for_termination=False, - ) - - pipeline_run_sensor: BaseOperator = AzureDataFactoryPipelineRunStatusSensor( - task_id="pipeline_run_sensor", - run_id=run_pipeline2.output["run_id"], - ) - # [END howto_operator_adf_run_pipeline_async] - - begin >> Label("No async wait") >> run_pipeline1 - begin >> Label("Do async wait with sensor") >> run_pipeline2 - [run_pipeline1, pipeline_run_sensor] >> end - - # Task dependency created via `XComArgs`: - # run_pipeline2 >> pipeline_run_sensor diff --git a/airflow/providers/microsoft/azure/example_dags/example_adls_delete.py b/airflow/providers/microsoft/azure/example_dags/example_adls_delete.py deleted file mode 100644 index e007846f53292..0000000000000 --- a/airflow/providers/microsoft/azure/example_dags/example_adls_delete.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -from datetime import datetime - -from airflow import models -from airflow.providers.microsoft.azure.operators.adls import ADLSDeleteOperator -from airflow.providers.microsoft.azure.transfers.local_to_adls import LocalFilesystemToADLSOperator - -LOCAL_FILE_PATH = os.environ.get("LOCAL_FILE_PATH", 'localfile.txt') -REMOTE_FILE_PATH = os.environ.get("REMOTE_LOCAL_PATH", 'remote.txt') - - -with models.DAG( - "example_adls_delete", - start_date=datetime(2021, 1, 1), - schedule_interval=None, - tags=['example'], -) as dag: - - upload_file = LocalFilesystemToADLSOperator( - task_id='upload_task', - local_path=LOCAL_FILE_PATH, - remote_path=REMOTE_FILE_PATH, - ) - # [START howto_operator_adls_delete] - remove_file = ADLSDeleteOperator(task_id="delete_task", path=REMOTE_FILE_PATH, recursive=True) - # [END howto_operator_adls_delete] - - upload_file >> remove_file diff --git a/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py b/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py deleted file mode 100644 index 42f258a54a415..0000000000000 --- a/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example dag for using the AzureContainerInstancesOperator. -""" -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.providers.microsoft.azure.operators.container_instances import AzureContainerInstancesOperator - -with DAG( - dag_id='aci_example', - default_args={'retries': 1}, - schedule_interval=timedelta(days=1), - start_date=datetime(2018, 11, 1), - catchup=False, - tags=['example'], -) as dag: - - t1 = AzureContainerInstancesOperator( - ci_conn_id='azure_default', - registry_conn_id=None, - resource_group='resource-group', - name='aci-test-{{ ds }}', - image='hello-world', - region='WestUS2', - environment_variables={}, - volumes=[], - memory_in_gb=4.0, - cpu=1.0, - task_id='start_container', - ) diff --git a/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py b/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py deleted file mode 100644 index 626fc54644436..0000000000000 --- a/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py +++ /dev/null @@ -1,59 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -""" -This is only an example DAG to highlight usage of AzureCosmosDocumentSensor to detect -if a document now exists. - -You can trigger this manually with `airflow dags trigger example_cosmosdb_sensor`. - -*Note: Make sure that connection `azure_cosmos_default` is properly set before running -this example.* -""" - -from datetime import datetime - -from airflow import DAG -from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator -from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor - -with DAG( - dag_id='example_azure_cosmosdb_sensor', - default_args={'database_name': 'airflow_example_db'}, - start_date=datetime(2021, 1, 1), - catchup=False, - doc_md=__doc__, - tags=['example'], -) as dag: - - t1 = AzureCosmosDocumentSensor( - task_id='check_cosmos_file', - collection_name='airflow_example_coll', - document_id='airflow_checkid', - ) - - t2 = AzureCosmosInsertDocumentOperator( - task_id='insert_cosmos_file', - collection_name='new-collection', - document={"id": "someuniqueid", "param1": "value1", "param2": "value2"}, - ) - - t1 >> t2 diff --git a/airflow/providers/microsoft/azure/example_dags/example_fileshare.py b/airflow/providers/microsoft/azure/example_dags/example_fileshare.py deleted file mode 100644 index d50db3cb04027..0000000000000 --- a/airflow/providers/microsoft/azure/example_dags/example_fileshare.py +++ /dev/null @@ -1,52 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from datetime import datetime - -from airflow.decorators import task -from airflow.models import DAG -from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook - -NAME = 'myfileshare' -DIRECTORY = "mydirectory" - - -@task -def create_fileshare(): - """Create a fileshare with directory""" - hook = AzureFileShareHook() - hook.create_share(NAME) - hook.create_directory(share_name=NAME, directory_name=DIRECTORY) - exists = hook.check_for_directory(share_name=NAME, directory_name=DIRECTORY) - if not exists: - raise Exception - - -@task -def delete_fileshare(): - """Delete a fileshare""" - hook = AzureFileShareHook() - hook.delete_share(NAME) - - -with DAG( - "example_fileshare", - schedule_interval="@once", - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - create_fileshare() >> delete_fileshare() diff --git a/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py b/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py deleted file mode 100644 index 8dd0a682eb595..0000000000000 --- a/airflow/providers/microsoft/azure/example_dags/example_local_to_adls.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.microsoft.azure.operators.adls import ADLSDeleteOperator -from airflow.providers.microsoft.azure.transfers.local_to_adls import LocalFilesystemToADLSOperator - -LOCAL_FILE_PATH = os.environ.get("LOCAL_FILE_PATH", 'localfile.txt') -REMOTE_FILE_PATH = os.environ.get("REMOTE_LOCAL_PATH", 'remote.txt') - -with models.DAG( - "example_local_to_adls", - start_date=datetime(2021, 1, 1), - catchup=False, - schedule_interval=None, - tags=['example'], -) as dag: - # [START howto_operator_local_to_adls] - upload_file = LocalFilesystemToADLSOperator( - task_id='upload_task', - local_path=LOCAL_FILE_PATH, - remote_path=REMOTE_FILE_PATH, - ) - # [END howto_operator_local_to_adls] - - delete_file = ADLSDeleteOperator(task_id="remove_task", path=REMOTE_FILE_PATH, recursive=True) - - upload_file >> delete_file diff --git a/airflow/providers/microsoft/azure/example_dags/example_local_to_wasb.py b/airflow/providers/microsoft/azure/example_dags/example_local_to_wasb.py deleted file mode 100644 index 27372eee1fbbc..0000000000000 --- a/airflow/providers/microsoft/azure/example_dags/example_local_to_wasb.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Ignore missing args provided by default_args -# type: ignore[call-arg] - -import os -from datetime import datetime - -from airflow.models import DAG -from airflow.providers.microsoft.azure.operators.wasb_delete_blob import WasbDeleteBlobOperator -from airflow.providers.microsoft.azure.transfers.local_to_wasb import LocalFilesystemToWasbOperator - -PATH_TO_UPLOAD_FILE = os.environ.get('AZURE_PATH_TO_UPLOAD_FILE', 'example-text.txt') - -with DAG( - "example_local_to_wasb", - schedule_interval="@once", - start_date=datetime(2021, 1, 1), - catchup=False, - default_args={"container_name": "mycontainer", "blob_name": "myblob"}, -) as dag: - upload = LocalFilesystemToWasbOperator(task_id="upload_file", file_path=PATH_TO_UPLOAD_FILE) - delete = WasbDeleteBlobOperator(task_id="delete_file") - - upload >> delete diff --git a/airflow/providers/microsoft/azure/hooks/adx.py b/airflow/providers/microsoft/azure/hooks/adx.py index 750fa051b2dc4..634c18c172faa 100644 --- a/airflow/providers/microsoft/azure/hooks/adx.py +++ b/airflow/providers/microsoft/azure/hooks/adx.py @@ -15,8 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - """ This module contains Azure Data Explorer hook. @@ -25,7 +23,10 @@ KustoResponseDataSetV kusto """ -from typing import Any, Dict, Optional +from __future__ import annotations + +import warnings +from typing import Any from azure.kusto.data.exceptions import KustoServiceError from azure.kusto.data.request import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder @@ -33,6 +34,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import _ensure_prefixes class AzureDataExplorerHook(BaseHook): @@ -69,82 +71,100 @@ class AzureDataExplorerHook(BaseHook): :ref:`Azure Data Explorer connection`. 
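The hunks that follow drop the ``extra__azure_data_explorer__`` prefix from the connection form fields. For orientation, here is a minimal, hedged sketch of a connection written with the new unprefixed extras and of the hook's ``run_query`` call. The connection id matches the hook's default; the cluster URL, credentials, database and query are placeholder values, and in practice the connection would live in the metadata database or an ``AIRFLOW_CONN_*`` variable rather than being constructed inline.

```python
import json

from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook

# Illustrative connection using the new, unprefixed extra field names.
adx_conn = Connection(
    conn_id="azure_data_explorer_default",
    conn_type="azure_data_explorer",
    host="https://mycluster.kusto.windows.net",  # Data Explorer cluster URL
    login="my-app-client-id",
    password="my-app-secret",
    extra=json.dumps({"auth_method": "AAD_APP", "tenant": "my-tenant-id"}),
)

# With such a connection registered, the hook runs a KQL query against a database.
hook = AzureDataExplorerHook(azure_data_explorer_conn_id="azure_data_explorer_default")
response = hook.run_query("StormEvents | take 10", database="Samples")
```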
""" - conn_name_attr = 'azure_data_explorer_conn_id' - default_conn_name = 'azure_data_explorer_default' - conn_type = 'azure_data_explorer' - hook_name = 'Azure Data Explorer' + conn_name_attr = "azure_data_explorer_conn_id" + default_conn_name = "azure_data_explorer_default" + conn_type = "azure_data_explorer" + hook_name = "Azure Data Explorer" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField, StringField return { - "extra__azure_data_explorer__tenant": StringField( - lazy_gettext('Tenant ID'), widget=BS3TextFieldWidget() - ), - "extra__azure_data_explorer__auth_method": StringField( - lazy_gettext('Authentication Method'), widget=BS3TextFieldWidget() + "tenant": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), + "auth_method": StringField(lazy_gettext("Authentication Method"), widget=BS3TextFieldWidget()), + "certificate": PasswordField( + lazy_gettext("Application PEM Certificate"), widget=BS3PasswordFieldWidget() ), - "extra__azure_data_explorer__certificate": PasswordField( - lazy_gettext('Application PEM Certificate'), widget=BS3PasswordFieldWidget() - ), - "extra__azure_data_explorer__thumbprint": PasswordField( - lazy_gettext('Application Certificate Thumbprint'), widget=BS3PasswordFieldWidget() + "thumbprint": PasswordField( + lazy_gettext("Application Certificate Thumbprint"), widget=BS3PasswordFieldWidget() ), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="azure_data_explorer") + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'extra'], + "hidden_fields": ["schema", "port", "extra"], "relabeling": { - 'login': 'Username', - 'host': 'Data Explorer Cluster URL', + "login": "Username", + "host": "Data Explorer Cluster URL", }, "placeholders": { - 'login': 'Varies with authentication method', - 'password': 'Varies with authentication method', - 'extra__azure_data_explorer__auth_method': 'AAD_APP/AAD_APP_CERT/AAD_CREDS/AAD_DEVICE', - 'extra__azure_data_explorer__tenant': 'Used with AAD_APP/AAD_APP_CERT/AAD_CREDS', - 'extra__azure_data_explorer__certificate': 'Used with AAD_APP_CERT', - 'extra__azure_data_explorer__thumbprint': 'Used with AAD_APP_CERT', + "login": "Varies with authentication method", + "password": "Varies with authentication method", + "auth_method": "AAD_APP/AAD_APP_CERT/AAD_CREDS/AAD_DEVICE", + "tenant": "Used with AAD_APP/AAD_APP_CERT/AAD_CREDS", + "certificate": "Used with AAD_APP_CERT", + "thumbprint": "Used with AAD_APP_CERT", }, } def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None: super().__init__() self.conn_id = azure_data_explorer_conn_id - self.connection = self.get_conn() + self.connection = self.get_conn() # todo: make this a property, or just delete def get_conn(self) -> KustoClient: """Return a KustoClient object.""" conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson cluster = conn.host if not cluster: - raise AirflowException('Host connection option is required') + raise AirflowException("Host connection option is required") + + def warn_if_collison(key, backcompat_key): + if backcompat_key in extras: + warnings.warn( + f"Conflicting params `{key}` and 
`{backcompat_key}` found in extras for conn " + f"{self.conn_id}. Using value for `{key}`. Please ensure this is the correct value " + f"and remove the backcompat key `{backcompat_key}`." + ) def get_required_param(name: str) -> str: - """Extract required parameter value from connection, raise exception if not found""" - value = conn.extra_dejson.get(name) + """ + Extract required parameter value from connection, raise exception if not found. + + Warns if both ``foo`` and ``extra__azure_data_explorer__foo`` found in conn extra. + + Prefers unprefixed field. + """ + backcompat_prefix = "extra__azure_data_explorer__" + backcompat_key = f"{backcompat_prefix}{name}" + value = extras.get(name) + if value: + warn_if_collison(name, backcompat_key) + if not value: + value = extras.get(backcompat_key) if not value: - raise AirflowException(f'Required connection parameter is missing: `{name}`') + raise AirflowException(f"Required connection parameter is missing: `{name}`") return value - auth_method = get_required_param('extra__azure_data_explorer__auth_method') + auth_method = get_required_param("auth_method") - if auth_method == 'AAD_APP': - tenant = get_required_param('extra__azure_data_explorer__tenant') + if auth_method == "AAD_APP": + tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( cluster, conn.login, conn.password, tenant ) - elif auth_method == 'AAD_APP_CERT': - certificate = get_required_param('extra__azure_data_explorer__certificate') - thumbprint = get_required_param('extra__azure_data_explorer__thumbprint') - tenant = get_required_param('extra__azure_data_explorer__tenant') + elif auth_method == "AAD_APP_CERT": + certificate = get_required_param("certificate") + thumbprint = get_required_param("thumbprint") + tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_application_certificate_authentication( cluster, conn.login, @@ -152,19 +172,19 @@ def get_required_param(name: str) -> str: thumbprint, tenant, ) - elif auth_method == 'AAD_CREDS': - tenant = get_required_param('extra__azure_data_explorer__tenant') + elif auth_method == "AAD_CREDS": + tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_user_password_authentication( cluster, conn.login, conn.password, tenant ) - elif auth_method == 'AAD_DEVICE': + elif auth_method == "AAD_DEVICE": kcsb = KustoConnectionStringBuilder.with_aad_device_authentication(cluster) else: - raise AirflowException(f'Unknown authentication method: {auth_method}') + raise AirflowException(f"Unknown authentication method: {auth_method}") return KustoClient(kcsb) - def run_query(self, query: str, database: str, options: Optional[Dict] = None) -> KustoResponseDataSetV2: + def run_query(self, query: str, database: str, options: dict | None = None) -> KustoResponseDataSetV2: """ Run KQL query using provided configuration, and return `azure.kusto.data.response.KustoResponseDataSet` instance. 
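The lookup order implemented by ``get_required_param`` above can be distilled into a standalone sketch. The helper name ``resolve_extra_field`` is invented for illustration, but the behaviour (unprefixed key preferred, ``extra__azure_data_explorer__``-prefixed key as fallback, warning when both are set) mirrors the hook code in this hunk.

```python
import warnings


def resolve_extra_field(extras: dict, name: str, backcompat_prefix: str = "extra__azure_data_explorer__"):
    """Illustrative stand-in for the lookup order used by AzureDataExplorerHook.get_conn()."""
    backcompat_key = f"{backcompat_prefix}{name}"
    value = extras.get(name)
    if value and backcompat_key in extras:
        # Both spellings present: the unprefixed field wins and the old key should be removed.
        warnings.warn(f"Conflicting params `{name}` and `{backcompat_key}`; using `{name}`.")
    if not value:
        value = extras.get(backcompat_key)
    return value


# The unprefixed value is preferred over the old prefixed one.
extras = {"tenant": "new-tenant", "extra__azure_data_explorer__tenant": "old-tenant"}
assert resolve_extra_field(extras, "tenant") == "new-tenant"
```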
@@ -183,4 +203,4 @@ def run_query(self, query: str, database: str, options: Optional[Dict] = None) - try: return self.connection.execute(database, query, properties=properties) except KustoServiceError as error: - raise AirflowException(f'Error running Kusto query: {error}') + raise AirflowException(f"Error running Kusto query: {error}") diff --git a/airflow/providers/microsoft/azure/hooks/asb.py b/airflow/providers/microsoft/azure/hooks/asb.py new file mode 100644 index 0000000000000..3a25cfd56daf7 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/asb.py @@ -0,0 +1,250 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusSender +from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient + +from airflow.hooks.base import BaseHook + + +class BaseAzureServiceBusHook(BaseHook): + """ + BaseAzureServiceBusHook class to create session and create connection using connection string + + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + conn_name_attr = "azure_service_bus_conn_id" + default_conn_name = "azure_service_bus_default" + conn_type = "azure_service_bus" + hook_name = "Azure Service Bus" + + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ["port", "host", "extra", "login", "password"], + "relabeling": {"schema": "Connection String"}, + "placeholders": { + "schema": "Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=;SharedAccessKey=", # noqa + }, + } + + def __init__(self, azure_service_bus_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_service_bus_conn_id + + def get_conn(self): + raise NotImplementedError + + +class AdminClientHook(BaseAzureServiceBusHook): + """ + Interacts with ServiceBusAdministrationClient client + to create, update, list, and delete resources of a + Service Bus namespace. 
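Because ``get_ui_field_behaviour`` above hides every field except ``schema`` (relabelled "Connection String"), configuring the new Service Bus hooks amounts to storing the namespace connection string in that single field. An illustrative connection object showing the shape, with a placeholder endpoint and key:

```python
from airflow.models import Connection

# Illustrative Azure Service Bus connection: the namespace connection string goes into
# the "schema" field, which the connection form relabels as "Connection String".
asb_conn = Connection(
    conn_id="azure_service_bus_default",
    conn_type="azure_service_bus",
    schema=(
        "Endpoint=sb://my-namespace.servicebus.windows.net/;"
        "SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=<placeholder-key>"
    ),
)
print(asb_conn.schema)
```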
This hook uses the same Azure Service Bus client connection inherited + from the base class + """ + + def get_conn(self) -> ServiceBusAdministrationClient: + """ + Create and returns ServiceBusAdministrationClient by using the connection + string in connection details + """ + conn = self.get_connection(self.conn_id) + + connection_string: str = str(conn.schema) + return ServiceBusAdministrationClient.from_connection_string(connection_string) + + def create_queue( + self, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + ) -> QueueProperties: + """ + Create Queue by connecting to service Bus Admin client return the QueueProperties + + :param queue_name: The name of the queue or a QueueProperties with name. + :param max_delivery_count: The maximum delivery count. A message is automatically + dead lettered after this number of deliveries. Default value is 10.. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + + with self.get_conn() as service_mgmt_conn: + queue = service_mgmt_conn.create_queue( + queue_name, + max_delivery_count=max_delivery_count, + dead_lettering_on_message_expiration=dead_lettering_on_message_expiration, + enable_batched_operations=enable_batched_operations, + ) + return queue + + def delete_queue(self, queue_name: str) -> None: + """ + Delete the queue by queue_name in service bus namespace + + :param queue_name: The name of the queue or a QueueProperties with name. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + + with self.get_conn() as service_mgmt_conn: + service_mgmt_conn.delete_queue(queue_name) + + def delete_subscription(self, subscription_name: str, topic_name: str) -> None: + """ + Delete a topic subscription entities under a ServiceBus Namespace + + :param subscription_name: The subscription name that will own the rule in topic + :param topic_name: The topic that will own the subscription rule. + """ + if subscription_name is None: + raise TypeError("Subscription name cannot be None.") + if topic_name is None: + raise TypeError("Topic name cannot be None.") + + with self.get_conn() as service_mgmt_conn: + self.log.info("Deleting Subscription %s", subscription_name) + service_mgmt_conn.delete_subscription(topic_name, subscription_name) + + +class MessageHook(BaseAzureServiceBusHook): + """ + Interacts with ServiceBusClient and acts as a high level interface + for getting ServiceBusSender and ServiceBusReceiver. + """ + + def get_conn(self) -> ServiceBusClient: + """Create and returns ServiceBusClient by using the connection string in connection details""" + conn = self.get_connection(self.conn_id) + connection_string: str = str(conn.schema) + + self.log.info("Create and returns ServiceBusClient") + return ServiceBusClient.from_connection_string(conn_str=connection_string, logging_enable=True) + + def send_message(self, queue_name: str, messages: str | list[str], batch_message_flag: bool = False): + """ + By using ServiceBusClient Send message(s) to a Service Bus Queue. By using + batch_message_flag it enables and send message as batch message + + :param queue_name: The name of the queue or a QueueProperties with name. 
+ :param messages: Message which needs to be sent to the queue. It can be string or list of string. + :param batch_message_flag: bool flag, can be set to True if message needs to be + sent as batch message. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + if not messages: + raise ValueError("Messages list cannot be empty.") + with self.get_conn() as service_bus_client, service_bus_client.get_queue_sender( + queue_name=queue_name + ) as sender: + with sender: + if isinstance(messages, str): + if not batch_message_flag: + msg = ServiceBusMessage(messages) + sender.send_messages(msg) + else: + self.send_batch_message(sender, [messages]) + else: + if not batch_message_flag: + self.send_list_messages(sender, messages) + else: + self.send_batch_message(sender, messages) + + @staticmethod + def send_list_messages(sender: ServiceBusSender, messages: list[str]): + list_messages = [ServiceBusMessage(message) for message in messages] + sender.send_messages(list_messages) # type: ignore[arg-type] + + @staticmethod + def send_batch_message(sender: ServiceBusSender, messages: list[str]): + batch_message = sender.create_message_batch() + for message in messages: + batch_message.add_message(ServiceBusMessage(message)) + sender.send_messages(batch_message) + + def receive_message( + self, queue_name, max_message_count: int | None = 1, max_wait_time: float | None = None + ): + """ + Receive a batch of messages at once in a specified Queue name + + :param queue_name: The name of the queue name or a QueueProperties with name. + :param max_message_count: Maximum number of messages in the batch. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. + """ + if queue_name is None: + raise TypeError("Queue name cannot be None.") + + with self.get_conn() as service_bus_client, service_bus_client.get_queue_receiver( + queue_name=queue_name + ) as receiver: + with receiver: + received_msgs = receiver.receive_messages( + max_message_count=max_message_count, max_wait_time=max_wait_time + ) + for msg in received_msgs: + self.log.info(msg) + receiver.complete_message(msg) + + def receive_subscription_message( + self, + topic_name: str, + subscription_name: str, + max_message_count: int | None, + max_wait_time: float | None, + ): + """ + Receive a batch of subscription message at once. This approach is optimal if you wish + to process multiple messages simultaneously, or perform an ad-hoc receive as a single call. + + :param subscription_name: The subscription name that will own the rule in topic + :param topic_name: The topic that will own the subscription rule. + :param max_message_count: Maximum number of messages in the batch. + Actual number returned will depend on prefetch_count and incoming stream rate. + Setting to None will fully depend on the prefetch config. The default value is 1. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. If no + messages arrive, and no timeout is specified, this call will not return until the + connection is closed. If specified, an no messages arrive within the timeout period, + an empty list will be returned. 
+ """ + if subscription_name is None: + raise TypeError("Subscription name cannot be None.") + if topic_name is None: + raise TypeError("Topic name cannot be None.") + with self.get_conn() as service_bus_client, service_bus_client.get_subscription_receiver( + topic_name, subscription_name + ) as subscription_receiver: + with subscription_receiver: + received_msgs = subscription_receiver.receive_messages( + max_message_count=max_message_count, max_wait_time=max_wait_time + ) + for msg in received_msgs: + self.log.info(msg) + subscription_receiver.complete_message(msg) diff --git a/airflow/providers/microsoft/azure/hooks/azure_batch.py b/airflow/providers/microsoft/azure/hooks/azure_batch.py deleted file mode 100644 index 96e468c173460..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_batch.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.batch`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.batch`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_instance.py b/airflow/providers/microsoft/azure/hooks/azure_container_instance.py deleted file mode 100644 index 29ffe4e3ede91..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_container_instance.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_instance`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.hooks.container_instance`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_registry.py b/airflow/providers/microsoft/azure/hooks/azure_container_registry.py deleted file mode 100644 index 50ef42b0bde54..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_container_registry.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_registry`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_volume.py b/airflow/providers/microsoft/azure/hooks/azure_container_volume.py deleted file mode 100644 index 83a69e8a41cc8..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_container_volume.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_volume`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_volume`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/azure_cosmos.py b/airflow/providers/microsoft/azure/hooks/azure_cosmos.py deleted file mode 100644 index 9f1da045e4b89..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_cosmos.py +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.cosmos`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.cosmos import ( # noqa - AzureCosmosDBHook, - get_collection_link, - get_database_link, - get_document_link, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.cosmos`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py deleted file mode 100644 index 52faa0b91182e..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.data_factory`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.data_factory`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/azure_data_lake.py b/airflow/providers/microsoft/azure/hooks/azure_data_lake.py deleted file mode 100644 index aae7eec8db141..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_data_lake.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.data_lake`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.data_lake`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py deleted file mode 100644 index ec5da4b3d117f..0000000000000 --- a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.fileshare`.""" - -import warnings - -from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.fileshare`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py b/airflow/providers/microsoft/azure/hooks/base_azure.py index 85e634b633721..41c5ac70b3bf6 100644 --- a/airflow/providers/microsoft/azure/hooks/base_azure.py +++ b/airflow/providers/microsoft/azure/hooks/base_azure.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Dict +from typing import Any from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict from azure.common.credentials import ServicePrincipalCredentials @@ -34,13 +35,13 @@ class AzureBaseHook(BaseHook): which refers to the information to connect to the service. 
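The shim modules deleted above only re-exported their replacements and raised a ``DeprecationWarning``; with this release they are gone, so DAGs still importing from the old ``azure_*`` paths must switch to the module names given in each shim's warning message. A sketch of the before/after imports:

```python
# Before (modules removed in this release)  ->  after (current module locations).

# from airflow.providers.microsoft.azure.hooks.azure_batch import AzureBatchHook
from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook

# from airflow.providers.microsoft.azure.hooks.azure_container_instance import AzureContainerInstanceHook
from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook

# from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook

# from airflow.providers.microsoft.azure.hooks.azure_data_factory import AzureDataFactoryHook
from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook

# from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook

# from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook
from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
```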
""" - conn_name_attr = 'azure_conn_id' - default_conn_name = 'azure_default' - conn_type = 'azure' - hook_name = 'Azure' + conn_name_attr = "azure_conn_id" + default_conn_name = "azure_default" + conn_type = "azure" + hook_name = "Azure" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext @@ -48,40 +49,40 @@ def get_connection_form_widgets() -> Dict[str, Any]: return { "extra__azure__tenantId": StringField( - lazy_gettext('Azure Tenant ID'), widget=BS3TextFieldWidget() + lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget() ), "extra__azure__subscriptionId": StringField( - lazy_gettext('Azure Subscription ID'), widget=BS3TextFieldWidget() + lazy_gettext("Azure Subscription ID"), widget=BS3TextFieldWidget() ), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" import json return { - "hidden_fields": ['schema', 'port', 'host'], + "hidden_fields": ["schema", "port", "host"], "relabeling": { - 'login': 'Azure Client ID', - 'password': 'Azure Secret', + "login": "Azure Client ID", + "password": "Azure Secret", }, "placeholders": { - 'extra': json.dumps( + "extra": json.dumps( { "key_path": "path to json file for auth", "key_json": "specifies json dict for auth", }, indent=1, ), - 'login': 'client_id (token credentials auth)', - 'password': 'secret (token credentials auth)', - 'extra__azure__tenantId': 'tenantId (token credentials auth)', - 'extra__azure__subscriptionId': 'subscriptionId (token credentials auth)', + "login": "client_id (token credentials auth)", + "password": "secret (token credentials auth)", + "extra__azure__tenantId": "tenantId (token credentials auth)", + "extra__azure__subscriptionId": "subscriptionId (token credentials auth)", }, } - def __init__(self, sdk_client: Any, conn_id: str = 'azure_default'): + def __init__(self, sdk_client: Any, conn_id: str = "azure_default"): self.sdk_client = sdk_client self.conn_id = conn_id super().__init__() @@ -93,24 +94,24 @@ def get_conn(self) -> Any: :return: the authenticated client. 
""" conn = self.get_connection(self.conn_id) - tenant = conn.extra_dejson.get('extra__azure__tenantId') or conn.extra_dejson.get('tenantId') - subscription_id = conn.extra_dejson.get('extra__azure__subscriptionId') or conn.extra_dejson.get( - 'subscriptionId' + tenant = conn.extra_dejson.get("extra__azure__tenantId") or conn.extra_dejson.get("tenantId") + subscription_id = conn.extra_dejson.get("extra__azure__subscriptionId") or conn.extra_dejson.get( + "subscriptionId" ) - key_path = conn.extra_dejson.get('key_path') + key_path = conn.extra_dejson.get("key_path") if key_path: - if not key_path.endswith('.json'): - raise AirflowException('Unrecognised extension for key file.') - self.log.info('Getting connection using a JSON key file.') + if not key_path.endswith(".json"): + raise AirflowException("Unrecognised extension for key file.") + self.log.info("Getting connection using a JSON key file.") return get_client_from_auth_file(client_class=self.sdk_client, auth_path=key_path) - key_json = conn.extra_dejson.get('key_json') + key_json = conn.extra_dejson.get("key_json") if key_json: - self.log.info('Getting connection using a JSON config.') + self.log.info("Getting connection using a JSON config.") return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json) - self.log.info('Getting connection using specific credentials and subscription_id.') + self.log.info("Getting connection using specific credentials and subscription_id.") return self.sdk_client( credentials=ServicePrincipalCredentials( client_id=conn.login, secret=conn.password, tenant=tenant diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py index 7e1f0ac5bb749..e16e118fa9147 100644 --- a/airflow/providers/microsoft/azure/hooks/batch.py +++ b/airflow/providers/microsoft/azure/hooks/batch.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import time from datetime import timedelta -from typing import Any, Dict, Optional, Set +from typing import Any from azure.batch import BatchServiceClient, batch_auth, models as batch_models from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter @@ -26,6 +27,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import Connection +from airflow.providers.microsoft.azure.utils import get_field from airflow.utils import timezone @@ -37,32 +39,38 @@ class AzureBatchHook(BaseHook): of a service principal which will be used to start the container instance. 
""" - conn_name_attr = 'azure_batch_conn_id' - default_conn_name = 'azure_batch_default' - conn_type = 'azure_batch' - hook_name = 'Azure Batch Service' + conn_name_attr = "azure_batch_conn_id" + default_conn_name = "azure_batch_default" + conn_type = "azure_batch" + hook_name = "Azure Batch Service" + + def _get_field(self, extras, name): + return get_field( + conn_id=self.conn_id, + conn_type=self.conn_type, + extras=extras, + field_name=name, + ) @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { - "extra__azure_batch__account_url": StringField( - lazy_gettext('Batch Account URL'), widget=BS3TextFieldWidget() - ), + "account_url": StringField(lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], + "hidden_fields": ["schema", "port", "host", "extra"], "relabeling": { - 'login': 'Batch Account Name', - 'password': 'Batch Account Access Key', + "login": "Batch Account Name", + "password": "Batch Account Access Key", }, } @@ -84,9 +92,9 @@ def get_conn(self): """ conn = self._connection() - batch_account_url = conn.extra_dejson.get('extra__azure_batch__account_url') + batch_account_url = self._get_field(conn.extra_dejson, "account_url") if not batch_account_url: - raise AirflowException('Batch Account URL parameter is missing.') + raise AirflowException("Batch Account URL parameter is missing.") credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password) batch_client = BatchServiceClient(credentials, batch_url=batch_account_url) @@ -95,17 +103,17 @@ def get_conn(self): def configure_pool( self, pool_id: str, - vm_size: Optional[str] = None, - vm_publisher: Optional[str] = None, - vm_offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, - vm_sku: Optional[str] = None, - vm_version: Optional[str] = None, - vm_node_agent_sku_id: Optional[str] = None, - os_family: Optional[str] = None, - os_version: Optional[str] = None, - display_name: Optional[str] = None, - target_dedicated_nodes: Optional[int] = None, + vm_size: str, + vm_node_agent_sku_id: str, + vm_publisher: str | None = None, + vm_offer: str | None = None, + sku_starts_with: str | None = None, + vm_sku: str | None = None, + vm_version: str | None = None, + os_family: str | None = None, + os_version: str | None = None, + display_name: str | None = None, + target_dedicated_nodes: int | None = None, use_latest_image_and_sku: bool = False, **kwargs, ) -> PoolAddParameter: @@ -143,7 +151,7 @@ def configure_pool( """ if use_latest_image_and_sku: - self.log.info('Using latest verified virtual machine image with node agent sku') + self.log.info("Using latest verified virtual machine image with node agent sku") sku_to_use, image_ref_to_use = self._get_latest_verified_image_vm_and_sku( publisher=vm_publisher, offer=vm_offer, sku_starts_with=sku_starts_with ) @@ -160,7 +168,7 @@ def configure_pool( elif os_family: self.log.info( - 'Using cloud service configuration to create pool, virtual machine configuration ignored' + "Using cloud service configuration to create pool, virtual machine configuration ignored" ) pool = 
batch_models.PoolAddParameter( id=pool_id, @@ -174,7 +182,7 @@ def configure_pool( ) else: - self.log.info('Using virtual machine configuration to create a pool') + self.log.info("Using virtual machine configuration to create a pool") pool = batch_models.PoolAddParameter( id=pool_id, vm_size=vm_size, @@ -204,17 +212,17 @@ def create_pool(self, pool: PoolAddParameter) -> None: self.log.info("Attempting to create a pool: %s", pool.id) self.connection.pool.add(pool) self.log.info("Created pool: %s", pool.id) - except batch_models.BatchErrorException as e: - if e.error.code != "PoolExists": + except batch_models.BatchErrorException as err: + if not err.error or err.error.code != "PoolExists": raise else: self.log.info("Pool %s already exists", pool.id) def _get_latest_verified_image_vm_and_sku( self, - publisher: Optional[str] = None, - offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, + publisher: str | None = None, + offer: str | None = None, + sku_starts_with: str | None = None, ) -> tuple: """ Get latest verified image vm and sku @@ -240,20 +248,20 @@ def _get_latest_verified_image_vm_and_sku( agent_sku_id, image_ref_to_use = skus_to_use[0] return agent_sku_id, image_ref_to_use - def wait_for_all_node_state(self, pool_id: str, node_state: Set) -> list: + def wait_for_all_node_state(self, pool_id: str, node_state: set) -> list: """ Wait for all nodes in a pool to reach given states :param pool_id: A string that identifies the pool :param node_state: A set of batch_models.ComputeNodeState """ - self.log.info('waiting for all nodes in pool %s to reach one of: %s', pool_id, node_state) + self.log.info("waiting for all nodes in pool %s to reach one of: %s", pool_id, node_state) while True: # refresh pool to ensure that there is no resize error pool = self.connection.pool.get(pool_id) if pool.resize_errors is not None: resize_errors = "\n".join(repr(e) for e in pool.resize_errors) - raise RuntimeError(f'resize error encountered for pool {pool.id}:\n{resize_errors}') + raise RuntimeError(f"resize error encountered for pool {pool.id}:\n{resize_errors}") nodes = list(self.connection.compute_node.list(pool.id)) if len(nodes) >= pool.target_dedicated_nodes and all(node.state in node_state for node in nodes): return nodes @@ -266,7 +274,7 @@ def configure_job( self, job_id: str, pool_id: str, - display_name: Optional[str] = None, + display_name: str | None = None, **kwargs, ) -> JobAddParameter: """ @@ -294,7 +302,7 @@ def create_job(self, job: JobAddParameter) -> None: self.connection.job.add(job) self.log.info("Job %s created", job.id) except batch_models.BatchErrorException as err: - if err.error.code != "JobExists": + if not err.error or err.error.code != "JobExists": raise else: self.log.info("Job %s already exists", job.id) @@ -303,7 +311,7 @@ def configure_task( self, task_id: str, command_line: str, - display_name: Optional[str] = None, + display_name: str | None = None, container_settings=None, **kwargs, ) -> TaskAddParameter: @@ -339,12 +347,12 @@ def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None: self.connection.task.add(job_id=job_id, task=task) except batch_models.BatchErrorException as err: - if err.error.code != "TaskExists": + if not err.error or err.error.code != "TaskExists": raise else: self.log.info("Task %s already exists", task.id) - def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> None: + def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.CloudTask]: """ Wait for tasks in 
a particular job to complete @@ -357,8 +365,27 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> None: incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed] if not incomplete_tasks: - return + # detect if any task in job has failed + fail_tasks = [ + task + for task in tasks + if task.executionInfo.result == batch_models.TaskExecutionResult.failure + ] + return fail_tasks for task in incomplete_tasks: self.log.info("Waiting for %s to complete, currently on %s state", task.id, task.state) time.sleep(15) raise TimeoutError("Timed out waiting for tasks to complete") + + def test_connection(self): + """Test a configured Azure Batch connection.""" + try: + # Attempt to list existing jobs under the configured Batch account and retrieve + # the first in the returned iterator. The Azure Batch API does allow for creation of a + # BatchServiceClient with incorrect values but then will fail properly once items are + # retrieved using the client. We need to _actually_ try to retrieve an object to properly + # test the connection. + next(self.get_conn().job.list(), None) + except Exception as e: + return False, str(e) + return True, "Successfully connected to Azure Batch." diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py b/airflow/providers/microsoft/azure/hooks/container_instance.py index 9b0cd5d17264f..41d9fadce5e2f 100644 --- a/airflow/providers/microsoft/azure/hooks/container_instance.py +++ b/airflow/providers/microsoft/azure/hooks/container_instance.py @@ -15,10 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations import warnings -from typing import Any from azure.mgmt.containerinstance import ContainerInstanceManagementClient from azure.mgmt.containerinstance.models import ContainerGroup @@ -36,17 +35,17 @@ class AzureContainerInstanceHook(AzureBaseHook): client_id (Application ID) as login, the generated password as password, and tenantId and subscriptionId in the extra's field as a json. - :param conn_id: :ref:`Azure connection id` of + :param azure_conn_id: :ref:`Azure connection id` of a service principal which will be used to start the container instance. """ - conn_name_attr = 'azure_conn_id' - default_conn_name = 'azure_default' - conn_type = 'azure_container_instance' - hook_name = 'Azure Container Instance' + conn_name_attr = "azure_conn_id" + default_conn_name = "azure_default" + conn_type = "azure_container_instance" + hook_name = "Azure Container Instance" - def __init__(self, conn_id: str = default_conn_name) -> None: - super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=conn_id) + def __init__(self, azure_conn_id: str = default_conn_name) -> None: + super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=azure_conn_id) self.connection = self.get_conn() def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None: @@ -67,7 +66,6 @@ def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple: :param name: the name of the container group :return: A tuple with the state, exitcode, and details. If the exitcode is unknown 0 is returned. - :rtype: tuple(state,exitcode,details) """ warnings.warn( "get_state_exitcode_details() is deprecated. 
Related method is get_state()", @@ -85,7 +83,6 @@ def get_messages(self, resource_group: str, name: str) -> list: :param resource_group: the name of the resource group :param name: the name of the container group :return: A list of the event messages - :rtype: list[str] """ warnings.warn( "get_messages() is deprecated. Related method is get_state()", DeprecationWarning, stacklevel=2 @@ -94,14 +91,13 @@ def get_messages(self, resource_group: str, name: str) -> list: instance_view = cg_state.containers[0].instance_view return [event.message for event in instance_view.events] - def get_state(self, resource_group: str, name: str) -> Any: + def get_state(self, resource_group: str, name: str) -> ContainerGroup: """ Get the state of a container group :param resource_group: the name of the resource group :param name: the name of the container group :return: ContainerGroup - :rtype: ~azure.mgmt.containerinstance.models.ContainerGroup """ return self.connection.container_groups.get(resource_group, name, raw=False) @@ -113,7 +109,6 @@ def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list: :param name: the name of the container group :param tail: the size of the tail :return: A list of log messages - :rtype: list[str] """ logs = self.connection.container.list_logs(resource_group, name, name, tail=tail) return logs.content.splitlines(True) @@ -138,3 +133,15 @@ def exists(self, resource_group: str, name: str) -> bool: if container.name == name: return True return False + + def test_connection(self): + """Test a configured Azure Container Instance connection.""" + try: + # Attempt to list existing container groups under the configured subscription and retrieve the + # first in the returned iterator. We need to _actually_ try to retrieve an object to properly + # test the connection. + next(self.connection.container_groups.list(), None) + except Exception as e: + return False, str(e) + + return True, "Successfully connected to Azure Container Instance." diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py b/airflow/providers/microsoft/azure/hooks/container_registry.py index 6cc8e985178b3..785cf1a529c85 100644 --- a/airflow/providers/microsoft/azure/hooks/container_registry.py +++ b/airflow/providers/microsoft/azure/hooks/container_registry.py @@ -16,8 +16,9 @@ # specific language governing permissions and limitations # under the License. 
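Usage sketch (editorial, not part of the diff): the deprecation warnings above steer callers to get_state(), which now returns the full ContainerGroup. Only instance_view.events appears in this change; the current_state attribute names below are assumptions based on the azure-mgmt-containerinstance models.

from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook

hook = AzureContainerInstanceHook(azure_conn_id="azure_default")
group = hook.get_state("my-resource-group", "my-container-group")  # ContainerGroup

# Roughly what the deprecated get_state_exitcode_details() / get_messages() returned:
instance_view = group.containers[0].instance_view
state = instance_view.current_state.state          # assumed attribute name
exit_code = instance_view.current_state.exit_code  # assumed attribute name
messages = [event.message for event in instance_view.events]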
"""Hook for Azure Container Registry""" +from __future__ import annotations -from typing import Any, Dict +from typing import Any from azure.mgmt.containerinstance.models import ImageRegistryCredential @@ -33,29 +34,29 @@ class AzureContainerRegistryHook(BaseHook): """ - conn_name_attr = 'azure_container_registry_conn_id' - default_conn_name = 'azure_container_registry_default' - conn_type = 'azure_container_registry' - hook_name = 'Azure Container Registry' + conn_name_attr = "azure_container_registry_conn_id" + default_conn_name = "azure_container_registry_default" + conn_type = "azure_container_registry" + hook_name = "Azure Container Registry" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'extra'], + "hidden_fields": ["schema", "port", "extra"], "relabeling": { - 'login': 'Registry Username', - 'password': 'Registry Password', - 'host': 'Registry Server', + "login": "Registry Username", + "password": "Registry Password", + "host": "Registry Server", }, "placeholders": { - 'login': 'private registry username', - 'password': 'private registry password', - 'host': 'docker image registry server', + "login": "private registry username", + "password": "private registry password", + "host": "docker image registry server", }, } - def __init__(self, conn_id: str = 'azure_registry') -> None: + def __init__(self, conn_id: str = "azure_registry") -> None: super().__init__() self.conn_id = conn_id self.connection = self.get_conn() diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py b/airflow/providers/microsoft/azure/hooks/container_volume.py index fbd0e18723d6a..beaa6f9e2fc71 100644 --- a/airflow/providers/microsoft/azure/hooks/container_volume.py +++ b/airflow/providers/microsoft/azure/hooks/container_volume.py @@ -15,11 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict +from __future__ import annotations + +from typing import Any from azure.mgmt.containerinstance.models import AzureFileVolume, Volume from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field class AzureContainerVolumeHook(BaseHook): @@ -32,50 +35,59 @@ class AzureContainerVolumeHook(BaseHook): """ conn_name_attr = "azure_container_volume_conn_id" - default_conn_name = 'azure_container_volume_default' - conn_type = 'azure_container_volume' - hook_name = 'Azure Container Volume' + default_conn_name = "azure_container_volume_default" + conn_type = "azure_container_volume" + hook_name = "Azure Container Volume" - def __init__(self, azure_container_volume_conn_id: str = 'azure_container_volume_default') -> None: + def __init__(self, azure_container_volume_conn_id: str = "azure_container_volume_default") -> None: super().__init__() self.conn_id = azure_container_volume_conn_id + def _get_field(self, extras, name): + return get_field( + conn_id=self.conn_id, + conn_type=self.conn_type, + extras=extras, + field_name=name, + ) + @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField return { - "extra__azure_container_volume__connection_string": PasswordField( - lazy_gettext('Blob Storage Connection String (optional)'), widget=BS3PasswordFieldWidget() + "connection_string": PasswordField( + lazy_gettext("Blob Storage Connection String (optional)"), widget=BS3PasswordFieldWidget() ), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="azure_container_volume") + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'host', "extra"], + "hidden_fields": ["schema", "port", "host", "extra"], "relabeling": { - 'login': 'Azure Client ID', - 'password': 'Azure Secret', + "login": "Azure Client ID", + "password": "Azure Secret", }, "placeholders": { - 'login': 'client_id (token credentials auth)', - 'password': 'secret (token credentials auth)', - 'extra__azure_container_volume__connection_string': 'connection string auth', + "login": "client_id (token credentials auth)", + "password": "secret (token credentials auth)", + "connection_string": "connection string auth", }, } def get_storagekey(self) -> str: """Get Azure File Volume storage key""" conn = self.get_connection(self.conn_id) - service_options = conn.extra_dejson - - if 'extra__azure_container_volume__connection_string' in service_options: - for keyvalue in service_options['extra__azure_container_volume__connection_string'].split(";"): + extras = conn.extra_dejson + connection_string = self._get_field(extras, "connection_string") + if connection_string: + for keyvalue in connection_string.split(";"): key, value = keyvalue.split("=", 1) if key == "AccountKey": return value diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py index ed475978b0345..b2d63a4a72ce4 100644 --- a/airflow/providers/microsoft/azure/hooks/cosmos.py +++ b/airflow/providers/microsoft/azure/hooks/cosmos.py @@ -23,14 +23,18 @@ login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify the default 
database and collection to use (see connection `azure_cosmos_default` for an example). """ +from __future__ import annotations + +import json import uuid -from typing import Any, Dict, Optional +from typing import Any from azure.cosmos.cosmos_client import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from airflow.exceptions import AirflowBadRequest from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field class AzureCosmosDBHook(BaseHook): @@ -45,52 +49,61 @@ class AzureCosmosDBHook(BaseHook): :ref:`Azure CosmosDB connection`. """ - conn_name_attr = 'azure_cosmos_conn_id' - default_conn_name = 'azure_cosmos_default' - conn_type = 'azure_cosmos' - hook_name = 'Azure CosmosDB' + conn_name_attr = "azure_cosmos_conn_id" + default_conn_name = "azure_cosmos_default" + conn_type = "azure_cosmos" + hook_name = "Azure CosmosDB" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { - "extra__azure_cosmos__database_name": StringField( - lazy_gettext('Cosmos Database Name (optional)'), widget=BS3TextFieldWidget() + "database_name": StringField( + lazy_gettext("Cosmos Database Name (optional)"), widget=BS3TextFieldWidget() ), - "extra__azure_cosmos__collection_name": StringField( - lazy_gettext('Cosmos Collection Name (optional)'), widget=BS3TextFieldWidget() + "collection_name": StringField( + lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget() ), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="azure_cosmos") # todo: remove when min airflow version >= 2.5 + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], + "hidden_fields": ["schema", "port", "host", "extra"], "relabeling": { - 'login': 'Cosmos Endpoint URI', - 'password': 'Cosmos Master Key Token', + "login": "Cosmos Endpoint URI", + "password": "Cosmos Master Key Token", }, "placeholders": { - 'login': 'endpoint uri', - 'password': 'master key', - 'extra__azure_cosmos__database_name': 'database name', - 'extra__azure_cosmos__collection_name': 'collection name', + "login": "endpoint uri", + "password": "master key", + "database_name": "database name", + "collection_name": "collection name", }, } def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None: super().__init__() self.conn_id = azure_cosmos_conn_id - self._conn: Optional[CosmosClient] = None + self._conn: CosmosClient | None = None self.default_database_name = None self.default_collection_name = None + def _get_field(self, extras, name): + return get_field( + conn_id=self.conn_id, + conn_type=self.conn_type, + extras=extras, + field_name=name, + ) + def get_conn(self) -> CosmosClient: """Return a cosmos db client.""" if not self._conn: @@ -99,18 +112,14 @@ def get_conn(self) -> CosmosClient: endpoint_uri = conn.login master_key = conn.password - self.default_database_name = extras.get('database_name') or extras.get( - 'extra__azure_cosmos__database_name' - ) - self.default_collection_name = extras.get('collection_name') or extras.get( - 'extra__azure_cosmos__collection_name' - ) + self.default_database_name = self._get_field(extras, "database_name") 
+ self.default_collection_name = self._get_field(extras, "collection_name") # Initialize the Python Azure Cosmos DB client - self._conn = CosmosClient(endpoint_uri, {'masterKey': master_key}) + self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key}) return self._conn - def __get_database_name(self, database_name: Optional[str] = None) -> str: + def __get_database_name(self, database_name: str | None = None) -> str: self.get_conn() db_name = database_name if db_name is None: @@ -121,7 +130,7 @@ def __get_database_name(self, database_name: Optional[str] = None) -> str: return db_name - def __get_collection_name(self, collection_name: Optional[str] = None) -> str: + def __get_collection_name(self, collection_name: str | None = None) -> str: self.get_conn() coll_name = collection_name if coll_name is None: @@ -140,14 +149,22 @@ def does_collection_exist(self, collection_name: str, database_name: str) -> boo existing_container = list( self.get_conn() .get_database_client(self.__get_database_name(database_name)) - .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}]) + .query_containers( + "SELECT * FROM r WHERE r.id=@id", + parameters=[json.dumps({"name": "@id", "value": collection_name})], + ) ) if len(existing_container) == 0: return False return True - def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None: + def create_collection( + self, + collection_name: str, + database_name: str | None = None, + partition_key: str | None = None, + ) -> None: """Creates a new collection in the CosmosDB database.""" if collection_name is None: raise AirflowBadRequest("Collection name cannot be None.") @@ -157,13 +174,16 @@ def create_collection(self, collection_name: str, database_name: Optional[str] = existing_container = list( self.get_conn() .get_database_client(self.__get_database_name(database_name)) - .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}]) + .query_containers( + "SELECT * FROM r WHERE r.id=@id", + parameters=[json.dumps({"name": "@id", "value": collection_name})], + ) ) # Only create if we did not find it already existing if len(existing_container) == 0: self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container( - collection_name + collection_name, partition_key=partition_key ) def does_database_exist(self, database_name: str) -> bool: @@ -173,10 +193,8 @@ def does_database_exist(self, database_name: str) -> bool: existing_database = list( self.get_conn().query_databases( - { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [{"name": "@id", "value": database_name}], - } + "SELECT * FROM r WHERE r.id=@id", + parameters=[json.dumps({"name": "@id", "value": database_name})], ) ) if len(existing_database) == 0: @@ -193,10 +211,8 @@ def create_database(self, database_name: str) -> None: # to create it twice existing_database = list( self.get_conn().query_databases( - { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [{"name": "@id", "value": database_name}], - } + "SELECT * FROM r WHERE r.id=@id", + parameters=[json.dumps({"name": "@id", "value": database_name})], ) ) @@ -211,7 +227,7 @@ def delete_database(self, database_name: str) -> None: self.get_conn().delete_database(database_name) - def delete_collection(self, collection_name: str, database_name: Optional[str] = None) -> None: + def delete_collection(self, collection_name: str, database_name: str | None = None) -> None: """Deletes an 
existing collection in the CosmosDB database.""" if collection_name is None: raise AirflowBadRequest("Collection name cannot be None.") @@ -233,11 +249,11 @@ def upsert_document(self, document, database_name=None, collection_name=None, do raise AirflowBadRequest("You cannot insert a None document") # Add document id if isn't found - if 'id' in document: - if document['id'] is None: - document['id'] = document_id + if "id" in document: + if document["id"] is None: + document["id"] = document_id else: - document['id'] = document_id + document["id"] = document_id created_document = ( self.get_conn() @@ -249,7 +265,7 @@ def upsert_document(self, document, database_name=None, collection_name=None, do return created_document def insert_documents( - self, documents, database_name: Optional[str] = None, collection_name: Optional[str] = None + self, documents, database_name: str | None = None, collection_name: str | None = None ) -> list: """Insert a list of new documents into an existing collection in the CosmosDB database.""" if documents is None: @@ -267,18 +283,28 @@ def insert_documents( return created_documents def delete_document( - self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None + self, + document_id: str, + database_name: str | None = None, + collection_name: str | None = None, + partition_key: str | None = None, ) -> None: """Delete an existing document out of a collection in the CosmosDB database.""" if document_id is None: raise AirflowBadRequest("Cannot delete a document without an id") - - self.get_conn().get_database_client(self.__get_database_name(database_name)).get_container_client( - self.__get_collection_name(collection_name) - ).delete_item(document_id) + ( + self.get_conn() + .get_database_client(self.__get_database_name(database_name)) + .get_container_client(self.__get_collection_name(collection_name)) + .delete_item(document_id, partition_key=partition_key) + ) def get_document( - self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None + self, + document_id: str, + database_name: str | None = None, + collection_name: str | None = None, + partition_key: str | None = None, ): """Get a document from an existing collection in the CosmosDB database.""" if document_id is None: @@ -289,7 +315,7 @@ def get_document( self.get_conn() .get_database_client(self.__get_database_name(database_name)) .get_container_client(self.__get_collection_name(collection_name)) - .read_item(document_id) + .read_item(document_id, partition_key=partition_key) ) except CosmosHttpResponseError: return None @@ -297,29 +323,38 @@ def get_document( def get_documents( self, sql_string: str, - database_name: Optional[str] = None, - collection_name: Optional[str] = None, - partition_key: Optional[str] = None, - ) -> Optional[list]: + database_name: str | None = None, + collection_name: str | None = None, + partition_key: str | None = None, + ) -> list | None: """Get a list of documents from an existing collection in the CosmosDB database via SQL query.""" if sql_string is None: raise AirflowBadRequest("SQL query string cannot be None") - # Query them in SQL - query = {'query': sql_string} - try: result_iterable = ( self.get_conn() .get_database_client(self.__get_database_name(database_name)) .get_container_client(self.__get_collection_name(collection_name)) - .query_items(query, partition_key) + .query_items(sql_string, partition_key=partition_key) ) - return list(result_iterable) except CosmosHttpResponseError: return 
None + def test_connection(self): + """Test a configured Azure Cosmos connection.""" + try: + # Attempt to list existing databases under the configured subscription and retrieve the first in + # the returned iterator. The Azure Cosmos API does allow for creation of a + # CosmosClient with incorrect values but then will fail properly once items are + # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the + # connection. + next(iter(self.get_conn().list_databases()), None) + except Exception as e: + return False, str(e) + return True, "Successfully connected to Azure Cosmos." + def get_database_link(database_id: str) -> str: """Get Azure CosmosDB database link""" diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index 03fb99272d911..a05ae87538867 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -25,18 +25,22 @@ PipelineRun TriggerResource datafactory + DataFlow mgmt """ +from __future__ import annotations + import inspect import time from functools import wraps -from typing import Any, Callable, Dict, Optional, Set, Tuple, Union +from typing import Any, Callable, Union from azure.core.polling import LROPoller from azure.identity import ClientSecretCredential, DefaultAzureCredential from azure.mgmt.datafactory import DataFactoryManagementClient from azure.mgmt.datafactory.models import ( CreateRunResponse, + DataFlow, DatasetResource, Factory, LinkedServiceResource, @@ -70,14 +74,17 @@ def bind_argument(arg, default_key): if arg not in bound_args.arguments or bound_args.arguments[arg] is None: self = args[0] conn = self.get_connection(self.conn_id) - default_value = conn.extra_dejson.get(default_key) + extras = conn.extra_dejson + default_value = extras.get(default_key) or extras.get( + f"extra__azure_data_factory__{default_key}" + ) if not default_value: raise AirflowException("Could not determine the targeted data factory.") - bound_args.arguments[arg] = conn.extra_dejson[default_key] + bound_args.arguments[arg] = default_value - bind_argument("resource_group_name", "extra__azure_data_factory__resource_group_name") - bind_argument("factory_name", "extra__azure_data_factory__factory_name") + bind_argument("resource_group_name", "resource_group_name") + bind_argument("factory_name", "factory_name") return func(*bound_args.args, **bound_args.kwargs) @@ -88,8 +95,8 @@ class PipelineRunInfo(TypedDict): """Type class for the pipeline run info dictionary.""" run_id: str - factory_name: Optional[str] - resource_group_name: Optional[str] + factory_name: str | None + resource_group_name: str | None class AzureDataFactoryPipelineRunStatus: @@ -109,6 +116,23 @@ class AzureDataFactoryPipelineRunException(AirflowException): """An exception that indicates a pipeline run failed to complete.""" +def get_field(extras: dict, field_name: str, strict: bool = False): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = "extra__azure_data_factory__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." 
+ ) + if field_name in extras: + return extras[field_name] or None + prefixed_name = f"{backcompat_prefix}{field_name}" + if prefixed_name in extras: + return extras[prefixed_name] or None + if strict: + raise KeyError(f"Field {field_name} not found in extras") + + class AzureDataFactoryHook(BaseHook): """ A hook to interact with Azure Data Factory. @@ -116,41 +140,35 @@ class AzureDataFactoryHook(BaseHook): :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id`. """ - conn_type: str = 'azure_data_factory' - conn_name_attr: str = 'azure_data_factory_conn_id' - default_conn_name: str = 'azure_data_factory_default' - hook_name: str = 'Azure Data Factory' + conn_type: str = "azure_data_factory" + conn_name_attr: str = "azure_data_factory_conn_id" + default_conn_name: str = "azure_data_factory_default" + hook_name: str = "Azure Data Factory" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { - "extra__azure_data_factory__tenantId": StringField( - lazy_gettext('Tenant ID'), widget=BS3TextFieldWidget() - ), - "extra__azure_data_factory__subscriptionId": StringField( - lazy_gettext('Subscription ID'), widget=BS3TextFieldWidget() - ), - "extra__azure_data_factory__resource_group_name": StringField( - lazy_gettext('Resource Group Name'), widget=BS3TextFieldWidget() - ), - "extra__azure_data_factory__factory_name": StringField( - lazy_gettext('Factory Name'), widget=BS3TextFieldWidget() + "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), + "subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()), + "resource_group_name": StringField( + lazy_gettext("Resource Group Name"), widget=BS3TextFieldWidget() ), + "factory_name": StringField(lazy_gettext("Factory Name"), widget=BS3TextFieldWidget()), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], + "hidden_fields": ["schema", "port", "host", "extra"], "relabeling": { - 'login': 'Client ID', - 'password': 'Secret', + "login": "Client ID", + "password": "Secret", }, } @@ -164,10 +182,11 @@ def get_conn(self) -> DataFactoryManagementClient: return self._conn conn = self.get_connection(self.conn_id) - tenant = conn.extra_dejson.get('extra__azure_data_factory__tenantId') + extras = conn.extra_dejson + tenant = get_field(extras, "tenantId") try: - subscription_id = conn.extra_dejson['extra__azure_data_factory__subscriptionId'] + subscription_id = get_field(extras, "subscriptionId", strict=True) except KeyError: raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") @@ -187,7 +206,7 @@ def get_conn(self) -> DataFactoryManagementClient: @provide_targeted_factory def get_factory( - self, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None, **config: Any + self, resource_group_name: str | None = None, factory_name: str | None = None, **config: Any ) -> Factory: """ Get the factory. 
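Quick usage sketch of the backcompat helper defined above (the extras values are invented for illustration):

# Short names win; the old "extra__azure_data_factory__" prefix is still honoured.
assert get_field({"subscriptionId": "sub-123"}, "subscriptionId") == "sub-123"
assert get_field({"extra__azure_data_factory__subscriptionId": "sub-123"}, "subscriptionId") == "sub-123"
assert get_field({}, "tenantId") is None  # missing and strict=False -> None
# get_field({}, "tenantId", strict=True) raises KeyError instead, and passing an
# already-prefixed field_name raises ValueError.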
@@ -218,8 +237,8 @@ def _create_client(credential: Credentials, subscription_id: str): def update_factory( self, factory: Factory, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> Factory: """ @@ -243,8 +262,8 @@ def update_factory( def create_factory( self, factory: Factory, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> Factory: """ @@ -266,7 +285,7 @@ def create_factory( @provide_targeted_factory def delete_factory( - self, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None, **config: Any + self, resource_group_name: str | None = None, factory_name: str | None = None, **config: Any ) -> None: """ Delete the factory. @@ -281,8 +300,8 @@ def delete_factory( def get_linked_service( self, linked_service_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> LinkedServiceResource: """ @@ -314,8 +333,8 @@ def update_linked_service( self, linked_service_name: str, linked_service: LinkedServiceResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> LinkedServiceResource: """ @@ -341,8 +360,8 @@ def create_linked_service( self, linked_service_name: str, linked_service: LinkedServiceResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> LinkedServiceResource: """ @@ -367,8 +386,8 @@ def create_linked_service( def delete_linked_service( self, linked_service_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> None: """ @@ -387,8 +406,8 @@ def delete_linked_service( def get_dataset( self, dataset_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> DatasetResource: """ @@ -416,8 +435,8 @@ def update_dataset( self, dataset_name: str, dataset: DatasetResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> DatasetResource: """ @@ -443,8 +462,8 @@ def create_dataset( self, dataset_name: str, dataset: DatasetResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> DatasetResource: """ @@ -469,26 +488,135 @@ def create_dataset( def delete_dataset( self, dataset_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> None: """ Delete the dataset. :param dataset_name: The dataset name. - :param resource_group_name: The dataset name. + :param resource_group_name: The resource group name. :param factory_name: The factory name. 
:param config: Extra parameters for the ADF client. """ self.get_conn().datasets.delete(resource_group_name, factory_name, dataset_name, **config) + @provide_targeted_factory + def get_dataflow( + self, + dataflow_name: str, + resource_group_name: str | None = None, + factory_name: str | None = None, + **config: Any, + ) -> DataFlow: + """ + Get the dataflow. + + :param dataflow_name: The dataflow name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :return: The dataflow. + """ + return self.get_conn().data_flows.get(resource_group_name, factory_name, dataflow_name, **config) + + def _dataflow_exists( + self, + dataflow_name: str, + resource_group_name: str | None = None, + factory_name: str | None = None, + ) -> bool: + """Return whether the dataflow already exists.""" + dataflows = { + dataflow.name + for dataflow in self.get_conn().data_flows.list_by_factory(resource_group_name, factory_name) + } + + return dataflow_name in dataflows + + @provide_targeted_factory + def update_dataflow( + self, + dataflow_name: str, + dataflow: DataFlow, + resource_group_name: str | None = None, + factory_name: str | None = None, + **config: Any, + ) -> DataFlow: + """ + Update the dataflow. + + :param dataflow_name: The dataflow name. + :param dataflow: The dataflow resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the dataset does not exist. + :return: The dataflow. + """ + if not self._dataflow_exists( + dataflow_name, + resource_group_name, + factory_name, + ): + raise AirflowException(f"Dataflow {dataflow_name!r} does not exist.") + + return self.get_conn().data_flows.create_or_update( + resource_group_name, factory_name, dataflow_name, dataflow, **config + ) + + @provide_targeted_factory + def create_dataflow( + self, + dataflow_name: str, + dataflow: DataFlow, + resource_group_name: str | None = None, + factory_name: str | None = None, + **config: Any, + ) -> DataFlow: + """ + Create the dataflow. + + :param dataflow_name: The dataflow name. + :param dataflow: The dataflow resource definition. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. + :raise AirflowException: If the dataset already exists. + :return: The dataset. + """ + if self._dataflow_exists(dataflow_name, resource_group_name, factory_name): + raise AirflowException(f"Dataflow {dataflow_name!r} already exists.") + + return self.get_conn().data_flows.create_or_update( + resource_group_name, factory_name, dataflow_name, dataflow, **config + ) + + @provide_targeted_factory + def delete_dataflow( + self, + dataflow_name: str, + resource_group_name: str | None = None, + factory_name: str | None = None, + **config: Any, + ) -> None: + """ + Delete the dataflow. + + :param dataflow_name: The dataflow name. + :param resource_group_name: The resource group name. + :param factory_name: The factory name. + :param config: Extra parameters for the ADF client. 
+ """ + self.get_conn().data_flows.delete(resource_group_name, factory_name, dataflow_name, **config) + @provide_targeted_factory def get_pipeline( self, pipeline_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> PipelineResource: """ @@ -516,8 +644,8 @@ def update_pipeline( self, pipeline_name: str, pipeline: PipelineResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> PipelineResource: """ @@ -543,8 +671,8 @@ def create_pipeline( self, pipeline_name: str, pipeline: PipelineResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> PipelineResource: """ @@ -569,8 +697,8 @@ def create_pipeline( def delete_pipeline( self, pipeline_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> None: """ @@ -587,8 +715,8 @@ def delete_pipeline( def run_pipeline( self, pipeline_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> CreateRunResponse: """ @@ -608,8 +736,8 @@ def run_pipeline( def get_pipeline_run( self, run_id: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> PipelineRun: """ @@ -626,8 +754,8 @@ def get_pipeline_run( def get_pipeline_run_status( self, run_id: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, ) -> str: """ Get a pipeline run's current status. 
@@ -650,9 +778,9 @@ def get_pipeline_run_status( def wait_for_pipeline_run_status( self, run_id: str, - expected_statuses: Union[str, Set[str]], - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + expected_statuses: str | set[str], + resource_group_name: str | None = None, + factory_name: str | None = None, check_interval: int = 60, timeout: int = 60 * 60 * 24 * 7, ) -> bool: @@ -698,8 +826,8 @@ def wait_for_pipeline_run_status( def cancel_pipeline_run( self, run_id: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> None: """ @@ -716,8 +844,8 @@ def cancel_pipeline_run( def get_trigger( self, trigger_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> TriggerResource: """ @@ -745,8 +873,8 @@ def update_trigger( self, trigger_name: str, trigger: TriggerResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> TriggerResource: """ @@ -772,8 +900,8 @@ def create_trigger( self, trigger_name: str, trigger: TriggerResource, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> TriggerResource: """ @@ -798,8 +926,8 @@ def create_trigger( def delete_trigger( self, trigger_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> None: """ @@ -816,8 +944,8 @@ def delete_trigger( def start_trigger( self, trigger_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> LROPoller: """ @@ -835,8 +963,8 @@ def start_trigger( def stop_trigger( self, trigger_name: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> LROPoller: """ @@ -855,8 +983,8 @@ def rerun_trigger( self, trigger_name: str, run_id: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> None: """ @@ -877,8 +1005,8 @@ def cancel_trigger( self, trigger_name: str, run_id: str, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **config: Any, ) -> None: """ @@ -892,7 +1020,7 @@ def cancel_trigger( """ self.get_conn().trigger_runs.cancel(resource_group_name, factory_name, trigger_name, run_id, **config) - def test_connection(self) -> Tuple[bool, str]: + def test_connection(self) -> tuple[bool, str]: """Test a configured Azure Data Factory connection.""" success = (True, "Successfully connected to Azure Data Factory.") diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py index f386947041a55..4a9bc98f925fb 100644 --- a/airflow/providers/microsoft/azure/hooks/data_lake.py +++ 
b/airflow/providers/microsoft/azure/hooks/data_lake.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains integration with Azure Data Lake. @@ -24,12 +23,15 @@ login (=Client ID), password (=Client Secret) and extra fields tenant (Tenant) and account_name (Account Name) (see connection `azure_data_lake_default` for an example). """ -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any from azure.datalake.store import core, lib, multithread from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field class AzureDataLakeHook(BaseHook): @@ -43,60 +45,64 @@ class AzureDataLakeHook(BaseHook): :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection`. """ - conn_name_attr = 'azure_data_lake_conn_id' - default_conn_name = 'azure_data_lake_default' - conn_type = 'azure_data_lake' - hook_name = 'Azure Data Lake' + conn_name_attr = "azure_data_lake_conn_id" + default_conn_name = "azure_data_lake_default" + conn_type = "azure_data_lake" + hook_name = "Azure Data Lake" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { - "extra__azure_data_lake__tenant": StringField( - lazy_gettext('Azure Tenant ID'), widget=BS3TextFieldWidget() - ), - "extra__azure_data_lake__account_name": StringField( - lazy_gettext('Azure DataLake Store Name'), widget=BS3TextFieldWidget() + "tenant": StringField(lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()), + "account_name": StringField( + lazy_gettext("Azure DataLake Store Name"), widget=BS3TextFieldWidget() ), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="azure_data_lake") + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], + "hidden_fields": ["schema", "port", "host", "extra"], "relabeling": { - 'login': 'Azure Client ID', - 'password': 'Azure Client Secret', + "login": "Azure Client ID", + "password": "Azure Client Secret", }, "placeholders": { - 'login': 'client id', - 'password': 'secret', - 'extra__azure_data_lake__tenant': 'tenant id', - 'extra__azure_data_lake__account_name': 'datalake store', + "login": "client id", + "password": "secret", + "tenant": "tenant id", + "account_name": "datalake store", }, } def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None: super().__init__() self.conn_id = azure_data_lake_conn_id - self._conn: Optional[core.AzureDLFileSystem] = None - self.account_name: Optional[str] = None + self._conn: core.AzureDLFileSystem | None = None + self.account_name: str | None = None + + def _get_field(self, extras, name): + return get_field( + conn_id=self.conn_id, + conn_type=self.conn_type, + extras=extras, + field_name=name, + ) def get_conn(self) -> core.AzureDLFileSystem: """Return a AzureDLFileSystem object.""" if not self._conn: conn = self.get_connection(self.conn_id) - service_options = conn.extra_dejson - self.account_name = service_options.get('account_name') or service_options.get( - 
'extra__azure_data_lake__account_name' - ) - tenant = service_options.get('tenant') or service_options.get('extra__azure_data_lake__tenant') - + extras = conn.extra_dejson + self.account_name = self._get_field(extras, "account_name") + tenant = self._get_field(extras, "tenant") adl_creds = lib.auth(tenant_id=tenant, client_secret=conn.password, client_id=conn.login) self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name) self._conn.connect() @@ -108,7 +114,6 @@ def check_for_file(self, file_path: str) -> bool: :param file_path: Path and name of the file. :return: True if the file exists, False otherwise. - :rtype: bool """ try: files = self.get_conn().glob(file_path, details=False, invalidate_cache=True) diff --git a/airflow/providers/microsoft/azure/hooks/fileshare.py b/airflow/providers/microsoft/azure/hooks/fileshare.py index 1f78d6e4b2084..c3b9cd907dd0a 100644 --- a/airflow/providers/microsoft/azure/hooks/fileshare.py +++ b/airflow/providers/microsoft/azure/hooks/fileshare.py @@ -15,15 +15,45 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import warnings -from typing import IO, Any, Dict, List, Optional +from functools import wraps +from typing import IO, Any from azure.storage.file import File, FileService from airflow.hooks.base import BaseHook +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. + """ + + def dec(func): + @wraps(func) + def inner(): + field_behaviors = func() + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec + + class AzureFileShareHook(BaseHook): """ Interacts with Azure FileShare Storage. 
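Toy illustration of the _ensure_prefixes decorator defined above (field names invented for the example): core connection attributes pass through untouched, while custom placeholder keys gain the historical extra__<conn_type>__ prefix so Airflow versions older than 2.5 keep matching them against the prefixed extras.

@_ensure_prefixes(conn_type="azure_fileshare")
def ui_field_behaviour():
    return {
        "hidden_fields": ["schema", "port"],
        "placeholders": {"login": "account name", "sas_token": "sas token"},
    }

assert ui_field_behaviour()["placeholders"] == {
    "login": "account name",                           # core attribute, left as-is
    "extra__azure_fileshare__sas_token": "sas token",  # custom field, prefixed for backcompat
}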
@@ -35,81 +65,81 @@ class AzureFileShareHook(BaseHook): """ conn_name_attr = "azure_fileshare_conn_id" - default_conn_name = 'azure_fileshare_default' - conn_type = 'azure_fileshare' - hook_name = 'Azure FileShare' + default_conn_name = "azure_fileshare_default" + conn_type = "azure_fileshare" + hook_name = "Azure FileShare" - def __init__(self, azure_fileshare_conn_id: str = 'azure_fileshare_default') -> None: + def __init__(self, azure_fileshare_conn_id: str = "azure_fileshare_default") -> None: super().__init__() self.conn_id = azure_fileshare_conn_id self._conn = None @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField, StringField return { - "extra__azure_fileshare__sas_token": PasswordField( - lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget() + "sas_token": PasswordField(lazy_gettext("SAS Token (optional)"), widget=BS3PasswordFieldWidget()), + "connection_string": StringField( + lazy_gettext("Connection String (optional)"), widget=BS3TextFieldWidget() ), - "extra__azure_fileshare__connection_string": StringField( - lazy_gettext('Connection String (optional)'), widget=BS3TextFieldWidget() - ), - "extra__azure_fileshare__protocol": StringField( - lazy_gettext('Account URL or token (optional)'), widget=BS3TextFieldWidget() + "protocol": StringField( + lazy_gettext("Account URL or token (optional)"), widget=BS3TextFieldWidget() ), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="azure_fileshare") + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], + "hidden_fields": ["schema", "port", "host", "extra"], "relabeling": { - 'login': 'Blob Storage Login (optional)', - 'password': 'Blob Storage Key (optional)', + "login": "Blob Storage Login (optional)", + "password": "Blob Storage Key (optional)", }, "placeholders": { - 'login': 'account name', - 'password': 'secret', - 'extra__azure_fileshare__sas_token': 'account url or token (optional)', - 'extra__azure_fileshare__connection_string': 'account url or token (optional)', - 'extra__azure_fileshare__protocol': 'account url or token (optional)', + "login": "account name", + "password": "secret", + "sas_token": "account url or token (optional)", + "connection_string": "account url or token (optional)", + "protocol": "account url or token (optional)", }, } def get_conn(self) -> FileService: """Return the FileService object.""" - prefix = "extra__azure_fileshare__" + + def check_for_conflict(key): + backcompat_key = f"{backcompat_prefix}{key}" + if backcompat_key in extras: + warnings.warn( + f"Conflicting params `{key}` and `{backcompat_key}` found in extras for conn " + f"{self.conn_id}. Using value for `{key}`. Please ensure this is the correct value " + f"and remove the backcompat key `{backcompat_key}`." + ) + + backcompat_prefix = "extra__azure_fileshare__" if self._conn: return self._conn conn = self.get_connection(self.conn_id) - service_options_with_prefix = conn.extra_dejson + extras = conn.extra_dejson service_options = {} - for key, value in service_options_with_prefix.items(): - # in case dedicated FileShareHook is used, the connection will use the extras from UI. 
- # in case deprecated wasb hook is used, the old extras will work as well - if key.startswith(prefix): - if value != '': - service_options[key[len(prefix) :]] = value - else: - # warn if the deprecated wasb_connection is used - warnings.warn( - "You are using deprecated connection for AzureFileShareHook." - " Please change it to `Azure FileShare`.", - DeprecationWarning, - ) - else: + for key, value in extras.items(): + if value == "": + continue + if not key.startswith("extra__"): service_options[key] = value - # warn if the old non-prefixed value is used - warnings.warn( - "You are using deprecated connection for AzureFileShareHook." - " Please change it to `Azure FileShare`.", - DeprecationWarning, - ) + check_for_conflict(key) + elif key.startswith(backcompat_prefix): + short_name = key[len(backcompat_prefix) :] + if short_name not in service_options: # prefer values provided with short name + service_options[short_name] = value + else: + warnings.warn(f"Extra param `{key}` not recognized; ignoring.") self._conn = FileService(account_name=conn.login, account_key=conn.password, **service_options) return self._conn @@ -122,7 +152,6 @@ def check_for_directory(self, share_name: str, directory_name: str, **kwargs) -> :param kwargs: Optional keyword arguments that `FileService.exists()` takes. :return: True if the file exists, False otherwise. - :rtype: bool """ return self.get_conn().exists(share_name, directory_name, **kwargs) @@ -136,12 +165,11 @@ def check_for_file(self, share_name: str, directory_name: str, file_name: str, * :param kwargs: Optional keyword arguments that `FileService.exists()` takes. :return: True if the file exists, False otherwise. - :rtype: bool """ return self.get_conn().exists(share_name, directory_name, file_name, **kwargs) def list_directories_and_files( - self, share_name: str, directory_name: Optional[str] = None, **kwargs + self, share_name: str, directory_name: str | None = None, **kwargs ) -> list: """ Return the list of directories and files stored on a Azure File Share. @@ -151,11 +179,10 @@ def list_directories_and_files( :param kwargs: Optional keyword arguments that `FileService.list_directories_and_files()` takes. :return: A list of files and directories - :rtype: list """ return self.get_conn().list_directories_and_files(share_name, directory_name, **kwargs) - def list_files(self, share_name: str, directory_name: Optional[str] = None, **kwargs) -> List[str]: + def list_files(self, share_name: str, directory_name: str | None = None, **kwargs) -> list[str]: """ Return the list of files stored on a Azure File Share. @@ -164,7 +191,6 @@ def list_files(self, share_name: str, directory_name: Optional[str] = None, **kw :param kwargs: Optional keyword arguments that `FileService.list_directories_and_files()` takes. :return: A list of files - :rtype: list """ return [ obj.name @@ -180,7 +206,6 @@ def create_share(self, share_name: str, **kwargs) -> bool: :param kwargs: Optional keyword arguments that `FileService.create_share()` takes. :return: True if share is created, False if share already exists. - :rtype: bool """ return self.get_conn().create_share(share_name, **kwargs) @@ -192,7 +217,6 @@ def delete_share(self, share_name: str, **kwargs) -> bool: :param kwargs: Optional keyword arguments that `FileService.delete_share()` takes. :return: True if share is deleted, False if share does not exist. 
- :rtype: bool """ return self.get_conn().delete_share(share_name, **kwargs) @@ -205,7 +229,6 @@ def create_directory(self, share_name: str, directory_name: str, **kwargs) -> li :param kwargs: Optional keyword arguments that `FileService.create_directory()` takes. :return: A list of files and directories - :rtype: list """ return self.get_conn().create_directory(share_name, directory_name, **kwargs) @@ -286,3 +309,19 @@ def load_stream( self.get_conn().create_file_from_stream( share_name, directory_name, file_name, stream, count, **kwargs ) + + def test_connection(self): + """Test Azure FileShare connection.""" + success = (True, "Successfully connected to Azure File Share.") + + try: + # Attempt to retrieve file share information + next(iter(self.get_conn().list_shares())) + return success + except StopIteration: + # If the iterator returned is empty it should still be considered a successful connection since + # it's possible to create a storage account without any file share and none could + # legitimately exist yet. + return success + except Exception as e: + return False, str(e) diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py b/airflow/providers/microsoft/azure/hooks/synapse.py new file mode 100644 index 0000000000000..55c117cb0ff93 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/synapse.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import time +from typing import Any, Union + +from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.synapse.spark import SparkClient +from azure.synapse.spark.models import SparkBatchJobOptions + +from airflow.exceptions import AirflowTaskTimeout +from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import get_field + +Credentials = Union[ClientSecretCredential, DefaultAzureCredential] + + +class AzureSynapseSparkBatchRunStatus: + """Azure Synapse Spark Job operation statuses.""" + + NOT_STARTED = "not_started" + STARTING = "starting" + RUNNING = "running" + IDLE = "idle" + BUSY = "busy" + SHUTTING_DOWN = "shutting_down" + ERROR = "error" + DEAD = "dead" + KILLED = "killed" + SUCCESS = "success" + + TERMINAL_STATUSES = {SUCCESS, DEAD, KILLED, ERROR} + + +class AzureSynapseHook(BaseHook): + """ + A hook to interact with Azure Synapse. + :param azure_synapse_conn_id: The :ref:`Azure Synapse connection id`. 
+ :param spark_pool: The Apache Spark pool used to submit the job + """ + + conn_type: str = "azure_synapse" + conn_name_attr: str = "azure_synapse_conn_id" + default_conn_name: str = "azure_synapse_default" + hook_name: str = "Azure Synapse" + + @staticmethod + def get_connection_form_widgets() -> dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), + "subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()), + } + + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ["schema", "port", "extra"], + "relabeling": {"login": "Client ID", "password": "Secret", "host": "Synapse Workspace URL"}, + } + + def __init__(self, azure_synapse_conn_id: str = default_conn_name, spark_pool: str = ""): + self.job_id: int | None = None + self._conn: SparkClient | None = None + self.conn_id = azure_synapse_conn_id + self.spark_pool = spark_pool + super().__init__() + + def _get_field(self, extras, name): + return get_field( + conn_id=self.conn_id, + conn_type=self.conn_type, + extras=extras, + field_name=name, + ) + + def get_conn(self) -> SparkClient: + if self._conn is not None: + return self._conn + + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + tenant = self._get_field(extras, "tenantId") + spark_pool = self.spark_pool + livy_api_version = "2022-02-22-preview" + + subscription_id = self._get_field(extras, "subscriptionId") + if not subscription_id: + raise ValueError("A Subscription ID is required to connect to Azure Synapse.") + + credential: Credentials + if conn.login is not None and conn.password is not None: + if not tenant: + raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.") + + credential = ClientSecretCredential( + client_id=conn.login, client_secret=conn.password, tenant_id=tenant + ) + else: + credential = DefaultAzureCredential() + + self._conn = self._create_client(credential, conn.host, spark_pool, livy_api_version, subscription_id) + + return self._conn + + @staticmethod + def _create_client(credential: Credentials, host, spark_pool, livy_api_version, subscription_id: str): + return SparkClient( + credential=credential, + endpoint=host, + spark_pool_name=spark_pool, + livy_api_version=livy_api_version, + subscription_id=subscription_id, + ) + + def run_spark_job( + self, + payload: SparkBatchJobOptions, + ): + """ + Run a job in an Apache Spark pool. + :param payload: Livy compatible payload which represents the spark job that a user wants to submit. + """ + job = self.get_conn().spark_batch.create_spark_batch_job(payload) + self.job_id = job.id + return job + + def get_job_run_status(self): + """Get the job run status.""" + job_run_status = self.get_conn().spark_batch.get_spark_batch_job(batch_id=self.job_id).state + return job_run_status + + def wait_for_job_run_status( + self, + job_id: int | None, + expected_statuses: str | set[str], + check_interval: int = 60, + timeout: int = 60 * 60 * 24 * 7, + ) -> bool: + """ + Waits for a job run to match an expected status. + + :param job_id: The job run identifier. + :param expected_statuses: The desired status(es) to check against a job run's current status. 
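A hedged usage sketch of the Synapse hook introduced above: submit a Livy batch job, poll until it reaches a terminal or expected status (the polling loop and `cancel_job_run` continue just below), and cancel on failure. The storage URL, pool name, and the `SparkBatchJobOptions` field values are illustrative placeholders, not provider defaults.

# Illustrative only; values are placeholders.
from azure.synapse.spark.models import SparkBatchJobOptions

from airflow.providers.microsoft.azure.hooks.synapse import (
    AzureSynapseHook,
    AzureSynapseSparkBatchRunStatus,
)

hook = AzureSynapseHook(azure_synapse_conn_id="azure_synapse_default", spark_pool="my-pool")
job = hook.run_spark_job(
    SparkBatchJobOptions(
        name="wordcount",
        file="abfss://jobs@mystorageaccount.dfs.core.windows.net/wordcount.py",
    )
)
# Block until the batch reaches "success" (or a terminal failure / timeout).
finished_ok = hook.wait_for_job_run_status(
    job_id=job.id,
    expected_statuses={AzureSynapseSparkBatchRunStatus.SUCCESS},
    check_interval=30,
    timeout=60 * 60,
)
if not finished_ok:
    hook.cancel_job_run(job.id)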
+ :param check_interval: Time in seconds to check on a job run's status. + :param timeout: Time in seconds to wait for a job to reach a terminal status or the expected + status. + + """ + job_run_status = self.get_job_run_status() + start_time = time.monotonic() + + while ( + job_run_status not in AzureSynapseSparkBatchRunStatus.TERMINAL_STATUSES + and job_run_status not in expected_statuses + ): + # Check if the job-run duration has exceeded the ``timeout`` configured. + if start_time + timeout < time.monotonic(): + raise AirflowTaskTimeout( + f"Job {job_id} has not reached a terminal status after {timeout} seconds." + ) + + # Wait to check the status of the job run based on the ``check_interval`` configured. + self.log.info("Sleeping for %s seconds", str(check_interval)) + time.sleep(check_interval) + + job_run_status = self.get_job_run_status() + self.log.info("Current spark job run status is %s", job_run_status) + + return job_run_status in expected_statuses + + def cancel_job_run( + self, + job_id: int, + ) -> None: + """ + Cancel the spark job run. + :param job_id: The synapse spark job identifier. + """ + self.get_conn().spark_batch.cancel_spark_batch_job(job_id) diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index 0bcd9952fcf3a..27680a5b69408 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains integration with Azure Blob Storage. @@ -23,10 +22,13 @@ Airflow connection of type `wasb` exists. Authorization can be done by supplying a login (=Storage account name) and password (=KEY), or login and SAS token in the extra field (see connection `wasb_default` for an example). - """ +from __future__ import annotations -from typing import Any, Dict, List, Optional +import logging +import os +from functools import wraps +from typing import Any from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError from azure.identity import ClientSecretCredential, DefaultAzureCredential @@ -36,6 +38,34 @@ from airflow.hooks.base import BaseHook +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. + """ + + def dec(func): + @wraps(func) + def inner(): + field_behaviors = func() + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec + + class WasbHook(BaseHook): """ Interacts with Azure Blob Storage through the ``wasb://`` protocol. @@ -53,52 +83,51 @@ class WasbHook(BaseHook): :param public_read: Whether an anonymous public read access should be used. 
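A tiny demonstration of what the `_ensure_prefixes` decorator added above does to a field-behaviour dict, assuming it is run inside (or imported from) `airflow/providers/microsoft/azure/hooks/wasb.py`; the `fake_field_behaviour` function is purely illustrative. Placeholder keys that are not built-in connection attributes gain the `extra__<conn_type>__` prefix, while attributes such as `login` are left untouched.

# Illustrative only; assumes `_ensure_prefixes` from the module above is in scope.
@_ensure_prefixes(conn_type="wasb")
def fake_field_behaviour():
    return {
        "hidden_fields": ["schema", "port"],
        "placeholders": {"login": "account name", "sas_token": "account url or token"},
    }


print(fake_field_behaviour()["placeholders"])
# {'login': 'account name', 'extra__wasb__sas_token': 'account url or token'}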
default is False """ - conn_name_attr = 'wasb_conn_id' - default_conn_name = 'wasb_default' - conn_type = 'wasb' - hook_name = 'Azure Blob Storage' + conn_name_attr = "wasb_conn_id" + default_conn_name = "wasb_default" + conn_type = "wasb" + hook_name = "Azure Blob Storage" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField, StringField return { - "extra__wasb__connection_string": PasswordField( - lazy_gettext('Blob Storage Connection String (optional)'), widget=BS3PasswordFieldWidget() + "connection_string": PasswordField( + lazy_gettext("Blob Storage Connection String (optional)"), widget=BS3PasswordFieldWidget() ), - "extra__wasb__shared_access_key": PasswordField( - lazy_gettext('Blob Storage Shared Access Key (optional)'), widget=BS3PasswordFieldWidget() + "shared_access_key": PasswordField( + lazy_gettext("Blob Storage Shared Access Key (optional)"), widget=BS3PasswordFieldWidget() ), - "extra__wasb__tenant_id": StringField( - lazy_gettext('Tenant Id (Active Directory Auth)'), widget=BS3TextFieldWidget() - ), - "extra__wasb__sas_token": PasswordField( - lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget() + "tenant_id": StringField( + lazy_gettext("Tenant Id (Active Directory Auth)"), widget=BS3TextFieldWidget() ), + "sas_token": PasswordField(lazy_gettext("SAS Token (optional)"), widget=BS3PasswordFieldWidget()), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="wasb") + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema', 'port'], + "hidden_fields": ["schema", "port"], "relabeling": { - 'login': 'Blob Storage Login (optional)', - 'password': 'Blob Storage Key (optional)', - 'host': 'Account Name (Active Directory Auth)', + "login": "Blob Storage Login (optional)", + "password": "Blob Storage Key (optional)", + "host": "Account Name (Active Directory Auth)", }, "placeholders": { - 'extra': 'additional options for use with FileService and AzureFileVolume', - 'login': 'account name', - 'password': 'secret', - 'host': 'account url', - 'extra__wasb__connection_string': 'connection string auth', - 'extra__wasb__tenant_id': 'tenant', - 'extra__wasb__shared_access_key': 'shared access key', - 'extra__wasb__sas_token': 'account url or token', + "extra": "additional options for use with FileService and AzureFileVolume", + "login": "account name", + "password": "secret", + "host": "account url", + "connection_string": "connection string auth", + "tenant_id": "tenant", + "shared_access_key": "shared access key", + "sas_token": "account url or token", }, } @@ -112,6 +141,23 @@ def __init__( self.public_read = public_read self.blob_service_client = self.get_conn() + logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy") + try: + logger.setLevel(os.environ.get("AZURE_HTTP_LOGGING_LEVEL", logging.WARNING)) + except ValueError: + logger.setLevel(logging.WARNING) + + def _get_field(self, extra_dict, field_name): + prefix = "extra__wasb__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{prefix}' prefix " + f"when using this method." 
+ ) + if field_name in extra_dict: + return extra_dict[field_name] or None + return extra_dict.get(f"{prefix}{field_name}") or None + def get_conn(self) -> BlobServiceClient: """Return the BlobServiceClient object.""" conn = self.get_connection(self.conn_id) @@ -121,31 +167,33 @@ def get_conn(self) -> BlobServiceClient: # Here we use anonymous public read # more info # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources - return BlobServiceClient(account_url=conn.host) + return BlobServiceClient(account_url=conn.host, **extra) - if extra.get('connection_string') or extra.get('extra__wasb__connection_string'): + connection_string = self._get_field(extra, "connection_string") + if connection_string: # connection_string auth takes priority - connection_string = extra.get('connection_string') or extra.get('extra__wasb__connection_string') - return BlobServiceClient.from_connection_string(connection_string) - if extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key'): - shared_access_key = extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key') + return BlobServiceClient.from_connection_string(connection_string, **extra) + + shared_access_key = self._get_field(extra, "shared_access_key") + if shared_access_key: # using shared access key - return BlobServiceClient(account_url=conn.host, credential=shared_access_key) - if extra.get('tenant_id') or extra.get('extra__wasb__tenant_id'): + return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra) + + tenant = self._get_field(extra, "tenant_id") + if tenant: # use Active Directory auth app_id = conn.login app_secret = conn.password - tenant = extra.get('tenant_id', extra.get('extra__wasb__tenant_id')) token_credential = ClientSecretCredential(tenant, app_id, app_secret) - return BlobServiceClient(account_url=conn.host, credential=token_credential) + return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra) - sas_token = extra.get('sas_token') or extra.get('extra__wasb__sas_token') + sas_token = self._get_field(extra, "sas_token") if sas_token: - if sas_token.startswith('https'): - return BlobServiceClient(account_url=sas_token) + if sas_token.startswith("https"): + return BlobServiceClient(account_url=sas_token, **extra) else: return BlobServiceClient( - account_url=f'https://{conn.login}.blob.core.windows.net/{sas_token}' + account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra ) # Fall back to old auth (password) or use managed identity if not provided. @@ -185,7 +233,6 @@ def check_for_blob(self, container_name: str, blob_name: str, **kwargs) -> bool: :param blob_name: Name of the blob. :param kwargs: Optional keyword arguments for ``BlobClient.get_blob_properties`` takes. :return: True if the blob exists, False otherwise. - :rtype: bool """ try: self._get_blob_client(container_name, blob_name).get_blob_properties(**kwargs) @@ -193,7 +240,7 @@ def check_for_blob(self, container_name: str, blob_name: str, **kwargs) -> bool: return False return True - def check_for_prefix(self, container_name: str, prefix: str, **kwargs): + def check_for_prefix(self, container_name: str, prefix: str, **kwargs) -> bool: """ Check if a prefix exists on Azure Blob storage. @@ -201,7 +248,6 @@ def check_for_prefix(self, container_name: str, prefix: str, **kwargs): :param prefix: Prefix of the blob. 
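A compact restatement of the authentication precedence implemented by `WasbHook.get_conn` above, written as a hypothetical helper (not part of the provider): anonymous public read is checked first and omitted here; after that a connection string wins, then a shared access key, then Active Directory (tenant id), then a SAS token, and finally the login/password account key or managed identity fallback.

# Hypothetical helper; both the plain and legacy "extra__wasb__" keys are honoured.
def pick_wasb_auth(extra, password):
    if extra.get("connection_string") or extra.get("extra__wasb__connection_string"):
        return "connection_string"
    if extra.get("shared_access_key") or extra.get("extra__wasb__shared_access_key"):
        return "shared_access_key"
    if extra.get("tenant_id") or extra.get("extra__wasb__tenant_id"):
        return "active_directory"
    if extra.get("sas_token") or extra.get("extra__wasb__sas_token"):
        return "sas_token"
    return "account_key" if password else "managed_identity"


# Tenant id takes precedence over a SAS token when both are present.
assert pick_wasb_auth({"sas_token": "...", "tenant_id": "my-tenant"}, "secret") == "active_directory"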
:param kwargs: Optional keyword arguments that ``ContainerClient.walk_blobs`` takes :return: True if blobs matching the prefix exist, False otherwise. - :rtype: bool """ blobs = self.get_blobs_list(container_name=container_name, prefix=prefix, **kwargs) return len(blobs) > 0 @@ -209,11 +255,11 @@ def check_for_prefix(self, container_name: str, prefix: str, **kwargs): def get_blobs_list( self, container_name: str, - prefix: Optional[str] = None, - include: Optional[List[str]] = None, - delimiter: Optional[str] = '/', + prefix: str | None = None, + include: list[str] | None = None, + delimiter: str = "/", **kwargs, - ) -> List: + ) -> list: """ List blobs in a given container @@ -250,7 +296,7 @@ def load_file( useful if the target container may not exist yet. Defaults to False. :param kwargs: Optional keyword arguments that ``BlobClient.upload_blob()`` takes. """ - with open(file_path, 'rb') as data: + with open(file_path, "rb") as data: self.upload( container_name=container_name, blob_name=blob_name, @@ -314,11 +360,11 @@ def upload( container_name: str, blob_name: str, data: Any, - blob_type: str = 'BlockBlob', - length: Optional[int] = None, + blob_type: str = "BlockBlob", + length: int | None = None, create_container: bool = False, **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Creates a new blob from a data source with automatic chunking. @@ -339,7 +385,7 @@ def upload( return blob_client.upload_blob(data, blob_type, length=length, **kwargs) def download( - self, container_name, blob_name, offset: Optional[int] = None, length: Optional[int] = None, **kwargs + self, container_name, blob_name, offset: int | None = None, length: int | None = None, **kwargs ) -> StorageStreamDownloader: """ Downloads a blob to the StorageStreamDownloader @@ -361,7 +407,7 @@ def create_container(self, container_name: str) -> None: """ container_client = self._get_container_client(container_name) try: - self.log.debug('Attempting to create container: %s', container_name) + self.log.debug("Attempting to create container: %s", container_name) container_client.create_container() self.log.info("Created container: %s", container_name) except ResourceExistsError: @@ -381,7 +427,7 @@ def create_container(self, container_name: str) -> None: self.conn_id, ) except Exception as e: - self.log.info('Error while attempting to create container %r: %s', container_name, e) + self.log.info("Error while attempting to create container %r: %s", container_name, e) raise def delete_container(self, container_name: str) -> None: @@ -391,13 +437,13 @@ def delete_container(self, container_name: str) -> None: :param container_name: The name of the container """ try: - self.log.debug('Attempting to delete container: %s', container_name) + self.log.debug("Attempting to delete container: %s", container_name) self._get_container_client(container_name).delete_container() - self.log.info('Deleted container: %s', container_name) + self.log.info("Deleted container: %s", container_name) except ResourceNotFoundError: - self.log.info('Unable to delete container %s (not found)', container_name) + self.log.info("Unable to delete container %s (not found)", container_name) except: # noqa: E722 - self.log.info('Error deleting container: %s', container_name) + self.log.info("Error deleting container: %s", container_name) raise def delete_blobs(self, container_name: str, *blobs, **kwargs) -> None: @@ -417,7 +463,7 @@ def delete_file( blob_name: str, is_prefix: bool = False, ignore_if_missing: bool = False, - delimiter: str = '', + 
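A hedged usage sketch of the blob helpers shown above: upload a local file and then list everything under the same prefix. The container name, paths, and connection id are placeholders.

# Illustrative only; values are placeholders.
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook

hook = WasbHook(wasb_conn_id="wasb_default")
hook.load_file(
    file_path="/tmp/report.csv",
    container_name="my-container",
    blob_name="reports/2022/report.csv",
    create_container=True,  # create the container on the fly if it is missing
)
for name in hook.get_blobs_list(container_name="my-container", prefix="reports/2022/"):
    print(name)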
delimiter: str = "", **kwargs, ) -> None: """ @@ -439,6 +485,17 @@ def delete_file( else: blobs_to_delete = [] if not ignore_if_missing and len(blobs_to_delete) == 0: - raise AirflowException(f'Blob(s) not found: {blob_name}') + raise AirflowException(f"Blob(s) not found: {blob_name}") self.delete_blobs(container_name, *blobs_to_delete, **kwargs) + + def test_connection(self): + """Test Azure Blob Storage connection.""" + success = (True, "Successfully connected to Azure Blob Storage.") + + try: + # Attempt to retrieve storage account information + self.get_conn().get_account_information() + return success + except Exception as e: + return False, str(e) diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py index 9ec0cdf646fc4..8c0fe220830ef 100644 --- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py +++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py @@ -15,18 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import os import shutil -import sys -from typing import Dict, Optional, Tuple - -from azure.common import AzureHttpError - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from typing import Any +from airflow.compat.functools import cached_property from airflow.configuration import conf from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin @@ -44,13 +39,14 @@ def __init__( base_log_folder: str, wasb_log_folder: str, wasb_container: str, - filename_template: str, delete_local_copy: str, + *, + filename_template: str | None = None, ) -> None: super().__init__(base_log_folder, filename_template) self.wasb_container = wasb_container self.remote_base = wasb_log_folder - self.log_relative_path = '' + self.log_relative_path = "" self._hook = None self.closed = False self.upload_on_close = True @@ -59,16 +55,17 @@ def __init__( @cached_property def hook(self): """Returns WasbHook.""" - remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID') + remote_conn_id = conf.get("logging", "REMOTE_LOG_CONN_ID") try: from airflow.providers.microsoft.azure.hooks.wasb import WasbHook return WasbHook(remote_conn_id) - except AzureHttpError: + except Exception: self.log.exception( - 'Could not create an WasbHook with connection id "%s".' - ' Please make sure that apache-airflow[azure] is installed' - ' and the Wasb connection exists.', + "Could not create a WasbHook with connection id '%s'. " + "Do you have apache-airflow[azure] installed? " + "Does connection the connection exist, and is it " + "configured properly?", remote_conn_id, ) return None @@ -107,7 +104,9 @@ def close(self) -> None: # Mark closed so we don't double write if close is called twice self.closed = True - def _read(self, ti, try_number: int, metadata: Optional[str] = None) -> Tuple[str, Dict[str, bool]]: + def _read( + self, ti, try_number: int, metadata: dict[str, Any] | None = None + ) -> tuple[str, dict[str, bool]]: """ Read logs of given task instance and try_number from Wasb remote storage. If failed, read the log from task instance host machine. 
@@ -128,10 +127,10 @@ def _read(self, ti, try_number: int, metadata: Optional[str] = None) -> Tuple[st # local machine even if there are errors reading remote logs, as # returned remote_log will contain error messages. remote_log = self.wasb_read(remote_loc, return_error=True) - log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n' - return log, {'end_of_log': True} + log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n" + return log, {"end_of_log": True} else: - return super()._read(ti, try_number) + return super()._read(ti, try_number, metadata) def wasb_log_exists(self, remote_log_location: str) -> bool: """ @@ -158,13 +157,13 @@ def wasb_read(self, remote_log_location: str, return_error: bool = False): """ try: return self.hook.read_file(self.wasb_container, remote_log_location) - except AzureHttpError: - msg = f'Could not read logs from {remote_log_location}' + except Exception: + msg = f"Could not read logs from {remote_log_location}" self.log.exception(msg) # return error if needed if return_error: return msg - return '' + return "" def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> None: """ @@ -178,9 +177,9 @@ def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> """ if append and self.wasb_log_exists(remote_log_location): old_log = self.wasb_read(remote_log_location) - log = '\n'.join([old_log, log]) if old_log else log + log = "\n".join([old_log, log]) if old_log else log try: self.hook.load_string(log, self.wasb_container, remote_log_location, overwrite=True) - except AzureHttpError: - self.log.exception('Could not write logs to %s', remote_log_location) + except Exception: + self.log.exception("Could not write logs to %s", remote_log_location) diff --git a/airflow/providers/microsoft/azure/operators/adls.py b/airflow/providers/microsoft/azure/operators/adls.py index 9664a107e7a97..7d738ea53cb19 100644 --- a/airflow/providers/microsoft/azure/operators/adls.py +++ b/airflow/providers/microsoft/azure/operators/adls.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from typing import TYPE_CHECKING, Any, Sequence @@ -38,8 +39,8 @@ class ADLSDeleteOperator(BaseOperator): :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection`. 
""" - template_fields: Sequence[str] = ('path',) - ui_color = '#901dd2' + template_fields: Sequence[str] = ("path",) + ui_color = "#901dd2" def __init__( self, @@ -47,7 +48,7 @@ def __init__( path: str, recursive: bool = False, ignore_not_found: bool = True, - azure_data_lake_conn_id: str = 'azure_data_lake_default', + azure_data_lake_conn_id: str = "azure_data_lake_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -56,7 +57,7 @@ def __init__( self.ignore_not_found = ignore_not_found self.azure_data_lake_conn_id = azure_data_lake_conn_id - def execute(self, context: "Context") -> Any: + def execute(self, context: Context) -> Any: hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) return hook.remove(path=self.path, recursive=self.recursive, ignore_not_found=self.ignore_not_found) @@ -83,17 +84,17 @@ class ADLSListOperator(BaseOperator): ) """ - template_fields: Sequence[str] = ('path',) - ui_color = '#901dd2' + template_fields: Sequence[str] = ("path",) + ui_color = "#901dd2" def __init__( - self, *, path: str, azure_data_lake_conn_id: str = 'azure_data_lake_default', **kwargs + self, *, path: str, azure_data_lake_conn_id: str = "azure_data_lake_default", **kwargs ) -> None: super().__init__(**kwargs) self.path = path self.azure_data_lake_conn_id = azure_data_lake_conn_id - def execute(self, context: "Context") -> list: + def execute(self, context: Context) -> list: hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) - self.log.info('Getting list of ADLS files in path: %s', self.path) + self.log.info("Getting list of ADLS files in path: %s", self.path) return hook.list(path=self.path) diff --git a/airflow/providers/microsoft/azure/operators/adls_delete.py b/airflow/providers/microsoft/azure/operators/adls_delete.py deleted file mode 100644 index 3796005dacbe6..0000000000000 --- a/airflow/providers/microsoft/azure/operators/adls_delete.py +++ /dev/null @@ -1,38 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.adls`.""" - -import warnings - -from airflow.providers.microsoft.azure.operators.adls import ADLSDeleteOperator - - -class AzureDataLakeStorageDeleteOperator(ADLSDeleteOperator): - """ - This class is deprecated. - Please use `airflow.providers.microsoft.azure.operators.adls.ADLSDeleteOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. 
- Please use - `airflow.providers.microsoft.azure.operators.adls.ADLSDeleteOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/microsoft/azure/operators/adls_list.py b/airflow/providers/microsoft/azure/operators/adls_list.py deleted file mode 100644 index 715e20aff0e4a..0000000000000 --- a/airflow/providers/microsoft/azure/operators/adls_list.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.adls`.""" - -import warnings - -from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator - - -class AzureDataLakeStorageListOperator(ADLSListOperator): - """ - This class is deprecated. - Please use `airflow.providers.microsoft.azure.operators.adls.ADLSListOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - """This class is deprecated. - Please use - `airflow.providers.microsoft.azure.operators.adls.ADLSListOperator`.""", - DeprecationWarning, - stacklevel=3, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/microsoft/azure/operators/adx.py b/airflow/providers/microsoft/azure/operators/adx.py index 91b130ff0a85f..1578451fcbc68 100644 --- a/airflow/providers/microsoft/azure/operators/adx.py +++ b/airflow/providers/microsoft/azure/operators/adx.py @@ -15,10 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - """This module contains Azure Data Explorer operators""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from azure.kusto.data._models import KustoResultTable @@ -42,17 +42,17 @@ class AzureDataExplorerQueryOperator(BaseOperator): :ref:`Azure Data Explorer connection`. 
""" - ui_color = '#00a1f2' - template_fields: Sequence[str] = ('query', 'database') - template_ext: Sequence[str] = ('.kql',) + ui_color = "#00a1f2" + template_fields: Sequence[str] = ("query", "database") + template_ext: Sequence[str] = (".kql",) def __init__( self, *, query: str, database: str, - options: Optional[dict] = None, - azure_data_explorer_conn_id: str = 'azure_data_explorer_default', + options: dict | None = None, + azure_data_explorer_conn_id: str = "azure_data_explorer_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -65,7 +65,7 @@ def get_hook(self) -> AzureDataExplorerHook: """Returns new instance of AzureDataExplorerHook""" return AzureDataExplorerHook(self.azure_data_explorer_conn_id) - def execute(self, context: "Context") -> Union[KustoResultTable, str]: + def execute(self, context: Context) -> KustoResultTable | str: """ Run KQL Query on Azure Data Explorer (Kusto). Returns `PrimaryResult` of Query v2 HTTP response contents @@ -73,7 +73,7 @@ def execute(self, context: "Context") -> Union[KustoResultTable, str]: """ hook = self.get_hook() response = hook.run_query(self.query, self.database, self.options) - if conf.getboolean('core', 'enable_xcom_pickling'): + if conf.getboolean("core", "enable_xcom_pickling"): return response.primary_results[0] else: return str(response.primary_results[0]) diff --git a/airflow/providers/microsoft/azure/operators/asb.py b/airflow/providers/microsoft/azure/operators/asb.py new file mode 100644 index 0000000000000..ccb8678a0d285 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/asb.py @@ -0,0 +1,634 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook + +if TYPE_CHECKING: + from azure.servicebus.management._models import AuthorizationRule + + from airflow.utils.context import Context + + +class AzureServiceBusCreateQueueOperator(BaseOperator): + """ + Creates a Azure Service Bus queue under a Service Bus Namespace by using ServiceBusAdministrationClient + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusCreateQueueOperator` + + :param queue_name: The name of the queue. should be unique. + :param max_delivery_count: The maximum delivery count. A message is automatically + dead lettered after this number of deliveries. Default value is 10.. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. 
+ :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + max_delivery_count: int = 10, + dead_lettering_on_message_expiration: bool = True, + enable_batched_operations: bool = True, + azure_service_bus_conn_id: str = "azure_service_bus_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.max_delivery_count = max_delivery_count + self.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + self.enable_batched_operations = enable_batched_operations + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """Creates Queue in Azure Service Bus namespace, by connecting to Service Bus Admin client in hook""" + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # create queue with name + queue = hook.create_queue( + self.queue_name, + self.max_delivery_count, + self.dead_lettering_on_message_expiration, + self.enable_batched_operations, + ) + self.log.info("Created Queue %s", queue.name) + + +class AzureServiceBusSendMessageOperator(BaseOperator): + """ + Send Message or batch message to the Service Bus queue + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusSendMessageOperator` + + :param queue_name: The name of the queue. should be unique. + :param message: Message which needs to be sent to the queue. It can be string or list of string. + :param batch: Its boolean flag by default it is set to False, if the message needs to be sent + as batch message it can be set to True. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + message: str | list[str], + batch: bool = False, + azure_service_bus_conn_id: str = "azure_service_bus_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.batch = batch + self.message = message + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """ + Sends Message to the specific queue in Service Bus namespace, by + connecting to Service Bus client + """ + # Create the hook + hook = MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # send message + hook.send_message(self.queue_name, self.message, self.batch) + + +class AzureServiceBusReceiveMessageOperator(BaseOperator): + """ + Receive a batch of messages at once in a specified Queue name + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusReceiveMessageOperator` + + :param queue_name: The name of the queue name or a QueueProperties with name. + :param max_message_count: Maximum number of messages in the batch. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. 
+ """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = "azure_service_bus_default", + max_message_count: int = 10, + max_wait_time: float = 5, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + self.max_message_count = max_message_count + self.max_wait_time = max_wait_time + + def execute(self, context: Context) -> None: + """ + Receive Message in specific queue in Service Bus namespace, + by connecting to Service Bus client + """ + # Create the hook + hook = MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Receive message + hook.receive_message( + self.queue_name, max_message_count=self.max_message_count, max_wait_time=self.max_wait_time + ) + + +class AzureServiceBusDeleteQueueOperator(BaseOperator): + """ + Deletes the Queue in the Azure Service Bus namespace + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusDeleteQueueOperator` + + :param queue_name: The name of the queue in Service Bus namespace. + :param azure_service_bus_conn_id: Reference to the + :ref: `Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("queue_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + queue_name: str, + azure_service_bus_conn_id: str = "azure_service_bus_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.queue_name = queue_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """Delete Queue in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # delete queue with name + hook.delete_queue(self.queue_name) + + +class AzureServiceBusTopicCreateOperator(BaseOperator): + """ + Create an Azure Service Bus Topic under a Service Bus Namespace by using ServiceBusAdministrationClient + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusTopicCreateOperator` + + :param topic_name: Name of the topic. + :param default_message_time_to_live: ISO 8601 default message time span to live value. This is + the duration after which the message expires, starting from when the message is sent to Service + Bus. This is the default value used when TimeToLive is not set on a message itself. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format + like "PT300S" is accepted. + :param max_size_in_megabytes: The maximum size of the topic in megabytes, which is the size of + memory allocated for the topic. + :param requires_duplicate_detection: A value indicating if this topic requires duplicate + detection. + :param duplicate_detection_history_time_window: ISO 8601 time span structure that defines the + duration of the duplicate detection history. The default value is 10 minutes. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format + like "PT300S" is accepted. + :param enable_batched_operations: Value that indicates whether server-side batched operations + are enabled. + :param size_in_bytes: The size of the topic, in bytes. + :param filtering_messages_before_publishing: Filter messages before publishing. 
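A hedged end-to-end sketch chaining the four queue operators defined above (create, send, receive, delete); the queue name and DAG settings are placeholders and the default `azure_service_bus_default` connection is assumed.

# Illustrative only; values are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.operators.asb import (
    AzureServiceBusCreateQueueOperator,
    AzureServiceBusDeleteQueueOperator,
    AzureServiceBusReceiveMessageOperator,
    AzureServiceBusSendMessageOperator,
)

with DAG("asb_queue_example", start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False) as dag:
    create_queue = AzureServiceBusCreateQueueOperator(task_id="create_queue", queue_name="demo-queue")
    send_message = AzureServiceBusSendMessageOperator(
        task_id="send_message", queue_name="demo-queue", message="hello", batch=False
    )
    receive_message = AzureServiceBusReceiveMessageOperator(
        task_id="receive_message", queue_name="demo-queue", max_message_count=10, max_wait_time=5
    )
    delete_queue = AzureServiceBusDeleteQueueOperator(task_id="delete_queue", queue_name="demo-queue")

    create_queue >> send_message >> receive_message >> delete_queue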
+ :param authorization_rules: List of Authorization rules for resource. + :param support_ordering: A value that indicates whether the topic supports ordering. + :param auto_delete_on_idle: ISO 8601 time span idle interval after which the topic is + automatically deleted. The minimum duration is 5 minutes. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration format + like "PT300S" is accepted. + :param enable_partitioning: A value that indicates whether the topic is to be partitioned + across multiple message brokers. + :param enable_express: A value that indicates whether Express Entities are enabled. An express + queue holds a message in memory temporarily before writing it to persistent storage. + :param user_metadata: Metadata associated with the topic. + :param max_message_size_in_kilobytes: The maximum size in kilobytes of message payload that + can be accepted by the queue. This feature is only available when using a Premium namespace + and Service Bus API version "2021-05" or higher. + The minimum allowed value is 1024 while the maximum allowed value is 102400. Default value is 1024. + """ + + template_fields: Sequence[str] = ("topic_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + azure_service_bus_conn_id: str = "azure_service_bus_default", + default_message_time_to_live: datetime.timedelta | str | None = None, + max_size_in_megabytes: int | None = None, + requires_duplicate_detection: bool | None = None, + duplicate_detection_history_time_window: datetime.timedelta | str | None = None, + enable_batched_operations: bool | None = None, + size_in_bytes: int | None = None, + filtering_messages_before_publishing: bool | None = None, + authorization_rules: list[AuthorizationRule] | None = None, + support_ordering: bool | None = None, + auto_delete_on_idle: datetime.timedelta | str | None = None, + enable_partitioning: bool | None = None, + enable_express: bool | None = None, + user_metadata: str | None = None, + max_message_size_in_kilobytes: int | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + self.default_message_time_to_live = default_message_time_to_live + self.max_size_in_megabytes = max_size_in_megabytes + self.requires_duplicate_detection = requires_duplicate_detection + self.duplicate_detection_history_time_window = duplicate_detection_history_time_window + self.enable_batched_operations = enable_batched_operations + self.size_in_bytes = size_in_bytes + self.filtering_messages_before_publishing = filtering_messages_before_publishing + self.authorization_rules = authorization_rules + self.support_ordering = support_ordering + self.auto_delete_on_idle = auto_delete_on_idle + self.enable_partitioning = enable_partitioning + self.enable_express = enable_express + self.user_metadata = user_metadata + self.max_message_size_in_kilobytes = max_message_size_in_kilobytes + + def execute(self, context: Context) -> str: + """Creates Topic in Service Bus namespace, by connecting to Service Bus Admin client""" + if self.topic_name is None: + raise TypeError("Topic name cannot be None.") + + # Create the hook + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + with hook.get_conn() as service_mgmt_conn: + topic_properties = service_mgmt_conn.get_topic(self.topic_name) + if topic_properties and topic_properties.name == self.topic_name: + self.log.info("Topic name already exists") + 
return topic_properties.name + topic = service_mgmt_conn.create_topic( + topic_name=self.topic_name, + default_message_time_to_live=self.default_message_time_to_live, + max_size_in_megabytes=self.max_size_in_megabytes, + requires_duplicate_detection=self.requires_duplicate_detection, + duplicate_detection_history_time_window=self.duplicate_detection_history_time_window, + enable_batched_operations=self.enable_batched_operations, + size_in_bytes=self.size_in_bytes, + filtering_messages_before_publishing=self.filtering_messages_before_publishing, + authorization_rules=self.authorization_rules, + support_ordering=self.support_ordering, + auto_delete_on_idle=self.auto_delete_on_idle, + enable_partitioning=self.enable_partitioning, + enable_express=self.enable_express, + user_metadata=self.user_metadata, + max_message_size_in_kilobytes=self.max_message_size_in_kilobytes, + ) + self.log.info("Created Topic %s", topic.name) + return topic.name + + +class AzureServiceBusSubscriptionCreateOperator(BaseOperator): + """ + Create an Azure Service Bus Topic Subscription under a Service Bus Namespace + by using ServiceBusAdministrationClient + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusSubscriptionCreateOperator` + + :param topic_name: The topic that will own the to-be-created subscription. + :param subscription_name: Name of the subscription that need to be created + :param lock_duration: ISO 8601 time span duration of a peek-lock; that is, the amount of time that + the message is locked for other receivers. The maximum value for LockDuration is 5 minutes; the + default value is 1 minute. Input value of either type ~datetime.timedelta or string in ISO 8601 + duration format like "PT300S" is accepted. + :param requires_session: A value that indicates whether the queue supports the concept of sessions. + :param default_message_time_to_live: ISO 8601 default message time span to live value. This is the + duration after which the message expires, starting from when the message is sent to + Service Bus. This is the default value used when TimeToLive is not set on a message itself. + Input value of either type ~datetime.timedelta or string in ISO 8601 duration + format like "PT300S" is accepted. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription has + dead letter support when a message expires. + :param dead_lettering_on_filter_evaluation_exceptions: A value that indicates whether this + subscription has dead letter support when a message expires. + :param max_delivery_count: The maximum delivery count. A message is automatically dead lettered + after this number of deliveries. Default value is 10. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + :param forward_to: The name of the recipient entity to which all the messages sent to the + subscription are forwarded to. + :param user_metadata: Metadata associated with the subscription. Maximum number of characters is 1024. + :param forward_dead_lettered_messages_to: The name of the recipient entity to which all the + messages sent to the subscription are forwarded to. + :param auto_delete_on_idle: ISO 8601 time Span idle interval after which the subscription is + automatically deleted. The minimum duration is 5 minutes. Input value of either + type ~datetime.timedelta or string in ISO 8601 duration format like "PT300S" is accepted. 
+ :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + azure_service_bus_conn_id: str = "azure_service_bus_default", + lock_duration: datetime.timedelta | str | None = None, + requires_session: bool | None = None, + default_message_time_to_live: datetime.timedelta | str | None = None, + dead_lettering_on_message_expiration: bool | None = True, + dead_lettering_on_filter_evaluation_exceptions: bool | None = None, + max_delivery_count: int | None = 10, + enable_batched_operations: bool | None = True, + forward_to: str | None = None, + user_metadata: str | None = None, + forward_dead_lettered_messages_to: str | None = None, + auto_delete_on_idle: datetime.timedelta | str | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.lock_duration = lock_duration + self.requires_session = requires_session + self.default_message_time_to_live = default_message_time_to_live + self.dl_on_message_expiration = dead_lettering_on_message_expiration + self.dl_on_filter_evaluation_exceptions = dead_lettering_on_filter_evaluation_exceptions + self.max_delivery_count = max_delivery_count + self.enable_batched_operations = enable_batched_operations + self.forward_to = forward_to + self.user_metadata = user_metadata + self.forward_dead_lettered_messages_to = forward_dead_lettered_messages_to + self.auto_delete_on_idle = auto_delete_on_idle + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """Creates Subscription in Service Bus namespace, by connecting to Service Bus Admin client""" + if self.subscription_name is None: + raise TypeError("Subscription name cannot be None.") + if self.topic_name is None: + raise TypeError("Topic name cannot be None.") + # Create the hook + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + with hook.get_conn() as service_mgmt_conn: + # create subscription with name + subscription = service_mgmt_conn.create_subscription( + topic_name=self.topic_name, + subscription_name=self.subscription_name, + lock_duration=self.lock_duration, + requires_session=self.requires_session, + default_message_time_to_live=self.default_message_time_to_live, + dead_lettering_on_message_expiration=self.dl_on_message_expiration, + dead_lettering_on_filter_evaluation_exceptions=self.dl_on_filter_evaluation_exceptions, + max_delivery_count=self.max_delivery_count, + enable_batched_operations=self.enable_batched_operations, + forward_to=self.forward_to, + user_metadata=self.user_metadata, + forward_dead_lettered_messages_to=self.forward_dead_lettered_messages_to, + auto_delete_on_idle=self.auto_delete_on_idle, + ) + self.log.info("Created subscription %s", subscription.name) + + +class AzureServiceBusUpdateSubscriptionOperator(BaseOperator): + """ + Update an Azure ServiceBus Topic Subscription under a ServiceBus Namespace + by using ServiceBusAdministrationClient + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusUpdateSubscriptionOperator` + + :param topic_name: The topic that will own the to-be-created subscription. + :param subscription_name: Name of the subscription that need to be created. 
+ :param max_delivery_count: The maximum delivery count. A message is automatically dead lettered + after this number of deliveries. Default value is 10. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription + has dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection`. + """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + max_delivery_count: int | None = None, + dead_lettering_on_message_expiration: bool | None = None, + enable_batched_operations: bool | None = None, + azure_service_bus_conn_id: str = "azure_service_bus_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.max_delivery_count = max_delivery_count + self.dl_on_message_expiration = dead_lettering_on_message_expiration + self.enable_batched_operations = enable_batched_operations + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """Updates Subscription properties, by connecting to Service Bus Admin client""" + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + with hook.get_conn() as service_mgmt_conn: + subscription_prop = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name) + if self.max_delivery_count: + subscription_prop.max_delivery_count = self.max_delivery_count + if self.dl_on_message_expiration is not None: + subscription_prop.dead_lettering_on_message_expiration = self.dl_on_message_expiration + if self.enable_batched_operations is not None: + subscription_prop.enable_batched_operations = self.enable_batched_operations + # update by updating the properties in the model + service_mgmt_conn.update_subscription(self.topic_name, subscription_prop) + updated_subscription = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name) + self.log.info("Subscription Updated successfully %s", updated_subscription) + + +class ASBReceiveSubscriptionMessageOperator(BaseOperator): + """ + Receive a Batch messages from a Service Bus Subscription under specific Topic. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:ASBReceiveSubscriptionMessageOperator` + + :param subscription_name: The subscription name that will own the rule in topic + :param topic_name: The topic that will own the subscription rule. + :param max_message_count: Maximum number of messages in the batch. + Actual number returned will depend on prefetch_count and incoming stream rate. + Setting to None will fully depend on the prefetch config. The default value is 1. + :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. If no + messages arrive, and no timeout is specified, this call will not return until the + connection is closed. If specified, an no messages arrive within the timeout period, + an empty list will be returned. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection `. 
+ """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + max_message_count: int | None = 1, + max_wait_time: float | None = 5, + azure_service_bus_conn_id: str = "azure_service_bus_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.max_message_count = max_message_count + self.max_wait_time = max_wait_time + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """ + Receive Message in specific queue in Service Bus namespace, + by connecting to Service Bus client + """ + # Create the hook + hook = MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # Receive message + hook.receive_subscription_message( + self.topic_name, self.subscription_name, self.max_message_count, self.max_wait_time + ) + + +class AzureServiceBusSubscriptionDeleteOperator(BaseOperator): + """ + Deletes the topic subscription in the Azure ServiceBus namespace + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusSubscriptionDeleteOperator` + + :param topic_name: The topic that will own the to-be-created subscription. + :param subscription_name: Name of the subscription that need to be created + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection `. + """ + + template_fields: Sequence[str] = ("topic_name", "subscription_name") + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + subscription_name: str, + azure_service_bus_conn_id: str = "azure_service_bus_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.subscription_name = subscription_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """Delete topic subscription in Service Bus namespace, by connecting to Service Bus Admin client""" + # Create the hook + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + # delete subscription with name + hook.delete_subscription(self.subscription_name, self.topic_name) + + +class AzureServiceBusTopicDeleteOperator(BaseOperator): + """ + Deletes the topic in the Azure Service Bus namespace + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureServiceBusTopicDeleteOperator` + + :param topic_name: Name of the topic to be deleted. + :param azure_service_bus_conn_id: Reference to the + :ref:`Azure Service Bus connection `. 
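A hedged sketch combining the topic and subscription operators defined above (topic create, subscription create, receive from subscription, subscription delete); the topic/subscription names and DAG settings are placeholders and the default connection is assumed.

# Illustrative only; values are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.operators.asb import (
    ASBReceiveSubscriptionMessageOperator,
    AzureServiceBusSubscriptionCreateOperator,
    AzureServiceBusSubscriptionDeleteOperator,
    AzureServiceBusTopicCreateOperator,
)

with DAG("asb_topic_example", start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False) as dag:
    create_topic = AzureServiceBusTopicCreateOperator(task_id="create_topic", topic_name="demo-topic")
    create_subscription = AzureServiceBusSubscriptionCreateOperator(
        task_id="create_subscription", topic_name="demo-topic", subscription_name="demo-sub"
    )
    receive_subscription_message = ASBReceiveSubscriptionMessageOperator(
        task_id="receive_subscription_message",
        topic_name="demo-topic",
        subscription_name="demo-sub",
        max_message_count=5,
    )
    delete_subscription = AzureServiceBusSubscriptionDeleteOperator(
        task_id="delete_subscription", topic_name="demo-topic", subscription_name="demo-sub"
    )

    create_topic >> create_subscription >> receive_subscription_message >> delete_subscription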
+ """ + + template_fields: Sequence[str] = ("topic_name",) + ui_color = "#e4f0e8" + + def __init__( + self, + *, + topic_name: str, + azure_service_bus_conn_id: str = "azure_service_bus_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.topic_name = topic_name + self.azure_service_bus_conn_id = azure_service_bus_conn_id + + def execute(self, context: Context) -> None: + """Delete topic in Service Bus namespace, by connecting to Service Bus Admin client""" + if self.topic_name is None: + raise TypeError("Topic name cannot be None.") + hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) + + with hook.get_conn() as service_mgmt_conn: + topic_properties = service_mgmt_conn.get_topic(self.topic_name) + if topic_properties and topic_properties.name == self.topic_name: + service_mgmt_conn.delete_topic(self.topic_name) + self.log.info("Topic %s deleted.", self.topic_name) + else: + self.log.info("Topic %s does not exist.", self.topic_name) diff --git a/airflow/providers/microsoft/azure/operators/azure_batch.py b/airflow/providers/microsoft/azure/operators/azure_batch.py deleted file mode 100644 index baa931e6c76e3..0000000000000 --- a/airflow/providers/microsoft/azure/operators/azure_batch.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.batch`.""" - -import warnings - -from airflow.providers.microsoft.azure.operators.batch import AzureBatchOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.batch`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py deleted file mode 100644 index bb9bd05f24831..0000000000000 --- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.microsoft.azure.operators.container_instances`. -""" - -import warnings - -from airflow.providers.microsoft.azure.operators.container_instances import ( # noqa - AzureContainerInstancesOperator, -) - -warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.microsoft.azure.operators.container_instances`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/operators/azure_cosmos.py b/airflow/providers/microsoft/azure/operators/azure_cosmos.py deleted file mode 100644 index 8ef095350ea7b..0000000000000 --- a/airflow/providers/microsoft/azure/operators/azure_cosmos.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.cosmos`.""" - -import warnings - -from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.cosmos`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py index b1e3ee5cd3c8c..0a14993c58758 100644 --- a/airflow/providers/microsoft/azure/operators/batch.py +++ b/airflow/providers/microsoft/azure/operators/batch.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from azure.batch import models as batch_models @@ -82,7 +83,7 @@ class AzureBatchOperator(BaseOperator): use_latest_image_and_sku is set to True :param vm_sku: The name of the virtual machine sku to use :param vm_version: The version of the virtual machine - :param vm_version: Optional[str] + :param vm_version: str | None :param vm_node_agent_sku_id: The node agent sku id of the virtual machine :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. 
:param os_version: The OS family version @@ -92,13 +93,13 @@ class AzureBatchOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'batch_pool_id', - 'batch_pool_vm_size', - 'batch_job_id', - 'batch_task_id', - 'batch_task_command_line', + "batch_pool_id", + "batch_pool_vm_size", + "batch_job_id", + "batch_task_id", + "batch_task_command_line", ) - ui_color = '#f0f0e4' + ui_color = "#f0f0e4" def __init__( self, @@ -108,31 +109,31 @@ def __init__( batch_job_id: str, batch_task_command_line: str, batch_task_id: str, - vm_publisher: Optional[str] = None, - vm_offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, - vm_sku: Optional[str] = None, - vm_version: Optional[str] = None, - vm_node_agent_sku_id: Optional[str] = None, - os_family: Optional[str] = None, - os_version: Optional[str] = None, - batch_pool_display_name: Optional[str] = None, - batch_job_display_name: Optional[str] = None, - batch_job_manager_task: Optional[batch_models.JobManagerTask] = None, - batch_job_preparation_task: Optional[batch_models.JobPreparationTask] = None, - batch_job_release_task: Optional[batch_models.JobReleaseTask] = None, - batch_task_display_name: Optional[str] = None, - batch_task_container_settings: Optional[batch_models.TaskContainerSettings] = None, - batch_start_task: Optional[batch_models.StartTask] = None, + vm_node_agent_sku_id: str, + vm_publisher: str | None = None, + vm_offer: str | None = None, + sku_starts_with: str | None = None, + vm_sku: str | None = None, + vm_version: str | None = None, + os_family: str | None = None, + os_version: str | None = None, + batch_pool_display_name: str | None = None, + batch_job_display_name: str | None = None, + batch_job_manager_task: batch_models.JobManagerTask | None = None, + batch_job_preparation_task: batch_models.JobPreparationTask | None = None, + batch_job_release_task: batch_models.JobReleaseTask | None = None, + batch_task_display_name: str | None = None, + batch_task_container_settings: batch_models.TaskContainerSettings | None = None, + batch_start_task: batch_models.StartTask | None = None, batch_max_retries: int = 3, - batch_task_resource_files: Optional[List[batch_models.ResourceFile]] = None, - batch_task_output_files: Optional[List[batch_models.OutputFile]] = None, - batch_task_user_identity: Optional[batch_models.UserIdentity] = None, - target_low_priority_nodes: Optional[int] = None, - target_dedicated_nodes: Optional[int] = None, + batch_task_resource_files: list[batch_models.ResourceFile] | None = None, + batch_task_output_files: list[batch_models.OutputFile] | None = None, + batch_task_user_identity: batch_models.UserIdentity | None = None, + target_low_priority_nodes: int | None = None, + target_dedicated_nodes: int | None = None, enable_auto_scale: bool = False, - auto_scale_formula: Optional[str] = None, - azure_batch_conn_id='azure_batch_default', + auto_scale_formula: str | None = None, + azure_batch_conn_id="azure_batch_default", use_latest_verified_vm_image_and_sku: bool = False, timeout: int = 25, should_delete_job: bool = False, @@ -236,7 +237,7 @@ def _check_inputs(self) -> Any: "Some required parameters are missing.Please you must set all the required parameters. 
" ) - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: self._check_inputs() self.hook.connection.config.retry_policy = self.batch_max_retries @@ -292,17 +293,20 @@ def execute(self, context: "Context") -> None: # Add task to job self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task) # Wait for tasks to complete - self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout) + fail_tasks = self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout) # Clean up if self.should_delete_job: # delete job first self.clean_up(job_id=self.batch_job_id) if self.should_delete_pool: self.clean_up(self.batch_pool_id) + # raise exception if any task fail + if fail_tasks: + raise AirflowException(f"Job fail. The failed task are: {fail_tasks}") def on_kill(self) -> None: response = self.hook.connection.job.terminate( - job_id=self.batch_job_id, terminate_reason='Job killed by user' + job_id=self.batch_job_id, terminate_reason="Job killed by user" ) self.log.info("Azure Batch job (%s) terminated: %s", self.batch_job_id, response) @@ -310,7 +314,7 @@ def get_hook(self) -> AzureBatchHook: """Create and return an AzureBatchHook.""" return AzureBatchHook(azure_batch_conn_id=self.azure_batch_conn_id) - def clean_up(self, pool_id: Optional[str] = None, job_id: Optional[str] = None) -> None: + def clean_up(self, pool_id: str | None = None, job_id: str | None = None) -> None: """ Delete the given pool and job in the batch account diff --git a/airflow/providers/microsoft/azure/operators/container_instances.py b/airflow/providers/microsoft/azure/operators/container_instances.py index 519ce8fe41937..68b1cc57baa0a 100644 --- a/airflow/providers/microsoft/azure/operators/container_instances.py +++ b/airflow/providers/microsoft/azure/operators/container_instances.py @@ -15,15 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import re from collections import namedtuple from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from azure.mgmt.containerinstance.models import ( Container, ContainerGroup, + ContainerGroupNetworkProfile, ContainerPort, EnvironmentVariable, IpAddress, @@ -44,12 +46,12 @@ Volume = namedtuple( - 'Volume', - ['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'], + "Volume", + ["conn_id", "account_name", "share_name", "mount_path", "read_only"], ) -DEFAULT_ENVIRONMENT_VARIABLES: Dict[str, str] = {} +DEFAULT_ENVIRONMENT_VARIABLES: dict[str, str] = {} DEFAULT_SECURED_VARIABLES: Sequence[str] = [] DEFAULT_VOLUMES: Sequence[Volume] = [] DEFAULT_MEMORY_IN_GB = 2.0 @@ -88,6 +90,7 @@ class AzureContainerInstancesOperator(BaseOperator): :param restart_policy: Restart policy for all containers within the container group. Possible values include: 'Always', 'OnFailure', 'Never' :param ip_address: The IP address type of the container group. + :param network_profile: The network profile information for a container group. 
**Example**:: @@ -116,32 +119,33 @@ class AzureContainerInstancesOperator(BaseOperator): ) """ - template_fields: Sequence[str] = ('name', 'image', 'command', 'environment_variables') + template_fields: Sequence[str] = ("name", "image", "command", "environment_variables") template_fields_renderers = {"command": "bash", "environment_variables": "json"} def __init__( self, *, ci_conn_id: str, - registry_conn_id: Optional[str], + registry_conn_id: str | None, resource_group: str, name: str, image: str, region: str, - environment_variables: Optional[dict] = None, - secured_variables: Optional[str] = None, - volumes: Optional[list] = None, - memory_in_gb: Optional[Any] = None, - cpu: Optional[Any] = None, - gpu: Optional[Any] = None, - command: Optional[List[str]] = None, + environment_variables: dict | None = None, + secured_variables: str | None = None, + volumes: list | None = None, + memory_in_gb: Any | None = None, + cpu: Any | None = None, + gpu: Any | None = None, + command: list[str] | None = None, remove_on_error: bool = True, fail_if_exists: bool = True, - tags: Optional[Dict[str, str]] = None, - os_type: str = 'Linux', - restart_policy: str = 'Never', - ip_address: Optional[IpAddress] = None, - ports: Optional[List[ContainerPort]] = None, + tags: dict[str, str] | None = None, + os_type: str = "Linux", + restart_policy: str = "Never", + ip_address: IpAddress | None = None, + ports: list[ContainerPort] | None = None, + network_profile: ContainerGroupNetworkProfile | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -164,14 +168,14 @@ def __init__( self._ci_hook: Any = None self.tags = tags self.os_type = os_type - if self.os_type not in ['Linux', 'Windows']: + if self.os_type not in ["Linux", "Windows"]: raise AirflowException( "Invalid value for the os_type argument. " "Please set 'Linux' or 'Windows' as the os_type. " f"Found `{self.os_type}`." ) self.restart_policy = restart_policy - if self.restart_policy not in ['Always', 'OnFailure', 'Never']: + if self.restart_policy not in ["Always", "OnFailure", "Never"]: raise AirflowException( "Invalid value for the restart_policy argument. " "Please set one of 'Always', 'OnFailure','Never' as the restart_policy. " @@ -179,12 +183,13 @@ def __init__( ) self.ip_address = ip_address self.ports = ports + self.network_profile = network_profile - def execute(self, context: "Context") -> int: + def execute(self, context: Context) -> int: # Check name again in case it was templated. 
self._check_name(self.name) - self._ci_hook = AzureContainerInstanceHook(conn_id=self.ci_conn_id) + self._ci_hook = AzureContainerInstanceHook(azure_conn_id=self.ci_conn_id) if self.fail_if_exists: self.log.info("Testing if container group already exists") @@ -193,7 +198,7 @@ def execute(self, context: "Context") -> int: if self.registry_conn_id: registry_hook = AzureContainerRegistryHook(self.registry_conn_id) - image_registry_credentials: Optional[list] = [ + image_registry_credentials: list | None = [ registry_hook.connection, ] else: @@ -207,8 +212,8 @@ def execute(self, context: "Context") -> int: e = EnvironmentVariable(name=key, value=value) environment_variables.append(e) - volumes: List[Union[Volume, Volume]] = [] - volume_mounts: List[Union[VolumeMount, VolumeMount]] = [] + volumes: list[Volume | Volume] = [] + volume_mounts: list[VolumeMount | VolumeMount] = [] for conn_id, account_name, share_name, mount_path, read_only in self.volumes: hook = AzureContainerVolumeHook(conn_id) @@ -251,6 +256,7 @@ def execute(self, context: "Context") -> int: os_type=self.os_type, tags=self.tags, ip_address=self.ip_address, + network_profile=self.network_profile, ) self._ci_hook.create_or_update(self.resource_group, self.name, container_group) @@ -331,7 +337,7 @@ def _monitor_logging(self, resource_group: str, name: str) -> int: except AirflowTaskTimeout: raise except CloudError as err: - if 'ResourceNotFound' in str(err): + if "ResourceNotFound" in str(err): self.log.warning( "ResourceNotFound, container is probably removed " "by another process " @@ -345,7 +351,7 @@ def _monitor_logging(self, resource_group: str, name: str) -> int: sleep(1) - def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any]: + def _log_last(self, logs: list | None, last_line_logged: Any) -> Any | None: if logs: # determine the last line which was logged before last_line_index = 0 @@ -364,12 +370,12 @@ def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any @staticmethod def _check_name(name: str) -> str: - if '{{' in name: + if "{{" in name: # Let macros pass as they cannot be checked at construction time return name regex_check = re.match("[a-z0-9]([-a-z0-9]*[a-z0-9])?", name) if regex_check is None or regex_check.group() != name: raise AirflowException('ACI name must match regex [a-z0-9]([-a-z0-9]*[a-z0-9])? (like "my-name")') if len(name) > 63: - raise AirflowException('ACI name cannot be longer than 63 characters') + raise AirflowException("ACI name cannot be longer than 63 characters") return name diff --git a/airflow/providers/microsoft/azure/operators/cosmos.py b/airflow/providers/microsoft/azure/operators/cosmos.py index ef3638c6d7d85..674a8520531f2 100644 --- a/airflow/providers/microsoft/azure/operators/cosmos.py +++ b/airflow/providers/microsoft/azure/operators/cosmos.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator @@ -36,8 +38,8 @@ class AzureCosmosInsertDocumentOperator(BaseOperator): :ref:`Azure CosmosDB connection`. 
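A short usage sketch for the Cosmos insert operator touched here; the database, collection and document values are placeholders, and the task is assumed to sit inside an existing DAG definition.

from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator

# Placeholder values; "azure_cosmos_default" is the operator's default connection id.
insert_run_marker = AzureCosmosInsertDocumentOperator(
    task_id="insert_run_marker",
    database_name="airflow",
    collection_name="run_markers",
    document={"id": "run-2022-08-01", "status": "success"},
)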
""" - template_fields: Sequence[str] = ('database_name', 'collection_name') - ui_color = '#e4f0e8' + template_fields: Sequence[str] = ("database_name", "collection_name") + ui_color = "#e4f0e8" def __init__( self, @@ -45,7 +47,7 @@ def __init__( database_name: str, collection_name: str, document: dict, - azure_cosmos_conn_id: str = 'azure_cosmos_default', + azure_cosmos_conn_id: str = "azure_cosmos_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -54,7 +56,7 @@ def __init__( self.document = document self.azure_cosmos_conn_id = azure_cosmos_conn_id - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: # Create the hook hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id) diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index 488ccbced0702..fa29d38a06711 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.hooks.base import BaseHook from airflow.models import BaseOperator, BaseOperatorLink, XCom @@ -23,46 +24,40 @@ AzureDataFactoryHook, AzureDataFactoryPipelineRunException, AzureDataFactoryPipelineRunStatus, + get_field, ) +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context -class AzureDataFactoryPipelineRunLink(BaseOperatorLink): +class AzureDataFactoryPipelineRunLink(LoggingMixin, BaseOperatorLink): """Constructs a link to monitor a pipeline run in Azure Data Factory.""" name = "Monitor Pipeline Run" def get_link( self, - operator, - dttm=None, + operator: BaseOperator, *, - ti_key: Optional["TaskInstanceKey"] = None, + ti_key: TaskInstanceKey, ) -> str: - if ti_key is not None: - run_id = XCom.get_value(key="run_id", ti_key=ti_key) - else: - assert dttm - run_id = XCom.get_one( - key="run_id", - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - ) - - conn = BaseHook.get_connection(operator.azure_data_factory_conn_id) - subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"] + if not isinstance(operator, AzureDataFactoryRunPipelineOperator): + self.log.info("The %s is not %s class.", operator.__class__, AzureDataFactoryRunPipelineOperator) + return "" + run_id = XCom.get_value(key="run_id", ti_key=ti_key) + conn_id = operator.azure_data_factory_conn_id + conn = BaseHook.get_connection(conn_id) + extras = conn.extra_dejson + subscription_id = get_field(extras, "subscriptionId") + if not subscription_id: + raise KeyError(f"Param subscriptionId not found in conn_id '{conn_id}'") # Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory # connection or passed directly to the operator. 
- resource_group_name = operator.resource_group_name or conn.extra_dejson.get( - "extra__azure_data_factory__resource_group_name" - ) - factory_name = operator.factory_name or conn.extra_dejson.get( - "extra__azure_data_factory__factory_name" - ) + resource_group_name = operator.resource_group_name or get_field(extras, "resource_group_name") + factory_name = operator.factory_name or get_field(extras, "factory_name") url = ( f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}" f"?factory=/subscriptions/{subscription_id}/" @@ -129,13 +124,13 @@ def __init__( pipeline_name: str, azure_data_factory_conn_id: str = AzureDataFactoryHook.default_conn_name, wait_for_termination: bool = True, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, - reference_pipeline_run_id: Optional[str] = None, - is_recovery: Optional[bool] = None, - start_activity_name: Optional[str] = None, - start_from_failure: Optional[bool] = None, - parameters: Optional[Dict[str, Any]] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, + reference_pipeline_run_id: str | None = None, + is_recovery: bool | None = None, + start_activity_name: str | None = None, + start_from_failure: bool | None = None, + parameters: dict[str, Any] | None = None, timeout: int = 60 * 60 * 24 * 7, check_interval: int = 60, **kwargs, @@ -154,7 +149,7 @@ def __init__( self.timeout = timeout self.check_interval = check_interval - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) self.log.info("Executing the %s pipeline.", self.pipeline_name) response = self.hook.run_pipeline( diff --git a/airflow/providers/microsoft/azure/operators/synapse.py b/airflow/providers/microsoft/azure/operators/synapse.py new file mode 100644 index 0000000000000..b9d97704c57f6 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/synapse.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from azure.synapse.spark.models import SparkBatchJobOptions + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, AzureSynapseSparkBatchRunStatus + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AzureSynapseRunSparkBatchOperator(BaseOperator): + """ + Executes a Spark job on Azure Synapse. + + .. see also:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AzureSynapseRunSparkBatchOperator` + + :param azure_synapse_conn_id: The connection identifier for connecting to Azure Synapse. 
+ :param wait_for_termination: Flag to wait on a job run's termination. + :param spark_pool: The target synapse spark pool used to submit the job + :param payload: Livy compatible payload which represents the spark job that a user wants to submit + :param timeout: Time in seconds to wait for a job to reach a terminal status for non-asynchronous + waits. Used only if ``wait_for_termination`` is True. + :param check_interval: Time in seconds to check on a job run's status for non-asynchronous waits. + Used only if ``wait_for_termination`` is True. + """ + + template_fields: Sequence[str] = ( + "azure_synapse_conn_id", + "spark_pool", + ) + template_fields_renderers = {"parameters": "json"} + + ui_color = "#0678d4" + + def __init__( + self, + *, + azure_synapse_conn_id: str = AzureSynapseHook.default_conn_name, + wait_for_termination: bool = True, + spark_pool: str = "", + payload: SparkBatchJobOptions, + timeout: int = 60 * 60 * 24 * 7, + check_interval: int = 60, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_id = None + self.azure_synapse_conn_id = azure_synapse_conn_id + self.wait_for_termination = wait_for_termination + self.spark_pool = spark_pool + self.payload = payload + self.timeout = timeout + self.check_interval = check_interval + + def execute(self, context: Context) -> None: + self.hook = AzureSynapseHook( + azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool + ) + self.log.info("Executing the Synapse spark job.") + response = self.hook.run_spark_job(payload=self.payload) + self.log.info(response) + self.job_id = vars(response)["id"] + # Push the ``job_id`` value to XCom regardless of what happens during execution. This allows for + # retrieval the executed job's ``id`` for downstream tasks especially if performing an + # asynchronous wait. + context["ti"].xcom_push(key="job_id", value=self.job_id) + + if self.wait_for_termination: + self.log.info("Waiting for job run %s to terminate.", self.job_id) + + if self.hook.wait_for_job_run_status( + job_id=self.job_id, + expected_statuses=AzureSynapseSparkBatchRunStatus.SUCCESS, + check_interval=self.check_interval, + timeout=self.timeout, + ): + self.log.info("Job run %s has completed successfully.", self.job_id) + else: + raise Exception(f"Job run {self.job_id} has failed or has been cancelled.") + + def on_kill(self) -> None: + if self.job_id: + self.hook.cancel_job_run( + job_id=self.job_id, + ) + self.log.info("Job run %s has been cancelled successfully.", self.job_id) diff --git a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py index 1242c59593d33..7015746a825cd 100644 --- a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py +++ b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator @@ -39,14 +40,14 @@ class WasbDeleteBlobOperator(BaseOperator): blob does not exist. 
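Stepping back to the new AzureSynapseRunSparkBatchOperator defined above, a minimal DAG sketch may help reviewers; the Spark pool name, the abfss URI and the SparkBatchJobOptions fields used here are assumptions to verify against the azure-synapse-spark model.

from datetime import datetime

from airflow import DAG
from azure.synapse.spark.models import SparkBatchJobOptions

from airflow.providers.microsoft.azure.operators.synapse import AzureSynapseRunSparkBatchOperator

with DAG(
    dag_id="example_synapse_spark_job",
    start_date=datetime(2022, 8, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    run_spark_job = AzureSynapseRunSparkBatchOperator(
        task_id="run_spark_job",
        spark_pool="my-spark-pool",  # placeholder pool name
        # "name" and "file" mirror the Livy batch request; other options are omitted here,
        # and the exact field names should be checked against the SDK model.
        payload=SparkBatchJobOptions(
            name="wordcount",
            file="abfss://jobs@mystorage.dfs.core.windows.net/wordcount.py",
        ),
        wait_for_termination=True,
        check_interval=60,
    )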
""" - template_fields: Sequence[str] = ('container_name', 'blob_name') + template_fields: Sequence[str] = ("container_name", "blob_name") def __init__( self, *, container_name: str, blob_name: str, - wasb_conn_id: str = 'wasb_default', + wasb_conn_id: str = "wasb_default", check_options: Any = None, is_prefix: bool = False, ignore_if_missing: bool = False, @@ -62,8 +63,8 @@ def __init__( self.is_prefix = is_prefix self.ignore_if_missing = ignore_if_missing - def execute(self, context: "Context") -> None: - self.log.info('Deleting blob: %s\n in wasb://%s', self.blob_name, self.container_name) + def execute(self, context: Context) -> None: + self.log.info("Deleting blob: %s\n in wasb://%s", self.blob_name, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) hook.delete_file( diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 10cab462b4c53..7d8341f77f83a 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -21,6 +21,11 @@ name: Microsoft Azure description: | `Microsoft Azure `__ versions: + - 5.0.0 + - 4.3.0 + - 4.2.0 + - 4.1.0 + - 4.0.0 - 3.9.0 - 3.8.0 - 3.7.2 @@ -40,8 +45,27 @@ versions: - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - azure-batch>=8.0.0 + - azure-cosmos>=4.0.0 + - azure-datalake-store>=0.0.45 + - azure-identity>=1.3.1 + - azure-keyvault-secrets>=4.1.0,<5.0 + - azure-kusto-data>=0.0.43,<0.1 + # Azure integration uses old libraries and the limits below reflect that + # TODO: upgrade to newer versions of all the below libraries + - azure-mgmt-containerinstance>=1.5.0,<2.0 + - azure-mgmt-datafactory>=1.0.0,<2.0 + - azure-mgmt-datalake-store>=0.5.0 + - azure-mgmt-resource>=2.2.0 + - azure-storage-blob>=12.14.0 + - azure-storage-common>=2.1.0 + - azure-storage-file>=2.1.0 + # Limited due to https://github.com/Azure/azure-uamqp-python/issues/191 + - azure-servicebus>=7.6.1; platform_machine != "aarch64" + - azure-synapse-spark + - adal>=1.2.7 integrations: - integration-name: Microsoft Azure Batch @@ -88,12 +112,21 @@ integrations: external-doc-url: https://azure.microsoft.com/ logo: /integration-logos/azure/Microsoft-Azure.png tags: [azure] + - integration-name: Microsoft Azure Service Bus + external-doc-url: https://azure.microsoft.com/en-us/services/service-bus/ + logo: /integration-logos/azure/Service-Bus.svg + how-to-guide: + - /docs/apache-airflow-providers-microsoft-azure/operators/asb.rst + tags: [azure] + - integration-name: Microsoft Azure Synapse + external-doc-url: https://azure.microsoft.com/en-us/services/synapse-analytics/ + how-to-guide: + - /docs/apache-airflow-providers-microsoft-azure/operators/azure_synapse.rst + tags: [azure] operators: - integration-name: Microsoft Azure Data Lake Storage python-modules: - - airflow.providers.microsoft.azure.operators.adls_list - - airflow.providers.microsoft.azure.operators.adls_delete - airflow.providers.microsoft.azure.operators.adls - integration-name: Microsoft Azure Data Explorer python-modules: @@ -101,27 +134,29 @@ operators: - integration-name: Microsoft Azure Batch python-modules: - airflow.providers.microsoft.azure.operators.batch - - airflow.providers.microsoft.azure.operators.azure_batch - integration-name: Microsoft Azure Container Instances python-modules: - airflow.providers.microsoft.azure.operators.container_instances - - airflow.providers.microsoft.azure.operators.azure_container_instances - 
integration-name: Microsoft Azure Cosmos DB python-modules: - airflow.providers.microsoft.azure.operators.cosmos - - airflow.providers.microsoft.azure.operators.azure_cosmos - integration-name: Microsoft Azure Blob Storage python-modules: - airflow.providers.microsoft.azure.operators.wasb_delete_blob - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.operators.data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.operators.asb + - integration-name: Microsoft Azure Synapse + python-modules: + - airflow.providers.microsoft.azure.operators.synapse sensors: - integration-name: Microsoft Azure Cosmos DB python-modules: - airflow.providers.microsoft.azure.sensors.cosmos - - airflow.providers.microsoft.azure.sensors.azure_cosmos - integration-name: Microsoft Azure Blob Storage python-modules: - airflow.providers.microsoft.azure.sensors.wasb @@ -135,38 +170,36 @@ hooks: - airflow.providers.microsoft.azure.hooks.container_volume - airflow.providers.microsoft.azure.hooks.container_registry - airflow.providers.microsoft.azure.hooks.container_instance - - airflow.providers.microsoft.azure.hooks.azure_container_volume - - airflow.providers.microsoft.azure.hooks.azure_container_registry - - airflow.providers.microsoft.azure.hooks.azure_container_instance - integration-name: Microsoft Azure Data Explorer python-modules: - airflow.providers.microsoft.azure.hooks.adx - integration-name: Microsoft Azure FileShare python-modules: - airflow.providers.microsoft.azure.hooks.fileshare - - airflow.providers.microsoft.azure.hooks.azure_fileshare - integration-name: Microsoft Azure python-modules: - airflow.providers.microsoft.azure.hooks.base_azure - integration-name: Microsoft Azure Batch python-modules: - airflow.providers.microsoft.azure.hooks.batch - - airflow.providers.microsoft.azure.hooks.azure_batch - integration-name: Microsoft Azure Data Lake Storage python-modules: - airflow.providers.microsoft.azure.hooks.data_lake - - airflow.providers.microsoft.azure.hooks.azure_data_lake - integration-name: Microsoft Azure Cosmos DB python-modules: - airflow.providers.microsoft.azure.hooks.cosmos - - airflow.providers.microsoft.azure.hooks.azure_cosmos - integration-name: Microsoft Azure Blob Storage python-modules: - airflow.providers.microsoft.azure.hooks.wasb - integration-name: Microsoft Azure Data Factory python-modules: - airflow.providers.microsoft.azure.hooks.data_factory - - airflow.providers.microsoft.azure.hooks.azure_data_factory + - integration-name: Microsoft Azure Service Bus + python-modules: + - airflow.providers.microsoft.azure.hooks.asb + - integration-name: Microsoft Azure Synapse + python-modules: + - airflow.providers.microsoft.azure.hooks.synapse transfers: - source-integration-name: Local @@ -176,9 +209,6 @@ transfers: - source-integration-name: Oracle target-integration-name: Microsoft Azure Data Lake Storage python-module: airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake - - source-integration-name: Local - target-integration-name: Microsoft Azure Blob Storage - python-module: airflow.providers.microsoft.azure.transfers.file_to_wasb - source-integration-name: Local target-integration-name: Microsoft Azure Blob Storage python-module: airflow.providers.microsoft.azure.transfers.local_to_wasb @@ -191,18 +221,6 @@ transfers: how-to-guide: /docs/apache-airflow-providers-microsoft-azure/operators/sftp_to_wasb.rst python-module: 
airflow.providers.microsoft.azure.transfers.sftp_to_wasb -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook - - airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook - - airflow.providers.microsoft.azure.hooks.batch.AzureBatchHook - - airflow.providers.microsoft.azure.hooks.cosmos.AzureCosmosDBHook - - airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeHook - - airflow.providers.microsoft.azure.hooks.fileshare.AzureFileShareHook - - airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook - - airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook - - airflow.providers.microsoft.azure.hooks.wasb.WasbHook - - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook @@ -229,10 +247,13 @@ connection-types: - hook-class-name: >- airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-type: azure_container_registry + - hook-class-name: airflow.providers.microsoft.azure.hooks.asb.BaseAzureServiceBusHook + connection-type: azure_service_bus + - hook-class-name: airflow.providers.microsoft.azure.hooks.synapse.AzureSynapseHook + connection-type: azure_synapse secrets-backends: - airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend - - airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend logging: - airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler diff --git a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py b/airflow/providers/microsoft/azure/secrets/azure_key_vault.py deleted file mode 100644 index f15ed17fde4c2..0000000000000 --- a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.secrets.key_vault`.""" - -import warnings - -from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.secrets.key_vault`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/secrets/key_vault.py b/airflow/providers/microsoft/azure/secrets/key_vault.py index 0dde708a8d123..fb9f2d690afc5 100644 --- a/airflow/providers/microsoft/azure/secrets/key_vault.py +++ b/airflow/providers/microsoft/azure/secrets/key_vault.py @@ -14,29 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import re -import sys import warnings -from typing import Optional from azure.core.exceptions import ResourceNotFoundError from azure.identity import DefaultAzureCredential from azure.keyvault.secrets import SecretClient -from airflow.version import version as airflow_version - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.secrets import BaseSecretsBackend from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.version import version as airflow_version def _parse_version(val): - val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val) - return tuple(int(x) for x in val.split('.')) + val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), val) + return tuple(int(x) for x in val.split(".")) class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin): @@ -78,11 +73,11 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin): def __init__( self, - connections_prefix: str = 'airflow-connections', - variables_prefix: str = 'airflow-variables', - config_prefix: str = 'airflow-config', - vault_url: str = '', - sep: str = '-', + connections_prefix: str = "airflow-connections", + variables_prefix: str = "airflow-variables", + config_prefix: str = "airflow-config", + vault_url: str = "", + sep: str = "-", **kwargs, ) -> None: super().__init__() @@ -109,7 +104,7 @@ def client(self) -> SecretClient: client = SecretClient(vault_url=self.vault_url, credential=credential, **self.kwargs) return client - def get_conn_value(self, conn_id: str) -> Optional[str]: + def get_conn_value(self, conn_id: str) -> str | None: """ Get a serialized representation of Airflow Connection from an Azure Key Vault secret @@ -120,7 +115,7 @@ def get_conn_value(self, conn_id: str) -> Optional[str]: return self._get_secret(self.connections_prefix, conn_id) - def get_conn_uri(self, conn_id: str) -> Optional[str]: + def get_conn_uri(self, conn_id: str) -> str | None: """ Return URI representation of Connection conn_id. @@ -138,7 +133,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]: ) return self.get_conn_value(conn_id) - def get_variable(self, key: str) -> Optional[str]: + def get_variable(self, key: str) -> str | None: """ Get an Airflow Variable from an Azure Key Vault secret. 
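The secret-name handling in this backend is easier to review with a tiny illustration; the values are arbitrary, and the empty-prefix case reflects the behaviour introduced in a later hunk of this file.

from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend

# Underscores are replaced with the separator so the name satisfies Key Vault naming rules.
print(AzureKeyVaultBackend.build_path("airflow-connections", "my_postgres"))  # airflow-connections-my-postgres

# With an empty prefix, no leading separator is prepended to the secret name.
print(AzureKeyVaultBackend.build_path("", "my_postgres"))  # my-postgres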
@@ -150,7 +145,7 @@ def get_variable(self, key: str) -> Optional[str]: return self._get_secret(self.variables_prefix, key) - def get_config(self, key: str) -> Optional[str]: + def get_config(self, key: str) -> str | None: """ Get Airflow Configuration @@ -163,7 +158,7 @@ def get_config(self, key: str) -> Optional[str]: return self._get_secret(self.config_prefix, key) @staticmethod - def build_path(path_prefix: str, secret_id: str, sep: str = '-') -> str: + def build_path(path_prefix: str, secret_id: str, sep: str = "-") -> str: """ Given a path_prefix and secret_id, build a valid secret name for the Azure Key Vault Backend. Also replaces underscore in the path with dashes to support easy switching between @@ -173,10 +168,14 @@ def build_path(path_prefix: str, secret_id: str, sep: str = '-') -> str: :param secret_id: Name of the secret :param sep: Separator used to concatenate path_prefix and secret_id """ - path = f'{path_prefix}{sep}{secret_id}' - return path.replace('_', sep) + # When an empty prefix is given, do not add a separator to the secret name + if path_prefix == "": + path = f"{secret_id}" + else: + path = f"{path_prefix}{sep}{secret_id}" + return path.replace("_", sep) - def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + def _get_secret(self, path_prefix: str, secret_id: str) -> str | None: """ Get an Azure Key Vault secret value @@ -188,5 +187,5 @@ def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: secret = self.client.get_secret(name=name) return secret.value except ResourceNotFoundError as ex: - self.log.debug('Secret %s not found: %s', name, ex) + self.log.debug("Secret %s not found: %s", name, ex) return None diff --git a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py deleted file mode 100644 index 0adeddac60a26..0000000000000 --- a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.sensors.cosmos`.""" - -import warnings - -from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.cosmos`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/sensors/cosmos.py b/airflow/providers/microsoft/azure/sensors/cosmos.py index 295aa6d111a99..2692b9ed86f8a 100644 --- a/airflow/providers/microsoft/azure/sensors/cosmos.py +++ b/airflow/providers/microsoft/azure/sensors/cosmos.py @@ -15,6 +15,8 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook @@ -46,7 +48,7 @@ class AzureCosmosDocumentSensor(BaseSensorOperator): :ref:`Azure CosmosDB connection`. """ - template_fields: Sequence[str] = ('database_name', 'collection_name', 'document_id') + template_fields: Sequence[str] = ("database_name", "collection_name", "document_id") def __init__( self, @@ -63,7 +65,7 @@ def __init__( self.collection_name = collection_name self.document_id = document_id - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: self.log.info("*** Entering poke") hook = AzureCosmosDBHook(self.azure_cosmos_conn_id) return hook.get_document(self.document_id, self.database_name, self.collection_name) is not None diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index ab328986777a8..9d550405a4d57 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, @@ -52,8 +53,8 @@ def __init__( *, run_id: str, azure_data_factory_conn_id: str = AzureDataFactoryHook.default_conn_name, - resource_group_name: Optional[str] = None, - factory_name: Optional[str] = None, + resource_group_name: str | None = None, + factory_name: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -62,7 +63,7 @@ def __init__( self.resource_group_name = resource_group_name self.factory_name = factory_name - def poke(self, context: "Context") -> bool: + def poke(self, context: Context) -> bool: self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) pipeline_run_status = self.hook.get_pipeline_run_status( run_id=self.run_id, diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py index 5deda098d4e51..388a571f7d7a4 100644 --- a/airflow/providers/microsoft/azure/sensors/wasb.py +++ b/airflow/providers/microsoft/azure/sensors/wasb.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.sensors.base import BaseSensorOperator @@ -36,15 +37,15 @@ class WasbBlobSensor(BaseSensorOperator): `WasbHook.check_for_blob()` takes. 
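A small sketch of the blob sensor in use; the container and blob names are placeholders, the default "wasb_default" connection is assumed, and the task would normally live inside a DAG definition.

from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor

# poke_interval and timeout come from BaseSensorOperator; values here are illustrative.
wait_for_export = WasbBlobSensor(
    task_id="wait_for_export",
    container_name="landing",
    blob_name="exports/2022-08-01/data.csv",
    poke_interval=60,
    timeout=6 * 60 * 60,
)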
""" - template_fields: Sequence[str] = ('container_name', 'blob_name') + template_fields: Sequence[str] = ("container_name", "blob_name") def __init__( self, *, container_name: str, blob_name: str, - wasb_conn_id: str = 'wasb_default', - check_options: Optional[dict] = None, + wasb_conn_id: str = "wasb_default", + check_options: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -55,8 +56,8 @@ def __init__( self.blob_name = blob_name self.check_options = check_options - def poke(self, context: "Context"): - self.log.info('Poking for blob: %s\n in wasb://%s', self.blob_name, self.container_name) + def poke(self, context: Context): + self.log.info("Poking for blob: %s\n in wasb://%s", self.blob_name, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) return hook.check_for_blob(self.container_name, self.blob_name, **self.check_options) @@ -72,15 +73,15 @@ class WasbPrefixSensor(BaseSensorOperator): `WasbHook.check_for_prefix()` takes. """ - template_fields: Sequence[str] = ('container_name', 'prefix') + template_fields: Sequence[str] = ("container_name", "prefix") def __init__( self, *, container_name: str, prefix: str, - wasb_conn_id: str = 'wasb_default', - check_options: Optional[dict] = None, + wasb_conn_id: str = "wasb_default", + check_options: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -91,7 +92,7 @@ def __init__( self.prefix = prefix self.check_options = check_options - def poke(self, context: "Context") -> bool: - self.log.info('Poking for prefix: %s in wasb://%s', self.prefix, self.container_name) + def poke(self, context: Context) -> bool: + self.log.info("Poking for prefix: %s in wasb://%s", self.prefix, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) return hook.check_for_prefix(self.container_name, self.prefix, **self.check_options) diff --git a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py index 370bdfd146a3c..f22c17c059077 100644 --- a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py +++ b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# +from __future__ import annotations + import tempfile -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -60,7 +61,7 @@ class AzureBlobStorageToGCSOperator(BaseOperator): def __init__( self, *, - wasb_conn_id='wasb_default', + wasb_conn_id="wasb_default", gcp_conn_id: str = "google_cloud_default", blob_name: str, file_path: str, @@ -69,8 +70,8 @@ def __init__( object_name: str, filename: str, gzip: bool, - delegate_to: Optional[str], - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + delegate_to: str | None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -95,7 +96,7 @@ def __init__( "filename", ) - def execute(self, context: "Context") -> str: + def execute(self, context: Context) -> str: azure_hook = WasbHook(wasb_conn_id=self.wasb_conn_id) gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py b/airflow/providers/microsoft/azure/transfers/file_to_wasb.py deleted file mode 100644 index 3979ad4bcd9e1..0000000000000 --- a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.transfers.local_to_wasb`.""" - -import warnings - -from airflow.providers.microsoft.azure.transfers.local_to_wasb import LocalFilesystemToWasbOperator # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.transfers.local_to_wasb`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/microsoft/azure/transfers/local_to_adls.py b/airflow/providers/microsoft/azure/transfers/local_to_adls.py index dd7e76e135606..072b17652843b 100644 --- a/airflow/providers/microsoft/azure/transfers/local_to_adls.py +++ b/airflow/providers/microsoft/azure/transfers/local_to_adls.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
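A hedged usage sketch for the Azure Blob Storage to GCS transfer shown above; argument names that fall outside the hunks displayed here (container_name, bucket_name) are taken from the operator's full signature and should be double-checked, and all values are placeholders.

from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator

# All names are placeholders; the blob is staged locally at file_path before upload to GCS.
copy_blob_to_gcs = AzureBlobStorageToGCSOperator(
    task_id="copy_blob_to_gcs",
    wasb_conn_id="wasb_default",
    gcp_conn_id="google_cloud_default",
    container_name="landing",      # assumed parameter, not visible in the hunk above
    blob_name="export.csv",
    file_path="/tmp/export.csv",
    bucket_name="my-gcs-bucket",   # assumed parameter, not visible in the hunk above
    object_name="export.csv",
    filename="/tmp/export.csv",
    gzip=False,
    delegate_to=None,
)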
+from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -54,7 +56,7 @@ class LocalFilesystemToADLSOperator(BaseOperator): """ template_fields: Sequence[str] = ("local_path", "remote_path") - ui_color = '#e4f0e8' + ui_color = "#e4f0e8" def __init__( self, @@ -65,8 +67,8 @@ def __init__( nthreads: int = 64, buffersize: int = 4194304, blocksize: int = 4194304, - extra_upload_options: Optional[Dict[str, Any]] = None, - azure_data_lake_conn_id: str = 'azure_data_lake_default', + extra_upload_options: dict[str, Any] | None = None, + azure_data_lake_conn_id: str = "azure_data_lake_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -79,13 +81,13 @@ def __init__( self.extra_upload_options = extra_upload_options self.azure_data_lake_conn_id = azure_data_lake_conn_id - def execute(self, context: "Context") -> None: - if '**' in self.local_path: + def execute(self, context: Context) -> None: + if "**" in self.local_path: raise AirflowException("Recursive glob patterns using `**` are not supported") if not self.extra_upload_options: self.extra_upload_options = {} hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) - self.log.info('Uploading %s to %s', self.local_path, self.remote_path) + self.log.info("Uploading %s to %s", self.local_path, self.remote_path) return hook.upload_file( local_path=self.local_path, remote_path=self.remote_path, diff --git a/airflow/providers/microsoft/azure/transfers/local_to_wasb.py b/airflow/providers/microsoft/azure/transfers/local_to_wasb.py index 387e7ee6beda6..29dd11c1e72e4 100644 --- a/airflow/providers/microsoft/azure/transfers/local_to_wasb.py +++ b/airflow/providers/microsoft/azure/transfers/local_to_wasb.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.wasb import WasbHook @@ -39,7 +40,7 @@ class LocalFilesystemToWasbOperator(BaseOperator): `WasbHook.load_file()` takes. 
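A minimal sketch of the local-to-WASB upload; paths and names are placeholders, and the "overwrite" load option is an assumption about what WasbHook.load_file() forwards to the underlying upload call.

from airflow.providers.microsoft.azure.transfers.local_to_wasb import LocalFilesystemToWasbOperator

# Placeholder paths and names; load_options are passed through to WasbHook.load_file().
upload_report = LocalFilesystemToWasbOperator(
    task_id="upload_report",
    file_path="/tmp/report.csv",
    container_name="reports",
    blob_name="daily/report.csv",
    load_options={"overwrite": True},  # assumption: forwarded to the Azure SDK upload call
)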
""" - template_fields: Sequence[str] = ('file_path', 'container_name', 'blob_name') + template_fields: Sequence[str] = ("file_path", "container_name", "blob_name") def __init__( self, @@ -47,9 +48,9 @@ def __init__( file_path: str, container_name: str, blob_name: str, - wasb_conn_id: str = 'wasb_default', + wasb_conn_id: str = "wasb_default", create_container: bool = False, - load_options: Optional[dict] = None, + load_options: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -62,11 +63,11 @@ def __init__( self.create_container = create_container self.load_options = load_options - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: """Upload a file to Azure Blob Storage.""" hook = WasbHook(wasb_conn_id=self.wasb_conn_id) self.log.info( - 'Uploading %s to wasb://%s as %s', + "Uploading %s to wasb://%s as %s", self.file_path, self.container_name, self.blob_name, diff --git a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py index 150ac6c0f2bbb..3e632e512f634 100644 --- a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py +++ b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import os from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence import unicodecsv as csv @@ -48,9 +49,9 @@ class OracleToAzureDataLakeOperator(BaseOperator): :param quoting: Quoting strategy. See unicodecsv quoting for more information. 
""" - template_fields: Sequence[str] = ('filename', 'sql', 'sql_params') + template_fields: Sequence[str] = ("filename", "sql", "sql_params") template_fields_renderers = {"sql_params": "py"} - ui_color = '#e08c8c' + ui_color = "#e08c8c" def __init__( self, @@ -60,7 +61,7 @@ def __init__( azure_data_lake_path: str, oracle_conn_id: str, sql: str, - sql_params: Optional[dict] = None, + sql_params: dict | None = None, delimiter: str = ",", encoding: str = "utf-8", quotechar: str = '"', @@ -81,8 +82,8 @@ def __init__( self.quotechar = quotechar self.quoting = quoting - def _write_temp_file(self, cursor: Any, path_to_save: Union[str, bytes, int]) -> None: - with open(path_to_save, 'wb') as csvfile: + def _write_temp_file(self, cursor: Any, path_to_save: str | bytes | int) -> None: + with open(path_to_save, "wb") as csvfile: csv_writer = csv.writer( csvfile, delimiter=self.delimiter, @@ -94,7 +95,7 @@ def _write_temp_file(self, cursor: Any, path_to_save: Union[str, bytes, int]) -> csv_writer.writerows(cursor) csvfile.flush() - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id) azure_data_lake_hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) @@ -103,7 +104,7 @@ def execute(self, context: "Context") -> None: cursor = conn.cursor() # type: ignore[attr-defined] cursor.execute(self.sql, self.sql_params) - with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp: + with TemporaryDirectory(prefix="airflow_oracle_to_azure_op_") as temp: self._write_temp_file(cursor, os.path.join(temp, self.filename)) self.log.info("Uploading local file to Azure Data Lake") azure_data_lake_hook.upload_file( diff --git a/airflow/providers/microsoft/azure/transfers/sftp_to_wasb.py b/airflow/providers/microsoft/azure/transfers/sftp_to_wasb.py index 1b50865cbdb7d..7f8aa7eb5bb00 100644 --- a/airflow/providers/microsoft/azure/transfers/sftp_to_wasb.py +++ b/airflow/providers/microsoft/azure/transfers/sftp_to_wasb.py @@ -16,28 +16,24 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains SFTP to Azure Blob Storage operator.""" +from __future__ import annotations + import os -import sys from collections import namedtuple from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Sequence if TYPE_CHECKING: from airflow.utils.context import Context - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.providers.sftp.hooks.sftp import SFTPHook WILDCARD = "*" -SftpFile = namedtuple('SftpFile', 'sftp_file_path, blob_name') +SftpFile = namedtuple("SftpFile", "sftp_file_path, blob_name") class SFTPToWasbOperator(BaseOperator): @@ -80,8 +76,8 @@ def __init__( container_name: str, blob_prefix: str = "", sftp_conn_id: str = "sftp_default", - wasb_conn_id: str = 'wasb_default', - load_options: Optional[Dict] = None, + wasb_conn_id: str = "wasb_default", + load_options: dict | None = None, move_object: bool = False, wasb_overwrite_object: bool = False, create_container: bool = False, @@ -101,10 +97,10 @@ def __init__( def dry_run(self) -> None: super().dry_run() - sftp_files: List[SftpFile] = self.get_sftp_files_map() + sftp_files: list[SftpFile] = self.get_sftp_files_map() for file in sftp_files: self.log.info( - 'Process will upload file from (SFTP) %s to wasb://%s as %s', + "Process will upload file from (SFTP) %s to wasb://%s as %s", file.sftp_file_path, self.container_name, file.blob_name, @@ -112,14 +108,14 @@ def dry_run(self) -> None: if self.move_object: self.log.info("Executing delete of %s", file) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Upload a file from SFTP to Azure Blob Storage.""" - sftp_files: List[SftpFile] = self.get_sftp_files_map() + sftp_files: list[SftpFile] = self.get_sftp_files_map() uploaded_files = self.copy_files_to_wasb(sftp_files) if self.move_object: self.delete_files(uploaded_files) - def get_sftp_files_map(self) -> List[SftpFile]: + def get_sftp_files_map(self) -> list[SftpFile]: """Get SFTP files from the source path, it may use a WILDCARD to this end.""" sftp_files = [] @@ -137,7 +133,7 @@ def get_sftp_files_map(self) -> List[SftpFile]: return sftp_files - def get_tree_behavior(self) -> Tuple[str, Optional[str], Optional[str]]: + def get_tree_behavior(self) -> tuple[str, str | None, str | None]: """Extracts from source path the tree behavior to interact with the remote folder""" self.check_wildcards_limit() @@ -174,7 +170,7 @@ def get_full_path_blob(self, file: str) -> str: """Get a blob name based on the previous name and a blob_prefix variable""" return self.blob_prefix + os.path.basename(file) - def copy_files_to_wasb(self, sftp_files: List[SftpFile]) -> List[str]: + def copy_files_to_wasb(self, sftp_files: list[SftpFile]) -> list[str]: """Upload a list of files from sftp_files to Azure Blob Storage with a new Blob Name.""" uploaded_files = [] wasb_hook = WasbHook(wasb_conn_id=self.wasb_conn_id) @@ -182,7 +178,7 @@ def copy_files_to_wasb(self, sftp_files: List[SftpFile]) -> List[str]: with NamedTemporaryFile("w") as tmp: self.sftp_hook.retrieve_file(file.sftp_file_path, tmp.name) self.log.info( - 'Uploading %s to wasb://%s as %s', + "Uploading %s to wasb://%s as %s", file.sftp_file_path, 
self.container_name, file.blob_name, @@ -199,7 +195,7 @@ def copy_files_to_wasb(self, sftp_files: List[SftpFile]) -> List[str]: return uploaded_files - def delete_files(self, uploaded_files: List[str]) -> None: + def delete_files(self, uploaded_files: list[str]) -> None: """Delete files at SFTP which have been moved to Azure Blob Storage.""" for sftp_file_path in uploaded_files: self.log.info("Executing delete of %s", sftp_file_path) diff --git a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py new file mode 100644 index 0000000000000..8c01100469055 --- /dev/null +++ b/airflow/providers/microsoft/azure/utils.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import warnings +from functools import wraps + + +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. + """ + + def dec(func): + @wraps(func) + def inner(): + field_behaviors = func() + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec + + +def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = f"extra__{conn_type}__" + backcompat_key = f"{backcompat_prefix}{field_name}" + ret = None + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." + ) + if field_name in extras: + if backcompat_key in extras: + warnings.warn( + f"Conflicting params `{field_name}` and `{backcompat_key}` found in extras for conn " + f"{conn_id}. Using value for `{field_name}`. Please ensure this is the correct " + f"value and remove the backcompat key `{backcompat_key}`." + ) + ret = extras[field_name] + elif backcompat_key in extras: + ret = extras.get(backcompat_key) + if ret == "": + return None + return ret diff --git a/airflow/providers/microsoft/mssql/CHANGELOG.rst b/airflow/providers/microsoft/mssql/CHANGELOG.rst index f1459156817fe..5fc41a0e4d72d 100644 --- a/airflow/providers/microsoft/mssql/CHANGELOG.rst +++ b/airflow/providers/microsoft/mssql/CHANGELOG.rst @@ -16,9 +16,90 @@ under the License. +.. 
NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Remove unnecessary newlines around single arg in signature (#27525)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + + +.. Review and move the new changes to one of the sections above: + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix MsSqlHook.get_uri: pymssql driver to scheme (25092) (#25185)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Microsoft example DAGs to new design #22452 - mssql (#24139)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/microsoft/mssql/example_dags/example_mssql.py b/airflow/providers/microsoft/mssql/example_dags/example_mssql.py deleted file mode 100644 index 84a3b68a199f6..0000000000000 --- a/airflow/providers/microsoft/mssql/example_dags/example_mssql.py +++ /dev/null @@ -1,137 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example use of MsSql related operators. -""" -# [START mssql_operator_howto_guide] - -from datetime import datetime - -from airflow import DAG -from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook -from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator - -dag = DAG( - 'example_mssql', - schedule_interval='@daily', - start_date=datetime(2021, 10, 1), - tags=['example'], - catchup=False, -) - -# [START howto_operator_mssql] - -# Example of creating a task to create a table in MsSql - -create_table_mssql_task = MsSqlOperator( - task_id='create_country_table', - mssql_conn_id='airflow_mssql', - sql=r""" - CREATE TABLE Country ( - country_id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, - name TEXT, - continent TEXT - ); - """, - dag=dag, -) - -# [END howto_operator_mssql] - -# [START mssql_hook_howto_guide_insert_mssql_hook] - - -@dag.task(task_id="insert_mssql_task") -def insert_mssql_hook(): - mssql_hook = MsSqlHook(mssql_conn_id='airflow_mssql', schema='airflow') - - rows = [ - ('India', 'Asia'), - ('Germany', 'Europe'), - ('Argentina', 'South America'), - ('Ghana', 'Africa'), - ('Japan', 'Asia'), - ('Namibia', 'Africa'), - ] - target_fields = ['name', 'continent'] - mssql_hook.insert_rows(table='Country', rows=rows, target_fields=target_fields) - - -# [END mssql_hook_howto_guide_insert_mssql_hook] - -# [START mssql_operator_howto_guide_create_table_mssql_from_external_file] -# Example of creating a task that calls an sql command from an external file. -create_table_mssql_from_external_file = MsSqlOperator( - task_id='create_table_from_external_file', - mssql_conn_id='airflow_mssql', - sql='create_table.sql', - dag=dag, -) -# [END mssql_operator_howto_guide_create_table_mssql_from_external_file] - -# [START mssql_operator_howto_guide_populate_user_table] -populate_user_table = MsSqlOperator( - task_id='populate_user_table', - mssql_conn_id='airflow_mssql', - sql=r""" - INSERT INTO Users (username, description) - VALUES ( 'Danny', 'Musician'); - INSERT INTO Users (username, description) - VALUES ( 'Simone', 'Chef'); - INSERT INTO Users (username, description) - VALUES ( 'Lily', 'Florist'); - INSERT INTO Users (username, description) - VALUES ( 'Tim', 'Pet shop owner'); - """, -) -# [END mssql_operator_howto_guide_populate_user_table] - -# [START mssql_operator_howto_guide_get_all_countries] -get_all_countries = MsSqlOperator( - task_id="get_all_countries", - mssql_conn_id='airflow_mssql', - sql=r"""SELECT * FROM Country;""", -) -# [END mssql_operator_howto_guide_get_all_countries] - -# [START mssql_operator_howto_guide_get_all_description] -get_all_description = MsSqlOperator( - task_id="get_all_description", - mssql_conn_id='airflow_mssql', - sql=r"""SELECT description FROM Users;""", -) -# [END mssql_operator_howto_guide_get_all_description] - -# [START mssql_operator_howto_guide_params_passing_get_query] -get_countries_from_continent = MsSqlOperator( - task_id="get_countries_from_continent", - mssql_conn_id='airflow_mssql', - sql=r"""SELECT * FROM Country where {{ params.column }}='{{ params.value }}';""", - params={"column": "CONVERT(VARCHAR, continent)", "value": "Asia"}, -) -# [END mssql_operator_howto_guide_params_passing_get_query] -( - create_table_mssql_task - >> insert_mssql_hook() - >> create_table_mssql_from_external_file - >> populate_user_table - >> get_all_countries - >> get_all_description - >> get_countries_from_continent -) -# [END 
mssql_operator_howto_guide] diff --git a/airflow/providers/microsoft/mssql/hooks/mssql.py b/airflow/providers/microsoft/mssql/hooks/mssql.py index 75241f1f296d3..d6d6209526542 100644 --- a/airflow/providers/microsoft/mssql/hooks/mssql.py +++ b/airflow/providers/microsoft/mssql/hooks/mssql.py @@ -15,30 +15,82 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Microsoft SQLServer hook module""" +from __future__ import annotations + +from typing import Any import pymssql -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook class MsSqlHook(DbApiHook): """Interact with Microsoft SQL Server.""" - conn_name_attr = 'mssql_conn_id' - default_conn_name = 'mssql_default' - conn_type = 'mssql' - hook_name = 'Microsoft SQL Server' + conn_name_attr = "mssql_conn_id" + default_conn_name = "mssql_default" + conn_type = "mssql" + hook_name = "Microsoft SQL Server" supports_autocommit = True + DEFAULT_SQLALCHEMY_SCHEME = "mssql+pymssql" - def __init__(self, *args, **kwargs) -> None: + def __init__( + self, + *args, + sqlalchemy_scheme: str | None = None, + **kwargs, + ) -> None: + """ + :param args: passed to DBApiHook + :param sqlalchemy_scheme: Scheme sqlalchemy connection. Default is ``mssql+pymssql`` Only used for + ``get_sqlalchemy_engine`` and ``get_sqlalchemy_connection`` methods. + :param kwargs: passed to DbApiHook + """ super().__init__(*args, **kwargs) self.schema = kwargs.pop("schema", None) + self._sqlalchemy_scheme = sqlalchemy_scheme - def get_conn( - self, - ) -> pymssql.connect: + @property + def connection_extra_lower(self) -> dict: + """ + ``connection.extra_dejson`` but where keys are converted to lower case. + This is used internally for case-insensitive access of mssql params. + """ + conn = self.get_connection(self.mssql_conn_id) # type: ignore[attr-defined] + return {k.lower(): v for k, v in conn.extra_dejson.items()} + + @property + def sqlalchemy_scheme(self) -> str: + """Sqlalchemy scheme either from constructor, connection extras or default.""" + return ( + self._sqlalchemy_scheme + or self.connection_extra_lower.get("sqlalchemy_scheme") + or self.DEFAULT_SQLALCHEMY_SCHEME + ) + + def get_uri(self) -> str: + from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit + + r = list(urlsplit(super().get_uri())) + # change pymssql driver: + r[0] = self.sqlalchemy_scheme + # remove query string 'sqlalchemy_scheme' like parameters: + qs = parse_qs(r[3], keep_blank_values=True) + for k in list(qs.keys()): + if k.lower() == "sqlalchemy_scheme": + qs.pop(k, None) + r[3] = urlencode(qs, doseq=True) + return urlunsplit(r) + + def get_sqlalchemy_connection( + self, connect_kwargs: dict | None = None, engine_kwargs: dict | None = None + ) -> Any: + """Sqlalchemy connection object""" + engine = self.get_sqlalchemy_engine(engine_kwargs=engine_kwargs) + return engine.connect(**(connect_kwargs or {})) + + def get_conn(self) -> pymssql.connect: """Returns a mssql connection object""" conn = self.get_connection(self.mssql_conn_id) # type: ignore[attr-defined] diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py index 7082c7c52d90b..1b7c47886c05a 100644 --- a/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/airflow/providers/microsoft/mssql/operators/mssql.py @@ -15,19 +15,15 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union +from __future__ import annotations -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook -from airflow.www import utils as wwwutils +import warnings +from typing import Sequence -if TYPE_CHECKING: - from airflow.hooks.dbapi import DbApiHook - from airflow.utils.context import Context +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -class MsSqlOperator(BaseOperator): +class MsSqlOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Microsoft SQL database @@ -48,49 +44,23 @@ class MsSqlOperator(BaseOperator): :param database: name of database which overwrite defined one in connection """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. - template_fields_renderers = {'sql': 'tsql' if 'tsql' in wwwutils.get_attr_renderer() else 'sql'} - ui_color = '#ededed' + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "tsql"} + ui_color = "#ededed" def __init__( - self, - *, - sql: str, - mssql_conn_id: str = 'mssql_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - autocommit: bool = False, - database: Optional[str] = None, - **kwargs, + self, *, mssql_conn_id: str = "mssql_default", database: str | None = None, **kwargs ) -> None: - super().__init__(**kwargs) - self.mssql_conn_id = mssql_conn_id - self.sql = sql - self.parameters = parameters - self.autocommit = autocommit - self.database = database - self._hook: Optional[Union[MsSqlHook, 'DbApiHook']] = None + if database is not None: + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = {"schema": database, **hook_params} - def get_hook(self) -> Optional[Union[MsSqlHook, 'DbApiHook']]: - """ - Will retrieve hook as determined by :meth:`~.Connection.get_hook` if one is defined, and - :class:`~.MsSqlHook` otherwise. - - For example, if the connection ``conn_type`` is ``'odbc'``, :class:`~.OdbcHook` will be used. - """ - if not self._hook: - conn = MsSqlHook.get_connection(conn_id=self.mssql_conn_id) - try: - self._hook = conn.get_hook() - self._hook.schema = self.database # type: ignore[union-attr] - except AirflowException: - self._hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id, schema=self.database) - return self._hook - - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) - hook = self.get_hook() - hook.run( # type: ignore[union-attr] - sql=self.sql, autocommit=self.autocommit, parameters=self.parameters + super().__init__(conn_id=mssql_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`. 
+ Also, you can provide `hook_params={'schema': }`.""", + DeprecationWarning, + stacklevel=2, ) diff --git a/airflow/providers/microsoft/mssql/provider.yaml b/airflow/providers/microsoft/mssql/provider.yaml index 5194fdfb866df..228f411cb26dc 100644 --- a/airflow/providers/microsoft/mssql/provider.yaml +++ b/airflow/providers/microsoft/mssql/provider.yaml @@ -22,6 +22,11 @@ description: | `Microsoft SQL Server (MSSQL) `__ versions: + - 3.3.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -32,8 +37,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - pymssql>=2.1.5 integrations: - integration-name: Microsoft SQL Server (MSSQL) @@ -53,9 +60,6 @@ hooks: python-modules: - airflow.providers.microsoft.mssql.hooks.mssql -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook - connection-types: - hook-class-name: airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook connection-type: mssql diff --git a/airflow/providers/microsoft/psrp/.latest-doc-only-change.txt b/airflow/providers/microsoft/psrp/.latest-doc-only-change.txt index 570fad6daee29..ff7136e07d744 100644 --- a/airflow/providers/microsoft/psrp/.latest-doc-only-change.txt +++ b/airflow/providers/microsoft/psrp/.latest-doc-only-change.txt @@ -1 +1 @@ -97496ba2b41063fa24393c58c5c648a0cdb5a7f8 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/microsoft/psrp/CHANGELOG.rst b/airflow/providers/microsoft/psrp/CHANGELOG.rst index 0bdf4d919944a..d3f18353e939e 100644 --- a/airflow/providers/microsoft/psrp/CHANGELOG.rst +++ b/airflow/providers/microsoft/psrp/CHANGELOG.rst @@ -16,9 +16,51 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +2.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Add documentation for July 2022 Provider's release (#25030)`` + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Ensure @contextmanager decorates generator func (#23103)`` + * ``Introduce 'flake8-implicit-str-concat' plugin to static checks (#23873)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 1.1.4 ..... diff --git a/airflow/providers/microsoft/psrp/hooks/psrp.py b/airflow/providers/microsoft/psrp/hooks/psrp.py index 0aebe63d0319e..00cdae17a28a3 100644 --- a/airflow/providers/microsoft/psrp/hooks/psrp.py +++ b/airflow/providers/microsoft/psrp/hooks/psrp.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from contextlib import contextmanager from copy import copy from logging import DEBUG, ERROR, INFO, WARNING -from typing import Any, Callable, Dict, Generator, Optional +from typing import Any, Callable, Generator from weakref import WeakKeyDictionary from pypsrp.host import PSHost @@ -75,18 +76,18 @@ class PsrpHook(BaseHook): _conn = None _configuration_name = None - _wsman_ref: "WeakKeyDictionary[RunspacePool, WSMan]" = WeakKeyDictionary() + _wsman_ref: WeakKeyDictionary[RunspacePool, WSMan] = WeakKeyDictionary() def __init__( self, psrp_conn_id: str, logging_level: int = DEBUG, - operation_timeout: Optional[int] = None, - runspace_options: Optional[Dict[str, Any]] = None, - wsman_options: Optional[Dict[str, Any]] = None, - on_output_callback: Optional[OutputCallback] = None, + operation_timeout: int | None = None, + runspace_options: dict[str, Any] | None = None, + wsman_options: dict[str, Any] | None = None, + on_output_callback: OutputCallback | None = None, exchange_keys: bool = True, - host: Optional[PSHost] = None, + host: PSHost | None = None, ): self.conn_id = psrp_conn_id self._logging_level = logging_level @@ -214,7 +215,7 @@ def invoke(self) -> Generator[PowerShell, None, None]: if local_context: self.__exit__(None, None, None) - def invoke_cmdlet(self, name: str, use_local_scope=None, **parameters: Dict[str, str]) -> PowerShell: + def invoke_cmdlet(self, name: str, use_local_scope=None, **parameters: dict[str, str]) -> PowerShell: """Invoke a PowerShell cmdlet and return session.""" with self.invoke() as ps: ps.add_cmdlet(name, use_local_scope=use_local_scope) @@ -232,7 +233,7 @@ def _log_record(self, log, record): if message_type == MessageType.ERROR_RECORD: log(INFO, "%s: %s", record.reason, record) if record.script_stacktrace: - for trace in record.script_stacktrace.split('\r\n'): + for trace in record.script_stacktrace.split("\r\n"): log(INFO, trace) level = INFORMATIONAL_RECORD_LEVEL_MAP.get(message_type) diff --git a/airflow/providers/microsoft/psrp/operators/psrp.py b/airflow/providers/microsoft/psrp/operators/psrp.py index ea07ee9115567..733b8cb29fd3a 100644 --- a/airflow/providers/microsoft/psrp/operators/psrp.py +++ b/airflow/providers/microsoft/psrp/operators/psrp.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
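As a usage sketch for the PsrpHook API touched above (not taken from this patch), the following assumes an Airflow connection named "psrp_default" that points at a reachable PowerShell Remoting endpoint; the cmdlet name and its Path parameter are purely illustrative.

from airflow.providers.microsoft.psrp.hooks.psrp import PsrpHook

# Assumes a configured "psrp_default" connection; cmdlet and parameters are made up.
with PsrpHook(psrp_conn_id="psrp_default") as hook:
    ps = hook.invoke_cmdlet("Get-ChildItem", Path="C:\\Temp")
    for item in ps.output:
        print(item)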
+from __future__ import annotations from logging import DEBUG -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Sequence from jinja2.nativetypes import NativeEnvironment from pypsrp.powershell import Command @@ -27,12 +28,7 @@ from airflow.models import BaseOperator from airflow.providers.microsoft.psrp.hooks.psrp import PsrpHook from airflow.settings import json - - -# TODO: Replace with airflow.utils.helpers.exactly_one in Airflow 2.3. -def exactly_one(*args): - return len(set(filter(None, args))) == 1 - +from airflow.utils.helpers import exactly_one if TYPE_CHECKING: from airflow.utils.context import Context @@ -93,14 +89,14 @@ def __init__( self, *, psrp_conn_id: str, - command: Optional[str] = None, - powershell: Optional[str] = None, - cmdlet: Optional[str] = None, - parameters: Optional[Dict[str, str]] = None, + command: str | None = None, + powershell: str | None = None, + cmdlet: str | None = None, + parameters: dict[str, str] | None = None, logging_level: int = DEBUG, - runspace_options: Optional[Dict[str, Any]] = None, - wsman_options: Optional[Dict[str, Any]] = None, - psrp_session_init: Optional[Command] = None, + runspace_options: dict[str, Any] | None = None, + wsman_options: dict[str, Any] | None = None, + psrp_session_init: Command | None = None, **kwargs, ) -> None: args = {command, powershell, cmdlet} @@ -109,7 +105,7 @@ def __init__( if parameters and not cmdlet: raise ValueError("Parameters only allowed with 'cmdlet'") if cmdlet: - kwargs.setdefault('task_id', cmdlet) + kwargs.setdefault("task_id", cmdlet) super().__init__(**kwargs) self.conn_id = psrp_conn_id self.command = command @@ -121,7 +117,7 @@ def __init__( self.wsman_options = wsman_options self.psrp_session_init = psrp_session_init - def execute(self, context: "Context") -> Optional[List[Any]]: + def execute(self, context: Context) -> list[Any] | None: with PsrpHook( self.conn_id, logging_level=self.logging_level, diff --git a/airflow/providers/microsoft/psrp/provider.yaml b/airflow/providers/microsoft/psrp/provider.yaml index 0741c53200694..ae39d29266597 100644 --- a/airflow/providers/microsoft/psrp/provider.yaml +++ b/airflow/providers/microsoft/psrp/provider.yaml @@ -24,6 +24,8 @@ description: | `__. versions: + - 2.1.0 + - 2.0.0 - 1.1.4 - 1.1.3 - 1.1.2 @@ -32,7 +34,7 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: +dependencies: - pypsrp>=0.8.0 integrations: diff --git a/airflow/providers/microsoft/winrm/.latest-doc-only-change.txt b/airflow/providers/microsoft/winrm/.latest-doc-only-change.txt index ab24993f57139..ff7136e07d744 100644 --- a/airflow/providers/microsoft/winrm/.latest-doc-only-change.txt +++ b/airflow/providers/microsoft/winrm/.latest-doc-only-change.txt @@ -1 +1 @@ -8b6b0848a3cacf9999477d6af4d2a87463f03026 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/microsoft/winrm/CHANGELOG.rst b/airflow/providers/microsoft/winrm/CHANGELOG.rst index 79f3700bb44f8..a357075534f90 100644 --- a/airflow/providers/microsoft/winrm/CHANGELOG.rst +++ b/airflow/providers/microsoft/winrm/CHANGELOG.rst @@ -16,9 +16,54 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... 
+ +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``A few docs fixups (#26788)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate Microsoft example DAGs to new design #22452 - winrm (#24140)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Replace usage of 'DummyOperator' with 'EmptyOperator' (#22974)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.5 ..... diff --git a/airflow/providers/microsoft/winrm/example_dags/example_winrm.py b/airflow/providers/microsoft/winrm/example_dags/example_winrm.py deleted file mode 100644 index f95472989991f..0000000000000 --- a/airflow/providers/microsoft/winrm/example_dags/example_winrm.py +++ /dev/null @@ -1,63 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -------------------------------------------------------------------------------- -# Caveat: This Dag will not run because of missing scripts. -# The purpose of this is to give you a sample of a real world example DAG! -# -------------------------------------------------------------------------------- - -# -------------------------------------------------------------------------------- -# Load The Dependencies -# -------------------------------------------------------------------------------- -""" -This is an example dag for using the WinRMOperator. 
-""" -from datetime import datetime, timedelta - -from airflow import DAG - -try: - from airflow.operators.empty import EmptyOperator -except ModuleNotFoundError: - from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore -from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook -from airflow.providers.microsoft.winrm.operators.winrm import WinRMOperator - -with DAG( - dag_id='POC_winrm_parallel', - schedule_interval='0 0 * * *', - start_date=datetime(2021, 1, 1), - dagrun_timeout=timedelta(minutes=60), - tags=['example'], - catchup=False, -) as dag: - - run_this_last = EmptyOperator(task_id='run_this_last') - - # [START create_hook] - winRMHook = WinRMHook(ssh_conn_id='ssh_POC1') - # [END create_hook] - - # [START run_operator] - t1 = WinRMOperator(task_id="wintask1", command='ls -altr', winrm_hook=winRMHook) - - t2 = WinRMOperator(task_id="wintask2", command='sleep 60', winrm_hook=winRMHook) - - t3 = WinRMOperator(task_id="wintask3", command='echo \'luke test\' ', winrm_hook=winRMHook) - # [END run_operator] - - [t1, t2, t3] >> run_this_last diff --git a/airflow/providers/microsoft/winrm/hooks/winrm.py b/airflow/providers/microsoft/winrm/hooks/winrm.py index ee11dbac9c9d2..22ec7666bc6cd 100644 --- a/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -15,9 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """Hook for winrm remote execution.""" -from typing import Optional +from __future__ import annotations from winrm.protocol import Protocol @@ -40,8 +39,8 @@ class WinRMHook(BaseHook): :seealso: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py :param ssh_conn_id: connection id from airflow Connections from where - all the required parameters can be fetched like username and password. - Thought the priority is given to the param passed during init + all the required parameters can be fetched like username and password, + though priority is given to the params passed during init. :param endpoint: When not set, endpoint will be constructed like this: 'http://{remote_host}:{remote_port}/wsman' :param remote_host: Remote host to connect to. Ignored if `endpoint` is set. 
@@ -76,24 +75,24 @@ class WinRMHook(BaseHook): def __init__( self, - ssh_conn_id: Optional[str] = None, - endpoint: Optional[str] = None, - remote_host: Optional[str] = None, + ssh_conn_id: str | None = None, + endpoint: str | None = None, + remote_host: str | None = None, remote_port: int = 5985, - transport: str = 'plaintext', - username: Optional[str] = None, - password: Optional[str] = None, - service: str = 'HTTP', - keytab: Optional[str] = None, - ca_trust_path: Optional[str] = None, - cert_pem: Optional[str] = None, - cert_key_pem: Optional[str] = None, - server_cert_validation: str = 'validate', + transport: str = "plaintext", + username: str | None = None, + password: str | None = None, + service: str = "HTTP", + keytab: str | None = None, + ca_trust_path: str | None = None, + cert_pem: str | None = None, + cert_key_pem: str | None = None, + server_cert_validation: str = "validate", kerberos_delegation: bool = False, read_timeout_sec: int = 30, operation_timeout_sec: int = 20, - kerberos_hostname_override: Optional[str] = None, - message_encryption: Optional[str] = 'auto', + kerberos_hostname_override: str | None = None, + message_encryption: str | None = "auto", credssp_disable_tlsv1_2: bool = False, send_cbt: bool = True, ) -> None: @@ -126,7 +125,7 @@ def get_conn(self): if self.client: return self.client - self.log.debug('Creating WinRM client for conn_id: %s', self.ssh_conn_id) + self.log.debug("Creating WinRM client for conn_id: %s", self.ssh_conn_id) if self.ssh_conn_id is not None: conn = self.get_connection(self.ssh_conn_id) @@ -159,7 +158,7 @@ def get_conn(self): if "server_cert_validation" in extra_options: self.server_cert_validation = str(extra_options["server_cert_validation"]) if "kerberos_delegation" in extra_options: - self.kerberos_delegation = str(extra_options["kerberos_delegation"]).lower() == 'true' + self.kerberos_delegation = str(extra_options["kerberos_delegation"]).lower() == "true" if "read_timeout_sec" in extra_options: self.read_timeout_sec = int(extra_options["read_timeout_sec"]) if "operation_timeout_sec" in extra_options: @@ -170,10 +169,10 @@ def get_conn(self): self.message_encryption = str(extra_options["message_encryption"]) if "credssp_disable_tlsv1_2" in extra_options: self.credssp_disable_tlsv1_2 = ( - str(extra_options["credssp_disable_tlsv1_2"]).lower() == 'true' + str(extra_options["credssp_disable_tlsv1_2"]).lower() == "true" ) if "send_cbt" in extra_options: - self.send_cbt = str(extra_options["send_cbt"]).lower() == 'true' + self.send_cbt = str(extra_options["send_cbt"]).lower() == "true" if not self.remote_host: raise AirflowException("Missing required param: remote_host") @@ -190,7 +189,7 @@ def get_conn(self): # If endpoint is not set, then build a standard wsman endpoint from host and port. if not self.endpoint: - self.endpoint = f'http://{self.remote_host}:{self.remote_port}/wsman' + self.endpoint = f"http://{self.remote_host}:{self.remote_port}/wsman" try: if self.password and self.password.strip(): diff --git a/airflow/providers/microsoft/winrm/operators/winrm.py b/airflow/providers/microsoft/winrm/operators/winrm.py index ea96c43a8cd71..7aef926173c08 100644 --- a/airflow/providers/microsoft/winrm/operators/winrm.py +++ b/airflow/providers/microsoft/winrm/operators/winrm.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
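The extras handling shown above can be summarized with a stdlib-only snippet; the sample extra values are invented, but the parsing mirrors the hook: boolean-ish extras are compared case-insensitively against "true" and numeric ones are cast explicitly.

# Connection "extra" values arrive as strings from the UI/JSON (sample values are made up).
extra_options = {"kerberos_delegation": "True", "send_cbt": "false", "read_timeout_sec": "45"}

kerberos_delegation = str(extra_options["kerberos_delegation"]).lower() == "true"  # True
send_cbt = str(extra_options["send_cbt"]).lower() == "true"                        # False
read_timeout_sec = int(extra_options["read_timeout_sec"])                          # 45
print(kerberos_delegation, send_cbt, read_timeout_sec)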
+from __future__ import annotations import logging from base64 import b64encode -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from winrm.exceptions import WinRMOperationTimeoutError @@ -34,7 +35,7 @@ # requests.packages.urllib3.exceptions.HeaderParsingError: [StartBoundaryNotFoundDefect(), # MultipartInvariantViolationDefect()], unparsed data: '' -logging.getLogger('urllib3.connectionpool').setLevel(logging.ERROR) +logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) class WinRMOperator(BaseOperator): @@ -51,18 +52,18 @@ class WinRMOperator(BaseOperator): :param timeout: timeout for executing the command. """ - template_fields: Sequence[str] = ('command',) + template_fields: Sequence[str] = ("command",) template_fields_renderers = {"command": "powershell"} def __init__( self, *, - winrm_hook: Optional[WinRMHook] = None, - ssh_conn_id: Optional[str] = None, - remote_host: Optional[str] = None, - command: Optional[str] = None, - ps_path: Optional[str] = None, - output_encoding: str = 'utf-8', + winrm_hook: WinRMHook | None = None, + ssh_conn_id: str | None = None, + remote_host: str | None = None, + command: str | None = None, + ps_path: str | None = None, + output_encoding: str = "utf-8", timeout: int = 10, **kwargs, ) -> None: @@ -75,7 +76,7 @@ def __init__( self.output_encoding = output_encoding self.timeout = timeout - def execute(self, context: "Context") -> Union[list, str]: + def execute(self, context: Context) -> list | str: if self.ssh_conn_id and not self.winrm_hook: self.log.info("Hook not found, creating...") self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id) @@ -94,9 +95,9 @@ def execute(self, context: "Context") -> Union[list, str]: try: if self.ps_path is not None: self.log.info("Running command as powershell script: '%s'...", self.command) - encoded_ps = b64encode(self.command.encode('utf_16_le')).decode('ascii') + encoded_ps = b64encode(self.command.encode("utf_16_le")).decode("ascii") command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined] - winrm_client, f'{self.ps_path} -encodedcommand {encoded_ps}' + winrm_client, f"{self.ps_path} -encodedcommand {encoded_ps}" ) else: self.log.info("Running command: '%s'...", self.command) @@ -144,13 +145,13 @@ def execute(self, context: "Context") -> Union[list, str]: if return_code == 0: # returning output if do_xcom_push is set - enable_pickling = conf.getboolean('core', 'enable_xcom_pickling') + enable_pickling = conf.getboolean("core", "enable_xcom_pickling") if enable_pickling: return stdout_buffer else: - return b64encode(b''.join(stdout_buffer)).decode(self.output_encoding) + return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding) else: - stderr_output = b''.join(stderr_buffer).decode(self.output_encoding) + stderr_output = b"".join(stderr_buffer).decode(self.output_encoding) error_msg = ( f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}" ) diff --git a/airflow/providers/microsoft/winrm/provider.yaml b/airflow/providers/microsoft/winrm/provider.yaml index e41011aa7d548..a26bddb5b57f4 100644 --- a/airflow/providers/microsoft/winrm/provider.yaml +++ b/airflow/providers/microsoft/winrm/provider.yaml @@ -22,6 +22,8 @@ description: | `Windows Remote Management (WinRM) `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.5 - 2.0.4 - 2.0.3 @@ -33,8 +35,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - 
pywinrm>=0.4 integrations: - integration-name: Windows Remote Management (WinRM) diff --git a/airflow/providers/mongo/.latest-doc-only-change.txt b/airflow/providers/mongo/.latest-doc-only-change.txt index 029fd1fd22aec..ff7136e07d744 100644 --- a/airflow/providers/mongo/.latest-doc-only-change.txt +++ b/airflow/providers/mongo/.latest-doc-only-change.txt @@ -1 +1 @@ -2d109401b3566aef613501691d18cf7e4c776cd2 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/mongo/CHANGELOG.rst b/airflow/providers/mongo/CHANGELOG.rst index 7228008cae433..074c1d2d7c642 100644 --- a/airflow/providers/mongo/CHANGELOG.rst +++ b/airflow/providers/mongo/CHANGELOG.rst @@ -16,9 +16,53 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Fix links to sources for examples (#24386)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Bump pre-commit hook versions (#22887)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.3.3 ..... diff --git a/airflow/providers/mongo/hooks/mongo.py b/airflow/providers/mongo/hooks/mongo.py index 96a5ec800302a..4e9d99646fcc4 100644 --- a/airflow/providers/mongo/hooks/mongo.py +++ b/airflow/providers/mongo/hooks/mongo.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. """Hook for Mongo DB""" +from __future__ import annotations + from ssl import CERT_NONE from types import TracebackType -from typing import List, Optional, Type, Union import pymongo from pymongo import MongoClient, ReplaceOne @@ -44,10 +45,10 @@ class MongoHook(BaseHook): when connecting to MongoDB. 
""" - conn_name_attr = 'conn_id' - default_conn_name = 'mongo_default' - conn_type = 'mongo' - hook_name = 'MongoDB' + conn_name_attr = "conn_id" + default_conn_name = "mongo_default" + conn_type = "mongo" + hook_name = "MongoDB" def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: @@ -57,21 +58,21 @@ def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: self.extras = self.connection.extra_dejson.copy() self.client = None - srv = self.extras.pop('srv', False) - scheme = 'mongodb+srv' if srv else 'mongodb' + srv = self.extras.pop("srv", False) + scheme = "mongodb+srv" if srv else "mongodb" - creds = f'{self.connection.login}:{self.connection.password}@' if self.connection.login else '' - port = '' if self.connection.port is None else f':{self.connection.port}' - self.uri = f'{scheme}://{creds}{self.connection.host}{port}/{self.connection.schema}' + creds = f"{self.connection.login}:{self.connection.password}@" if self.connection.login else "" + port = "" if self.connection.port is None else f":{self.connection.port}" + self.uri = f"{scheme}://{creds}{self.connection.host}{port}/{self.connection.schema}" def __enter__(self): return self def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: if self.client is not None: self.close_conn() @@ -85,8 +86,8 @@ def get_conn(self) -> MongoClient: options = self.extras # If we are using SSL disable requiring certs from specific hostname - if options.get('ssl', False): - options.update({'ssl_cert_reqs': CERT_NONE}) + if options.get("ssl", False): + options.update({"ssl_cert_reqs": CERT_NONE}) self.client = MongoClient(self.uri, **options) @@ -100,7 +101,7 @@ def close_conn(self) -> None: self.client = None def get_collection( - self, mongo_collection: str, mongo_db: Optional[str] = None + self, mongo_collection: str, mongo_db: str | None = None ) -> pymongo.collection.Collection: """ Fetches a mongo collection object for querying. @@ -113,7 +114,7 @@ def get_collection( return mongo_conn.get_database(mongo_db).get_collection(mongo_collection) def aggregate( - self, mongo_collection: str, aggregate_query: list, mongo_db: Optional[str] = None, **kwargs + self, mongo_collection: str, aggregate_query: list, mongo_db: str | None = None, **kwargs ) -> pymongo.command_cursor.CommandCursor: """ Runs an aggregation pipeline and returns the results @@ -129,8 +130,8 @@ def find( mongo_collection: str, query: dict, find_one: bool = False, - mongo_db: Optional[str] = None, - projection: Optional[Union[list, dict]] = None, + mongo_db: str | None = None, + projection: list | dict | None = None, **kwargs, ) -> pymongo.cursor.Cursor: """ @@ -145,7 +146,7 @@ def find( return collection.find(query, projection, **kwargs) def insert_one( - self, mongo_collection: str, doc: dict, mongo_db: Optional[str] = None, **kwargs + self, mongo_collection: str, doc: dict, mongo_db: str | None = None, **kwargs ) -> pymongo.results.InsertOneResult: """ Inserts a single document into a mongo collection @@ -156,7 +157,7 @@ def insert_one( return collection.insert_one(doc, **kwargs) def insert_many( - self, mongo_collection: str, docs: dict, mongo_db: Optional[str] = None, **kwargs + self, mongo_collection: str, docs: dict, mongo_db: str | None = None, **kwargs ) -> pymongo.results.InsertManyResult: """ Inserts many docs into a mongo collection. 
@@ -171,7 +172,7 @@ def update_one( mongo_collection: str, filter_doc: dict, update_doc: dict, - mongo_db: Optional[str] = None, + mongo_db: str | None = None, **kwargs, ) -> pymongo.results.UpdateResult: """ @@ -194,7 +195,7 @@ def update_many( mongo_collection: str, filter_doc: dict, update_doc: dict, - mongo_db: Optional[str] = None, + mongo_db: str | None = None, **kwargs, ) -> pymongo.results.UpdateResult: """ @@ -216,8 +217,8 @@ def replace_one( self, mongo_collection: str, doc: dict, - filter_doc: Optional[dict] = None, - mongo_db: Optional[str] = None, + filter_doc: dict | None = None, + mongo_db: str | None = None, **kwargs, ) -> pymongo.results.UpdateResult: """ @@ -238,18 +239,18 @@ def replace_one( collection = self.get_collection(mongo_collection, mongo_db=mongo_db) if not filter_doc: - filter_doc = {'_id': doc['_id']} + filter_doc = {"_id": doc["_id"]} return collection.replace_one(filter_doc, doc, **kwargs) def replace_many( self, mongo_collection: str, - docs: List[dict], - filter_docs: Optional[List[dict]] = None, - mongo_db: Optional[str] = None, + docs: list[dict], + filter_docs: list[dict] | None = None, + mongo_db: str | None = None, upsert: bool = False, - collation: Optional[pymongo.collation.Collation] = None, + collation: pymongo.collation.Collation | None = None, **kwargs, ) -> pymongo.results.BulkWriteResult: """ @@ -266,7 +267,7 @@ def replace_many( :param mongo_collection: The name of the collection to update. :param docs: The new documents. :param filter_docs: A list of queries that match the documents to replace. - Can be omitted; then the _id fields from docs will be used. + Can be omitted; then the _id fields from airflow.docs will be used. :param mongo_db: The name of the database to use. Can be omitted; then the database from the connection string is used. :param upsert: If ``True``, perform an insert if no documents @@ -279,7 +280,7 @@ def replace_many( collection = self.get_collection(mongo_collection, mongo_db=mongo_db) if not filter_docs: - filter_docs = [{'_id': doc['_id']} for doc in docs] + filter_docs = [{"_id": doc["_id"]} for doc in docs] requests = [ ReplaceOne(filter_docs[i], docs[i], upsert=upsert, collation=collation) for i in range(len(docs)) @@ -288,7 +289,7 @@ def replace_many( return collection.bulk_write(requests, **kwargs) def delete_one( - self, mongo_collection: str, filter_doc: dict, mongo_db: Optional[str] = None, **kwargs + self, mongo_collection: str, filter_doc: dict, mongo_db: str | None = None, **kwargs ) -> pymongo.results.DeleteResult: """ Deletes a single document in a mongo collection. @@ -305,7 +306,7 @@ def delete_one( return collection.delete_one(filter_doc, **kwargs) def delete_many( - self, mongo_collection: str, filter_doc: dict, mongo_db: Optional[str] = None, **kwargs + self, mongo_collection: str, filter_doc: dict, mongo_db: str | None = None, **kwargs ) -> pymongo.results.DeleteResult: """ Deletes one or more documents in a mongo collection. 
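To illustrate the bulk-replace request construction in MongoHook.replace_many above, here is a short sketch with made-up documents; pymongo's ReplaceOne is the same class the hook uses, and the final bulk_write call is left as a comment since it needs a live collection.

from pymongo import ReplaceOne

# Mirrors replace_many above: when filter_docs is omitted, each document's _id
# becomes its own filter (the sample documents are invented).
docs = [{"_id": 1, "name": "alpha"}, {"_id": 2, "name": "beta"}]
filter_docs = [{"_id": doc["_id"]} for doc in docs]
requests = [ReplaceOne(filter_docs[i], docs[i], upsert=False) for i in range(len(docs))]
# collection.bulk_write(requests) would then apply all replacements in one batch.
print(len(requests))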
diff --git a/airflow/providers/mongo/provider.yaml b/airflow/providers/mongo/provider.yaml index efc5dd6d4d9b9..70e019d7fd6fc 100644 --- a/airflow/providers/mongo/provider.yaml +++ b/airflow/providers/mongo/provider.yaml @@ -22,6 +22,8 @@ description: | `MongoDB `__ versions: + - 3.1.0 + - 3.0.0 - 2.3.3 - 2.3.2 - 2.3.1 @@ -32,8 +34,12 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - dnspython>=1.13.0 + # pymongo 4.0.0 removes connection option `ssl_cert_reqs` which is used in providers-mongo/2.2.0 + # TODO: Upgrade to pymongo 4.0.0+ + - pymongo>=3.6.0,<4.0.0 integrations: - integration-name: MongoDB @@ -50,9 +56,6 @@ hooks: python-modules: - airflow.providers.mongo.hooks.mongo -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.mongo.hooks.mongo.MongoHook - connection-types: - hook-class-name: airflow.providers.mongo.hooks.mongo.MongoHook connection-type: mongo diff --git a/airflow/providers/mongo/sensors/mongo.py b/airflow/providers/mongo/sensors/mongo.py index 9d9a85268ec3c..83b9443f8e5c6 100644 --- a/airflow/providers/mongo/sensors/mongo.py +++ b/airflow/providers/mongo/sensors/mongo.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.providers.mongo.hooks.mongo import MongoHook @@ -42,7 +44,7 @@ class MongoSensor(BaseSensorOperator): :param mongo_db: Target MongoDB name. """ - template_fields: Sequence[str] = ('collection', 'query') + template_fields: Sequence[str] = ("collection", "query") def __init__( self, *, collection: str, query: dict, mongo_conn_id: str = "mongo_default", mongo_db=None, **kwargs @@ -53,7 +55,7 @@ def __init__( self.query = query self.mongo_db = mongo_db - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.log.info( "Sensor check existence of the document that matches the following query: %s", self.query ) diff --git a/airflow/providers/mysql/.latest-doc-only-change.txt b/airflow/providers/mysql/.latest-doc-only-change.txt index 570fad6daee29..13020f96f4489 100644 --- a/airflow/providers/mysql/.latest-doc-only-change.txt +++ b/airflow/providers/mysql/.latest-doc-only-change.txt @@ -1 +1 @@ -97496ba2b41063fa24393c58c5c648a0cdb5a7f8 +df00436569bb6fb79ce8c0b7ca71dddf02b854ef diff --git a/airflow/providers/mysql/CHANGELOG.rst b/airflow/providers/mysql/CHANGELOG.rst index d713a34b05993..16c3bc1ab4279 100644 --- a/airflow/providers/mysql/CHANGELOG.rst +++ b/airflow/providers/mysql/CHANGELOG.rst @@ -19,9 +19,92 @@ The version of MySQL server has to be 5.6.4+. The exact version upper bound depe on the version of ``mysqlclient`` package. For example, ``mysqlclient`` 1.3.12 can only be used with MySQL server 5.6.4 through 5.7. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. 
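A hypothetical usage example matching the MongoSensor hunk above; the connection id, database, collection and query are invented, and the task is meant to sit inside a DAG definition.

from airflow.providers.mongo.sensors.mongo import MongoSensor

# Illustrative only: poke() succeeds once a document matching `query` exists.
wait_for_order = MongoSensor(
    task_id="wait_for_completed_order",
    mongo_conn_id="mongo_default",
    collection="orders",
    query={"status": "complete"},
    mongo_db="shop",
    poke_interval=60,
)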
+ +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``D400 first line should end with period batch02 (#25268)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +Bug Fixes +~~~~~~~~~ + +* ``Close the MySQL connections once operations are done. (#24508)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate MySQL example DAGs to new design #22453 (#24142)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/mysql/example_dags/example_mysql.py b/airflow/providers/mysql/example_dags/example_mysql.py deleted file mode 100644 index a41da05a24d89..0000000000000 --- a/airflow/providers/mysql/example_dags/example_mysql.py +++ /dev/null @@ -1,51 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example use of MySql related operators. 
-""" - -from datetime import datetime - -from airflow import DAG -from airflow.providers.mysql.operators.mysql import MySqlOperator - -dag = DAG( - 'example_mysql', - start_date=datetime(2021, 1, 1), - default_args={'mysql_conn_id': 'mysql_conn_id'}, - tags=['example'], - catchup=False, -) - -# [START howto_operator_mysql] - -drop_table_mysql_task = MySqlOperator(task_id='drop_table_mysql', sql=r"""DROP TABLE table_name;""", dag=dag) - -# [END howto_operator_mysql] - -# [START howto_operator_mysql_external_file] - -mysql_task = MySqlOperator( - task_id='drop_table_mysql_external_file', - sql='/scripts/drop_table.sql', - dag=dag, -) - -# [END howto_operator_mysql_external_file] - -drop_table_mysql_task >> mysql_task diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py index 1f8513b9b7bbc..843ac0b8aadc8 100644 --- a/airflow/providers/mysql/hooks/mysql.py +++ b/airflow/providers/mysql/hooks/mysql.py @@ -15,19 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module allows to connect to a MySQL database.""" +from __future__ import annotations + import json -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Union -from airflow.hooks.dbapi import DbApiHook from airflow.models import Connection +from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: from mysql.connector.abstracts import MySQLConnectionAbstract from MySQLdb.connections import Connection as MySQLdbConnection -MySQLConnectionTypes = Union['MySQLdbConnection', 'MySQLConnectionAbstract'] +MySQLConnectionTypes = Union["MySQLdbConnection", "MySQLConnectionAbstract"] class MySqlHook(DbApiHook): @@ -48,10 +49,10 @@ class MySqlHook(DbApiHook): :param connection: The :ref:`MySQL connection id ` used for MySQL credentials. """ - conn_name_attr = 'mysql_conn_id' - default_conn_name = 'mysql_default' - conn_type = 'mysql' - hook_name = 'MySQL' + conn_name_attr = "mysql_conn_id" + default_conn_name = "mysql_default" + conn_type = "mysql" + hook_name = "MySQL" supports_autocommit = True def __init__(self, *args, **kwargs) -> None: @@ -61,91 +62,101 @@ def __init__(self, *args, **kwargs) -> None: def set_autocommit(self, conn: MySQLConnectionTypes, autocommit: bool) -> None: """ - The MySQLdb (mysqlclient) client uses an `autocommit` method rather - than an `autocommit` property to set the autocommit setting + Set *autocommit*. + + *mysqlclient* uses an *autocommit* method rather than an *autocommit* + property, so we need to override this to support it. :param conn: connection to set autocommit setting :param autocommit: autocommit setting - :rtype: None """ - if hasattr(conn.__class__, 'autocommit') and isinstance(conn.__class__.autocommit, property): + if hasattr(conn.__class__, "autocommit") and isinstance(conn.__class__.autocommit, property): conn.autocommit = autocommit else: conn.autocommit(autocommit) def get_autocommit(self, conn: MySQLConnectionTypes) -> bool: """ - The MySQLdb (mysqlclient) client uses a `get_autocommit` method - rather than an `autocommit` property to get the autocommit setting + Whether *autocommit* is active. + + *mysqlclient* uses an *get_autocommit* method rather than an *autocommit* + property, so we need to override this to support it. :param conn: connection to get autocommit setting from. 
:return: connection autocommit setting - :rtype: bool """ - if hasattr(conn.__class__, 'autocommit') and isinstance(conn.__class__.autocommit, property): + if hasattr(conn.__class__, "autocommit") and isinstance(conn.__class__.autocommit, property): return conn.autocommit else: return conn.get_autocommit() - def _get_conn_config_mysql_client(self, conn: Connection) -> Dict: + def _get_conn_config_mysql_client(self, conn: Connection) -> dict: conn_config = { "user": conn.login, - "passwd": conn.password or '', - "host": conn.host or 'localhost', - "db": self.schema or conn.schema or '', + "passwd": conn.password or "", + "host": conn.host or "localhost", + "db": self.schema or conn.schema or "", } # check for authentication via AWS IAM - if conn.extra_dejson.get('iam', False): - conn_config['passwd'], conn.port = self.get_iam_token(conn) - conn_config["read_default_group"] = 'enable-cleartext-plugin' + if conn.extra_dejson.get("iam", False): + conn_config["passwd"], conn.port = self.get_iam_token(conn) + conn_config["read_default_group"] = "enable-cleartext-plugin" conn_config["port"] = int(conn.port) if conn.port else 3306 - if conn.extra_dejson.get('charset', False): + if conn.extra_dejson.get("charset", False): conn_config["charset"] = conn.extra_dejson["charset"] - if conn_config["charset"].lower() in ('utf8', 'utf-8'): + if conn_config["charset"].lower() in ("utf8", "utf-8"): conn_config["use_unicode"] = True - if conn.extra_dejson.get('cursor', False): + if conn.extra_dejson.get("cursor", False): import MySQLdb.cursors - if (conn.extra_dejson["cursor"]).lower() == 'sscursor': + if (conn.extra_dejson["cursor"]).lower() == "sscursor": conn_config["cursorclass"] = MySQLdb.cursors.SSCursor - elif (conn.extra_dejson["cursor"]).lower() == 'dictcursor': + elif (conn.extra_dejson["cursor"]).lower() == "dictcursor": conn_config["cursorclass"] = MySQLdb.cursors.DictCursor - elif (conn.extra_dejson["cursor"]).lower() == 'ssdictcursor': + elif (conn.extra_dejson["cursor"]).lower() == "ssdictcursor": conn_config["cursorclass"] = MySQLdb.cursors.SSDictCursor - local_infile = conn.extra_dejson.get('local_infile', False) - if conn.extra_dejson.get('ssl', False): + local_infile = conn.extra_dejson.get("local_infile", False) + if conn.extra_dejson.get("ssl", False): # SSL parameter for MySQL has to be a dictionary and in case # of extra/dejson we can get string if extra is passed via # URL parameters - dejson_ssl = conn.extra_dejson['ssl'] + dejson_ssl = conn.extra_dejson["ssl"] if isinstance(dejson_ssl, str): dejson_ssl = json.loads(dejson_ssl) - conn_config['ssl'] = dejson_ssl - if conn.extra_dejson.get('unix_socket'): - conn_config['unix_socket'] = conn.extra_dejson['unix_socket'] + conn_config["ssl"] = dejson_ssl + if conn.extra_dejson.get("ssl_mode", False): + conn_config["ssl_mode"] = conn.extra_dejson["ssl_mode"] + if conn.extra_dejson.get("unix_socket"): + conn_config["unix_socket"] = conn.extra_dejson["unix_socket"] if local_infile: conn_config["local_infile"] = 1 return conn_config - def _get_conn_config_mysql_connector_python(self, conn: Connection) -> Dict: + def _get_conn_config_mysql_connector_python(self, conn: Connection) -> dict: conn_config = { - 'user': conn.login, - 'password': conn.password or '', - 'host': conn.host or 'localhost', - 'database': self.schema or conn.schema or '', - 'port': int(conn.port) if conn.port else 3306, + "user": conn.login, + "password": conn.password or "", + "host": conn.host or "localhost", + "database": self.schema or conn.schema or "", + "port": 
int(conn.port) if conn.port else 3306, } - if conn.extra_dejson.get('allow_local_infile', False): + if conn.extra_dejson.get("allow_local_infile", False): conn_config["allow_local_infile"] = True + # Ref: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html + for key, value in conn.extra_dejson.items(): + if key.startswith("ssl_"): + conn_config[key] = value return conn_config def get_conn(self) -> MySQLConnectionTypes: """ + Connection to a MySQL database. + Establishes a connection to a mysql database by extracting the connection configuration from the Airflow connection. @@ -157,24 +168,24 @@ def get_conn(self) -> MySQLConnectionTypes: """ conn = self.connection or self.get_connection(getattr(self, self.conn_name_attr)) - client_name = conn.extra_dejson.get('client', 'mysqlclient') + client_name = conn.extra_dejson.get("client", "mysqlclient") - if client_name == 'mysqlclient': + if client_name == "mysqlclient": import MySQLdb conn_config = self._get_conn_config_mysql_client(conn) return MySQLdb.connect(**conn_config) - if client_name == 'mysql-connector-python': + if client_name == "mysql-connector-python": import mysql.connector conn_config = self._get_conn_config_mysql_connector_python(conn) return mysql.connector.connect(**conn_config) - raise ValueError('Unknown MySQL client name provided!') + raise ValueError("Unknown MySQL client name provided!") def bulk_load(self, table: str, tmp_file: str) -> None: - """Loads a tab-delimited file into a database table""" + """Load a tab-delimited file into a database table.""" conn = self.get_conn() cur = conn.cursor() cur.execute( @@ -184,9 +195,10 @@ def bulk_load(self, table: str, tmp_file: str) -> None: """ ) conn.commit() + conn.close() def bulk_dump(self, table: str, tmp_file: str) -> None: - """Dumps a database table into a tab-delimited file""" + """Dump a database table into a tab-delimited file.""" conn = self.get_conn() cur = conn.cursor() cur.execute( @@ -196,29 +208,33 @@ def bulk_dump(self, table: str, tmp_file: str) -> None: """ ) conn.commit() + conn.close() @staticmethod - def _serialize_cell(cell: object, conn: Optional[Connection] = None) -> object: + def _serialize_cell(cell: object, conn: Connection | None = None) -> Any: """ + Convert argument to a literal. + The package MySQLdb converts an argument to a literal when passing those separately to execute. Hence, this method does nothing. :param cell: The cell to insert into the table :param conn: The database connection :return: The same cell - :rtype: object """ return cell - def get_iam_token(self, conn: Connection) -> Tuple[str, int]: + def get_iam_token(self, conn: Connection) -> tuple[str, int]: """ + Retrieve a temporary password to connect to MySQL. + Uses AWSHook to retrieve a temporary password to connect to MySQL Port is required. 
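As a usage illustration for the extras handling above (client selection plus the new ssl_* passthrough that only applies to the mysql-connector-python client), a hedged sketch; the host, credentials and CA path are placeholders, and passing a pre-built Connection via the connection kwarg follows the hook's own docstring.

# Hedged sketch: connection extras driving MySqlHook behaviour.
import json

from airflow.models import Connection
from airflow.providers.mysql.hooks.mysql import MySqlHook

conn = Connection(
    conn_id="mysql_ssl_example",
    conn_type="mysql",
    host="db.example.com",
    login="app_user",
    password="app_password",
    schema="analytics",
    extra=json.dumps(
        {
            "client": "mysql-connector-python",
            "ssl_ca": "/etc/ssl/certs/ca.pem",  # forwarded verbatim because the key starts with "ssl_"
            "ssl_verify_cert": True,
        }
    ),
)

hook = MySqlHook(connection=conn)  # pre-built Connection, per the hook docstring above
rows = hook.get_records("SELECT 1")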
If none is provided, default 3306 is used """ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook - aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default') - aws_hook = AwsBaseHook(aws_conn_id, client_type='rds') + aws_conn_id = conn.extra_dejson.get("aws_conn_id", "aws_default") + aws_hook = AwsBaseHook(aws_conn_id, client_type="rds") if conn.port is None: port = 3306 else: @@ -228,7 +244,7 @@ def get_iam_token(self, conn: Connection) -> Tuple[str, int]: return token, port def bulk_load_custom( - self, table: str, tmp_file: str, duplicate_key_handling: str = 'IGNORE', extra_options: str = '' + self, table: str, tmp_file: str, duplicate_key_handling: str = "IGNORE", extra_options: str = "" ) -> None: """ A more configurable way to load local data from a file into the database. @@ -263,3 +279,4 @@ def bulk_load_custom( cursor.close() conn.commit() + conn.close() diff --git a/airflow/providers/mysql/operators/mysql.py b/airflow/providers/mysql/operators/mysql.py index d51a97e6fe38a..1609d09411a0a 100644 --- a/airflow/providers/mysql/operators/mysql.py +++ b/airflow/providers/mysql/operators/mysql.py @@ -15,18 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import ast -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from __future__ import annotations -from airflow.models import BaseOperator -from airflow.providers.mysql.hooks.mysql import MySqlHook -from airflow.www import utils as wwwutils +import warnings +from typing import Sequence -if TYPE_CHECKING: - from airflow.utils.context import Context +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -class MySqlOperator(BaseOperator): +class MySqlOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific MySQL database @@ -47,38 +44,26 @@ class MySqlOperator(BaseOperator): :param database: name of database which overwrite defined one in connection """ - template_fields: Sequence[str] = ('sql', 'parameters') - # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. 
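The hunk that follows turns MySqlOperator into a deprecated thin wrapper around SQLExecuteQueryOperator; per its deprecation message, the database argument becomes hook_params={"schema": ...}. A hedged migration sketch, assuming a "mysql_default" connection exists:

# Old, now-deprecated form:
#     MySqlOperator(task_id="run_query", sql="SELECT 1", database="airflow_db")
# Equivalent new form suggested by the deprecation message:
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

run_query = SQLExecuteQueryOperator(
    task_id="run_query",
    conn_id="mysql_default",
    sql="SELECT 1",
    hook_params={"schema": "airflow_db"},  # replaces the old database= argument
)

Routing database through hook_params keeps the operator database-agnostic; the schema is resolved by the hook when the connection is opened.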
+ template_fields: Sequence[str] = ("sql", "parameters") template_fields_renderers = { - 'sql': 'mysql' if 'mysql' in wwwutils.get_attr_renderer() else 'sql', - 'parameters': 'json', + "sql": "mysql", + "parameters": "json", } - template_ext: Sequence[str] = ('.sql', '.json') - ui_color = '#ededed' + template_ext: Sequence[str] = (".sql", ".json") + ui_color = "#ededed" def __init__( - self, - *, - sql: Union[str, List[str]], - mysql_conn_id: str = 'mysql_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - autocommit: bool = False, - database: Optional[str] = None, - **kwargs, + self, *, mysql_conn_id: str = "mysql_default", database: str | None = None, **kwargs ) -> None: - super().__init__(**kwargs) - self.mysql_conn_id = mysql_conn_id - self.sql = sql - self.autocommit = autocommit - self.parameters = parameters - self.database = database + if database is not None: + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = {"schema": database, **hook_params} - def prepare_template(self) -> None: - """Parse template file for attribute parameters.""" - if isinstance(self.parameters, str): - self.parameters = ast.literal_eval(self.parameters) - - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) - hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database) - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + super().__init__(conn_id=mysql_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`. + Also, you can provide `hook_params={'schema': }`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/mysql/provider.yaml b/airflow/providers/mysql/provider.yaml index 6368244d349e2..3c5030e9ddeb5 100644 --- a/airflow/providers/mysql/provider.yaml +++ b/airflow/providers/mysql/provider.yaml @@ -22,6 +22,11 @@ description: | `MySQL `__ versions: + - 3.3.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -34,8 +39,11 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - mysql-connector-python>=8.0.11; platform_machine != "aarch64" + - mysqlclient>=1.3.6; platform_machine != "aarch64" integrations: - integration-name: MySQL @@ -70,8 +78,6 @@ transfers: target-integration-name: MySQL python-module: airflow.providers.mysql.transfers.trino_to_mysql -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.mysql.hooks.mysql.MySqlHook connection-types: - hook-class-name: airflow.providers.mysql.hooks.mysql.MySqlHook diff --git a/airflow/providers/mysql/transfers/presto_to_mysql.py b/airflow/providers/mysql/transfers/presto_to_mysql.py index 45529237729e0..b38e6b8654d68 100644 --- a/airflow/providers/mysql/transfers/presto_to_mysql.py +++ b/airflow/providers/mysql/transfers/presto_to_mysql.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.providers.presto.hooks.presto import PrestoHook -from airflow.www import utils as wwwutils if TYPE_CHECKING: from airflow.utils.context import Context @@ -43,23 +44,22 @@ class PrestoToMySqlOperator(BaseOperator): the task twice won't double load data). (templated) """ - template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator') - template_ext: Sequence[str] = ('.sql',) - # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. + template_fields: Sequence[str] = ("sql", "mysql_table", "mysql_preoperator") + template_ext: Sequence[str] = (".sql",) template_fields_renderers = { "sql": "sql", - "mysql_preoperator": "mysql" if "mysql" in wwwutils.get_attr_renderer() else "sql", + "mysql_preoperator": "mysql", } - ui_color = '#a0e08c' + ui_color = "#a0e08c" def __init__( self, *, sql: str, mysql_table: str, - presto_conn_id: str = 'presto_default', - mysql_conn_id: str = 'mysql_default', - mysql_preoperator: Optional[str] = None, + presto_conn_id: str = "presto_default", + mysql_conn_id: str = "mysql_default", + mysql_preoperator: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -69,7 +69,7 @@ def __init__( self.mysql_preoperator = mysql_preoperator self.presto_conn_id = presto_conn_id - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: presto = PrestoHook(presto_conn_id=self.presto_conn_id) self.log.info("Extracting data from Presto: %s", self.sql) results = presto.get_records(self.sql) diff --git a/airflow/providers/mysql/transfers/s3_to_mysql.py b/airflow/providers/mysql/transfers/s3_to_mysql.py index 51ad95fcfb9d5..d2e0285ccac63 100644 --- a/airflow/providers/mysql/transfers/s3_to_mysql.py +++ b/airflow/providers/mysql/transfers/s3_to_mysql.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
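For the Presto-to-MySQL transfer touched above, a hedged usage sketch; the connection ids, query and table names are placeholders.

# Hedged usage sketch for PrestoToMySqlOperator; all identifiers are placeholders.
from airflow.providers.mysql.transfers.presto_to_mysql import PrestoToMySqlOperator

load_daily_counts = PrestoToMySqlOperator(
    task_id="presto_to_mysql",
    presto_conn_id="presto_default",
    mysql_conn_id="mysql_default",
    sql="SELECT name, count(*) AS cnt FROM hive.default.events GROUP BY name",
    mysql_table="stats.daily_event_counts",
    mysql_preoperator="TRUNCATE TABLE stats.daily_event_counts;",  # runs on MySQL before the insert
)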
+from __future__ import annotations import os -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -43,38 +44,38 @@ class S3ToMySqlOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 's3_source_key', - 'mysql_table', + "s3_source_key", + "mysql_table", ) template_ext: Sequence[str] = () - ui_color = '#f4a460' + ui_color = "#f4a460" def __init__( self, *, s3_source_key: str, mysql_table: str, - mysql_duplicate_key_handling: str = 'IGNORE', - mysql_extra_options: Optional[str] = None, - aws_conn_id: str = 'aws_default', - mysql_conn_id: str = 'mysql_default', + mysql_duplicate_key_handling: str = "IGNORE", + mysql_extra_options: str | None = None, + aws_conn_id: str = "aws_default", + mysql_conn_id: str = "mysql_default", **kwargs, ) -> None: super().__init__(**kwargs) self.s3_source_key = s3_source_key self.mysql_table = mysql_table self.mysql_duplicate_key_handling = mysql_duplicate_key_handling - self.mysql_extra_options = mysql_extra_options or '' + self.mysql_extra_options = mysql_extra_options or "" self.aws_conn_id = aws_conn_id self.mysql_conn_id = mysql_conn_id - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """ Executes the transfer operation from S3 to MySQL. :param context: The context that is being provided when executing. """ - self.log.info('Loading %s to MySql table %s...', self.s3_source_key, self.mysql_table) + self.log.info("Loading %s to MySql table %s...", self.s3_source_key, self.mysql_table) s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) file = s3_hook.download_file(key=self.s3_source_key) diff --git a/airflow/providers/mysql/transfers/trino_to_mysql.py b/airflow/providers/mysql/transfers/trino_to_mysql.py index 3c013a7eaed2a..8ff5ed0446b5e 100644 --- a/airflow/providers/mysql/transfers/trino_to_mysql.py +++ b/airflow/providers/mysql/transfers/trino_to_mysql.py @@ -15,12 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.providers.trino.hooks.trino import TrinoHook -from airflow.www import utils as wwwutils if TYPE_CHECKING: from airflow.utils.context import Context @@ -43,23 +44,22 @@ class TrinoToMySqlOperator(BaseOperator): the task twice won't double load data). (templated) """ - template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator') - template_ext: Sequence[str] = ('.sql',) - # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. 
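A hedged usage sketch for the S3-to-MySQL transfer updated above; the S3 key, table name and LOAD DATA options are placeholders, and the extra options are assumed to be appended to the LOAD DATA statement by MySqlHook.bulk_load_custom.

# Hedged sketch: loading a file from S3 into MySQL with S3ToMySqlOperator.
from airflow.providers.mysql.transfers.s3_to_mysql import S3ToMySqlOperator

load_users = S3ToMySqlOperator(
    task_id="s3_to_mysql",
    aws_conn_id="aws_default",
    mysql_conn_id="mysql_default",
    s3_source_key="exports/users.tsv",
    mysql_table="staging.users",
    mysql_duplicate_key_handling="IGNORE",
    mysql_extra_options="FIELDS TERMINATED BY '\\t'",
)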
+ template_fields: Sequence[str] = ("sql", "mysql_table", "mysql_preoperator") + template_ext: Sequence[str] = (".sql",) template_fields_renderers = { "sql": "sql", - "mysql_preoperator": "mysql" if "mysql" in wwwutils.get_attr_renderer() else "sql", + "mysql_preoperator": "mysql", } - ui_color = '#a0e08c' + ui_color = "#a0e08c" def __init__( self, *, sql: str, mysql_table: str, - trino_conn_id: str = 'trino_default', - mysql_conn_id: str = 'mysql_default', - mysql_preoperator: Optional[str] = None, + trino_conn_id: str = "trino_default", + mysql_conn_id: str = "mysql_default", + mysql_preoperator: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -69,7 +69,7 @@ def __init__( self.mysql_preoperator = mysql_preoperator self.trino_conn_id = trino_conn_id - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: trino = TrinoHook(trino_conn_id=self.trino_conn_id) self.log.info("Extracting data from Trino: %s", self.sql) results = trino.get_records(self.sql) diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py b/airflow/providers/mysql/transfers/vertica_to_mysql.py index 595b2cb01b3dc..a7df1f029dd24 100644 --- a/airflow/providers/mysql/transfers/vertica_to_mysql.py +++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from contextlib import closing from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence import MySQLdb import unicodecsv as csv @@ -26,14 +27,10 @@ from airflow.models import BaseOperator from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.providers.vertica.hooks.vertica import VerticaHook -from airflow.www import utils as wwwutils if TYPE_CHECKING: from airflow.utils.context import Context -# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. -MYSQL_RENDERER = 'mysql' if 'mysql' in wwwutils.get_attr_renderer() else 'sql' - class VerticaToMySqlOperator(BaseOperator): """ @@ -57,23 +54,23 @@ class VerticaToMySqlOperator(BaseOperator): destination MySQL connection: {'local_infile': true}. 
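Tying that docstring to a concrete task, a hedged sketch of the bulk-load path; it assumes the MySQL connection carries {"local_infile": true} in its extras, and all identifiers are placeholders.

# Hedged sketch of VerticaToMySqlOperator with bulk_load=True; this path streams rows
# through a tab-delimited temp file and LOAD DATA LOCAL INFILE, so the MySQL
# connection must enable local_infile as noted above.
from airflow.providers.mysql.transfers.vertica_to_mysql import VerticaToMySqlOperator

copy_orders = VerticaToMySqlOperator(
    task_id="vertica_to_mysql_bulk",
    vertica_conn_id="vertica_default",
    mysql_conn_id="mysql_default",
    sql="SELECT order_id, amount FROM public.orders",
    mysql_table="reporting.orders",
    mysql_preoperator="TRUNCATE TABLE reporting.orders;",
    bulk_load=True,
)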
""" - template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator', 'mysql_postoperator') - template_ext: Sequence[str] = ('.sql',) + template_fields: Sequence[str] = ("sql", "mysql_table", "mysql_preoperator", "mysql_postoperator") + template_ext: Sequence[str] = (".sql",) template_fields_renderers = { "sql": "sql", - "mysql_preoperator": MYSQL_RENDERER, - "mysql_postoperator": MYSQL_RENDERER, + "mysql_preoperator": "mysql", + "mysql_postoperator": "mysql", } - ui_color = '#a0e08c' + ui_color = "#a0e08c" def __init__( self, sql: str, mysql_table: str, - vertica_conn_id: str = 'vertica_default', - mysql_conn_id: str = 'mysql_default', - mysql_preoperator: Optional[str] = None, - mysql_postoperator: Optional[str] = None, + vertica_conn_id: str = "vertica_default", + mysql_conn_id: str = "mysql_default", + mysql_preoperator: str | None = None, + mysql_postoperator: str | None = None, bulk_load: bool = False, *args, **kwargs, @@ -87,7 +84,7 @@ def __init__( self.vertica_conn_id = vertica_conn_id self.bulk_load = bulk_load - def execute(self, context: 'Context'): + def execute(self, context: Context): vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id) mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) @@ -133,7 +130,7 @@ def _bulk_load_transfer(self, mysql, vertica): self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name) self.log.info(self.sql) - csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8') + csv_writer = csv.writer(tmpfile, delimiter="\t", encoding="utf-8") for row in cursor.iterate(): csv_writer.writerow(row) count += 1 diff --git a/airflow/providers/neo4j/.latest-doc-only-change.txt b/airflow/providers/neo4j/.latest-doc-only-change.txt index 029fd1fd22aec..ff7136e07d744 100644 --- a/airflow/providers/neo4j/.latest-doc-only-change.txt +++ b/airflow/providers/neo4j/.latest-doc-only-change.txt @@ -1 +1 @@ -2d109401b3566aef613501691d18cf7e4c776cd2 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/neo4j/CHANGELOG.rst b/airflow/providers/neo4j/CHANGELOG.rst index e5f3dad9b8cc3..25a416cf9995c 100644 --- a/airflow/providers/neo4j/CHANGELOG.rst +++ b/airflow/providers/neo4j/CHANGELOG.rst @@ -17,9 +17,63 @@ specific language governing permissions and limitations under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Add documentation for July 2022 Provider's release (#25030)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate Neo4j example DAGs to new design #22454 (#24143)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Bump pre-commit hook versions (#22887)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/neo4j/example_dags/example_neo4j.py b/airflow/providers/neo4j/example_dags/example_neo4j.py deleted file mode 100644 index f1ebbb0011aa5..0000000000000 --- a/airflow/providers/neo4j/example_dags/example_neo4j.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example use of Neo4j related operators. -""" - -from datetime import datetime - -from airflow import DAG -from airflow.providers.neo4j.operators.neo4j import Neo4jOperator - -dag = DAG( - 'example_neo4j', - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) - -# [START run_query_neo4j_operator] - -neo4j_task = Neo4jOperator( - task_id='run_neo4j_query', - neo4j_conn_id='neo4j_conn_id', - sql='MATCH (tom {name: "Tom Hanks", date: "{{ds}}"}) RETURN tom', - dag=dag, -) - -# [END run_query_neo4j_operator] diff --git a/airflow/providers/neo4j/hooks/neo4j.py b/airflow/providers/neo4j/hooks/neo4j.py index ad9ce4ac3cb88..8f193dec38d87 100644 --- a/airflow/providers/neo4j/hooks/neo4j.py +++ b/airflow/providers/neo4j/hooks/neo4j.py @@ -15,10 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """This module allows to connect to a Neo4j database.""" +from __future__ import annotations + +from typing import Any -from neo4j import GraphDatabase, Neo4jDriver, Result +from neo4j import Driver, GraphDatabase from airflow.hooks.base import BaseHook from airflow.models import Connection @@ -33,18 +35,18 @@ class Neo4jHook(BaseHook): :param neo4j_conn_id: Reference to :ref:`Neo4j connection id `. 
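The Neo4j example DAG deleted above boils down to a single Neo4jOperator task; here is a hedged sketch of the same DAG in the style this change standardises on (double quotes, DAG as a context manager). The connection id is the placeholder the deleted example used.

# Hedged sketch, equivalent to the removed example_neo4j DAG.
from datetime import datetime

from airflow import DAG
from airflow.providers.neo4j.operators.neo4j import Neo4jOperator

with DAG(
    dag_id="example_neo4j",
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example"],
) as dag:
    Neo4jOperator(
        task_id="run_neo4j_query",
        neo4j_conn_id="neo4j_conn_id",
        sql='MATCH (tom {name: "Tom Hanks", date: "{{ ds }}"}) RETURN tom',
    )

On the hook side, connection extras steer the driver URI: for example, extras of {"neo4j_scheme": true, "certs_trusted_ca": true} on a host with no explicit port resolve to neo4j+s://<host>:7687, as the get_uri changes further down show.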
""" - conn_name_attr = 'neo4j_conn_id' - default_conn_name = 'neo4j_default' - conn_type = 'neo4j' - hook_name = 'Neo4j' + conn_name_attr = "neo4j_conn_id" + default_conn_name = "neo4j_default" + conn_type = "neo4j" + hook_name = "Neo4j" def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.neo4j_conn_id = conn_id self.connection = kwargs.pop("connection", None) - self.client = None + self.client: Driver | None = None - def get_conn(self) -> Neo4jDriver: + def get_conn(self) -> Driver: """ Function that initiates a new Neo4j connection with username, password and database schema. @@ -55,9 +57,9 @@ def get_conn(self) -> Neo4jDriver: self.connection = self.get_connection(self.neo4j_conn_id) uri = self.get_uri(self.connection) - self.log.info('URI: %s', uri) + self.log.info("URI: %s", uri) - is_encrypted = self.connection.extra_dejson.get('encrypted', False) + is_encrypted = self.connection.extra_dejson.get("encrypted", False) self.client = GraphDatabase.driver( uri, auth=(self.connection.login, self.connection.password), encrypted=is_encrypted @@ -76,24 +78,24 @@ def get_uri(self, conn: Connection) -> str: :param conn: connection object. :return: uri """ - use_neo4j_scheme = conn.extra_dejson.get('neo4j_scheme', False) - scheme = 'neo4j' if use_neo4j_scheme else 'bolt' + use_neo4j_scheme = conn.extra_dejson.get("neo4j_scheme", False) + scheme = "neo4j" if use_neo4j_scheme else "bolt" # Self signed certificates - ssc = conn.extra_dejson.get('certs_self_signed', False) + ssc = conn.extra_dejson.get("certs_self_signed", False) # Only certificates signed by CA. - trusted_ca = conn.extra_dejson.get('certs_trusted_ca', False) - encryption_scheme = '' + trusted_ca = conn.extra_dejson.get("certs_trusted_ca", False) + encryption_scheme = "" if ssc: - encryption_scheme = '+ssc' + encryption_scheme = "+ssc" elif trusted_ca: - encryption_scheme = '+s' + encryption_scheme = "+s" return f"{scheme}{encryption_scheme}://{conn.host}:{7687 if conn.port is None else conn.port}" - def run(self, query) -> Result: + def run(self, query) -> list[Any]: """ Function to create a neo4j session and execute the query in the session. diff --git a/airflow/providers/neo4j/operators/neo4j.py b/airflow/providers/neo4j/operators/neo4j.py index b61f0734f0841..fc4d043daad7e 100644 --- a/airflow/providers/neo4j/operators/neo4j.py +++ b/airflow/providers/neo4j/operators/neo4j.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence from airflow.models import BaseOperator from airflow.providers.neo4j.hooks.neo4j import Neo4jHook @@ -37,14 +39,14 @@ class Neo4jOperator(BaseOperator): :param neo4j_conn_id: Reference to :ref:`Neo4j connection id `. 
""" - template_fields: Sequence[str] = ('sql',) + template_fields: Sequence[str] = ("sql",) def __init__( self, *, sql: str, - neo4j_conn_id: str = 'neo4j_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + neo4j_conn_id: str = "neo4j_default", + parameters: Iterable | Mapping | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -52,7 +54,7 @@ def __init__( self.sql = sql self.parameters = parameters - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) + def execute(self, context: Context) -> None: + self.log.info("Executing: %s", self.sql) hook = Neo4jHook(conn_id=self.neo4j_conn_id) hook.run(self.sql) diff --git a/airflow/providers/neo4j/provider.yaml b/airflow/providers/neo4j/provider.yaml index e5fa643cdc844..7a48d9371ee61 100644 --- a/airflow/providers/neo4j/provider.yaml +++ b/airflow/providers/neo4j/provider.yaml @@ -22,6 +22,9 @@ description: | `Neo4j `__ versions: + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -32,8 +35,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - neo4j>=4.2.1 integrations: - integration-name: Neo4j @@ -52,8 +56,6 @@ hooks: python-modules: - airflow.providers.neo4j.hooks.neo4j -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.neo4j.hooks.neo4j.Neo4jHook connection-types: - hook-class-name: airflow.providers.neo4j.hooks.neo4j.Neo4jHook diff --git a/airflow/providers/odbc/CHANGELOG.rst b/airflow/providers/odbc/CHANGELOG.rst index 02f8409b0a235..0dd5d95948d9f 100644 --- a/airflow/providers/odbc/CHANGELOG.rst +++ b/airflow/providers/odbc/CHANGELOG.rst @@ -16,9 +16,77 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.1.2 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Fix odbc hook sqlalchemy_scheme docstring (#25421)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index 9ce32fa337118..20e8e8864e5ab 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. """This module contains ODBC hook.""" -from typing import Any, Optional +from __future__ import annotations + +from typing import Any from urllib.parse import quote_plus import pyodbc -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.utils.helpers import merge_dicts @@ -31,21 +33,21 @@ class OdbcHook(DbApiHook): See :doc:`/connections/odbc` for full documentation. """ - DEFAULT_SQLALCHEMY_SCHEME = 'mssql+pyodbc' - conn_name_attr = 'odbc_conn_id' - default_conn_name = 'odbc_default' - conn_type = 'odbc' - hook_name = 'ODBC' + DEFAULT_SQLALCHEMY_SCHEME = "mssql+pyodbc" + conn_name_attr = "odbc_conn_id" + default_conn_name = "odbc_default" + conn_type = "odbc" + hook_name = "ODBC" supports_autocommit = True def __init__( self, *args, - database: Optional[str] = None, - driver: Optional[str] = None, - dsn: Optional[str] = None, - connect_kwargs: Optional[dict] = None, - sqlalchemy_scheme: Optional[str] = None, + database: str | None = None, + driver: str | None = None, + dsn: str | None = None, + connect_kwargs: dict | None = None, + sqlalchemy_scheme: str | None = None, **kwargs, ) -> None: """ @@ -75,16 +77,16 @@ def connection(self): return self._connection @property - def database(self) -> Optional[str]: + def database(self) -> str | None: """Database provided in init if exists; otherwise, ``schema`` from ``Connection`` object.""" return self._database or self.connection.schema @property - def sqlalchemy_scheme(self) -> Optional[str]: - """Database provided in init if exists; otherwise, ``schema`` from ``Connection`` object.""" + def sqlalchemy_scheme(self) -> str: + """Sqlalchemy scheme either from constructor, connection extras or default.""" return ( self._sqlalchemy_scheme - or self.connection_extra_lower.get('sqlalchemy_scheme') + or self.connection_extra_lower.get("sqlalchemy_scheme") or self.DEFAULT_SQLALCHEMY_SCHEME ) @@ -98,19 +100,19 @@ def connection_extra_lower(self) -> dict: return {k.lower(): v for k, v in self.connection.extra_dejson.items()} @property - def driver(self) -> Optional[str]: + def driver(self) -> str | None: """Driver from init param if given; else try to find one in connection extra.""" if not self._driver: - driver = self.connection_extra_lower.get('driver') + driver = self.connection_extra_lower.get("driver") if driver: self._driver = driver - return self._driver and self._driver.strip().lstrip('{').rstrip('}').strip() + return self._driver and self._driver.strip().lstrip("{").rstrip("}").strip() @property - def dsn(self) -> Optional[str]: + def dsn(self) -> str | None: 
"""DSN from init param if given; else try to find one in connection extra.""" if not self._dsn: - dsn = self.connection_extra_lower.get('dsn') + dsn = self.connection_extra_lower.get("dsn") if dsn: self._dsn = dsn.strip() return self._dsn @@ -124,7 +126,7 @@ def odbc_connection_string(self): ``Connection.extra`` will be added to the connection string. """ if not self._conn_str: - conn_str = '' + conn_str = "" if self.driver: conn_str += f"DRIVER={{{self.driver}}};" if self.dsn: @@ -141,7 +143,7 @@ def odbc_connection_string(self): if self.connection.port: conn_str += f"PORT={self.connection.port};" - extra_exclude = {'driver', 'dsn', 'connect_kwargs', 'sqlalchemy_scheme'} + extra_exclude = {"driver", "dsn", "connect_kwargs", "sqlalchemy_scheme"} extra_params = { k: v for k, v in self.connection.extra_dejson.items() if not k.lower() in extra_exclude } @@ -161,13 +163,13 @@ def connect_kwargs(self) -> dict: If ``attrs_before`` provided, keys and values are converted to int, as required by pyodbc. """ - conn_connect_kwargs = self.connection_extra_lower.get('connect_kwargs', {}) + conn_connect_kwargs = self.connection_extra_lower.get("connect_kwargs", {}) hook_connect_kwargs = self._connect_kwargs or {} merged_connect_kwargs = merge_dicts(conn_connect_kwargs, hook_connect_kwargs) - if 'attrs_before' in merged_connect_kwargs: - merged_connect_kwargs['attrs_before'] = { - int(k): int(v) for k, v in merged_connect_kwargs['attrs_before'].items() + if "attrs_before" in merged_connect_kwargs: + merged_connect_kwargs["attrs_before"] = { + int(k): int(v) for k, v in merged_connect_kwargs["attrs_before"].items() } return merged_connect_kwargs @@ -178,13 +180,16 @@ def get_conn(self) -> pyodbc.Connection: return conn def get_uri(self) -> str: - """URI invoked in :py:meth:`~airflow.hooks.dbapi.DbApiHook.get_sqlalchemy_engine` method""" + """ + URI invoked in :py:meth:`~airflow.providers.common.sql.hooks.sql.DbApiHook.get_sqlalchemy_engine` + method. 
+ """ quoted_conn_str = quote_plus(self.odbc_connection_string) uri = f"{self.sqlalchemy_scheme}:///?odbc_connect={quoted_conn_str}" return uri def get_sqlalchemy_connection( - self, connect_kwargs: Optional[dict] = None, engine_kwargs: Optional[dict] = None + self, connect_kwargs: dict | None = None, engine_kwargs: dict | None = None ) -> Any: """Sqlalchemy connection object""" engine = self.get_sqlalchemy_engine(engine_kwargs=engine_kwargs) diff --git a/airflow/providers/odbc/provider.yaml b/airflow/providers/odbc/provider.yaml index 35ae833ee43f1..c1199437005f3 100644 --- a/airflow/providers/odbc/provider.yaml +++ b/airflow/providers/odbc/provider.yaml @@ -22,6 +22,11 @@ description: | `ODBC `__ versions: + - 3.2.0 + - 3.1.2 + - 3.1.1 + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +35,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - pyodbc integrations: - integration-name: ODBC @@ -44,8 +51,6 @@ hooks: python-modules: - airflow.providers.odbc.hooks.odbc -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.odbc.hooks.odbc.OdbcHook connection-types: - hook-class-name: airflow.providers.odbc.hooks.odbc.OdbcHook diff --git a/airflow/providers/openfaas/.latest-doc-only-change.txt b/airflow/providers/openfaas/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/openfaas/.latest-doc-only-change.txt +++ b/airflow/providers/openfaas/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/openfaas/CHANGELOG.rst b/airflow/providers/openfaas/CHANGELOG.rst index eabc696fbe474..780dac3d22eed 100644 --- a/airflow/providers/openfaas/CHANGELOG.rst +++ b/airflow/providers/openfaas/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
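To make the OdbcHook pieces above concrete, a hedged walk-through of how a DSN-less connection becomes a SQLAlchemy URI; the driver name, host and credentials are placeholders, and the unchanged SERVER/DATABASE/UID/PWD parts of the string builder are assumed from context rather than shown in this hunk.

# Hedged illustration of odbc_connection_string and get_uri; values are placeholders.
from urllib.parse import quote_plus

conn_str = (
    "DRIVER={ODBC Driver 18 for SQL Server};"
    "SERVER=sqlserver.example.com;"
    "DATABASE=sales;"
    "UID=app_user;"
    "PWD=app_password;"
    "PORT=1433;"
)
# get_uri() percent-encodes the whole string into the default mssql+pyodbc scheme:
uri = f"mssql+pyodbc:///?odbc_connect={quote_plus(conn_str)}"
print(uri)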
Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.3 ..... diff --git a/airflow/providers/openfaas/hooks/openfaas.py b/airflow/providers/openfaas/hooks/openfaas.py index a5d40b600004b..a9202ad9d5362 100644 --- a/airflow/providers/openfaas/hooks/openfaas.py +++ b/airflow/providers/openfaas/hooks/openfaas.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Dict +from typing import Any import requests @@ -41,7 +42,7 @@ class OpenFaasHook(BaseHook): DEPLOY_FUNCTION = "/system/functions" UPDATE_FUNCTION = "/system/functions" - def __init__(self, function_name=None, conn_id: str = 'open_faas_default', *args, **kwargs) -> None: + def __init__(self, function_name=None, conn_id: str = "open_faas_default", *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.function_name = function_name self.conn_id = conn_id @@ -50,7 +51,7 @@ def get_conn(self): conn = self.get_connection(self.conn_id) return conn - def deploy_function(self, overwrite_function_if_exist: bool, body: Dict[str, Any]) -> None: + def deploy_function(self, overwrite_function_if_exist: bool, body: dict[str, Any]) -> None: """Deploy OpenFaaS function""" if overwrite_function_if_exist: self.log.info("Function already exist %s going to update", self.function_name) @@ -62,11 +63,11 @@ def deploy_function(self, overwrite_function_if_exist: bool, body: Dict[str, Any if response.status_code != OK_STATUS_CODE: self.log.error("Response status %d", response.status_code) self.log.error("Failed to deploy") - raise AirflowException('failed to deploy') + raise AirflowException("failed to deploy") else: self.log.info("Function deployed %s", self.function_name) - def invoke_async_function(self, body: Dict[str, Any]) -> None: + def invoke_async_function(self, body: dict[str, Any]) -> None: """Invoking function asynchronously""" url = self.get_conn().host + self.INVOKE_ASYNC_FUNCTION + self.function_name self.log.info("Invoking function asynchronously %s", url) @@ -75,9 +76,9 @@ def invoke_async_function(self, body: Dict[str, Any]) -> None: self.log.info("Invoked %s", self.function_name) else: self.log.error("Response status %d", response.status_code) - raise AirflowException('failed to invoke function') + raise AirflowException("failed to invoke function") - def invoke_function(self, body: Dict[str, Any]) -> None: + def invoke_function(self, body: dict[str, Any]) -> None: """Invoking function synchronously, will block until function completes and returns""" url = self.get_conn().host + self.INVOKE_FUNCTION + self.function_name self.log.info("Invoking function synchronously %s", url) @@ -88,9 +89,9 @@ def invoke_function(self, body: Dict[str, Any]) -> None: self.log.info("Response %s", response.text) else: self.log.error("Response status %d", response.status_code) - raise AirflowException('failed to invoke function') + raise AirflowException("failed to invoke function") - def update_function(self, body: Dict[str, Any]) -> None: + def update_function(self, body: dict[str, Any]) -> None: """Update OpenFaaS function""" url = self.get_conn().host + self.UPDATE_FUNCTION self.log.info("Updating function %s", url) @@ -98,7 +99,7 @@ def update_function(self, body: Dict[str, 
Any]) -> None: if response.status_code != OK_STATUS_CODE: self.log.error("Response status %d", response.status_code) self.log.error("Failed to update response %s", response.content.decode("utf-8")) - raise AirflowException('failed to update ' + self.function_name) + raise AirflowException("failed to update " + self.function_name) else: self.log.info("Function was updated") diff --git a/airflow/providers/openfaas/provider.yaml b/airflow/providers/openfaas/provider.yaml index 38cf110447962..90fed6940ffd3 100644 --- a/airflow/providers/openfaas/provider.yaml +++ b/airflow/providers/openfaas/provider.yaml @@ -22,6 +22,8 @@ description: | `OpenFaaS `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.3 - 2.0.2 - 2.0.1 @@ -30,8 +32,8 @@ versions: - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 integrations: - integration-name: OpenFaaS diff --git a/airflow/providers/opsgenie/.latest-doc-only-change.txt b/airflow/providers/opsgenie/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/opsgenie/.latest-doc-only-change.txt +++ b/airflow/providers/opsgenie/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/opsgenie/CHANGELOG.rst b/airflow/providers/opsgenie/CHANGELOG.rst index bb6d463df3eac..f2764429329b4 100644 --- a/airflow/providers/opsgenie/CHANGELOG.rst +++ b/airflow/providers/opsgenie/CHANGELOG.rst @@ -16,9 +16,60 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +5.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +Remove 'OpsgenieAlertOperator' also removed hooks.opsgenie_alert path + +* ``Remove deprecated code from Opsgenie provider (#27252)`` + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Migrate Opsgenie example DAGs to new design #22455 (#24144)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.1.0 ..... diff --git a/airflow/providers/opsgenie/hooks/opsgenie.py b/airflow/providers/opsgenie/hooks/opsgenie.py index 602ca4e7a5cae..a65bcd9649d15 100644 --- a/airflow/providers/opsgenie/hooks/opsgenie.py +++ b/airflow/providers/opsgenie/hooks/opsgenie.py @@ -15,9 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# - -from typing import Optional +from __future__ import annotations from opsgenie_sdk import ( AlertApi, @@ -46,18 +44,18 @@ class OpsgenieAlertHook(BaseHook): """ - conn_name_attr = 'opsgenie_conn_id' - default_conn_name = 'opsgenie_default' - conn_type = 'opsgenie' - hook_name = 'Opsgenie' + conn_name_attr = "opsgenie_conn_id" + default_conn_name = "opsgenie_default" + conn_type = "opsgenie" + hook_name = "Opsgenie" - def __init__(self, opsgenie_conn_id: str = 'opsgenie_default') -> None: + def __init__(self, opsgenie_conn_id: str = "opsgenie_default") -> None: super().__init__() # type: ignore[misc] self.conn_id = opsgenie_conn_id configuration = Configuration() conn = self.get_connection(self.conn_id) - configuration.api_key['Authorization'] = conn.password - configuration.host = conn.host or 'https://api.opsgenie.com' + configuration.api_key["Authorization"] = conn.password + configuration.host = conn.host or "https://api.opsgenie.com" self.alert_api_instance = AlertApi(ApiClient(configuration)) def _get_api_key(self) -> str: @@ -65,7 +63,6 @@ def _get_api_key(self) -> str: Get the API key from the connection :return: API key - :rtype: str """ conn = self.get_connection(self.conn_id) return conn.password @@ -75,18 +72,16 @@ def get_conn(self) -> AlertApi: Get the underlying AlertApi client :return: AlertApi client - :rtype: opsgenie_sdk.AlertApi """ return self.alert_api_instance - def create_alert(self, payload: Optional[dict] = None) -> SuccessResponse: + def create_alert(self, payload: dict | None = None) -> SuccessResponse: """ Create an alert on Opsgenie :param payload: Opsgenie API Create Alert payload values See https://docs.opsgenie.com/docs/alert-api#section-create-alert :return: api response - :rtype: opsgenie_sdk.SuccessResponse """ payload = payload or {} @@ -95,15 +90,15 @@ def create_alert(self, payload: Optional[dict] = None) -> SuccessResponse: api_response = self.alert_api_instance.create_alert(create_alert_payload) return api_response except OpenApiException as e: - self.log.exception('Exception when sending alert to opsgenie with payload: %s', payload) + self.log.exception("Exception when sending alert to opsgenie with payload: %s", payload) raise e def close_alert( self, identifier: str, - identifier_type: Optional[str] = 'id', - payload: Optional[dict] = None, - **kwargs: Optional[dict], + identifier_type: str | None = "id", + payload: dict | None = None, + **kwargs: dict | None, ) -> SuccessResponse: """ Close an alert in Opsgenie @@ -117,7 +112,6 @@ def close_alert( :return: SuccessResponse If the method is called asynchronously, returns the request thread. 
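For the hook methods above, a hedged usage sketch; it assumes an "opsgenie_default" connection exists, and the message, alias and note are placeholders.

# Hedged sketch: creating and then closing an alert through OpsgenieAlertHook.
from airflow.providers.opsgenie.hooks.opsgenie import OpsgenieAlertHook

hook = OpsgenieAlertHook(opsgenie_conn_id="opsgenie_default")

hook.create_alert(
    {
        "message": "Nightly load failed",
        "alias": "nightly-load-failure",
        "priority": "P3",
    }
)
hook.close_alert(
    identifier="nightly-load-failure",
    identifier_type="alias",
    payload={"note": "Resolved after rerun."},
)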
- :rtype: opsgenie_sdk.SuccessResponse """ payload = payload or {} try: @@ -130,15 +124,15 @@ def close_alert( ) return api_response except OpenApiException as e: - self.log.exception('Exception when closing alert in opsgenie with payload: %s', payload) + self.log.exception("Exception when closing alert in opsgenie with payload: %s", payload) raise e def delete_alert( self, identifier: str, - identifier_type: Optional[str] = None, - user: Optional[str] = None, - source: Optional[str] = None, + identifier_type: str | None = None, + user: str | None = None, + source: str | None = None, ) -> SuccessResponse: """ Delete an alert in Opsgenie @@ -149,7 +143,6 @@ def delete_alert( :param user: Display name of the request owner. :param source: Display name of the request source :return: SuccessResponse - :rtype: opsgenie_sdk.SuccessResponse """ try: api_response = self.alert_api_instance.delete_alert( @@ -160,5 +153,5 @@ def delete_alert( ) return api_response except OpenApiException as e: - self.log.exception('Exception when calling AlertApi->delete_alert: %s\n', e) + self.log.exception("Exception when calling AlertApi->delete_alert: %s\n", e) raise e diff --git a/airflow/providers/opsgenie/hooks/opsgenie_alert.py b/airflow/providers/opsgenie/hooks/opsgenie_alert.py deleted file mode 100644 index abd9c89e09308..0000000000000 --- a/airflow/providers/opsgenie/hooks/opsgenie_alert.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.opsgenie.hooks.opsgenie`.""" - -import warnings - -from airflow.providers.opsgenie.hooks.opsgenie import OpsgenieAlertHook # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.opsgenie.hooks.opsgenie`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/opsgenie/operators/opsgenie.py b/airflow/providers/opsgenie/operators/opsgenie.py index 1e2ff05b60319..eb3df82f6f923 100644 --- a/airflow/providers/opsgenie/operators/opsgenie.py +++ b/airflow/providers/opsgenie/operators/opsgenie.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.opsgenie.hooks.opsgenie import OpsgenieAlertHook @@ -59,25 +60,25 @@ class OpsgenieCreateAlertOperator(BaseOperator): :param note: Additional note that will be added while creating the alert. 
(templated) """ - template_fields: Sequence[str] = ('message', 'alias', 'description', 'entity', 'priority', 'note') + template_fields: Sequence[str] = ("message", "alias", "description", "entity", "priority", "note") def __init__( self, *, message: str, - opsgenie_conn_id: str = 'opsgenie_default', - alias: Optional[str] = None, - description: Optional[str] = None, - responders: Optional[List[dict]] = None, - visible_to: Optional[List[dict]] = None, - actions: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - details: Optional[dict] = None, - entity: Optional[str] = None, - source: Optional[str] = None, - priority: Optional[str] = None, - user: Optional[str] = None, - note: Optional[str] = None, + opsgenie_conn_id: str = "opsgenie_default", + alias: str | None = None, + description: str | None = None, + responders: list[dict] | None = None, + visible_to: list[dict] | None = None, + actions: list[str] | None = None, + tags: list[str] | None = None, + details: dict | None = None, + entity: str | None = None, + source: str | None = None, + priority: str | None = None, + user: str | None = None, + note: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -96,9 +97,9 @@ def __init__( self.priority = priority self.user = user self.note = note - self.hook: Optional[OpsgenieAlertHook] = None + self.hook: OpsgenieAlertHook | None = None - def _build_opsgenie_payload(self) -> Dict[str, Any]: + def _build_opsgenie_payload(self) -> dict[str, Any]: """ Construct the Opsgenie JSON payload. All relevant parameters are combined here to a valid Opsgenie JSON payload. @@ -127,7 +128,7 @@ def _build_opsgenie_payload(self) -> Dict[str, Any]: payload[key] = val return payload - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Call the OpsgenieAlertHook to post message""" self.hook = OpsgenieAlertHook(self.opsgenie_conn_id) self.hook.create_alert(self._build_opsgenie_payload()) @@ -161,12 +162,12 @@ def __init__( self, *, identifier: str, - opsgenie_conn_id: str = 'opsgenie_default', - identifier_type: Optional[str] = None, - user: Optional[str] = None, - note: Optional[str] = None, - source: Optional[str] = None, - close_alert_kwargs: Optional[dict] = None, + opsgenie_conn_id: str = "opsgenie_default", + identifier_type: str | None = None, + user: str | None = None, + note: str | None = None, + source: str | None = None, + close_alert_kwargs: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -178,9 +179,9 @@ def __init__( self.note = note self.source = source self.close_alert_kwargs = close_alert_kwargs - self.hook: Optional[OpsgenieAlertHook] = None + self.hook: OpsgenieAlertHook | None = None - def _build_opsgenie_close_alert_payload(self) -> Dict[str, Any]: + def _build_opsgenie_close_alert_payload(self) -> dict[str, Any]: """ Construct the Opsgenie JSON payload. All relevant parameters are combined here to a valid Opsgenie JSON payload. 
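A hedged sketch pairing the create and close operators as they look after this change; the connection id, alias and alert fields are placeholders.

# Hedged sketch of OpsgenieCreateAlertOperator followed by OpsgenieCloseAlertOperator.
from airflow.providers.opsgenie.operators.opsgenie import (
    OpsgenieCloseAlertOperator,
    OpsgenieCreateAlertOperator,
)

create_alert = OpsgenieCreateAlertOperator(
    task_id="create_alert",
    opsgenie_conn_id="opsgenie_default",
    message="Data quality check failed",
    alias="dq-check-orders",
    priority="P2",
    tags=["airflow", "data-quality"],
)

close_alert = OpsgenieCloseAlertOperator(
    task_id="close_alert",
    opsgenie_conn_id="opsgenie_default",
    identifier="dq-check-orders",
    identifier_type="alias",
    note="Closed automatically after reprocessing.",
)

create_alert >> close_alert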
@@ -199,7 +200,7 @@ def _build_opsgenie_close_alert_payload(self) -> Dict[str, Any]: payload[key] = val return payload - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Call the OpsgenieAlertHook to close alert""" self.hook = OpsgenieAlertHook(self.opsgenie_conn_id) self.hook.close_alert( @@ -232,16 +233,16 @@ class OpsgenieDeleteAlertOperator(BaseOperator): :param source: Display name of the request source """ - template_fields: Sequence[str] = ('identifier',) + template_fields: Sequence[str] = ("identifier",) def __init__( self, *, identifier: str, - opsgenie_conn_id: str = 'opsgenie_default', - identifier_type: Optional[str] = None, - user: Optional[str] = None, - source: Optional[str] = None, + opsgenie_conn_id: str = "opsgenie_default", + identifier_type: str | None = None, + user: str | None = None, + source: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -252,7 +253,7 @@ def __init__( self.user = user self.source = source - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Call the OpsgenieAlertHook to delete alert""" hook = OpsgenieAlertHook(self.opsgenie_conn_id) hook.delete_alert( diff --git a/airflow/providers/opsgenie/operators/opsgenie_alert.py b/airflow/providers/opsgenie/operators/opsgenie_alert.py deleted file mode 100644 index ea133f5fa1b7d..0000000000000 --- a/airflow/providers/opsgenie/operators/opsgenie_alert.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.opsgenie.operators.opsgenie`.""" - -import warnings - -from airflow.providers.opsgenie.operators.opsgenie import OpsgenieCreateAlertOperator - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.opsgenie.operators.opsgenie`.", - DeprecationWarning, - stacklevel=2, -) - - -class OpsgenieAlertOperator(OpsgenieCreateAlertOperator): - """ - This operator is deprecated. - Please use :class:`airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator`. - """ - - def __init__(self, *args, **kwargs): - warnings.warn( - "This operator is deprecated. 
" - "Please use :class:`airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator`.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/airflow/providers/opsgenie/provider.yaml b/airflow/providers/opsgenie/provider.yaml index fbf88cd28ebc4..58aba33976f60 100644 --- a/airflow/providers/opsgenie/provider.yaml +++ b/airflow/providers/opsgenie/provider.yaml @@ -22,6 +22,8 @@ description: | `Opsgenie `__ versions: + - 5.0.0 + - 4.0.0 - 3.1.0 - 3.0.3 - 3.0.2 @@ -33,8 +35,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - opsgenie-sdk>=2.1.5 integrations: - integration-name: Opsgenie @@ -47,18 +50,13 @@ integrations: operators: - integration-name: Opsgenie python-modules: - - airflow.providers.opsgenie.operators.opsgenie_alert - airflow.providers.opsgenie.operators.opsgenie hooks: - integration-name: Opsgenie python-modules: - - airflow.providers.opsgenie.hooks.opsgenie_alert - airflow.providers.opsgenie.hooks.opsgenie -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.opsgenie.hooks.opsgenie.OpsgenieAlertHook - connection-types: - hook-class-name: airflow.providers.opsgenie.hooks.opsgenie.OpsgenieAlertHook connection-type: opsgenie diff --git a/airflow/providers/oracle/CHANGELOG.rst b/airflow/providers/oracle/CHANGELOG.rst index 6d98b4e783548..e6baa4437e7b0 100644 --- a/airflow/providers/oracle/CHANGELOG.rst +++ b/airflow/providers/oracle/CHANGELOG.rst @@ -16,9 +16,111 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.5.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.4.0 +..... + +Features +~~~~~~~~ + +* ``Add oracledb thick mode support for oracle provider (#26576)`` + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + + +3.3.0 +..... + +Features +~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Align Black and blacken-docs configs (#24785)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.1.0 +..... + +Features +~~~~~~~~~ + +* ``Update Oracle library to latest version (#24311)`` + +3.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Add 'parameters' to templated fields in 'OracleOperator' (#22857)`` + +Misc +~~~~ + +* ``Make numpy effectively an optional dependency for Oracle provider (#24272)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py index 84c0f5d6a1e33..0ba7425e166f0 100644 --- a/airflow/providers/oracle/hooks/oracle.py +++ b/airflow/providers/oracle/hooks/oracle.py @@ -15,15 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations +import math import warnings from datetime import datetime -from typing import Dict, List, Optional, Union -import cx_Oracle -import numpy +import oracledb -from airflow.hooks.dbapi import DbApiHook +try: + import numpy +except ImportError: + numpy = None # type: ignore + +from airflow.providers.common.sql.hooks.sql import DbApiHook PARAM_TYPES = {bool, float, int, str} @@ -37,22 +42,85 @@ def _map_param(value): return value +def _get_bool(val): + if isinstance(val, bool): + return val + if isinstance(val, str): + val = val.lower().strip() + if val == "true": + return True + if val == "false": + return False + return None + + +def _get_first_bool(*vals): + for val in vals: + converted = _get_bool(val) + if isinstance(converted, bool): + return converted + return None + + class OracleHook(DbApiHook): """ Interact with Oracle SQL. :param oracle_conn_id: The :ref:`Oracle connection id ` used for Oracle credentials. + :param thick_mode: Specify whether to use python-oracledb in thick mode. Defaults to False. + If set to True, you must have the Oracle Client libraries installed. + See `oracledb docs` + for more info. + :param thick_mode_lib_dir: Path to use to find the Oracle Client libraries when using thick mode. + If not specified, defaults to the standard way of locating the Oracle Client library on the OS. + See `oracledb docs + ` + for more info. + :param thick_mode_config_dir: Path to use to find the Oracle Client library + configuration files when using thick mode. + If not specified, defaults to the standard way of locating the Oracle Client + library configuration files on the OS. + See `oracledb docs + ` + for more info. + :param fetch_decimals: Specify whether numbers should be fetched as ``decimal.Decimal`` values. + See `defaults.fetch_decimals + ` + for more info. + :param fetch_lobs: Specify whether to fetch strings/bytes for CLOBs or BLOBs instead of locators. + See `defaults.fetch_lobs + ` + for more info. 
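A short sketch of how the new hook arguments might be passed; the connection id and the Instant Client path are assumptions, and the same settings can instead come from the connection extras as described below:

.. code-block:: python

    from airflow.providers.oracle.hooks.oracle import OracleHook

    hook = OracleHook(
        oracle_conn_id="oracle_default",
        thick_mode=True,  # requires the Oracle Client libraries on the worker
        thick_mode_lib_dir="/opt/oracle/instantclient",  # hypothetical install path
        fetch_decimals=True,  # NUMBER columns are returned as decimal.Decimal
        fetch_lobs=False,  # CLOB/BLOB columns are returned as str/bytes, not locators
    )
    rows = hook.get_records("SELECT owner, table_name FROM all_tables FETCH FIRST 5 ROWS ONLY")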
""" - conn_name_attr = 'oracle_conn_id' - default_conn_name = 'oracle_default' - conn_type = 'oracle' - hook_name = 'Oracle' + conn_name_attr = "oracle_conn_id" + default_conn_name = "oracle_default" + conn_type = "oracle" + hook_name = "Oracle" + _test_connection_sql = "select 1 from dual" supports_autocommit = True - def get_conn(self) -> 'OracleHook': + def __init__( + self, + *args, + thick_mode: bool | None = None, + thick_mode_lib_dir: str | None = None, + thick_mode_config_dir: str | None = None, + fetch_decimals: bool | None = None, + fetch_lobs: bool | None = None, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + self.thick_mode = thick_mode + self.thick_mode_lib_dir = thick_mode_lib_dir + self.thick_mode_config_dir = thick_mode_config_dir + self.fetch_decimals = fetch_decimals + self.fetch_lobs = fetch_lobs + + def get_conn(self) -> oracledb.Connection: """ Returns a oracle connection object Optional parameters for using a custom DSN connection @@ -72,32 +140,61 @@ def get_conn(self) -> 'OracleHook': .. code-block:: python - { - "dsn": ( - "(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)" - "(HOST=host)(PORT=1521))(CONNECT_DATA=(SID=sid)))" - ) - } + {"dsn": ("(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=host)(PORT=1521))(CONNECT_DATA=(SID=sid)))")} - see more param detail in - `cx_Oracle.connect `_ + see more param detail in `oracledb.connect + `_ """ conn = self.get_connection(self.oracle_conn_id) # type: ignore[attr-defined] - conn_config = {'user': conn.login, 'password': conn.password} - sid = conn.extra_dejson.get('sid') - mod = conn.extra_dejson.get('module') + conn_config = {"user": conn.login, "password": conn.password} + sid = conn.extra_dejson.get("sid") + mod = conn.extra_dejson.get("module") schema = conn.schema - service_name = conn.extra_dejson.get('service_name') + # Enable oracledb thick mode if thick_mode is set to True + # Parameters take precedence over connection config extra + # Defaults to use thin mode if not provided in params or connection config extra + thick_mode = _get_first_bool(self.thick_mode, conn.extra_dejson.get("thick_mode")) + if thick_mode is True: + if self.thick_mode_lib_dir is None: + self.thick_mode_lib_dir = conn.extra_dejson.get("thick_mode_lib_dir") + if not isinstance(self.thick_mode_lib_dir, (str, type(None))): + raise TypeError( + f"thick_mode_lib_dir expected str or None, " + f"got {type(self.thick_mode_lib_dir).__name__}" + ) + if self.thick_mode_config_dir is None: + self.thick_mode_config_dir = conn.extra_dejson.get("thick_mode_config_dir") + if not isinstance(self.thick_mode_config_dir, (str, type(None))): + raise TypeError( + f"thick_mode_config_dir expected str or None, " + f"got {type(self.thick_mode_config_dir).__name__}" + ) + oracledb.init_oracle_client( + lib_dir=self.thick_mode_lib_dir, config_dir=self.thick_mode_config_dir + ) + + # Set oracledb Defaults Attributes if provided + # (https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html) + fetch_decimals = _get_first_bool(self.fetch_decimals, conn.extra_dejson.get("fetch_decimals")) + if isinstance(fetch_decimals, bool): + oracledb.defaults.fetch_decimals = fetch_decimals + + fetch_lobs = _get_first_bool(self.fetch_lobs, conn.extra_dejson.get("fetch_lobs")) + if isinstance(fetch_lobs, bool): + oracledb.defaults.fetch_lobs = fetch_lobs + + # Set up DSN + service_name = conn.extra_dejson.get("service_name") port = conn.port if conn.port else 1521 if conn.host and sid and not service_name: - conn_config['dsn'] = cx_Oracle.makedsn(conn.host, port, sid) 
+ conn_config["dsn"] = oracledb.makedsn(conn.host, port, sid) elif conn.host and service_name and not sid: - conn_config['dsn'] = cx_Oracle.makedsn(conn.host, port, service_name=service_name) + conn_config["dsn"] = oracledb.makedsn(conn.host, port, service_name=service_name) else: - dsn = conn.extra_dejson.get('dsn') + dsn = conn.extra_dejson.get("dsn") if dsn is None: dsn = conn.host if conn.port is not None: @@ -112,53 +209,42 @@ def get_conn(self) -> 'OracleHook': stacklevel=2, ) dsn += "/" + conn.schema - conn_config['dsn'] = dsn - - if 'encoding' in conn.extra_dejson: - conn_config['encoding'] = conn.extra_dejson.get('encoding') - # if `encoding` is specific but `nencoding` is not - # `nencoding` should use same values as `encoding` to set encoding, inspired by - # https://github.com/oracle/python-cx_Oracle/issues/157#issuecomment-371877993 - if 'nencoding' not in conn.extra_dejson: - conn_config['nencoding'] = conn.extra_dejson.get('encoding') - if 'nencoding' in conn.extra_dejson: - conn_config['nencoding'] = conn.extra_dejson.get('nencoding') - if 'threaded' in conn.extra_dejson: - conn_config['threaded'] = conn.extra_dejson.get('threaded') - if 'events' in conn.extra_dejson: - conn_config['events'] = conn.extra_dejson.get('events') - - mode = conn.extra_dejson.get('mode', '').lower() - if mode == 'sysdba': - conn_config['mode'] = cx_Oracle.SYSDBA - elif mode == 'sysasm': - conn_config['mode'] = cx_Oracle.SYSASM - elif mode == 'sysoper': - conn_config['mode'] = cx_Oracle.SYSOPER - elif mode == 'sysbkp': - conn_config['mode'] = cx_Oracle.SYSBKP - elif mode == 'sysdgd': - conn_config['mode'] = cx_Oracle.SYSDGD - elif mode == 'syskmt': - conn_config['mode'] = cx_Oracle.SYSKMT - elif mode == 'sysrac': - conn_config['mode'] = cx_Oracle.SYSRAC - - purity = conn.extra_dejson.get('purity', '').lower() - if purity == 'new': - conn_config['purity'] = cx_Oracle.ATTR_PURITY_NEW - elif purity == 'self': - conn_config['purity'] = cx_Oracle.ATTR_PURITY_SELF - elif purity == 'default': - conn_config['purity'] = cx_Oracle.ATTR_PURITY_DEFAULT - - conn = cx_Oracle.connect(**conn_config) + conn_config["dsn"] = dsn + + if "events" in conn.extra_dejson: + conn_config["events"] = conn.extra_dejson.get("events") + + mode = conn.extra_dejson.get("mode", "").lower() + if mode == "sysdba": + conn_config["mode"] = oracledb.AUTH_MODE_SYSDBA + elif mode == "sysasm": + conn_config["mode"] = oracledb.AUTH_MODE_SYSASM + elif mode == "sysoper": + conn_config["mode"] = oracledb.AUTH_MODE_SYSOPER + elif mode == "sysbkp": + conn_config["mode"] = oracledb.AUTH_MODE_SYSBKP + elif mode == "sysdgd": + conn_config["mode"] = oracledb.AUTH_MODE_SYSDGD + elif mode == "syskmt": + conn_config["mode"] = oracledb.AUTH_MODE_SYSKMT + elif mode == "sysrac": + conn_config["mode"] = oracledb.AUTH_MODE_SYSRAC + + purity = conn.extra_dejson.get("purity", "").lower() + if purity == "new": + conn_config["purity"] = oracledb.PURITY_NEW + elif purity == "self": + conn_config["purity"] = oracledb.PURITY_SELF + elif purity == "default": + conn_config["purity"] = oracledb.PURITY_DEFAULT + + conn = oracledb.connect(**conn_config) if mod is not None: conn.module = mod # if Connection.schema is defined, set schema after connecting successfully # cannot be part of conn_config - # https://cx-oracle.readthedocs.io/en/latest/api_manual/connection.html?highlight=schema#Connection.current_schema + # https://python-oracledb.readthedocs.io/en/latest/api_manual/connection.html?highlight=schema#Connection.current_schema # Only set schema when not using 
conn.schema as Service Name if schema and service_name: conn.current_schema = schema @@ -168,10 +254,10 @@ def get_conn(self) -> 'OracleHook': def insert_rows( self, table: str, - rows: List[tuple], + rows: list[tuple], target_fields=None, commit_every: int = 1000, - replace: Optional[bool] = False, + replace: bool | None = False, **kwargs, ) -> None: """ @@ -179,7 +265,7 @@ def insert_rows( the whole set of inserts is treated as one transaction Changes from standard DbApiHook implementation: - - Oracle SQL queries in cx_Oracle can not be terminated with a semicolon (`;`) + - Oracle SQL queries in oracledb can not be terminated with a semicolon (`;`) - Replace NaN values with NULL using `numpy.nan_to_num` (not using `is_nan()` because of input types error for strings) - Coerce datetime cells to Oracle DATETIME format during insert @@ -194,10 +280,10 @@ def insert_rows( :param replace: Whether to replace instead of insert """ if target_fields: - target_fields = ', '.join(target_fields) - target_fields = f'({target_fields})' + target_fields = ", ".join(target_fields) + target_fields = f"({target_fields})" else: - target_fields = '' + target_fields = "" conn = self.get_conn() if self.supports_autocommit: self.set_autocommit(conn, False) @@ -210,14 +296,14 @@ def insert_rows( if isinstance(cell, str): lst.append("'" + str(cell).replace("'", "''") + "'") elif cell is None: - lst.append('NULL') - elif isinstance(cell, float) and numpy.isnan(cell): # coerce numpy NaN to NULL - lst.append('NULL') - elif isinstance(cell, numpy.datetime64): + lst.append("NULL") + elif isinstance(cell, float) and math.isnan(cell): # coerce numpy NaN to NULL + lst.append("NULL") + elif numpy and isinstance(cell, numpy.datetime64): lst.append("'" + str(cell) + "'") elif isinstance(cell, datetime): lst.append( - "to_date('" + cell.strftime('%Y-%m-%d %H:%M:%S') + "','YYYY-MM-DD HH24:MI:SS')" + "to_date('" + cell.strftime("%Y-%m-%d %H:%M:%S") + "','YYYY-MM-DD HH24:MI:SS')" ) else: lst.append(str(cell)) @@ -226,21 +312,21 @@ def insert_rows( cur.execute(sql) if i % commit_every == 0: conn.commit() # type: ignore[attr-defined] - self.log.info('Loaded %s into %s rows so far', i, table) + self.log.info("Loaded %s into %s rows so far", i, table) conn.commit() # type: ignore[attr-defined] cur.close() conn.close() # type: ignore[attr-defined] - self.log.info('Done loading. Loaded a total of %s rows', i) + self.log.info("Done loading. Loaded a total of %s rows", i) def bulk_insert_rows( self, table: str, - rows: List[tuple], - target_fields: Optional[List[str]] = None, + rows: list[tuple], + target_fields: list[str] | None = None, commit_every: int = 5000, ): """ - A performant bulk insert for cx_Oracle + A performant bulk insert for oracledb that uses prepared statements via `executemany()`. For best performance, pass in `rows` as an iterator. 
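As a rough illustration of the prepared-statement path above; the table, columns and connection id are invented:

.. code-block:: python

    from airflow.providers.oracle.hooks.oracle import OracleHook

    hook = OracleHook(oracle_conn_id="oracle_default")
    rows = ((i, f"name_{i}") for i in range(100_000))  # an iterator is fine, as noted above
    hook.bulk_insert_rows(
        table="demo_events",  # hypothetical table
        rows=rows,
        target_fields=["id", "name"],
        commit_every=5000,
    )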
@@ -259,10 +345,10 @@ def bulk_insert_rows( self.set_autocommit(conn, False) cursor = conn.cursor() # type: ignore[attr-defined] values_base = target_fields if target_fields else rows[0] - prepared_stm = 'insert into {tablename} {columns} values ({values})'.format( + prepared_stm = "insert into {tablename} {columns} values ({values})".format( tablename=table, - columns='({})'.format(', '.join(target_fields)) if target_fields else '', - values=', '.join(':%s' % i for i in range(1, len(values_base) + 1)), + columns="({})".format(", ".join(target_fields)) if target_fields else "", + values=", ".join(":%s" % i for i in range(1, len(values_base) + 1)), ) row_count = 0 # Chunk the rows @@ -274,14 +360,14 @@ def bulk_insert_rows( cursor.prepare(prepared_stm) cursor.executemany(None, row_chunk) conn.commit() # type: ignore[attr-defined] - self.log.info('[%s] inserted %s rows', table, row_count) + self.log.info("[%s] inserted %s rows", table, row_count) # Empty chunk row_chunk = [] # Commit the leftover chunk cursor.prepare(prepared_stm) cursor.executemany(None, row_chunk) conn.commit() # type: ignore[attr-defined] - self.log.info('[%s] inserted %s rows', table, row_count) + self.log.info("[%s] inserted %s rows", table, row_count) cursor.close() conn.close() # type: ignore[attr-defined] @@ -289,8 +375,8 @@ def callproc( self, identifier: str, autocommit: bool = False, - parameters: Optional[Union[List, Dict]] = None, - ) -> Optional[Union[List, Dict]]: + parameters: list | dict | None = None, + ) -> list | dict | None: """ Call the stored procedure identified by the provided string. @@ -302,7 +388,7 @@ def callproc( provided `parameters` argument. See - https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var + https://python-oracledb.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var for further reference. """ if parameters is None: @@ -339,18 +425,3 @@ def handler(cursor): ) return result - - # TODO: Merge this implementation back to DbApiHook when dropping - # support for Airflow 2.2. - def test_connection(self): - """Tests the connection by executing a select 1 from dual query""" - status, message = False, '' - try: - if self.get_first("select 1 from dual"): - status = True - message = 'Connection successfully tested' - except Exception as e: - status = False - message = str(e) - - return status, message diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py index b60d4b6e89100..fac193b4f8324 100644 --- a/airflow/providers/oracle/operators/oracle.py +++ b/airflow/providers/oracle/operators/oracle.py @@ -15,16 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Sequence, Union +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.oracle.hooks.oracle import OracleHook if TYPE_CHECKING: from airflow.utils.context import Context -class OracleOperator(BaseOperator): +class OracleOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Oracle database. 
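A hedged sketch of the ``callproc`` change above; the procedure name and arguments are invented and assume a procedure taking two IN parameters:

.. code-block:: python

    from airflow.providers.oracle.hooks.oracle import OracleHook

    hook = OracleHook(oracle_conn_id="oracle_default")
    result = hook.callproc(
        "refresh_sales_summary",  # hypothetical stored procedure
        autocommit=True,
        parameters=[2024, "EMEA"],
    )
    print(result)  # the supplied parameters are handed back (see the Cursor.var note above)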
@@ -40,33 +44,21 @@ class OracleOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'parameters', - 'sql', + "parameters", + "sql", ) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#ededed' + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#ededed" - def __init__( - self, - *, - sql: Union[str, List[str]], - oracle_conn_id: str = 'oracle_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - autocommit: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.oracle_conn_id = oracle_conn_id - self.sql = sql - self.autocommit = autocommit - self.parameters = parameters - - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) - hook = OracleHook(oracle_conn_id=self.oracle_conn_id) - if self.sql: - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + def __init__(self, *, oracle_conn_id: str = "oracle_default", **kwargs) -> None: + super().__init__(conn_id=oracle_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) class OracleStoredProcedureOperator(BaseOperator): @@ -80,17 +72,17 @@ class OracleStoredProcedureOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'parameters', - 'procedure', + "parameters", + "procedure", ) - ui_color = '#ededed' + ui_color = "#ededed" def __init__( self, *, procedure: str, - oracle_conn_id: str = 'oracle_default', - parameters: Optional[Union[Dict, List]] = None, + oracle_conn_id: str = "oracle_default", + parameters: dict | list | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -98,7 +90,7 @@ def __init__( self.procedure = procedure self.parameters = parameters - def execute(self, context: 'Context') -> Optional[Union[List, Dict]]: - self.log.info('Executing: %s', self.procedure) + def execute(self, context: Context): + self.log.info("Executing: %s", self.procedure) hook = OracleHook(oracle_conn_id=self.oracle_conn_id) return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters) diff --git a/airflow/providers/oracle/provider.yaml b/airflow/providers/oracle/provider.yaml index 254c577fecb5b..b80aad3261860 100644 --- a/airflow/providers/oracle/provider.yaml +++ b/airflow/providers/oracle/provider.yaml @@ -22,6 +22,12 @@ description: | `Oracle `__ versions: + - 3.5.0 + - 3.4.0 + - 3.3.0 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -33,8 +39,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - oracledb>=1.0.0 integrations: - integration-name: Oracle @@ -42,6 +50,11 @@ integrations: logo: /integration-logos/oracle/Oracle.png tags: [software] +additional-extras: + - name: numpy + dependencies: + - numpy + operators: - integration-name: Oracle python-modules: @@ -57,9 +70,6 @@ transfers: target-integration-name: Oracle python-module: airflow.providers.oracle.transfers.oracle_to_oracle -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.oracle.hooks.oracle.OracleHook - connection-types: - hook-class-name: airflow.providers.oracle.hooks.oracle.OracleHook connection-type: oracle diff --git a/airflow/providers/oracle/transfers/oracle_to_oracle.py 
b/airflow/providers/oracle/transfers/oracle_to_oracle.py index 9d16fa85f2061..bc603bcd17272 100644 --- a/airflow/providers/oracle/transfers/oracle_to_oracle.py +++ b/airflow/providers/oracle/transfers/oracle_to_oracle.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.oracle.hooks.oracle import OracleHook @@ -38,9 +40,9 @@ class OracleToOracleOperator(BaseOperator): :param rows_chunk: number of rows per chunk to commit. """ - template_fields: Sequence[str] = ('source_sql', 'source_sql_params') + template_fields: Sequence[str] = ("source_sql", "source_sql_params") template_fields_renderers = {"source_sql": "sql", "source_sql_params": "py"} - ui_color = '#e08c8c' + ui_color = "#e08c8c" def __init__( self, @@ -49,7 +51,7 @@ def __init__( destination_table: str, oracle_source_conn_id: str, source_sql: str, - source_sql_params: Optional[dict] = None, + source_sql_params: dict | None = None, rows_chunk: int = 5000, **kwargs, ) -> None: @@ -83,7 +85,7 @@ def _execute(self, src_hook, dest_hook, context) -> None: self.log.info("Finished data transfer.") cursor.close() - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: src_hook = OracleHook(oracle_conn_id=self.oracle_source_conn_id) dest_hook = OracleHook(oracle_conn_id=self.oracle_destination_conn_id) self._execute(src_hook, dest_hook, context) diff --git a/airflow/providers/pagerduty/.latest-doc-only-change.txt b/airflow/providers/pagerduty/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/pagerduty/.latest-doc-only-change.txt +++ b/airflow/providers/pagerduty/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/pagerduty/CHANGELOG.rst b/airflow/providers/pagerduty/CHANGELOG.rst index b3ebe76ab6c46..4f28966f95c93 100644 --- a/airflow/providers/pagerduty/CHANGELOG.rst +++ b/airflow/providers/pagerduty/CHANGELOG.rst @@ -16,9 +16,51 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Remove "bad characters" from our codebase (#24841)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/pagerduty/hooks/pagerduty.py b/airflow/providers/pagerduty/hooks/pagerduty.py index 98be80965aaab..1cd393c54393c 100644 --- a/airflow/providers/pagerduty/hooks/pagerduty.py +++ b/airflow/providers/pagerduty/hooks/pagerduty.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. """Hook for sending or receiving data from PagerDuty as well as creating PagerDuty incidents.""" +from __future__ import annotations + import warnings -from typing import Any, Dict, List, Optional +from typing import Any import pdpyras @@ -49,16 +51,16 @@ class PagerdutyHook(BaseHook): hook_name = "Pagerduty" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['port', 'login', 'schema', 'host'], + "hidden_fields": ["port", "login", "schema", "host"], "relabeling": { - 'password': 'Pagerduty API token', + "password": "Pagerduty API token", }, } - def __init__(self, token: Optional[str] = None, pagerduty_conn_id: Optional[str] = None) -> None: + def __init__(self, token: str | None = None, pagerduty_conn_id: str | None = None) -> None: super().__init__() self.routing_key = None self._session = None @@ -75,7 +77,7 @@ def __init__(self, token: Optional[str] = None, pagerduty_conn_id: Optional[str] self.token = token if self.token is None: - raise AirflowException('Cannot get token: No valid api token nor pagerduty_conn_id supplied.') + raise AirflowException("Cannot get token: No valid api token nor pagerduty_conn_id supplied.") def get_session(self) -> pdpyras.APISession: """ @@ -96,15 +98,15 @@ def create_event( severity: str, source: str = "airflow", action: str = "trigger", - routing_key: Optional[str] = None, - dedup_key: Optional[str] = None, - custom_details: Optional[Any] = None, - group: Optional[str] = None, - component: Optional[str] = None, - class_type: Optional[str] = None, - images: Optional[List[Any]] = None, - links: Optional[List[Any]] = None, - ) -> Dict: + routing_key: str | None = None, + dedup_key: str | None = None, + custom_details: Any | None = None, + group: str | None = None, + component: str | None = None, + class_type: str | None = None, + images: list[Any] | None = None, + links: list[Any] | None = None, + ) -> 
dict: """ Create event for service integration. @@ -121,7 +123,7 @@ def create_event( :param custom_details: Free-form details from the event. Can be a dictionary or a string. If a dictionary is passed it will show up in PagerDuty as a table. :param group: A cluster or grouping of sources. For example, sources - “prod-datapipe-02” and “prod-datapipe-03” might both be part of “prod-datapipe” + "prod-datapipe-02" and "prod-datapipe-03" might both be part of "prod-datapipe" :param component: The part or component of the affected system that is broken. :param class_type: The class/type of the event. :param images: List of images to include. Each dictionary in the list accepts the following keys: @@ -134,7 +136,6 @@ def create_event( `text`: [Optional] Plain text that describes the purpose of the link, and can be used as the link's text. :return: PagerDuty Events API v2 response. - :rtype: dict """ warnings.warn( "This method will be deprecated. Please use the " diff --git a/airflow/providers/pagerduty/hooks/pagerduty_events.py b/airflow/providers/pagerduty/hooks/pagerduty_events.py index c5eaffe105284..4fc2e4bf34fb6 100644 --- a/airflow/providers/pagerduty/hooks/pagerduty_events.py +++ b/airflow/providers/pagerduty/hooks/pagerduty_events.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """Hook for sending or receiving data from PagerDuty as well as creating PagerDuty incidents.""" -from typing import Any, Dict, List, Optional +from __future__ import annotations + +from typing import Any import pdpyras @@ -41,17 +43,17 @@ class PagerdutyEventsHook(BaseHook): hook_name = "Pagerduty Events" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['port', 'login', 'schema', 'host', 'extra'], + "hidden_fields": ["port", "login", "schema", "host", "extra"], "relabeling": { - 'password': 'Pagerduty Integration key', + "password": "Pagerduty Integration key", }, } def __init__( - self, integration_key: Optional[str] = None, pagerduty_events_conn_id: Optional[str] = None + self, integration_key: str | None = None, pagerduty_events_conn_id: str | None = None ) -> None: super().__init__() self.integration_key = None @@ -66,23 +68,23 @@ def __init__( if self.integration_key is None: raise AirflowException( - 'Cannot get token: No valid integration key nor pagerduty_events_conn_id supplied.' + "Cannot get token: No valid integration key nor pagerduty_events_conn_id supplied." ) def create_event( self, summary: str, severity: str, - source: str = 'airflow', - action: str = 'trigger', - dedup_key: Optional[str] = None, - custom_details: Optional[Any] = None, - group: Optional[str] = None, - component: Optional[str] = None, - class_type: Optional[str] = None, - images: Optional[List[Any]] = None, - links: Optional[List[Any]] = None, - ) -> Dict: + source: str = "airflow", + action: str = "trigger", + dedup_key: str | None = None, + custom_details: Any | None = None, + group: str | None = None, + component: str | None = None, + class_type: str | None = None, + images: list[Any] | None = None, + links: list[Any] | None = None, + ) -> dict: """ Create event for service integration. @@ -97,7 +99,7 @@ def create_event( :param custom_details: Free-form details from the event. Can be a dictionary or a string. If a dictionary is passed it will show up in PagerDuty as a table. :param group: A cluster or grouping of sources. 
For example, sources - “prod-datapipe-02” and “prod-datapipe-03” might both be part of “prod-datapipe” + "prod-datapipe-02" and "prod-datapipe-03" might both be part of "prod-datapipe" :param component: The part or component of the affected system that is broken. :param class_type: The class/type of the event. :param images: List of images to include. Each dictionary in the list accepts the following keys: @@ -110,7 +112,6 @@ def create_event( `text`: [Optional] Plain text that describes the purpose of the link, and can be used as the link's text. :return: PagerDuty Events API v2 response. - :rtype: dict """ payload = { "summary": summary, @@ -126,7 +127,7 @@ def create_event( if class_type: payload["class"] = class_type - actions = ('trigger', 'acknowledge', 'resolve') + actions = ("trigger", "acknowledge", "resolve") if action not in actions: raise ValueError(f"Event action must be one of: {', '.join(actions)}") data = { @@ -135,7 +136,7 @@ def create_event( } if dedup_key: data["dedup_key"] = dedup_key - elif action != 'trigger': + elif action != "trigger": raise ValueError( f"The dedup_key property is required for event_action={action} events, " f"and it must be a string." diff --git a/airflow/providers/pagerduty/provider.yaml b/airflow/providers/pagerduty/provider.yaml index de038c026e263..80965bdb69885 100644 --- a/airflow/providers/pagerduty/provider.yaml +++ b/airflow/providers/pagerduty/provider.yaml @@ -22,6 +22,8 @@ description: | `Pagerduty `__ versions: + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - pdpyras>=4.1.2 integrations: - integration-name: Pagerduty diff --git a/airflow/providers/papermill/.latest-doc-only-change.txt b/airflow/providers/papermill/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/papermill/.latest-doc-only-change.txt +++ b/airflow/providers/papermill/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/papermill/CHANGELOG.rst b/airflow/providers/papermill/CHANGELOG.rst index 99581324117b4..56f70310055a6 100644 --- a/airflow/providers/papermill/CHANGELOG.rst +++ b/airflow/providers/papermill/CHANGELOG.rst @@ -16,9 +16,56 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... 
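For context, a sketch of triggering and then resolving an event with the Pagerduty Events hook shown above; the connection id, summary and details dict are assumptions:

.. code-block:: python

    from airflow.providers.pagerduty.hooks.pagerduty_events import PagerdutyEventsHook

    hook = PagerdutyEventsHook(pagerduty_events_conn_id="pagerduty_events_default")
    response = hook.create_event(
        summary="Nightly job failed",
        severity="error",
        source="airflow",
        action="trigger",
        custom_details={"dag_id": "example_dag", "try_number": 1},
    )
    # the Events API v2 response usually echoes a dedup_key, which the
    # validation above requires for any non-trigger action
    dedup_key = response.get("dedup_key")
    if dedup_key:
        hook.create_event(
            summary="Nightly job failed",
            severity="error",
            action="resolve",
            dedup_key=dedup_key,
        )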
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Add support to specify language name in PapermillOperator (#23916)`` +* ``Fix langauge override in papermill operator (#24301)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate Papermill example DAGs to new design #22456 (#24146)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/papermill/example_dags/example_papermill.py b/airflow/providers/papermill/example_dags/example_papermill.py deleted file mode 100644 index c49b7715794ad..0000000000000 --- a/airflow/providers/papermill/example_dags/example_papermill.py +++ /dev/null @@ -1,85 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This DAG will use Papermill to run the notebook "hello_world", based on the execution date -it will create an output notebook "out-". All fields, including the keys in the parameters, are -templated. 
-""" -import os -from datetime import datetime, timedelta - -import scrapbook as sb - -from airflow import DAG -from airflow.decorators import task -from airflow.lineage import AUTO -from airflow.providers.papermill.operators.papermill import PapermillOperator - -START_DATE = datetime(2021, 1, 1) -SCHEDULE_INTERVAL = '0 0 * * *' -DAGRUN_TIMEOUT = timedelta(minutes=60) - -with DAG( - dag_id='example_papermill_operator', - schedule_interval=SCHEDULE_INTERVAL, - start_date=START_DATE, - dagrun_timeout=DAGRUN_TIMEOUT, - tags=['example'], - catchup=False, -) as dag_1: - # [START howto_operator_papermill] - run_this = PapermillOperator( - task_id="run_example_notebook", - input_nb="/tmp/hello_world.ipynb", - output_nb="/tmp/out-{{ execution_date }}.ipynb", - parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"}, - ) - # [END howto_operator_papermill] - - -@task -def check_notebook(inlets, execution_date): - """ - Verify the message in the notebook - """ - notebook = sb.read_notebook(inlets[0].url) - message = notebook.scraps['message'] - print(f"Message in notebook {message} for {execution_date}") - - if message.data != f"Ran from Airflow at {execution_date}!": - return False - - return True - - -with DAG( - dag_id='example_papermill_operator_2', - schedule_interval=SCHEDULE_INTERVAL, - start_date=START_DATE, - dagrun_timeout=DAGRUN_TIMEOUT, - catchup=False, -) as dag_2: - - run_this = PapermillOperator( - task_id="run_example_notebook", - input_nb=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_notebook.ipynb"), - output_nb="/tmp/out-{{ execution_date }}.ipynb", - parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"}, - ) - - run_this >> check_notebook(inlets=AUTO, execution_date="{{ execution_date }}") diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py index b79ac226c120f..531304d441c7f 100644 --- a/airflow/providers/papermill/operators/papermill.py +++ b/airflow/providers/papermill/operators/papermill.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
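For reference, a condensed sketch of the usage shown in the removed example DAG; the notebook paths and message are taken from that example and are illustrative:

.. code-block:: python

    import pendulum

    from airflow import DAG
    from airflow.providers.papermill.operators.papermill import PapermillOperator

    with DAG(
        dag_id="example_papermill_operator",
        start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        PapermillOperator(
            task_id="run_example_notebook",
            input_nb="/tmp/hello_world.ipynb",
            output_nb="/tmp/out-{{ execution_date }}.ipynb",
            parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"},
        )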
-from typing import TYPE_CHECKING, Dict, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence import attr import papermill as pm @@ -31,10 +33,10 @@ class NoteBook(File): """Jupyter notebook""" - type_hint: Optional[str] = "jupyter_notebook" - parameters: Optional[Dict] = {} + type_hint: str | None = "jupyter_notebook" + parameters: dict | None = {} - meta_schema: str = __name__ + '.NoteBook' + meta_schema: str = __name__ + ".NoteBook" class PapermillOperator(BaseOperator): @@ -50,16 +52,16 @@ class PapermillOperator(BaseOperator): supports_lineage = True - template_fields: Sequence[str] = ('input_nb', 'output_nb', 'parameters', 'kernel_name', 'language_name') + template_fields: Sequence[str] = ("input_nb", "output_nb", "parameters", "kernel_name", "language_name") def __init__( self, *, - input_nb: Optional[str] = None, - output_nb: Optional[str] = None, - parameters: Optional[Dict] = None, - kernel_name: Optional[str] = None, - language_name: Optional[str] = None, + input_nb: str | None = None, + output_nb: str | None = None, + parameters: dict | None = None, + kernel_name: str | None = None, + language_name: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -74,7 +76,7 @@ def __init__( if output_nb: self.outlets.append(NoteBook(url=output_nb)) - def execute(self, context: 'Context'): + def execute(self, context: Context): if not self.inlets or not self.outlets: raise ValueError("Input notebook or output notebook is not specified") @@ -86,5 +88,5 @@ def execute(self, context: 'Context'): progress_bar=False, report_mode=True, kernel_name=self.kernel_name, - language_name=self.language_name, + language=self.language_name, ) diff --git a/airflow/providers/papermill/provider.yaml b/airflow/providers/papermill/provider.yaml index 42d052c42b692..966a171b50463 100644 --- a/airflow/providers/papermill/provider.yaml +++ b/airflow/providers/papermill/provider.yaml @@ -22,6 +22,8 @@ description: | `Papermill `__ versions: + - 3.1.0 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -33,8 +35,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - papermill[all]>=1.2.1 + - scrapbook[all] integrations: - integration-name: Papermill diff --git a/airflow/providers/plexus/.latest-doc-only-change.txt b/airflow/providers/plexus/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/plexus/.latest-doc-only-change.txt +++ b/airflow/providers/plexus/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/plexus/CHANGELOG.rst b/airflow/providers/plexus/CHANGELOG.rst index 18298d2f4f61f..e917e7be01fb2 100644 --- a/airflow/providers/plexus/CHANGELOG.rst +++ b/airflow/providers/plexus/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. 
Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Plexus example DAGs to new design #22457 (#24147)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/plexus/example_dags/example_plexus.py b/airflow/providers/plexus/example_dags/example_plexus.py deleted file mode 100644 index 68ddcb7d030d2..0000000000000 --- a/airflow/providers/plexus/example_dags/example_plexus.py +++ /dev/null @@ -1,47 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from datetime import datetime - -from airflow import DAG -from airflow.providers.plexus.operators.job import PlexusJobOperator - -HOME = '/home/acc' -T3_PRERUN_SCRIPT = 'cp {home}/imdb/run_scripts/mlflow.sh {home}/ && chmod +x mlflow.sh'.format(home=HOME) - - -dag = DAG( - 'test', - default_args={'owner': 'core scientific', 'retries': 1}, - description='testing plexus operator', - start_date=datetime(2021, 1, 1), - schedule_interval='@once', - catchup=False, -) - -t1 = PlexusJobOperator( - task_id='test', - job_params={ - 'name': 'test', - 'app': 'MLFlow Pipeline 01', - 'queue': 'DGX-2 (gpu:Tesla V100-SXM3-32GB)', - 'num_nodes': 1, - 'num_cores': 1, - 'prerun_script': T3_PRERUN_SCRIPT, - }, - dag=dag, -) diff --git a/airflow/providers/plexus/hooks/plexus.py b/airflow/providers/plexus/hooks/plexus.py index 6c1d7bc7ebacb..21991066dc2b8 100644 --- a/airflow/providers/plexus/hooks/plexus.py +++ b/airflow/providers/plexus/hooks/plexus.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
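For reference, the shape of ``job_params`` used by the removed example DAG, placed inside a DAG definition as usual; all values are illustrative:

.. code-block:: python

    from airflow.providers.plexus.operators.job import PlexusJobOperator

    run_job = PlexusJobOperator(
        task_id="plexus_job",
        job_params={
            "name": "test",
            "app": "MLFlow Pipeline 01",
            "queue": "DGX-2 (gpu:Tesla V100-SXM3-32GB)",
            "num_nodes": 1,
            "num_cores": 1,
        },
    )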
+from __future__ import annotations + from typing import Any import arrow diff --git a/airflow/providers/plexus/operators/job.py b/airflow/providers/plexus/operators/job.py index d252ba8cab285..963924bbbbb91 100644 --- a/airflow/providers/plexus/operators/job.py +++ b/airflow/providers/plexus/operators/job.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import logging import time -from typing import Any, Dict, Optional +from typing import Any import requests @@ -43,7 +44,7 @@ class PlexusJobOperator(BaseOperator): """ - def __init__(self, job_params: Dict, **kwargs) -> None: + def __init__(self, job_params: dict, **kwargs) -> None: super().__init__(**kwargs) self.job_params = job_params @@ -121,14 +122,14 @@ def _api_lookup(self, param: str, hook): for dct in results: if dct[mapping[0]] == mapping[1]: v = dct[key] - if param == 'app': - self.is_service = dct['is_service'] + if param == "app": + self.is_service = dct["is_service"] if v is None: raise AirflowException(f"Could not locate value for param:{key} at endpoint: {endpoint}") return v - def construct_job_params(self, hook: Any) -> Dict[Any, Optional[Any]]: + def construct_job_params(self, hook: Any) -> dict[Any, Any | None]: """ Creates job_params dict for api call to launch a Plexus job. diff --git a/airflow/providers/plexus/provider.yaml b/airflow/providers/plexus/provider.yaml index b5ead4b2c5775..0e908c502caec 100644 --- a/airflow/providers/plexus/provider.yaml +++ b/airflow/providers/plexus/provider.yaml @@ -22,6 +22,8 @@ description: | `Plexus `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +32,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - arrow>=0.16.0 integrations: - integration-name: Plexus diff --git a/airflow/providers/postgres/CHANGELOG.rst b/airflow/providers/postgres/CHANGELOG.rst index 6922091b32fe3..b841dadb8d6e1 100644 --- a/airflow/providers/postgres/CHANGELOG.rst +++ b/airflow/providers/postgres/CHANGELOG.rst @@ -16,9 +16,101 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +5.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``PostgresHook: Added ON CONFLICT DO NOTHING statement when all target fields are primary keys (#26661)`` +* ``Add SQLExecuteQueryOperator (#25717)`` +* ``Rename schema to database in PostgresHook (#26744)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +5.2.2 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. 
Review and move the new changes to one of the sections above: + * ``Rename schema to database in 'PostgresHook' (#26436)`` + * ``Revert "Rename schema to database in 'PostgresHook' (#26436)" (#26734)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +5.2.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Bump dep on common-sql to fix issue with SQLTableCheckOperator (#26143)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``postgres provider: use non-binary psycopg2 (#25710)`` + +5.2.0 +..... + +Features +~~~~~~~~ + +* ``Use only public AwsHook's methods during IAM authorization (#25424)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +5.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +5.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Postgres example DAGs to new design #22458 (#24148)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 4.1.0 ..... diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 884a1c92927e9..3040b1df81ebb 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -15,10 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import os +import warnings from contextlib import closing from copy import deepcopy -from typing import Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, Union import psycopg2 import psycopg2.extensions @@ -26,8 +29,8 @@ from psycopg2.extensions import connection from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor -from airflow.hooks.dbapi import DbApiHook from airflow.models.connection import Connection +from airflow.providers.common.sql.hooks.sql import DbApiHook CursorType = Union[DictCursor, RealDictCursor, NamedTupleCursor] @@ -58,27 +61,55 @@ class PostgresHook(DbApiHook): reference to a specific postgres database. """ - conn_name_attr = 'postgres_conn_id' - default_conn_name = 'postgres_default' - conn_type = 'postgres' - hook_name = 'Postgres' + conn_name_attr = "postgres_conn_id" + default_conn_name = "postgres_default" + conn_type = "postgres" + hook_name = "Postgres" supports_autocommit = True def __init__(self, *args, **kwargs) -> None: + if "schema" in kwargs: + warnings.warn( + 'The "schema" arg has been renamed to "database" as it contained the database name.' 
+ 'Please use "database" to set the database name.', + DeprecationWarning, + stacklevel=2, + ) + kwargs["database"] = kwargs["schema"] super().__init__(*args, **kwargs) - self.connection: Optional[Connection] = kwargs.pop("connection", None) + self.connection: Connection | None = kwargs.pop("connection", None) self.conn: connection = None - self.schema: Optional[str] = kwargs.pop("schema", None) + self.database: str | None = kwargs.pop("database", None) + + @property + def schema(self): + warnings.warn( + 'The "schema" variable has been renamed to "database" as it contained the database name.' + 'Please use "database" to get the database name.', + DeprecationWarning, + stacklevel=2, + ) + return self.database + + @schema.setter + def schema(self, value): + warnings.warn( + 'The "schema" variable has been renamed to "database" as it contained the database name.' + 'Please use "database" to set the database name.', + DeprecationWarning, + stacklevel=2, + ) + self.database = value def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() - if _cursor == 'dictcursor': + if _cursor == "dictcursor": return psycopg2.extras.DictCursor - if _cursor == 'realdictcursor': + if _cursor == "realdictcursor": return psycopg2.extras.RealDictCursor - if _cursor == 'namedtuplecursor': + if _cursor == "namedtuplecursor": return psycopg2.extras.NamedTupleCursor - raise ValueError(f'Invalid cursor passed {_cursor}') + raise ValueError(f"Invalid cursor passed {_cursor}") def get_conn(self) -> connection: """Establishes a connection to a postgres database.""" @@ -86,27 +117,27 @@ def get_conn(self) -> connection: conn = deepcopy(self.connection or self.get_connection(conn_id)) # check for authentication via AWS IAM - if conn.extra_dejson.get('iam', False): + if conn.extra_dejson.get("iam", False): conn.login, conn.password, conn.port = self.get_iam_token(conn) conn_args = dict( host=conn.host, user=conn.login, password=conn.password, - dbname=self.schema or conn.schema, + dbname=self.database or conn.schema, port=conn.port, ) - raw_cursor = conn.extra_dejson.get('cursor', False) + raw_cursor = conn.extra_dejson.get("cursor", False) if raw_cursor: - conn_args['cursor_factory'] = self._get_cursor(raw_cursor) + conn_args["cursor_factory"] = self._get_cursor(raw_cursor) for arg_name, arg_val in conn.extra_dejson.items(): if arg_name not in [ - 'iam', - 'redshift', - 'cursor', - 'cluster-identifier', - 'aws_conn_id', + "iam", + "redshift", + "cursor", + "cluster-identifier", + "aws_conn_id", ]: conn_args[arg_name] = arg_val @@ -126,10 +157,10 @@ def copy_expert(self, sql: str, filename: str) -> None: """ self.log.info("Running copy expert: %s, filename: %s", sql, filename) if not os.path.isfile(filename): - with open(filename, 'w'): + with open(filename, "w"): pass - with open(filename, 'r+') as file: + with open(filename, "r+") as file: with closing(self.get_conn()) as conn: with closing(conn.cursor()) as cur: cur.copy_expert(sql, file) @@ -141,7 +172,9 @@ def get_uri(self) -> str: Extract the URI from the connection. :return: the extracted uri. 
""" - uri = super().get_uri().replace("postgres://", "postgresql://") + conn = self.get_connection(getattr(self, self.conn_name_attr)) + conn.schema = self.database or conn.schema + uri = conn.get_uri().replace("postgres://", "postgresql://") return uri def bulk_load(self, table: str, tmp_file: str) -> None: @@ -153,7 +186,7 @@ def bulk_dump(self, table: str, tmp_file: str) -> None: self.copy_expert(f"COPY {table} TO STDOUT", tmp_file) @staticmethod - def _serialize_cell(cell: object, conn: Optional[connection] = None) -> object: + def _serialize_cell(cell: object, conn: connection | None = None) -> Any: """ Postgresql will adapt all arguments to the execute() method internally, hence we return cell without any conversion. @@ -164,57 +197,56 @@ def _serialize_cell(cell: object, conn: Optional[connection] = None) -> object: :param cell: The cell to insert into the table :param conn: The database connection :return: The cell - :rtype: object """ return cell - def get_iam_token(self, conn: Connection) -> Tuple[str, str, int]: + def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: """ Uses AWSHook to retrieve a temporary password to connect to Postgres or Redshift. Port is required. If none is provided, default is used for each service """ - from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + try: + from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + except ImportError: + from airflow.exceptions import AirflowException + + raise AirflowException( + "apache-airflow-providers-amazon not installed, run: " + "pip install 'apache-airflow-providers-postgres[amazon]'." + ) - redshift = conn.extra_dejson.get('redshift', False) - aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default') - aws_hook = AwsBaseHook(aws_conn_id, client_type='rds') + aws_conn_id = conn.extra_dejson.get("aws_conn_id", "aws_default") login = conn.login - if conn.port is None: - port = 5439 if redshift else 5432 - else: - port = conn.port - if redshift: + if conn.extra_dejson.get("redshift", False): + port = conn.port or 5439 # Pull the custer-identifier from the beginning of the Redshift URL # ex. 
my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster - cluster_identifier = conn.extra_dejson.get('cluster-identifier', conn.host.split('.')[0]) - session, endpoint_url = aws_hook._get_credentials(region_name=None) - client = session.client( - "redshift", - endpoint_url=endpoint_url, - config=aws_hook.config, - verify=aws_hook.verify, - ) - cluster_creds = client.get_cluster_credentials( - DbUser=conn.login, - DbName=self.schema or conn.schema, + cluster_identifier = conn.extra_dejson.get("cluster-identifier", conn.host.split(".")[0]) + redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials + cluster_creds = redshift_client.get_cluster_credentials( + DbUser=login, + DbName=self.database or conn.schema, ClusterIdentifier=cluster_identifier, AutoCreate=False, ) - token = cluster_creds['DbPassword'] - login = cluster_creds['DbUser'] + token = cluster_creds["DbPassword"] + login = cluster_creds["DbUser"] else: - token = aws_hook.conn.generate_db_auth_token(conn.host, port, conn.login) + port = conn.port or 5432 + rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token + token = rds_client.generate_db_auth_token(conn.host, port, conn.login) return login, token, port - def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> Optional[List[str]]: + def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None: """ Helper method that returns the table primary key :param table: Name of the target table :param schema: Name of the target schema, public by default :return: Primary key columns list - :rtype: List[str] """ sql = """ select kcu.column_name @@ -230,9 +262,9 @@ def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> pk_columns = [row[0] for row in self.get_records(sql, (schema, table))] return pk_columns or None - @staticmethod + @classmethod def _generate_insert_sql( - table: str, values: Tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs + cls, table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs ) -> str: """ Static helper method that generates the INSERT SQL statement. 
@@ -245,10 +277,9 @@ def _generate_insert_sql( :param replace_index: the column or list of column names to act as index for the ON CONFLICT clause :return: The generated INSERT or REPLACE SQL statement - :rtype: str """ placeholders = [ - "%s", + cls.placeholder, ] * len(values) replace_index = kwargs.get("replace_index") @@ -256,21 +287,25 @@ def _generate_insert_sql( target_fields_fragment = ", ".join(target_fields) target_fields_fragment = f"({target_fields_fragment})" else: - target_fields_fragment = '' + target_fields_fragment = "" sql = f"INSERT INTO {table} {target_fields_fragment} VALUES ({','.join(placeholders)})" if replace: - if target_fields is None: + if not target_fields: raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names") - if replace_index is None: + if not replace_index: raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index") if isinstance(replace_index, str): replace_index = [replace_index] - replace_index_set = set(replace_index) - replace_target = [ - "{0} = excluded.{0}".format(col) for col in target_fields if col not in replace_index_set - ] - sql += f" ON CONFLICT ({', '.join(replace_index)}) DO UPDATE SET {', '.join(replace_target)}" + on_conflict_str = f" ON CONFLICT ({', '.join(replace_index)})" + replace_target = [f for f in target_fields if f not in replace_index] + + if replace_target: + replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target) + sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}" + else: + sql += f"{on_conflict_str} DO NOTHING" + return sql diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py index e0238aa88204b..a9489b666369a 100644 --- a/airflow/providers/postgres/operators/postgres.py +++ b/airflow/providers/postgres/operators/postgres.py @@ -15,19 +15,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from __future__ import annotations -from psycopg2.sql import SQL, Identifier +import warnings +from typing import Mapping, Sequence -from airflow.models import BaseOperator -from airflow.providers.postgres.hooks.postgres import PostgresHook -from airflow.www import utils as wwwutils +from psycopg2.sql import SQL, Identifier -if TYPE_CHECKING: - from airflow.utils.context import Context +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -class PostgresOperator(BaseOperator): +class PostgresOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Postgres database @@ -40,55 +38,54 @@ class PostgresOperator(BaseOperator): (default value: False) :param parameters: (optional) the parameters to render the SQL query with. :param database: name of database which overwrite defined one in connection + :param runtime_parameters: a mapping of runtime params added to the final sql being executed. + For example, you could set the schema via `{"search_path": "CUSTOM_SCHEMA"}`. """ - template_fields: Sequence[str] = ('sql',) - # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement. 
- template_fields_renderers = { - 'sql': 'postgresql' if 'postgresql' in wwwutils.get_attr_renderer() else 'sql' - } - template_ext: Sequence[str] = ('.sql',) - ui_color = '#ededed' + template_fields: Sequence[str] = ("sql",) + template_fields_renderers = {"sql": "postgresql"} + template_ext: Sequence[str] = (".sql",) + ui_color = "#ededed" def __init__( self, *, - sql: Union[str, List[str]], - postgres_conn_id: str = 'postgres_default', - autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, - database: Optional[str] = None, - runtime_parameters: Optional[Mapping] = None, + postgres_conn_id: str = "postgres_default", + database: str | None = None, + runtime_parameters: Mapping | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) - self.sql = sql - self.postgres_conn_id = postgres_conn_id - self.autocommit = autocommit - self.parameters = parameters - self.database = database - self.runtime_parameters = runtime_parameters - self.hook: Optional[PostgresHook] = None + if database is not None: + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = {"schema": database, **hook_params} + + if runtime_parameters: + sql = kwargs.pop("sql") + parameters = kwargs.pop("parameters", {}) - def execute(self, context: 'Context'): - self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database) - if self.runtime_parameters: final_sql = [] sql_param = {} - for param in self.runtime_parameters: + for param in runtime_parameters: set_param_sql = f"SET {{}} TO %({param})s;" dynamic_sql = SQL(set_param_sql).format(Identifier(f"{param}")) final_sql.append(dynamic_sql) - for param, val in self.runtime_parameters.items(): + for param, val in runtime_parameters.items(): sql_param.update({f"{param}": f"{val}"}) - if self.parameters: - sql_param.update(self.parameters) - if isinstance(self.sql, str): - final_sql.append(SQL(self.sql)) + if parameters: + sql_param.update(parameters) + if isinstance(sql, str): + final_sql.append(SQL(sql)) else: - final_sql.extend(list(map(SQL, self.sql))) - self.hook.run(final_sql, self.autocommit, parameters=sql_param) - else: - self.hook.run(self.sql, self.autocommit, parameters=self.parameters) - for output in self.hook.conn.notices: - self.log.info(output) + final_sql.extend(list(map(SQL, sql))) + + kwargs["sql"] = final_sql + kwargs["parameters"] = sql_param + + super().__init__(conn_id=postgres_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`. 
+ Also, you can provide `hook_params={'schema': }`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/postgres/provider.yaml b/airflow/providers/postgres/provider.yaml index c55aad1cca63a..46493ef959b15 100644 --- a/airflow/providers/postgres/provider.yaml +++ b/airflow/providers/postgres/provider.yaml @@ -22,6 +22,12 @@ description: | `PostgreSQL `__ versions: + - 5.3.0 + - 5.2.2 + - 5.2.1 + - 5.2.0 + - 5.1.0 + - 5.0.0 - 4.1.0 - 4.0.1 - 4.0.0 @@ -35,8 +41,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - psycopg2>=2.8.0 integrations: - integration-name: PostgreSQL @@ -56,9 +64,11 @@ hooks: python-modules: - airflow.providers.postgres.hooks.postgres -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.postgres.hooks.postgres.PostgresHook - connection-types: - hook-class-name: airflow.providers.postgres.hooks.postgres.PostgresHook connection-type: postgres + +additional-extras: + - name: amazon + dependencies: + - apache-airflow-providers-amazon>=2.6.0 diff --git a/airflow/providers/presto/.latest-doc-only-change.txt b/airflow/providers/presto/.latest-doc-only-change.txt index b4e77ccb86f61..ff7136e07d744 100644 --- a/airflow/providers/presto/.latest-doc-only-change.txt +++ b/airflow/providers/presto/.latest-doc-only-change.txt @@ -1 +1 @@ -5164cdbe98ad63754d969b4b300a7a0167565e33 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/presto/CHANGELOG.rst b/airflow/providers/presto/CHANGELOG.rst index 901e3966d0d35..0a724bd20c86e 100644 --- a/airflow/providers/presto/CHANGELOG.rst +++ b/airflow/providers/presto/CHANGELOG.rst @@ -16,9 +16,109 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.0.1 +..... + +Features +~~~~~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix placeholders in 'TrinoHook', 'PrestoHook', 'SqliteHook' (#25939)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + +4.0.0 +..... + + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecated ``hql`` parameter has been removed in ``get_records``, ``get_first``, ``get_pandas_df`` and ``run`` +methods of the ``PrestoHook``. + +Remove ``PrestoToSlackOperator`` in favor of Slack provider ``SqlToSlackOperator``. 
+ +* ``Remove 'PrestoToSlackOperator' (#25425)`` + +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)`` + +Features +~~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Adding generic 'SqlToSlackOperator' (#24663)`` +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Add 'PrestoToSlackOperator' (#23979)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Presto example DAGs to new design #22459 (#24145)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.1 ..... diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 95ecf86b52dda..825b28ad60a54 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import json import os -import warnings -from typing import Any, Callable, Iterable, Optional, overload +from typing import Any, Callable, Iterable, Mapping import prestodb from prestodb.exceptions import DatabaseError @@ -26,8 +27,8 @@ from airflow import AirflowException from airflow.configuration import conf -from airflow.hooks.dbapi import DbApiHook from airflow.models import Connection +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING try: @@ -36,27 +37,24 @@ # This is from airflow.utils.operator_helpers, # For the sake of provider backward compatibility, this is hardcoded if import fails # https://github.com/apache/airflow/pull/22416#issuecomment-1075531290 - DEFAULT_FORMAT_PREFIX = 'airflow.ctx.' + DEFAULT_FORMAT_PREFIX = "airflow.ctx." 
def generate_presto_client_info() -> str: """Return json string with dag_id, task_id, execution_date and try_number""" context_var = { - format_map['default'].replace(DEFAULT_FORMAT_PREFIX, ''): os.environ.get( - format_map['env_var_format'], '' + format_map["default"].replace(DEFAULT_FORMAT_PREFIX, ""): os.environ.get( + format_map["env_var_format"], "" ) for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() } - # try_number isn't available in context for airflow < 2.2.5 - # https://github.com/apache/airflow/issues/23059 - try_number = context_var.get('try_number', '') task_info = { - 'dag_id': context_var['dag_id'], - 'task_id': context_var['task_id'], - 'execution_date': context_var['execution_date'], - 'try_number': try_number, - 'dag_run_id': context_var['dag_run_id'], - 'dag_owner': context_var['dag_owner'], + "dag_id": context_var["dag_id"], + "task_id": context_var["task_id"], + "execution_date": context_var["execution_date"], + "try_number": context_var["try_number"], + "dag_run_id": context_var["dag_run_id"], + "dag_owner": context_var["dag_owner"], } return json.dumps(task_info, sort_keys=True) @@ -69,9 +67,9 @@ def _boolify(value): if isinstance(value, bool): return value if isinstance(value, str): - if value.lower() == 'false': + if value.lower() == "false": return False - elif value.lower() == 'true': + elif value.lower() == "true": return True return value @@ -86,33 +84,34 @@ class PrestoHook(DbApiHook): [[340698]] """ - conn_name_attr = 'presto_conn_id' - default_conn_name = 'presto_default' - conn_type = 'presto' - hook_name = 'Presto' + conn_name_attr = "presto_conn_id" + default_conn_name = "presto_default" + conn_type = "presto" + hook_name = "Presto" + placeholder = "?" def get_conn(self) -> Connection: """Returns a connection object""" db = self.get_connection(self.presto_conn_id) # type: ignore[attr-defined] extra = db.extra_dejson auth = None - if db.password and extra.get('auth') == 'kerberos': + if db.password and extra.get("auth") == "kerberos": raise AirflowException("Kerberos authorization doesn't support password.") elif db.password: auth = prestodb.auth.BasicAuthentication(db.login, db.password) - elif extra.get('auth') == 'kerberos': + elif extra.get("auth") == "kerberos": auth = prestodb.auth.KerberosAuthentication( - config=extra.get('kerberos__config', os.environ.get('KRB5_CONFIG')), - service_name=extra.get('kerberos__service_name'), - mutual_authentication=_boolify(extra.get('kerberos__mutual_authentication', False)), - force_preemptive=_boolify(extra.get('kerberos__force_preemptive', False)), - hostname_override=extra.get('kerberos__hostname_override'), + config=extra.get("kerberos__config", os.environ.get("KRB5_CONFIG")), + service_name=extra.get("kerberos__service_name"), + mutual_authentication=_boolify(extra.get("kerberos__mutual_authentication", False)), + force_preemptive=_boolify(extra.get("kerberos__force_preemptive", False)), + hostname_override=extra.get("kerberos__hostname_override"), sanitize_mutual_error_response=_boolify( - extra.get('kerberos__sanitize_mutual_error_response', True) + extra.get("kerberos__sanitize_mutual_error_response", True) ), - principal=extra.get('kerberos__principal', conf.get('kerberos', 'principal')), - delegate=_boolify(extra.get('kerberos__delegate', False)), - ca_bundle=extra.get('kerberos__ca_bundle'), + principal=extra.get("kerberos__principal", conf.get("kerberos", "principal")), + delegate=_boolify(extra.get("kerberos__delegate", False)), + ca_bundle=extra.get("kerberos__ca_bundle"), ) http_headers 
= {"X-Presto-Client-Info": generate_presto_client_info()} @@ -120,113 +119,54 @@ def get_conn(self) -> Connection: host=db.host, port=db.port, user=db.login, - source=db.extra_dejson.get('source', 'airflow'), + source=db.extra_dejson.get("source", "airflow"), http_headers=http_headers, - http_scheme=db.extra_dejson.get('protocol', 'http'), - catalog=db.extra_dejson.get('catalog', 'hive'), + http_scheme=db.extra_dejson.get("protocol", "http"), + catalog=db.extra_dejson.get("catalog", "hive"), schema=db.schema, auth=auth, isolation_level=self.get_isolation_level(), # type: ignore[func-returns-value] ) - if extra.get('verify') is not None: + if extra.get("verify") is not None: # Unfortunately verify parameter is available via public API. # The PR is merged in the presto library, but has not been released. # See: https://github.com/prestosql/presto-python-client/pull/31 - presto_conn._http_session.verify = _boolify(extra['verify']) + presto_conn._http_session.verify = _boolify(extra["verify"]) return presto_conn def get_isolation_level(self) -> Any: """Returns an isolation level""" db = self.get_connection(self.presto_conn_id) # type: ignore[attr-defined] - isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper() + isolation_level = db.extra_dejson.get("isolation_level", "AUTOCOMMIT").upper() return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) - @staticmethod - def _strip_sql(sql: str) -> str: - return sql.strip().rstrip(';') - - @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None): - """Get a set of records from Presto - - :param sql: SQL statement to be executed. - :param parameters: The parameters to render the SQL query with. - """ - - @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - """:sphinx-autoapi-skip:""" - - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - + def get_records( + self, + sql: str | list[str] = "", + parameters: Iterable | Mapping | None = None, + ) -> Any: + if not isinstance(sql, str): + raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!") try: - return super().get_records(self._strip_sql(sql), parameters) + return super().get_records(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise PrestoException(e) - @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: - """Returns only the first row, regardless of how many rows the query returns. - - :param sql: SQL statement to be executed. - :param parameters: The parameters to render the SQL query with. - """ - - @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - """:sphinx-autoapi-skip:""" - - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. 
You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - + def get_first(self, sql: str | list[str] = "", parameters: Iterable | Mapping | None = None) -> Any: + if not isinstance(sql, str): + raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!") try: - return super().get_first(self._strip_sql(sql), parameters) + return super().get_first(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise PrestoException(e) - @overload def get_pandas_df(self, sql: str = "", parameters=None, **kwargs): - """Get a pandas dataframe from a sql query. - - :param sql: SQL statement to be executed. - :param parameters: The parameters to render the SQL query with. - """ - - @overload - def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): - """:sphinx-autoapi-skip:""" - - def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - import pandas cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(sql), parameters) + cursor.execute(self.strip_sql_string(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise PrestoException(e) @@ -238,51 +178,29 @@ def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs) df = pandas.DataFrame(**kwargs) return df - @overload - def run( - self, - sql: str = "", - autocommit: bool = False, - parameters: Optional[dict] = None, - handler: Optional[Callable] = None, - ) -> None: - """Execute the statement against Presto. Can be used to create views.""" - - @overload - def run( - self, - sql: str = "", - autocommit: bool = False, - parameters: Optional[dict] = None, - handler: Optional[Callable] = None, - hql: str = "", - ) -> None: - """:sphinx-autoapi-skip:""" - def run( self, - sql: str = "", + sql: str | Iterable[str], autocommit: bool = False, - parameters: Optional[dict] = None, - handler: Optional[Callable] = None, - hql: str = "", - ) -> None: - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - - return super().run(sql=self._strip_sql(sql), parameters=parameters, handler=handler) + parameters: Iterable | Mapping | None = None, + handler: Callable | None = None, + split_statements: bool = False, + return_last: bool = True, + ) -> Any | list[Any] | None: + return super().run( + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=handler, + split_statements=split_statements, + return_last=return_last, + ) def insert_rows( self, table: str, rows: Iterable[tuple], - target_fields: Optional[Iterable[str]] = None, + target_fields: Iterable[str] | None = None, commit_every: int = 0, replace: bool = False, **kwargs, @@ -299,10 +217,22 @@ def insert_rows( """ if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT: self.log.info( - 'Transactions are not enable in presto connection. ' - 'Please use the isolation_level property to enable it. ' - 'Falling back to insert all rows in one transaction.' + "Transactions are not enable in presto connection. " + "Please use the isolation_level property to enable it. " + "Falling back to insert all rows in one transaction." 
) commit_every = 0 super().insert_rows(table, rows, target_fields, commit_every) + + @staticmethod + def _serialize_cell(cell: Any, conn: Connection | None = None) -> Any: + """ + Presto will adapt all arguments to the execute() method internally, + hence we return cell without any conversion. + + :param cell: The cell to insert into the table + :param conn: The database connection + :return: The cell + """ + return cell diff --git a/airflow/providers/presto/provider.yaml b/airflow/providers/presto/provider.yaml index 66c493e3f6f7c..01c2bf4020934 100644 --- a/airflow/providers/presto/provider.yaml +++ b/airflow/providers/presto/provider.yaml @@ -22,6 +22,11 @@ description: | `Presto `__ versions: + - 4.1.0 + - 4.0.1 + - 4.0.0 + - 3.1.0 + - 3.0.0 - 2.2.1 - 2.2.0 - 2.1.2 @@ -33,8 +38,11 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - presto-python-client>=0.8.2 + - pandas>=0.17.1 integrations: - integration-name: Presto @@ -53,8 +61,6 @@ transfers: how-to-guide: /docs/apache-airflow-providers-presto/operators/transfer/gcs_to_presto.rst python-module: airflow.providers.presto.transfers.gcs_to_presto -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.presto.hooks.presto.PrestoHook connection-types: - hook-class-name: airflow.providers.presto.hooks.presto.PrestoHook diff --git a/airflow/providers/presto/transfers/gcs_to_presto.py b/airflow/providers/presto/transfers/gcs_to_presto.py index a37498c92e093..466772856bd60 100644 --- a/airflow/providers/presto/transfers/gcs_to_presto.py +++ b/airflow/providers/presto/transfers/gcs_to_presto.py @@ -16,11 +16,12 @@ # specific language governing permissions and limitations # under the License. 
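
As a quick illustration of the reworked PrestoHook above (now built on the common-sql DbApiHook, with a class-level placeholder of "?" and a run() signature that takes split_statements/return_last), here is a minimal usage sketch. It is not part of the patch: the "presto_default" connection, the literal queries, and the "memory.default.example" table are illustrative placeholders.

# A minimal sketch (not part of the patch): exercising the reworked PrestoHook.
# Assumes a valid "presto_default" Airflow connection; queries/tables are placeholders.
from airflow.providers.presto.hooks.presto import PrestoHook

hook = PrestoHook()  # presto_conn_id defaults to "presto_default"

# get_records()/get_first() now accept only a single SQL string; a trailing
# semicolon is stripped via strip_sql_string() before execution.
first_row = hook.get_first("SELECT 1;")

# run() now mirrors the common-sql DbApiHook signature, so several statements
# can be sent in one call and only the last result kept.
results = hook.run(
    "SELECT 1; SELECT 2",
    split_statements=True,
    return_last=True,
    handler=lambda cursor: cursor.fetchall(),
)

# insert_rows() builds its INSERT statements with the new class-level
# placeholder "?" rather than the previous "%s" default.
hook.insert_rows(table="memory.default.example", rows=[(1, "a"), (2, "b")])
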
"""This module contains Google Cloud Storage to Presto operator.""" +from __future__ import annotations import csv import json from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -58,9 +59,9 @@ class GCSToPrestoOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_bucket', - 'source_object', - 'presto_table', + "source_bucket", + "source_object", + "presto_table", ) def __init__( @@ -71,10 +72,10 @@ def __init__( presto_table: str, presto_conn_id: str = "presto_default", gcp_conn_id: str = "google_cloud_default", - schema_fields: Optional[Iterable[str]] = None, - schema_object: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + schema_fields: Iterable[str] | None = None, + schema_object: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -88,7 +89,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/qubole/CHANGELOG.rst b/airflow/providers/qubole/CHANGELOG.rst index b044df058a39c..78330fd8c2d05 100644 --- a/airflow/providers/qubole/CHANGELOG.rst +++ b/airflow/providers/qubole/CHANGELOG.rst @@ -16,9 +16,98 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Review and move the new changes to one of the sections above: + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``D400 first line should end with period batch02 (#25268)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Make extra link work in UI (#25500)`` +* ``Move all "old" SQL operators to common.sql providers (#25350)`` +* ``Improve taskflow type hints with ParamSpec (#25173)`` + +Bug Fixes +~~~~~~~~~ + +* ``Correctly render 'results_parser_callable' param in Qubole docs (#25514)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Remove "bad characters" from our codebase (#24841)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Bug Fixes +~~~~~~~~~ + +* ``Add typing for airflow/configuration.py (#23716)`` +* ``Fix backwards-compatibility introduced by fixing mypy problems (#24230)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Qubole example DAGs to new design #22460 (#24149)`` + * ``Prepare provider documentation 2022.05.11 (#23631)`` + * ``Use new Breese for building, pulling and verifying the images. (#23104)`` + * ``Replace usage of 'DummyOperator' with 'EmptyOperator' (#22974)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/qubole/example_dags/__init__.py b/airflow/providers/qubole/example_dags/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/airflow/providers/qubole/example_dags/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/providers/qubole/example_dags/example_qubole.py b/airflow/providers/qubole/example_dags/example_qubole.py deleted file mode 100644 index ea9bdfd205ae0..0000000000000 --- a/airflow/providers/qubole/example_dags/example_qubole.py +++ /dev/null @@ -1,270 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import filecmp -import random -import textwrap -from datetime import datetime - -from airflow import DAG -from airflow.decorators import task - -try: - from airflow.operators.empty import EmptyOperator -except ModuleNotFoundError: - from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore -from airflow.operators.python import BranchPythonOperator -from airflow.providers.qubole.operators.qubole import QuboleOperator -from airflow.providers.qubole.sensors.qubole import QuboleFileSensor, QubolePartitionSensor -from airflow.utils.trigger_rule import TriggerRule - -START_DATE = datetime(2021, 1, 1) - -with DAG( - dag_id='example_qubole_operator', - schedule_interval=None, - start_date=START_DATE, - tags=['example'], -) as dag: - dag.doc_md = textwrap.dedent( - """ - This is only an example DAG to highlight usage of QuboleOperator in various scenarios, - some of these tasks may or may not work based on your Qubole account setup. - - Run a shell command from Qubole Analyze against your Airflow cluster with following to - trigger it manually `airflow dags trigger example_qubole_operator`. - - *Note: Make sure that connection `qubole_default` is properly set before running this - example. Also be aware that it might spin up clusters to run these examples.* - """ - ) - - @task(trigger_rule=TriggerRule.ALL_DONE) - def compare_result(hive_show_table, hive_s3_location, ti=None): - """ - Compares the results of two QuboleOperator tasks. - - :param hive_show_table: The "hive_show_table" task. - :param hive_s3_location: The "hive_s3_location" task. - :param ti: The TaskInstance object. - :return: True if the files are the same, False otherwise. - :rtype: bool - """ - qubole_result_1 = hive_show_table.get_results(ti) - qubole_result_2 = hive_s3_location.get_results(ti) - return filecmp.cmp(qubole_result_1, qubole_result_2) - - # [START howto_operator_qubole_run_hive_query] - hive_show_table = QuboleOperator( - task_id='hive_show_table', - command_type='hivecmd', - query='show tables', - cluster_label='{{ params.cluster_label }}', - fetch_logs=True, - # If `fetch_logs`=true, will fetch qubole command logs and concatenate - # them into corresponding airflow task logs - tags='airflow_example_run', - # To attach tags to qubole command, auto attach 3 tags - dag_id, task_id, run_id - params={ - 'cluster_label': 'default', - }, - ) - # [END howto_operator_qubole_run_hive_query] - - # [START howto_operator_qubole_run_hive_script] - hive_s3_location = QuboleOperator( - task_id='hive_s3_location', - command_type="hivecmd", - script_location="s3n://public-qubole/qbol-library/scripts/show_table.hql", - notify=True, - tags=['tag1', 'tag2'], - # If the script at s3 location has any qubole specific macros to be replaced - # macros='[{"date": "{{ ds }}"}, {"name" : "abc"}]', - ) - # [END howto_operator_qubole_run_hive_script] - - options = ['hadoop_jar_cmd', 'presto_cmd', 'db_query', 'spark_cmd'] - - branching = BranchPythonOperator(task_id='branching', python_callable=lambda: random.choice(options)) - - [hive_show_table, hive_s3_location] >> compare_result(hive_s3_location, hive_show_table) >> branching - - join = EmptyOperator(task_id='join', trigger_rule=TriggerRule.ONE_SUCCESS) - - # [START howto_operator_qubole_run_hadoop_jar] - hadoop_jar_cmd = QuboleOperator( - task_id='hadoop_jar_cmd', - command_type='hadoopcmd', - sub_command='jar s3://paid-qubole/HadoopAPIExamples/' - 'jars/hadoop-0.20.1-dev-streaming.jar ' - '-mapper wc ' - '-numReduceTasks 0 -input s3://paid-qubole/HadoopAPITests/' 
- 'data/3.tsv -output ' - 's3://paid-qubole/HadoopAPITests/data/3_wc', - cluster_label='{{ params.cluster_label }}', - fetch_logs=True, - params={ - 'cluster_label': 'default', - }, - ) - # [END howto_operator_qubole_run_hadoop_jar] - - # [START howto_operator_qubole_run_pig_script] - pig_cmd = QuboleOperator( - task_id='pig_cmd', - command_type="pigcmd", - script_location="s3://public-qubole/qbol-library/scripts/script1-hadoop-s3-small.pig", - parameters="key1=value1 key2=value2", - ) - # [END howto_operator_qubole_run_pig_script] - - branching >> hadoop_jar_cmd >> pig_cmd >> join - - # [START howto_operator_qubole_run_presto_query] - presto_cmd = QuboleOperator(task_id='presto_cmd', command_type='prestocmd', query='show tables') - # [END howto_operator_qubole_run_presto_query] - - # [START howto_operator_qubole_run_shell_script] - shell_cmd = QuboleOperator( - task_id='shell_cmd', - command_type="shellcmd", - script_location="s3://public-qubole/qbol-library/scripts/shellx.sh", - parameters="param1 param2", - ) - # [END howto_operator_qubole_run_shell_script] - - branching >> presto_cmd >> shell_cmd >> join - - # [START howto_operator_qubole_run_db_tap_query] - db_query = QuboleOperator( - task_id='db_query', command_type='dbtapquerycmd', query='show tables', db_tap_id=2064 - ) - # [END howto_operator_qubole_run_db_tap_query] - - # [START howto_operator_qubole_run_db_export] - db_export = QuboleOperator( - task_id='db_export', - command_type='dbexportcmd', - mode=1, - hive_table='default_qubole_airline_origin_destination', - db_table='exported_airline_origin_destination', - partition_spec='dt=20110104-02', - dbtap_id=2064, - ) - # [END howto_operator_qubole_run_db_export] - - branching >> db_query >> db_export >> join - - # [START howto_operator_qubole_run_db_import] - db_import = QuboleOperator( - task_id='db_import', - command_type='dbimportcmd', - mode=1, - hive_table='default_qubole_airline_origin_destination', - db_table='exported_airline_origin_destination', - where_clause='id < 10', - parallelism=2, - dbtap_id=2064, - ) - # [END howto_operator_qubole_run_db_import] - - # [START howto_operator_qubole_run_spark_scala] - prog = ''' - import scala.math.random - import org.apache.spark._ - - /** Computes an approximation to pi */ - object SparkPi { - def main(args: Array[String]) { - val conf = new SparkConf().setAppName("Spark Pi") - val spark = new SparkContext(conf) - val slices = if (args.length > 0) args(0).toInt else 2 - val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow - val count = spark.parallelize(1 until n, slices).map { i => - val x = random * 2 - 1 - val y = random * 2 - 1 - if (x*x + y*y < 1) 1 else 0 - }.reduce(_ + _) - println("Pi is roughly " + 4.0 * count / n) - spark.stop() - } - } - ''' - - spark_cmd = QuboleOperator( - task_id='spark_cmd', - command_type="sparkcmd", - program=prog, - language='scala', - arguments='--class SparkPi', - tags='airflow_example_run', - ) - # [END howto_operator_qubole_run_spark_scala] - - branching >> db_import >> spark_cmd >> join - -with DAG( - dag_id='example_qubole_sensor', - schedule_interval=None, - start_date=START_DATE, - tags=['example'], -) as dag2: - dag2.doc_md = textwrap.dedent( - """ - This is only an example DAG to highlight usage of QuboleSensor in various scenarios, - some of these tasks may or may not work based on your QDS account setup. - - Run a shell command from Qubole Analyze against your Airflow cluster with following to - trigger it manually `airflow dags trigger example_qubole_sensor`. 
- - *Note: Make sure that connection `qubole_default` is properly set before running - this example.* - """ - ) - - # [START howto_sensor_qubole_run_file_sensor] - check_s3_file = QuboleFileSensor( - task_id='check_s3_file', - poke_interval=60, - timeout=600, - data={ - "files": [ - "s3://paid-qubole/HadoopAPIExamples/jars/hadoop-0.20.1-dev-streaming.jar", - "s3://paid-qubole/HadoopAPITests/data/{{ ds.split('-')[2] }}.tsv", - ] # will check for availability of all the files in array - }, - ) - # [END howto_sensor_qubole_run_file_sensor] - - # [START howto_sensor_qubole_run_partition_sensor] - check_hive_partition = QubolePartitionSensor( - task_id='check_hive_partition', - poke_interval=10, - timeout=60, - data={ - "schema": "default", - "table": "my_partitioned_table", - "columns": [ - {"column": "month", "values": ["{{ ds.split('-')[1] }}"]}, - {"column": "day", "values": ["{{ ds.split('-')[2] }}", "{{ yesterday_ds.split('-')[2] }}"]}, - ], # will check for partitions like [month=12/day=12,month=12/day=13] - }, - ) - # [END howto_sensor_qubole_run_partition_sensor] - - check_s3_file >> check_hive_partition diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py index 7896fbd352807..f011bd5b5ce88 100644 --- a/airflow/providers/qubole/hooks/qubole.py +++ b/airflow/providers/qubole/hooks/qubole.py @@ -15,14 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """Qubole hook""" +from __future__ import annotations + import datetime import logging import os import pathlib import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any from qds_sdk.commands import ( Command, @@ -46,6 +47,7 @@ from airflow.utils.state import State if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context @@ -65,7 +67,7 @@ "jupytercmd": JupyterNotebookCommand, } -POSITIONAL_ARGS = {'hadoopcmd': ['sub_command'], 'shellcmd': ['parameters'], 'pigcmd': ['parameters']} +POSITIONAL_ARGS = {"hadoopcmd": ["sub_command"], "shellcmd": ["parameters"], "pigcmd": ["parameters"]} def flatten_list(list_of_lists) -> list: @@ -85,7 +87,7 @@ def get_options_list(command_class) -> list: return filter_options(options_list) -def build_command_args() -> Tuple[Dict[str, list], list]: +def build_command_args() -> tuple[dict[str, list], list]: """Build Command argument from command and options""" command_args, hyphen_args = {}, set() for cmd in COMMAND_CLASSES: @@ -113,56 +115,56 @@ def build_command_args() -> Tuple[Dict[str, list], list]: class QuboleHook(BaseHook): """Hook for Qubole communication""" - conn_name_attr = 'qubole_conn_id' - default_conn_name = 'qubole_default' - conn_type = 'qubole' - hook_name = 'Qubole' + conn_name_attr: str = "qubole_conn_id" + default_conn_name = "qubole_default" + conn_type = "qubole" + hook_name = "Qubole" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['login', 'schema', 'port', 'extra'], + "hidden_fields": ["login", "schema", "port", "extra"], "relabeling": { - 'host': 'API Endpoint', - 'password': 'Auth Token', + "host": "API Endpoint", + "password": "Auth Token", }, - "placeholders": {'host': 'https://.qubole.com/api'}, + "placeholders": {"host": "https://.qubole.com/api"}, } def __init__(self, *args, **kwargs) -> None: 
super().__init__() - conn = self.get_connection(kwargs.get('qubole_conn_id', self.default_conn_name)) + conn = self.get_connection(kwargs.get("qubole_conn_id", self.default_conn_name)) Qubole.configure(api_token=conn.password, api_url=conn.host) - self.task_id = kwargs['task_id'] - self.dag_id = kwargs['dag'].dag_id + self.task_id = kwargs["task_id"] + self.dag_id = kwargs["dag"].dag_id self.kwargs = kwargs - self.cls = COMMAND_CLASSES[self.kwargs['command_type']] - self.cmd: Optional[Command] = None - self.task_instance = None + self.cls = COMMAND_CLASSES[self.kwargs["command_type"]] + self.cmd: Command | None = None + self.task_instance: TaskInstance | None = None @staticmethod def handle_failure_retry(context) -> None: """Handle retries in case of failures""" - ti = context['ti'] - cmd_id = ti.xcom_pull(key='qbol_cmd_id', task_ids=ti.task_id) + ti = context["ti"] + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=ti.task_id) if cmd_id is not None: cmd = Command.find(cmd_id) if cmd is not None: - if cmd.status == 'done': - log.info('Command ID: %s has been succeeded, hence marking this TI as Success.', cmd_id) + if cmd.status == "done": + log.info("Command ID: %s has been succeeded, hence marking this TI as Success.", cmd_id) ti.state = State.SUCCESS - elif cmd.status == 'running': - log.info('Cancelling the Qubole Command Id: %s', cmd_id) + elif cmd.status == "running": + log.info("Cancelling the Qubole Command Id: %s", cmd_id) cmd.cancel() - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Execute call""" args = self.cls.parse(self.create_cmd_args(context)) self.cmd = self.cls.create(**args) - self.task_instance = context['task_instance'] - context['task_instance'].xcom_push(key='qbol_cmd_id', value=self.cmd.id) # type: ignore[attr-defined] + self.task_instance = context["task_instance"] + context["task_instance"].xcom_push(key="qbol_cmd_id", value=self.cmd.id) # type: ignore[attr-defined] self.log.info( "Qubole command created with Id: %s and Status: %s", self.cmd.id, # type: ignore[attr-defined] @@ -176,14 +178,14 @@ def execute(self, context: 'Context') -> None: "Command Id: %s and Status: %s", self.cmd.id, self.cmd.status # type: ignore[attr-defined] ) - if 'fetch_logs' in self.kwargs and self.kwargs['fetch_logs'] is True: + if "fetch_logs" in self.kwargs and self.kwargs["fetch_logs"] is True: self.log.info( "Logs for Command Id: %s \n%s", self.cmd.id, self.cmd.get_log() # type: ignore[attr-defined] ) - if self.cmd.status != 'done': # type: ignore[attr-defined] + if self.cmd.status != "done": # type: ignore[attr-defined] raise AirflowException( - 'Command Id: {} failed with Status: {}'.format( + "Command Id: {} failed with Status: {}".format( self.cmd.id, self.cmd.status # type: ignore[attr-defined] ) ) @@ -203,7 +205,7 @@ def kill(self, ti): cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=ti.task_id) self.cmd = self.cls.find(cmd_id) if self.cls and self.cmd: - self.log.info('Sending KILL signal to Qubole Command Id: %s', self.cmd.id) + self.log.info("Sending KILL signal to Qubole Command Id: %s", self.cmd.id) self.cmd.cancel() def get_results( @@ -227,16 +229,17 @@ def get_results( """ if fp is None: iso = datetime.datetime.utcnow().isoformat() - logpath = os.path.expanduser(conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER')) - resultpath = logpath + '/' + self.dag_id + '/' + self.task_id + '/results' + base_log_folder = conf.get_mandatory_value("logging", "BASE_LOG_FOLDER") + logpath = os.path.expanduser(base_log_folder) + 
resultpath = logpath + "/" + self.dag_id + "/" + self.task_id + "/results" pathlib.Path(resultpath).mkdir(parents=True, exist_ok=True) - fp = open(resultpath + '/' + iso, 'wb') + fp = open(resultpath + "/" + iso, "wb") if self.cmd is None: cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id) self.cmd = self.cls.find(cmd_id) - include_headers_str = 'true' if include_headers else 'false' + include_headers_str = "true" if include_headers else "false" self.cmd.get_results( fp, inline, delim, fetch, arguments=[include_headers_str] ) # type: ignore[attr-defined] @@ -266,12 +269,12 @@ def get_jobs_id(self, ti) -> None: cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id) Command.get_jobs_id(cmd_id) - def create_cmd_args(self, context) -> List[str]: + def create_cmd_args(self, context) -> list[str]: """Creates command arguments""" args = [] - cmd_type = self.kwargs['command_type'] + cmd_type = self.kwargs["command_type"] inplace_args = None - tags = {self.dag_id, self.task_id, context['run_id']} + tags = {self.dag_id, self.task_id, context["run_id"]} positional_args_list = flatten_list(POSITIONAL_ARGS.values()) for key, value in self.kwargs.items(): @@ -280,9 +283,9 @@ def create_cmd_args(self, context) -> List[str]: args.append(f"--{key.replace('_', '-')}={value}") elif key in positional_args_list: inplace_args = value - elif key == 'tags': + elif key == "tags": self._add_tags(tags, value) - elif key == 'notify': + elif key == "notify": if value is True: args.append("--notify") else: @@ -291,7 +294,7 @@ def create_cmd_args(self, context) -> List[str]: args.append(f"--tags={','.join(filter(None, tags))}") if inplace_args is not None: - args += inplace_args.split(' ') + args += inplace_args.split(" ") return args diff --git a/airflow/providers/qubole/hooks/qubole_check.py b/airflow/providers/qubole/hooks/qubole_check.py index 5dba31c5cbe97..03d9d2d83733b 100644 --- a/airflow/providers/qubole/hooks/qubole_check.py +++ b/airflow/providers/qubole/hooks/qubole_check.py @@ -15,21 +15,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# +from __future__ import annotations + import logging from io import StringIO -from typing import List, Optional, Union from qds_sdk.commands import Command from airflow.exceptions import AirflowException -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.qubole.hooks.qubole import QuboleHook log = logging.getLogger(__name__) -COL_DELIM = '\t' -ROW_DELIM = '\r\n' +COL_DELIM = "\t" +ROW_DELIM = "\r\n" def isint(value) -> bool: @@ -58,7 +58,7 @@ def isbool(value) -> bool: return False -def parse_first_row(row_list) -> List[Union[bool, float, int, str]]: +def parse_first_row(row_list) -> list[bool | float | int | str]: """Parse Qubole first record list""" record_list = [] first_row = row_list[0] if row_list else "" @@ -81,22 +81,22 @@ class QuboleCheckHook(QuboleHook, DbApiHook): def __init__(self, context, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.results_parser_callable = parse_first_row - if 'results_parser_callable' in kwargs and kwargs['results_parser_callable'] is not None: - if not callable(kwargs['results_parser_callable']): - raise AirflowException('`results_parser_callable` param must be callable') - self.results_parser_callable = kwargs['results_parser_callable'] + if "results_parser_callable" in kwargs and kwargs["results_parser_callable"] is not None: + if not callable(kwargs["results_parser_callable"]): + raise AirflowException("`results_parser_callable` param must be callable") + self.results_parser_callable = kwargs["results_parser_callable"] self.context = context @staticmethod def handle_failure_retry(context) -> None: - ti = context['ti'] - cmd_id = ti.xcom_pull(key='qbol_cmd_id', task_ids=ti.task_id) + ti = context["ti"] + cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=ti.task_id) if cmd_id is not None: cmd = Command.find(cmd_id) if cmd is not None: - if cmd.status == 'running': - log.info('Cancelling the Qubole Command Id: %s', cmd_id) + if cmd.status == "running": + log.info("Cancelling the Qubole Command Id: %s", cmd_id) cmd.cancel() def get_first(self, sql): @@ -107,13 +107,13 @@ def get_first(self, sql): record_list = self.results_parser_callable(row_list) return record_list - def get_query_results(self) -> Optional[str]: + def get_query_results(self) -> str | None: """Get Qubole query result""" if self.cmd is not None: cmd_id = self.cmd.id self.log.info("command id: %d", cmd_id) query_result_buffer = StringIO() - self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM, arguments=['true']) + self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM, arguments=["true"]) query_result = query_result_buffer.getvalue() query_result_buffer.close() return query_result diff --git a/airflow/providers/qubole/operators/qubole.py b/airflow/providers/qubole/operators/qubole.py index 15a39c61bfaa1..710387663ffa2 100644 --- a/airflow/providers/qubole/operators/qubole.py +++ b/airflow/providers/qubole/operators/qubole.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. 
"""Qubole operator""" +from __future__ import annotations + import re -from datetime import datetime -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from airflow.hooks.base import BaseHook from airflow.models import BaseOperator, BaseOperatorLink, XCom @@ -31,7 +32,7 @@ ) if TYPE_CHECKING: - from airflow.models.abstractoperator import AbstractOperator + from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context @@ -39,38 +40,30 @@ class QDSLink(BaseOperatorLink): """Link to QDS""" - name = 'Go to QDS' + name = "Go to QDS" def get_link( self, - operator: "AbstractOperator", - dttm: Optional[datetime] = None, + operator: BaseOperator, *, - ti_key: Optional["TaskInstanceKey"] = None, + ti_key: TaskInstanceKey, ) -> str: """ Get link to qubole command result page. :param operator: operator - :param dttm: datetime :return: url link """ conn = BaseHook.get_connection( getattr(operator, "qubole_conn_id", None) - or operator.kwargs['qubole_conn_id'] # type: ignore[attr-defined] + or operator.kwargs["qubole_conn_id"] # type: ignore[attr-defined] ) if conn and conn.host: - host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host) - else: - host = 'https://api.qubole.com/v2/analyze?command_id=' - if ti_key is not None: - qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key) + host = re.sub(r"api$", "v2/analyze?command_id=", conn.host) else: - assert dttm - qds_command_id = XCom.get_one( - key='qbol_cmd_id', dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm - ) - url = host + str(qds_command_id) if qds_command_id else '' + host = "https://api.qubole.com/v2/analyze?command_id=" + qds_command_id = XCom.get_value(key="qbol_cmd_id", ti_key=ti_key) + url = host + str(qds_command_id) if qds_command_id else "" return url @@ -176,8 +169,8 @@ class QuboleOperator(BaseOperator): jupytercmd: :path: Path including name of the Jupyter notebook to be run with extension. :arguments: Valid JSON to be sent to the notebook. Specify the parameters in notebooks and pass - the parameter value using the JSON format. key is the parameter’s name and value is - the parameter’s value. Supported types in parameters are string, integer, float and boolean. + the parameter value using the JSON format. key is the parameter's name and value is + the parameter's value. Supported types in parameters are string, integer, float and boolean. .. 
note: @@ -198,47 +191,47 @@ class QuboleOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'query', - 'script_location', - 'sub_command', - 'script', - 'files', - 'archives', - 'program', - 'cmdline', - 'sql', - 'where_clause', - 'tags', - 'extract_query', - 'boundary_query', - 'macros', - 'name', - 'parameters', - 'dbtap_id', - 'hive_table', - 'db_table', - 'split_column', - 'note_id', - 'db_update_keys', - 'export_dir', - 'partition_spec', - 'qubole_conn_id', - 'arguments', - 'user_program_arguments', - 'cluster_label', + "query", + "script_location", + "sub_command", + "script", + "files", + "archives", + "program", + "cmdline", + "sql", + "where_clause", + "tags", + "extract_query", + "boundary_query", + "macros", + "name", + "parameters", + "dbtap_id", + "hive_table", + "db_table", + "split_column", + "note_id", + "db_update_keys", + "export_dir", + "partition_spec", + "qubole_conn_id", + "arguments", + "user_program_arguments", + "cluster_label", ) - template_ext: Sequence[str] = ('.txt',) - ui_color = '#3064A1' - ui_fgcolor = '#fff' - qubole_hook_allowed_args_list = ['command_type', 'qubole_conn_id', 'fetch_logs'] + template_ext: Sequence[str] = (".txt",) + ui_color = "#3064A1" + ui_fgcolor = "#fff" + qubole_hook_allowed_args_list = ["command_type", "qubole_conn_id", "fetch_logs"] operator_extra_links = (QDSLink(),) def __init__(self, *, qubole_conn_id: str = "qubole_default", **kwargs) -> None: self.kwargs = kwargs - self.kwargs['qubole_conn_id'] = qubole_conn_id - self.hook: Optional[QuboleHook] = None + self.kwargs["qubole_conn_id"] = qubole_conn_id + self.hook: QuboleHook | None = None filtered_base_kwargs = self._get_filtered_args(kwargs) super().__init__(**filtered_base_kwargs) @@ -257,7 +250,7 @@ def _get_filtered_args(self, all_kwargs) -> dict: ) return {key: value for key, value in all_kwargs.items() if key not in qubole_args} - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: return self.get_hook().execute(context) def on_kill(self, ti=None) -> None: @@ -295,7 +288,7 @@ def __getattribute__(self, name: str) -> str: if name in self.kwargs: return self.kwargs[name] else: - return '' + return "" else: return object.__getattribute__(self, name) @@ -307,6 +300,6 @@ def __setattr__(self, name: str, value: str) -> None: def _get_template_fields(obj: BaseOperator) -> dict: - class_ = object.__getattribute__(obj, '__class__') - template_fields = object.__getattribute__(class_, 'template_fields') + class_ = object.__getattribute__(obj, "__class__") + template_fields = object.__getattribute__(class_, "template_fields") return template_fields diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py index e63ff308b3d66..d95b283e18e9b 100644 --- a/airflow/providers/qubole/operators/qubole_check.py +++ b/airflow/providers/qubole/operators/qubole_check.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
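The `QDSLink` hunk above replaces the deprecated `dttm` parameter with a required keyword-only `ti_key` and resolves the command id through `XCom.get_value`, so the link no longer needs an `execution_date` lookup. A hedged sketch of what a minimal operator extra link with the new signature can look like; the class name, XCom key, and URL are made up for illustration.

```python
# Illustrative operator extra link using the keyword-only ti_key signature
# shown above; names and URL are placeholders, not provider code.
from __future__ import annotations

from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.models.taskinstance import TaskInstanceKey


class ExampleResultLink(BaseOperatorLink):
    name = "Go to results"

    def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
        # The command id is read back from XCom for the exact task-instance key,
        # instead of being looked up by execution_date as in the old signature.
        command_id = XCom.get_value(key="example_cmd_id", ti_key=ti_key)
        return f"https://example.com/analyze?command_id={command_id}" if command_id else ""
```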
-# -from typing import Callable, Optional, Sequence, Union +from __future__ import annotations + +from typing import Callable, Sequence from airflow.exceptions import AirflowException -from airflow.operators.sql import SQLCheckOperator, SQLValueCheckOperator +from airflow.providers.common.sql.operators.sql import SQLCheckOperator, SQLValueCheckOperator from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook from airflow.providers.qubole.operators.qubole import QuboleOperator @@ -28,7 +29,7 @@ class _QuboleCheckOperatorMixin: """This is a Mixin for Qubole related check operators""" kwargs: dict - results_parser_callable: Optional[Callable] + results_parser_callable: Callable | None def execute(self, context=None) -> None: """Execute a check operation against Qubole""" @@ -86,19 +87,15 @@ class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOpe :ref:`howto/operator:QuboleCheckOperator` :param qubole_conn_id: Connection id which consists of qds auth_token + :param results_parser_callable: This is an optional parameter to extend the flexibility of parsing the + results of Qubole command to the users. This is a Python callable which can hold the logic to parse + list of rows returned by Qubole command. By default, only the values on first row are used for + performing checks. This callable should return a list of records on which the checks have to be + performed. kwargs: - Arguments specific to Qubole command can be referred from QuboleOperator docs. - :results_parser_callable: This is an optional parameter to - extend the flexibility of parsing the results of Qubole - command to the users. This is a python callable which - can hold the logic to parse list of rows returned by Qubole command. - By default, only the values on first row are used for performing checks. - This callable should return a list of records on - which the checks have to be performed. - .. note:: All fields in common with template fields of QuboleOperator and SQLCheckOperator are template-supported. @@ -108,17 +105,17 @@ class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOpe set(QuboleOperator.template_fields) | set(SQLCheckOperator.template_fields) ) template_ext = QuboleOperator.template_ext - ui_fgcolor = '#000' + ui_fgcolor = "#000" def __init__( self, *, qubole_conn_id: str = "qubole_default", - results_parser_callable: Optional[Callable] = None, + results_parser_callable: Callable | None = None, **kwargs, ) -> None: sql = get_sql_from_qbol_cmd(kwargs) - kwargs.pop('sql', None) + kwargs.pop("sql", None) super().__init__(qubole_conn_id=qubole_conn_id, sql=sql, **kwargs) self.results_parser_callable = results_parser_callable self.on_failure_callback = QuboleCheckHook.handle_failure_retry @@ -138,47 +135,38 @@ class QuboleValueCheckOperator(_QuboleCheckOperatorMixin, SQLValueCheckOperator, is not within the permissible limit of expected value. :param qubole_conn_id: Connection id which consists of qds auth_token - :param pass_value: Expected value of the query results. - - :param tolerance: Defines the permissible pass_value range, for example if - tolerance is 2, the Qubole command output can be anything between - -2*pass_value and 2*pass_value, without the operator erring out. - - + :param tolerance: Defines the permissible pass_value range, for example if tolerance is 2, the Qubole + command output can be anything between -2*pass_value and 2*pass_value, without the operator erring + out. 
+ :param results_parser_callable: This is an optional parameter to extend the flexibility of parsing the + results of Qubole command to the users. This is a Python callable which can hold the logic to parse + list of rows returned by Qubole command. By default, only the values on first row are used for + performing checks. This callable should return a list of records on which the checks have to be + performed. kwargs: - Arguments specific to Qubole command can be referred from QuboleOperator docs. - :results_parser_callable: This is an optional parameter to - extend the flexibility of parsing the results of Qubole - command to the users. This is a python callable which - can hold the logic to parse list of rows returned by Qubole command. - By default, only the values on first row are used for performing checks. - This callable should return a list of records on - which the checks have to be performed. - - .. note:: All fields in common with template fields of QuboleOperator and SQLValueCheckOperator are template-supported. """ template_fields = tuple(set(QuboleOperator.template_fields) | set(SQLValueCheckOperator.template_fields)) template_ext = QuboleOperator.template_ext - ui_fgcolor = '#000' + ui_fgcolor = "#000" def __init__( self, *, - pass_value: Union[str, int, float], - tolerance: Optional[Union[int, float]] = None, - results_parser_callable: Optional[Callable] = None, + pass_value: str | int | float, + tolerance: int | float | None = None, + results_parser_callable: Callable | None = None, qubole_conn_id: str = "qubole_default", **kwargs, ) -> None: sql = get_sql_from_qbol_cmd(kwargs) - kwargs.pop('sql', None) + kwargs.pop("sql", None) super().__init__( qubole_conn_id=qubole_conn_id, sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs ) @@ -190,11 +178,11 @@ def __init__( def get_sql_from_qbol_cmd(params) -> str: """Get Qubole sql from Qubole command""" - sql = '' - if 'query' in params: - sql = params['query'] - elif 'sql' in params: - sql = params['sql'] + sql = "" + if "query" in params: + sql = params["query"] + elif "sql" in params: + sql = params["sql"] return sql @@ -206,7 +194,7 @@ def handle_airflow_exception(airflow_exception, hook: QuboleCheckHook): qubole_command_results = hook.get_query_results() qubole_command_id = cmd.id exception_message = ( - f'\nQubole Command Id: {qubole_command_id}\nQubole Command Results:\n{qubole_command_results}' + f"\nQubole Command Id: {qubole_command_id}\nQubole Command Results:\n{qubole_command_results}" ) raise AirflowException(str(airflow_exception) + exception_message) raise AirflowException(str(airflow_exception)) diff --git a/airflow/providers/qubole/provider.yaml b/airflow/providers/qubole/provider.yaml index 59e5a82ea7ddf..2ada483c9a786 100644 --- a/airflow/providers/qubole/provider.yaml +++ b/airflow/providers/qubole/provider.yaml @@ -22,6 +22,11 @@ description: | `Qubole `__ versions: + - 3.3.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -32,8 +37,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - qds-sdk>=1.10.4 integrations: - integration-name: Qubole @@ -61,8 +68,6 @@ hooks: - airflow.providers.qubole.hooks.qubole - airflow.providers.qubole.hooks.qubole_check -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.qubole.hooks.qubole.QuboleHook connection-types: - hook-class-name: 
airflow.providers.qubole.hooks.qubole.QuboleHook diff --git a/airflow/providers/qubole/sensors/qubole.py b/airflow/providers/qubole/sensors/qubole.py index 1d1fdb62eda41..9a80d311a4875 100644 --- a/airflow/providers/qubole/sensors/qubole.py +++ b/airflow/providers/qubole/sensors/qubole.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from qds_sdk.qubole import Qubole @@ -29,17 +31,17 @@ class QuboleSensor(BaseSensorOperator): - """Base class for all Qubole Sensors""" + """Base class for all Qubole Sensors.""" - template_fields: Sequence[str] = ('data', 'qubole_conn_id') + template_fields: Sequence[str] = ("data", "qubole_conn_id") - template_ext: Sequence[str] = ('.txt',) + template_ext: Sequence[str] = (".txt",) def __init__(self, *, data, qubole_conn_id: str = "qubole_default", **kwargs) -> None: self.data = data self.qubole_conn_id = qubole_conn_id - if 'poke_interval' in kwargs and kwargs['poke_interval'] < 5: + if "poke_interval" in kwargs and kwargs["poke_interval"] < 5: raise AirflowException( f"Sorry, poke_interval can't be less than 5 sec for task '{kwargs['task_id']}' " f"in dag '{kwargs['dag'].dag_id}'." @@ -47,12 +49,12 @@ def __init__(self, *, data, qubole_conn_id: str = "qubole_default", **kwargs) -> super().__init__(**kwargs) - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: conn = BaseHook.get_connection(self.qubole_conn_id) Qubole.configure(api_token=conn.password, api_url=conn.host) - self.log.info('Poking: %s', self.data) + self.log.info("Poking: %s", self.data) status = False try: @@ -61,15 +63,16 @@ def poke(self, context: 'Context') -> bool: self.log.exception(e) status = False - self.log.info('Status of this Poke: %s', status) + self.log.info("Status of this Poke: %s", status) return status class QuboleFileSensor(QuboleSensor): """ - Wait for a file or folder to be present in cloud storage - and check for its presence via QDS APIs + Wait for a file or folder to be present in cloud storage. + + Check for file or folder presence via QDS APIs. .. seealso:: For more information on how to use this sensor, take a look at the guide: @@ -92,8 +95,9 @@ def __init__(self, **kwargs) -> None: class QubolePartitionSensor(QuboleSensor): """ - Wait for a Hive partition to show up in QHS (Qubole Hive Service) - and check for its presence via QDS APIs + Wait for a Hive partition to show up in QHS (Qubole Hive Service). + + Check for Hive partition presence via QDS APIs. .. seealso:: For more information on how to use this sensor, take a look at the guide: diff --git a/airflow/providers/redis/.latest-doc-only-change.txt b/airflow/providers/redis/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/redis/.latest-doc-only-change.txt +++ b/airflow/providers/redis/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/redis/CHANGELOG.rst b/airflow/providers/redis/CHANGELOG.rst index 15a44833c2286..3d1c7b5767485 100644 --- a/airflow/providers/redis/CHANGELOG.rst +++ b/airflow/providers/redis/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. 
NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/redis/hooks/redis.py b/airflow/providers/redis/hooks/redis.py index 909cc120a6ccc..f6b697bd3b9b5 100644 --- a/airflow/providers/redis/hooks/redis.py +++ b/airflow/providers/redis/hooks/redis.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """RedisHook module""" +from __future__ import annotations + from redis import Redis from airflow.hooks.base import BaseHook @@ -31,10 +32,10 @@ class RedisHook(BaseHook): ``{"ssl": true, "ssl_cert_reqs": "require", "ssl_cert_file": "/path/to/cert.pem", etc}``. 
""" - conn_name_attr = 'redis_conn_id' - default_conn_name = 'redis_default' - conn_type = 'redis' - hook_name = 'Redis' + conn_name_attr = "redis_conn_id" + default_conn_name = "redis_default" + conn_type = "redis" + hook_name = "Redis" def __init__(self, redis_conn_id: str = default_conn_name) -> None: """ @@ -56,8 +57,8 @@ def get_conn(self): conn = self.get_connection(self.redis_conn_id) self.host = conn.host self.port = conn.port - self.password = None if str(conn.password).lower() in ['none', 'false', ''] else conn.password - self.db = conn.extra_dejson.get('db') + self.password = None if str(conn.password).lower() in ["none", "false", ""] else conn.password + self.db = conn.extra_dejson.get("db") # check for ssl parameters in conn.extra ssl_arg_names = [ diff --git a/airflow/providers/redis/operators/redis_publish.py b/airflow/providers/redis/operators/redis_publish.py index 67315acfd8fb8..98b1b9d681b65 100644 --- a/airflow/providers/redis/operators/redis_publish.py +++ b/airflow/providers/redis/operators/redis_publish.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator @@ -33,16 +35,16 @@ class RedisPublishOperator(BaseOperator): :param redis_conn_id: redis connection to use """ - template_fields: Sequence[str] = ('channel', 'message') + template_fields: Sequence[str] = ("channel", "message") - def __init__(self, *, channel: str, message: str, redis_conn_id: str = 'redis_default', **kwargs) -> None: + def __init__(self, *, channel: str, message: str, redis_conn_id: str = "redis_default", **kwargs) -> None: super().__init__(**kwargs) self.redis_conn_id = redis_conn_id self.channel = channel self.message = message - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """ Publish the message to Redis channel @@ -50,8 +52,8 @@ def execute(self, context: 'Context') -> None: """ redis_hook = RedisHook(redis_conn_id=self.redis_conn_id) - self.log.info('Sending message %s to Redis on channel %s', self.message, self.channel) + self.log.info("Sending message %s to Redis on channel %s", self.message, self.channel) result = redis_hook.get_conn().publish(channel=self.channel, message=self.message) - self.log.info('Result of publishing %s', result) + self.log.info("Result of publishing %s", result) diff --git a/airflow/providers/redis/provider.yaml b/airflow/providers/redis/provider.yaml index 22cd508458ddb..0cbcab3a88e66 100644 --- a/airflow/providers/redis/provider.yaml +++ b/airflow/providers/redis/provider.yaml @@ -22,6 +22,8 @@ description: | `Redis `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +32,13 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + # Redis 4 introduced a number of changes that likely need testing including mixins in redis commands + # as well as unquoting URLS with `urllib.parse.unquote`: + # https://github.com/redis/redis-py/blob/master/CHANGES + # TODO: upgrade to support redis package >=4 + - redis~=3.2 integrations: - integration-name: Redis @@ -55,9 +62,6 @@ hooks: python-modules: - airflow.providers.redis.hooks.redis -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.redis.hooks.redis.RedisHook - connection-types: - hook-class-name: 
airflow.providers.redis.hooks.redis.RedisHook connection-type: redis diff --git a/airflow/providers/redis/sensors/redis_key.py b/airflow/providers/redis/sensors/redis_key.py index 064459ab0ffb3..855a39ab82ff0 100644 --- a/airflow/providers/redis/sensors/redis_key.py +++ b/airflow/providers/redis/sensors/redis_key.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Sequence from airflow.providers.redis.hooks.redis import RedisHook @@ -27,14 +29,14 @@ class RedisKeySensor(BaseSensorOperator): """Checks for the existence of a key in a Redis""" - template_fields: Sequence[str] = ('key',) - ui_color = '#f0eee4' + template_fields: Sequence[str] = ("key",) + ui_color = "#f0eee4" def __init__(self, *, key: str, redis_conn_id: str, **kwargs) -> None: super().__init__(**kwargs) self.redis_conn_id = redis_conn_id self.key = key - def poke(self, context: 'Context') -> bool: - self.log.info('Sensor checks for existence of key: %s', self.key) + def poke(self, context: Context) -> bool: + self.log.info("Sensor checks for existence of key: %s", self.key) return RedisHook(self.redis_conn_id).get_conn().exists(self.key) diff --git a/airflow/providers/redis/sensors/redis_pub_sub.py b/airflow/providers/redis/sensors/redis_pub_sub.py index dedfde72c8530..31c5e5af9ee2c 100644 --- a/airflow/providers/redis/sensors/redis_pub_sub.py +++ b/airflow/providers/redis/sensors/redis_pub_sub.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, List, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.providers.redis.hooks.redis import RedisHook from airflow.sensors.base import BaseSensorOperator @@ -33,17 +34,17 @@ class RedisPubSubSensor(BaseSensorOperator): :param redis_conn_id: the redis connection id """ - template_fields: Sequence[str] = ('channels',) - ui_color = '#f0eee4' + template_fields: Sequence[str] = ("channels",) + ui_color = "#f0eee4" - def __init__(self, *, channels: Union[List[str], str], redis_conn_id: str, **kwargs) -> None: + def __init__(self, *, channels: list[str] | str, redis_conn_id: str, **kwargs) -> None: super().__init__(**kwargs) self.channels = channels self.redis_conn_id = redis_conn_id self.pubsub = RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub() self.pubsub.subscribe(self.channels) - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: """ Check for message on subscribed channels and write to xcom the message with key ``message`` @@ -52,15 +53,15 @@ def poke(self, context: 'Context') -> bool: :param context: the context object :return: ``True`` if message (with type 'message') is available or ``False`` if not """ - self.log.info('RedisPubSubSensor checking for message on channels: %s', self.channels) + self.log.info("RedisPubSubSensor checking for message on channels: %s", self.channels) message = self.pubsub.get_message() - self.log.info('Message %s from channel %s', message, self.channels) + self.log.info("Message %s from channel %s", message, self.channels) # Process only message types - if message and message['type'] == 'message': + if message and message["type"] == "message": - context['ti'].xcom_push(key='message', value=message) + 
context["ti"].xcom_push(key="message", value=message) self.pubsub.unsubscribe(self.channels) return True diff --git a/airflow/providers/salesforce/.latest-doc-only-change.txt b/airflow/providers/salesforce/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/salesforce/.latest-doc-only-change.txt +++ b/airflow/providers/salesforce/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/salesforce/CHANGELOG.rst b/airflow/providers/salesforce/CHANGELOG.rst index 226212a912ed1..da323857adffa 100644 --- a/airflow/providers/salesforce/CHANGELOG.rst +++ b/airflow/providers/salesforce/CHANGELOG.rst @@ -16,9 +16,86 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +5.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Allow and prefer non-prefixed extra fields for SalesforceHook (#27075)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +5.1.0 +..... + +Features +~~~~~~~~ + +* ``Improve taskflow type hints with ParamSpec (#25173)`` + + +5.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Remove Tableau from Salesforce provider (#23747)`` + +.. warning:: Due to tableau extra removal, ``pip install apache-airflow-providers-salesforce[tableau]`` + will not work. You can install Tableau provider directly via ``pip install apache-airflow-providers-tableau``. + +Features +~~~~~~~~ + +* ``Add support for Salesforce bulk api (#24473)`` + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Correct parameter typing in 'SalesforceBulkOperator' (#24927)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Salesforce example DAGs to new design #22463 (#24127)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.4.4 ..... 
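The `RedisPubSubSensor` hunks above normalize quoting and annotations but keep the poke contract: the sensor subscribes when constructed, and `poke` returns True, pushes the payload to XCom under the key `message`, and unsubscribes once a `type == "message"` payload arrives. A hedged usage sketch follows; the dag_id, channel name, and connection id are invented examples.

```python
# Hypothetical DAG using RedisPubSubSensor as changed above; dag_id, channel
# and connection id are examples, not values taken from the diff.
import pendulum

from airflow import DAG
from airflow.providers.redis.sensors.redis_pub_sub import RedisPubSubSensor

with DAG(
    dag_id="example_redis_pub_sub",
    start_date=pendulum.datetime(2022, 10, 1, tz="UTC"),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_message = RedisPubSubSensor(
        task_id="wait_for_message",
        channels="orders",  # a single channel name or a list of channels
        redis_conn_id="redis_default",
        poke_interval=10,
    )
```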
diff --git a/airflow/providers/salesforce/hooks/salesforce.py b/airflow/providers/salesforce/hooks/salesforce.py index cdb2c918e935c..8087c42b07187 100644 --- a/airflow/providers/salesforce/hooks/salesforce.py +++ b/airflow/providers/salesforce/hooks/salesforce.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains a Salesforce Hook which allows you to connect to your Salesforce instance, retrieve data from it, and write that data to a file for other uses. @@ -23,20 +22,17 @@ .. note:: this hook also relies on the simple_salesforce package: https://github.com/simple-salesforce/simple-salesforce """ +from __future__ import annotations + import logging -import sys import time -from typing import Any, Dict, Iterable, List, Optional - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property +from typing import Any, Iterable import pandas as pd from requests import Session from simple_salesforce import Salesforce, api +from airflow.compat.functools import cached_property from airflow.hooks.base import BaseHook log = logging.getLogger(__name__) @@ -73,53 +69,52 @@ class SalesforceHook(BaseHook): def __init__( self, salesforce_conn_id: str = default_conn_name, - session_id: Optional[str] = None, - session: Optional[Session] = None, + session_id: str | None = None, + session: Session | None = None, ) -> None: super().__init__() self.conn_id = salesforce_conn_id self.session_id = session_id self.session = session + def _get_field(self, extras: dict, field_name: str): + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + backcompat_prefix = "extra__salesforce__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." 
+ ) + if field_name in extras: + return extras[field_name] or None + prefixed_name = f"{backcompat_prefix}{field_name}" + return extras.get(prefixed_name) or None + @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField, StringField return { - "extra__salesforce__security_token": PasswordField( - lazy_gettext("Security Token"), widget=BS3PasswordFieldWidget() - ), - "extra__salesforce__domain": StringField(lazy_gettext("Domain"), widget=BS3TextFieldWidget()), - "extra__salesforce__consumer_key": StringField( - lazy_gettext("Consumer Key"), widget=BS3TextFieldWidget() - ), - "extra__salesforce__private_key_file_path": PasswordField( + "security_token": PasswordField(lazy_gettext("Security Token"), widget=BS3PasswordFieldWidget()), + "domain": StringField(lazy_gettext("Domain"), widget=BS3TextFieldWidget()), + "consumer_key": StringField(lazy_gettext("Consumer Key"), widget=BS3TextFieldWidget()), + "private_key_file_path": PasswordField( lazy_gettext("Private Key File Path"), widget=BS3PasswordFieldWidget() ), - "extra__salesforce__private_key": PasswordField( - lazy_gettext("Private Key"), widget=BS3PasswordFieldWidget() - ), - "extra__salesforce__organization_id": StringField( - lazy_gettext("Organization ID"), widget=BS3TextFieldWidget() - ), - "extra__salesforce__instance": StringField(lazy_gettext("Instance"), widget=BS3TextFieldWidget()), - "extra__salesforce__instance_url": StringField( - lazy_gettext("Instance URL"), widget=BS3TextFieldWidget() - ), - "extra__salesforce__proxies": StringField(lazy_gettext("Proxies"), widget=BS3TextFieldWidget()), - "extra__salesforce__version": StringField( - lazy_gettext("API Version"), widget=BS3TextFieldWidget() - ), - "extra__salesforce__client_id": StringField( - lazy_gettext("Client ID"), widget=BS3TextFieldWidget() - ), + "private_key": PasswordField(lazy_gettext("Private Key"), widget=BS3PasswordFieldWidget()), + "organization_id": StringField(lazy_gettext("Organization ID"), widget=BS3TextFieldWidget()), + "instance": StringField(lazy_gettext("Instance"), widget=BS3TextFieldWidget()), + "instance_url": StringField(lazy_gettext("Instance URL"), widget=BS3TextFieldWidget()), + "proxies": StringField(lazy_gettext("Proxies"), widget=BS3TextFieldWidget()), + "version": StringField(lazy_gettext("API Version"), widget=BS3TextFieldWidget()), + "client_id": StringField(lazy_gettext("Client ID"), widget=BS3TextFieldWidget()), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { "hidden_fields": ["schema", "port", "extra", "host"], @@ -141,19 +136,19 @@ def conn(self) -> api.Salesforce: conn = Salesforce( username=connection.login, password=connection.password, - security_token=extras.get('extra__salesforce__security_token') or None, - domain=extras.get('extra__salesforce__domain') or None, + security_token=self._get_field(extras, "security_token") or None, + domain=self._get_field(extras, "domain") or None, session_id=self.session_id, - instance=extras.get('extra__salesforce__instance') or None, - instance_url=extras.get('extra__salesforce__instance_url') or None, - organizationId=extras.get('extra__salesforce__organization_id') or None, - 
version=extras.get('extra__salesforce__version') or api.DEFAULT_API_VERSION, - proxies=extras.get('extra__salesforce__proxies') or None, + instance=self._get_field(extras, "instance") or None, + instance_url=self._get_field(extras, "instance_url") or None, + organizationId=self._get_field(extras, "organization_id") or None, + version=self._get_field(extras, "version") or api.DEFAULT_API_VERSION, + proxies=self._get_field(extras, "proxies") or None, session=self.session, - client_id=extras.get('extra__salesforce__client_id') or None, - consumer_key=extras.get('extra__salesforce__consumer_key') or None, - privatekey_file=extras.get('extra__salesforce__private_key_file_path') or None, - privatekey=extras.get('extra__salesforce__private_key') or None, + client_id=self._get_field(extras, "client_id") or None, + consumer_key=self._get_field(extras, "consumer_key") or None, + privatekey_file=self._get_field(extras, "private_key_file_path") or None, + privatekey=self._get_field(extras, "private_key") or None, ) return conn @@ -161,9 +156,7 @@ def get_conn(self) -> api.Salesforce: """Returns a Salesforce instance. (cached)""" return self.conn - def make_query( - self, query: str, include_deleted: bool = False, query_params: Optional[dict] = None - ) -> dict: + def make_query(self, query: str, include_deleted: bool = False, query_params: dict | None = None) -> dict: """ Make a query to Salesforce. @@ -171,7 +164,6 @@ def make_query( :param include_deleted: True if the query should include deleted records. :param query_params: Additional optional arguments :return: The query result. - :rtype: dict """ conn = self.get_conn() @@ -180,7 +172,7 @@ def make_query( query_results = conn.query_all(query, include_deleted=include_deleted, **query_params) self.log.info( - "Received results: Total size: %s; Done: %s", query_results['totalSize'], query_results['done'] + "Received results: Total size: %s; Done: %s", query_results["totalSize"], query_results["done"] ) return query_results @@ -193,23 +185,21 @@ def describe_object(self, obj: str) -> dict: :param obj: The name of the Salesforce object that we are getting a description of. :return: the description of the Salesforce object. - :rtype: dict """ conn = self.get_conn() return conn.__getattr__(obj).describe() - def get_available_fields(self, obj: str) -> List[str]: + def get_available_fields(self, obj: str) -> list[str]: """ Get a list of all available fields for an object. :param obj: The name of the Salesforce object that we are getting a description of. :return: the names of the fields. - :rtype: list(str) """ obj_description = self.describe_object(obj) - return [field['name'] for field in obj_description['fields']] + return [field["name"] for field in obj_description["fields"]] def get_object_from_salesforce(self, obj: str, fields: Iterable[str]) -> dict: """ @@ -222,7 +212,6 @@ def get_object_from_salesforce(self, obj: str, fields: Iterable[str]) -> dict: :param obj: The object name to get from Salesforce. :param fields: The fields to get from the object. :return: all instances of the object from Salesforce. - :rtype: dict """ query = f"SELECT {','.join(fields)} FROM {obj}" @@ -240,7 +229,6 @@ def _to_timestamp(cls, column: pd.Series) -> pd.Series: :param column: A Series object representing a column of a dataframe. 
:return: a new series that maintains the same index as the original - :rtype: pandas.Series """ # try and convert the column to datetimes # the column MUST have a four digit year somewhere in the string @@ -272,7 +260,7 @@ def _to_timestamp(cls, column: pd.Series) -> pd.Series: def write_object_to_file( self, - query_results: List[dict], + query_results: list[dict], filename: str, fmt: str = "csv", coerce_to_timestamp: bool = False, @@ -308,10 +296,9 @@ def write_object_to_file( :param record_time_added: True if you want to add a Unix timestamp field to the resulting data that marks when the data was fetched from Salesforce. Default: False :return: the dataframe that gets written to the file. - :rtype: pandas.Dataframe """ fmt = fmt.lower() - if fmt not in ['csv', 'json', 'ndjson']: + if fmt not in ["csv", "json", "ndjson"]: raise ValueError(f"Format value is not recognized: {fmt}") df = self.object_to_df( @@ -349,7 +336,7 @@ def write_object_to_file( return df def object_to_df( - self, query_results: List[dict], coerce_to_timestamp: bool = False, record_time_added: bool = False + self, query_results: list[dict], coerce_to_timestamp: bool = False, record_time_added: bool = False ) -> pd.DataFrame: """ Export query results to dataframe. @@ -366,7 +353,6 @@ def object_to_df( :param record_time_added: True if you want to add a Unix timestamp field to the resulting data that marks when the data was fetched from Salesforce. Default: False :return: the dataframe. - :rtype: pandas.Dataframe """ # this line right here will convert all integers to floats # if there are any None/np.nan values in the column @@ -384,7 +370,7 @@ def object_to_df( # get the object name out of the query results # it's stored in the "attributes" dictionary # for each returned record - object_name = query_results[0]['attributes']['type'] + object_name = query_results[0]["attributes"]["type"] self.log.info("Coercing timestamps for: %s", object_name) @@ -394,9 +380,9 @@ def object_to_df( # are the ones that are either date or datetime types # strings are too general and we risk unintentional conversion possible_timestamp_cols = [ - field['name'].lower() - for field in schema['fields'] - if field['type'] in ["date", "datetime"] and field['name'].lower() in df.columns + field["name"].lower() + for field in schema["fields"] + if field["type"] in ["date", "datetime"] and field["name"].lower() in df.columns ] df[possible_timestamp_cols] = df[possible_timestamp_cols].apply(self._to_timestamp) diff --git a/airflow/providers/salesforce/hooks/tableau.py b/airflow/providers/salesforce/hooks/tableau.py deleted file mode 100644 index 8dcfe82a8e6b4..0000000000000 --- a/airflow/providers/salesforce/hooks/tableau.py +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
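The `SalesforceHook` changes above allow connection extras to be stored without the `extra__salesforce__` prefix while still honoring older prefixed values, with the short name taking precedence. A hedged standalone sketch of that resolution order; the plain dicts below stand in for `Connection.extra_dejson` and are not real connections.

```python
# Minimal sketch of the non-prefixed-first extras lookup added above;
# the dicts are illustrative stand-ins for Connection.extra_dejson.
def get_field(extras: dict, field_name: str, backcompat_prefix: str = "extra__salesforce__"):
    if field_name.startswith("extra__"):
        raise ValueError(f"Got prefixed name {field_name}; pass the short name instead.")
    if field_name in extras:
        # a short name that is present wins, even if empty (empty becomes None)
        return extras[field_name] or None
    return extras.get(f"{backcompat_prefix}{field_name}") or None


new_style = {"security_token": "abc123"}
old_style = {"extra__salesforce__security_token": "abc123"}

assert get_field(new_style, "security_token") == "abc123"
assert get_field(old_style, "security_token") == "abc123"
assert get_field({}, "security_token") is None
```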
- -import warnings - -from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.tableau.hooks.tableau`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/salesforce/operators/bulk.py b/airflow/providers/salesforce/operators/bulk.py new file mode 100644 index 0000000000000..2b22b1ff4a156 --- /dev/null +++ b/airflow/providers/salesforce/operators/bulk.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.salesforce.hooks.salesforce import SalesforceHook +from airflow.typing_compat import Literal + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class SalesforceBulkOperator(BaseOperator): + """ + Execute a Salesforce Bulk API and pushes results to xcom. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SalesforceBulkOperator` + + :param operation: Bulk operation to be performed + Available operations are in ['insert', 'update', 'upsert', 'delete', 'hard_delete'] + :param object_name: The name of the Salesforce object + :param payload: list of dict to be passed as a batch + :param external_id_field: unique identifier field for upsert operations + :param batch_size: number of records to assign for each batch in the job + :param use_serial: Process batches in serial mode + :param salesforce_conn_id: The :ref:`Salesforce Connection id `. + """ + + available_operations = ("insert", "update", "upsert", "delete", "hard_delete") + + def __init__( + self, + *, + operation: Literal["insert", "update", "upsert", "delete", "hard_delete"], + object_name: str, + payload: list, + external_id_field: str = "Id", + batch_size: int = 10000, + use_serial: bool = False, + salesforce_conn_id: str = "salesforce_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.operation = operation + self.object_name = object_name + self.payload = payload + self.external_id_field = external_id_field + self.batch_size = batch_size + self.use_serial = use_serial + self.salesforce_conn_id = salesforce_conn_id + self._validate_inputs() + + def _validate_inputs(self) -> None: + if not self.object_name: + raise ValueError("The required parameter 'object_name' cannot have an empty value.") + + if self.operation not in self.available_operations: + raise ValueError( + f"Operation {self.operation!r} not found! " + f"Available operations are {self.available_operations}." + ) + + def execute(self, context: Context): + """ + Makes an HTTP request to Salesforce Bulk API. 
+ + :param context: The task context during execution. + :return: API response if do_xcom_push is True + """ + sf_hook = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id) + conn = sf_hook.get_conn() + + result = [] + if self.operation == "insert": + result = conn.bulk.__getattr__(self.object_name).insert( + data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial + ) + elif self.operation == "update": + result = conn.bulk.__getattr__(self.object_name).update( + data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial + ) + elif self.operation == "upsert": + result = conn.bulk.__getattr__(self.object_name).upsert( + data=self.payload, + external_id_field=self.external_id_field, + batch_size=self.batch_size, + use_serial=self.use_serial, + ) + elif self.operation == "delete": + result = conn.bulk.__getattr__(self.object_name).delete( + data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial + ) + elif self.operation == "hard_delete": + result = conn.bulk.__getattr__(self.object_name).hard_delete( + data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial + ) + + if self.do_xcom_push and result: + return result + + return None diff --git a/airflow/providers/salesforce/operators/salesforce_apex_rest.py b/airflow/providers/salesforce/operators/salesforce_apex_rest.py index 703f5dfaeca3e..46b10a58670c9 100644 --- a/airflow/providers/salesforce/operators/salesforce_apex_rest.py +++ b/airflow/providers/salesforce/operators/salesforce_apex_rest.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING from airflow.models import BaseOperator @@ -41,9 +43,9 @@ def __init__( self, *, endpoint: str, - method: str = 'GET', + method: str = "GET", payload: dict, - salesforce_conn_id: str = 'salesforce_default', + salesforce_conn_id: str = "salesforce_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -52,12 +54,11 @@ def __init__( self.payload = payload self.salesforce_conn_id = salesforce_conn_id - def execute(self, context: 'Context') -> dict: + def execute(self, context: Context) -> dict: """ Makes an HTTP request to an APEX REST endpoint and pushes results to xcom. :param context: The task context during execution. :return: Apex response - :rtype: dict """ result: dict = {} sf_hook = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id) diff --git a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py b/airflow/providers/salesforce/operators/tableau_refresh_workbook.py deleted file mode 100644 index 007575caad142..0000000000000 --- a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
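The new `SalesforceBulkOperator` above validates the operation name, dispatches to the matching `simple-salesforce` bulk call for the given object, and returns the API response via XCom when `do_xcom_push` is enabled. A hedged usage sketch; the object name, payload, and connection id are invented for illustration, and the task would normally sit inside a `with DAG(...):` block.

```python
# Hypothetical task definition for the SalesforceBulkOperator added above;
# object name, payload and connection id are examples only.
from airflow.providers.salesforce.operators.bulk import SalesforceBulkOperator

insert_accounts = SalesforceBulkOperator(
    task_id="insert_accounts",
    operation="insert",            # one of: insert, update, upsert, delete, hard_delete
    object_name="Account",
    payload=[{"Name": "Example Co"}, {"Name": "Example Org"}],
    batch_size=10000,
    use_serial=False,
    salesforce_conn_id="salesforce_default",
)
```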
See the License for the -# specific language governing permissions and limitations -# under the License. - -import warnings - -from airflow.providers.tableau.operators.tableau_refresh_workbook import ( # noqa - TableauRefreshWorkbookOperator, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.tableau.operators.tableau_refresh_workbook`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/salesforce/provider.yaml index 77e62c44e84fa..01910156392ca 100644 --- a/airflow/providers/salesforce/provider.yaml +++ b/airflow/providers/salesforce/provider.yaml @@ -22,6 +22,10 @@ description: | `Salesforce `__ versions: + - 5.2.0 + - 5.1.0 + - 5.0.0 + - 4.0.0 - 3.4.4 - 3.4.3 - 3.4.2 @@ -35,14 +39,17 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - simple-salesforce>=1.0.0 + - pandas>=0.17.1 integrations: - integration-name: Salesforce external-doc-url: https://www.salesforce.com/ how-to-guide: - /docs/apache-airflow-providers-salesforce/operators/salesforce_apex_rest.rst + - /docs/apache-airflow-providers-salesforce/operators/bulk.rst logo: /integration-logos/salesforce/Salesforce.png tags: [service] @@ -50,24 +57,13 @@ operators: - integration-name: Salesforce python-modules: - airflow.providers.salesforce.operators.salesforce_apex_rest - - airflow.providers.salesforce.operators.tableau_refresh_workbook - -sensors: - - integration-name: Salesforce - python-modules: - - airflow.providers.salesforce.sensors.tableau_job_status + - airflow.providers.salesforce.operators.bulk hooks: - - integration-name: Tableau - python-modules: - - airflow.providers.salesforce.hooks.tableau - integration-name: Salesforce python-modules: - airflow.providers.salesforce.hooks.salesforce -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.salesforce.hooks.salesforce.SalesforceHook - connection-types: - hook-class-name: airflow.providers.salesforce.hooks.salesforce.SalesforceHook connection-type: salesforce diff --git a/airflow/providers/salesforce/sensors/tableau_job_status.py b/airflow/providers/salesforce/sensors/tableau_job_status.py deleted file mode 100644 index 09e2a373eaccb..0000000000000 --- a/airflow/providers/salesforce/sensors/tableau_job_status.py +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import warnings - -from airflow.providers.tableau.sensors.tableau import ( # noqa - TableauJobFailedException, - TableauJobStatusSensor, -) - -warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.tableau.sensors.tableau_job_status`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/samba/.latest-doc-only-change.txt b/airflow/providers/samba/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/samba/.latest-doc-only-change.txt +++ b/airflow/providers/samba/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/samba/CHANGELOG.rst b/airflow/providers/samba/CHANGELOG.rst index 430319327b44f..601862979d644 100644 --- a/airflow/providers/samba/CHANGELOG.rst +++ b/airflow/providers/samba/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.0.4 ..... diff --git a/airflow/providers/samba/hooks/samba.py b/airflow/providers/samba/hooks/samba.py index 383dcf0c076d5..1ca56c351d548 100644 --- a/airflow/providers/samba/hooks/samba.py +++ b/airflow/providers/samba/hooks/samba.py @@ -15,11 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import posixpath from functools import wraps from shutil import copyfileobj -from typing import Dict, Optional import smbclient import smbprotocol.connection @@ -39,12 +39,12 @@ class SambaHook(BaseHook): the connection is used in its place. 
""" - conn_name_attr = 'samba_conn_id' - default_conn_name = 'samba_default' - conn_type = 'samba' - hook_name = 'Samba' + conn_name_attr = "samba_conn_id" + default_conn_name = "samba_default" + conn_type = "samba" + hook_name = "Samba" - def __init__(self, samba_conn_id: str = default_conn_name, share: Optional[str] = None) -> None: + def __init__(self, samba_conn_id: str = default_conn_name, share: str | None = None) -> None: super().__init__() conn = self.get_connection(samba_conn_id) @@ -54,7 +54,7 @@ def __init__(self, samba_conn_id: str = default_conn_name, share: Optional[str] if not conn.password: self.log.info("Password not provided") - connection_cache: Dict[str, smbprotocol.connection.Connection] = {} + connection_cache: dict[str, smbprotocol.connection.Connection] = {} self._host = conn.host self._share = share or conn.schema diff --git a/airflow/providers/samba/provider.yaml b/airflow/providers/samba/provider.yaml index 26d2e582eced5..9b210018515a6 100644 --- a/airflow/providers/samba/provider.yaml +++ b/airflow/providers/samba/provider.yaml @@ -22,6 +22,8 @@ description: | `Samba `__ versions: + - 4.1.0 + - 4.0.0 - 3.0.4 - 3.0.3 - 3.0.2 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - smbprotocol>=1.5.0 integrations: - integration-name: Samba @@ -45,8 +48,6 @@ hooks: python-modules: - airflow.providers.samba.hooks.samba -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.samba.hooks.samba.SambaHook connection-types: - hook-class-name: airflow.providers.samba.hooks.samba.SambaHook diff --git a/airflow/providers/segment/.latest-doc-only-change.txt b/airflow/providers/segment/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/segment/.latest-doc-only-change.txt +++ b/airflow/providers/segment/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/segment/CHANGELOG.rst b/airflow/providers/segment/CHANGELOG.rst index c3246811ab957..789475d81cbe8 100644 --- a/airflow/providers/segment/CHANGELOG.rst +++ b/airflow/providers/segment/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/segment/hooks/segment.py b/airflow/providers/segment/hooks/segment.py index 053c9e037d1ee..684fbf78d03c5 100644 --- a/airflow/providers/segment/hooks/segment.py +++ b/airflow/providers/segment/hooks/segment.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# """ This module contains a Segment Hook which allows you to connect to your Segment account, @@ -24,6 +23,8 @@ NOTE: this hook also relies on the Segment analytics package: https://github.com/segmentio/analytics-python """ +from __future__ import annotations + import analytics from airflow.exceptions import AirflowException @@ -51,13 +52,13 @@ class SegmentHook(BaseHook): `{"write_key":"YOUR_SECURITY_TOKEN"}` """ - conn_name_attr = 'segment_conn_id' - default_conn_name = 'segment_default' - conn_type = 'segment' - hook_name = 'Segment' + conn_name_attr = "segment_conn_id" + default_conn_name = "segment_default" + conn_type = "segment" + hook_name = "Segment" def __init__( - self, segment_conn_id: str = 'segment_default', segment_debug_mode: bool = False, *args, **kwargs + self, segment_conn_id: str = "segment_default", segment_debug_mode: bool = False, *args, **kwargs ) -> None: super().__init__() self.segment_conn_id = segment_conn_id @@ -68,20 +69,20 @@ def __init__( # get the connection parameters self.connection = self.get_connection(self.segment_conn_id) self.extras = self.connection.extra_dejson - self.write_key = self.extras.get('write_key') + self.write_key = self.extras.get("write_key") if self.write_key is None: - raise AirflowException('No Segment write key provided') + raise AirflowException("No Segment write key provided") def get_conn(self) -> analytics: - self.log.info('Setting write key for Segment analytics connection') + self.log.info("Setting write key for Segment analytics connection") analytics.debug = self.segment_debug_mode if self.segment_debug_mode: - self.log.info('Setting Segment analytics connection to debug mode') + self.log.info("Setting Segment analytics connection to debug mode") analytics.on_error = self.on_error analytics.write_key = self.write_key return analytics def on_error(self, error: str, items: str) -> None: """Handles error callbacks when using Segment with segment_debug_mode set to True""" - self.log.error('Encountered Segment error: %s with items: %s', error, items) - raise AirflowException(f'Segment error: {error}') + self.log.error("Encountered Segment error: %s with items: %s", error, items) + raise AirflowException(f"Segment error: {error}") diff --git a/airflow/providers/segment/operators/segment_track_event.py b/airflow/providers/segment/operators/segment_track_event.py index 19f15df3f677a..8a0d2cd879d75 100644 --- a/airflow/providers/segment/operators/segment_track_event.py +++ 
b/airflow/providers/segment/operators/segment_track_event.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.segment.hooks.segment import SegmentHook @@ -36,16 +38,16 @@ class SegmentTrackEventOperator(BaseOperator): Defaults to False """ - template_fields: Sequence[str] = ('user_id', 'event', 'properties') - ui_color = '#ffd700' + template_fields: Sequence[str] = ("user_id", "event", "properties") + ui_color = "#ffd700" def __init__( self, *, user_id: str, event: str, - properties: Optional[dict] = None, - segment_conn_id: str = 'segment_default', + properties: dict | None = None, + segment_conn_id: str = "segment_default", segment_debug_mode: bool = False, **kwargs, ) -> None: @@ -57,11 +59,11 @@ def __init__( self.segment_debug_mode = segment_debug_mode self.segment_conn_id = segment_conn_id - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: hook = SegmentHook(segment_conn_id=self.segment_conn_id, segment_debug_mode=self.segment_debug_mode) self.log.info( - 'Sending track event (%s) for user id: %s with properties: %s', + "Sending track event (%s) for user id: %s with properties: %s", self.event, self.user_id, self.properties, diff --git a/airflow/providers/segment/provider.yaml b/airflow/providers/segment/provider.yaml index c695634c08f49..b46b96eecab56 100644 --- a/airflow/providers/segment/provider.yaml +++ b/airflow/providers/segment/provider.yaml @@ -22,6 +22,8 @@ description: | `Segment `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -30,8 +32,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - analytics-python>=1.2.9 integrations: - integration-name: Segment @@ -49,9 +52,6 @@ hooks: python-modules: - airflow.providers.segment.hooks.segment -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.segment.hooks.segment.SegmentHook - connection-types: - hook-class-name: airflow.providers.segment.hooks.segment.SegmentHook connection-type: segment diff --git a/airflow/providers/sendgrid/.latest-doc-only-change.txt b/airflow/providers/sendgrid/.latest-doc-only-change.txt index e7e8156d80b9e..ff7136e07d744 100644 --- a/airflow/providers/sendgrid/.latest-doc-only-change.txt +++ b/airflow/providers/sendgrid/.latest-doc-only-change.txt @@ -1 +1 @@ -b916b7507921129dc48d6add1bdc4b923b60c9b9 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/sendgrid/CHANGELOG.rst b/airflow/providers/sendgrid/CHANGELOG.rst index 8460ae4c30fbc..88121eff6b67d 100644 --- a/airflow/providers/sendgrid/CHANGELOG.rst +++ b/airflow/providers/sendgrid/CHANGELOG.rst @@ -16,9 +16,54 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. 
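A DAG sketch for the refactored Segment operator, assuming a `segment_default` connection whose extra carries the required `write_key`; the event name and properties are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.segment.operators.segment_track_event import SegmentTrackEventOperator

with DAG(
    dag_id="segment_track_example",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    SegmentTrackEventOperator(
        task_id="track_signup",
        user_id="user-123",                    # templated field
        event="signup",                        # templated field
        properties={"source": "airflow"},      # templated field, optional
        segment_conn_id="segment_default",
        segment_debug_mode=True,               # routes SDK errors to SegmentHook.on_error
    )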
+ +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add Airflow specific warning classes (#25799)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/sendgrid/provider.yaml b/airflow/providers/sendgrid/provider.yaml index d53eeaab8f2a4..210af20925f31 100644 --- a/airflow/providers/sendgrid/provider.yaml +++ b/airflow/providers/sendgrid/provider.yaml @@ -21,7 +21,13 @@ name: Sendgrid description: | `Sendgrid `__ +dependencies: + - apache-airflow>=2.3.0 + - sendgrid>=6.0.0 + versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 diff --git a/airflow/providers/sendgrid/utils/emailer.py b/airflow/providers/sendgrid/utils/emailer.py index 58a1968180914..4d872c542b2ee 100644 --- a/airflow/providers/sendgrid/utils/emailer.py +++ b/airflow/providers/sendgrid/utils/emailer.py @@ -16,13 +16,14 @@ # specific language governing permissions and limitations # under the License. 
"""Airflow module for email backend using sendgrid""" +from __future__ import annotations import base64 import logging import mimetypes import os import warnings -from typing import Dict, Iterable, Optional, Union +from typing import Iterable, Union import sendgrid from sendgrid.helpers.mail import ( @@ -50,9 +51,9 @@ def send_email( to: AddressesType, subject: str, html_content: str, - files: Optional[AddressesType] = None, - cc: Optional[AddressesType] = None, - bcc: Optional[AddressesType] = None, + files: AddressesType | None = None, + cc: AddressesType | None = None, + bcc: AddressesType | None = None, sandbox_mode: bool = False, conn_id: str = "sendgrid_default", **kwargs, @@ -67,8 +68,8 @@ def send_email( files = [] mail = Mail() - from_email = kwargs.get('from_email') or os.environ.get('SENDGRID_MAIL_FROM') - from_name = kwargs.get('from_name') or os.environ.get('SENDGRID_MAIL_SENDER') + from_email = kwargs.get("from_email") or os.environ.get("SENDGRID_MAIL_FROM") + from_name = kwargs.get("from_name") or os.environ.get("SENDGRID_MAIL_SENDER") mail.from_email = Email(from_email, from_name) mail.subject = subject mail.mail_settings = MailSettings() @@ -91,15 +92,15 @@ def send_email( personalization.add_bcc(Email(bcc_address)) # Add custom_args to personalization if present - pers_custom_args = kwargs.get('personalization_custom_args') + pers_custom_args = kwargs.get("personalization_custom_args") if isinstance(pers_custom_args, dict): for key in pers_custom_args.keys(): personalization.add_custom_arg(CustomArg(key, pers_custom_args[key])) mail.add_personalization(personalization) - mail.add_content(Content('text/html', html_content)) + mail.add_content(Content("text/html", html_content)) - categories = kwargs.get('categories', []) + categories = kwargs.get("categories", []) for cat in categories: mail.add_category(Category(cat)) @@ -108,7 +109,7 @@ def send_email( basename = os.path.basename(fname) with open(fname, "rb") as file: - content = base64.b64encode(file.read()).decode('utf-8') + content = base64.b64encode(file.read()).decode("utf-8") attachment = Attachment( file_content=content, @@ -122,7 +123,7 @@ def send_email( _post_sendgrid_mail(mail.get(), conn_id) -def _post_sendgrid_mail(mail_data: Dict, conn_id: str = "sendgrid_default") -> None: +def _post_sendgrid_mail(mail_data: dict, conn_id: str = "sendgrid_default") -> None: api_key = None try: conn = BaseHook.get_connection(conn_id) @@ -133,22 +134,22 @@ def _post_sendgrid_mail(mail_data: Dict, conn_id: str = "sendgrid_default") -> N warnings.warn( "Fetching Sendgrid credentials from environment variables will be deprecated in a future " "release. Please set credentials using a connection instead.", - PendingDeprecationWarning, + DeprecationWarning, stacklevel=2, ) - api_key = os.environ.get('SENDGRID_API_KEY') + api_key = os.environ.get("SENDGRID_API_KEY") sendgrid_client = sendgrid.SendGridAPIClient(api_key=api_key) response = sendgrid_client.client.mail.send.post(request_body=mail_data) # 2xx status code. 
if 200 <= response.status_code < 300: log.info( - 'Email with subject %s is successfully sent to recipients: %s', - mail_data['subject'], - mail_data['personalizations'], + "Email with subject %s is successfully sent to recipients: %s", + mail_data["subject"], + mail_data["personalizations"], ) else: log.error( - 'Failed to send out email with subject %s, status code: %s', - mail_data['subject'], + "Failed to send out email with subject %s, status code: %s", + mail_data["subject"], response.status_code, ) diff --git a/airflow/providers/sftp/CHANGELOG.rst b/airflow/providers/sftp/CHANGELOG.rst index 3a4348a09a793..7c769e71276ea 100644 --- a/airflow/providers/sftp/CHANGELOG.rst +++ b/airflow/providers/sftp/CHANGELOG.rst @@ -16,9 +16,87 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Bug Fixes +~~~~~~~~~ + +* ``SFTP Provider: Fix default folder permissions (#26593)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +4.1.0 +..... + +Features +~~~~~~~~ + +* ``SFTPOperator - add support for list of file paths (#26666)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Convert sftp hook to use paramiko instead of pysftp (#24512)`` + +Features +~~~~~~~~ + +* ``Update 'actual_file_to_check' with rendered 'path' (#24451)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Automatically detect if non-lazy logging interpolation is used (#24910)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + * ``Add documentation for July 2022 Provider's release (#25030)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Adding fnmatch type regex to SFTPSensor (#24084)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.6.0 ..... 
diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py index 58c820b838386..66ae50665b14f 100644 --- a/airflow/providers/sftp/hooks/sftp.py +++ b/airflow/providers/sftp/hooks/sftp.py @@ -16,15 +16,18 @@ # specific language governing permissions and limitations # under the License. """This module contains SFTP hook.""" +from __future__ import annotations + import datetime +import os import stat import warnings -from typing import Any, Dict, List, Optional, Tuple +from fnmatch import fnmatch +from typing import Any, Callable -import pysftp -import tenacity -from paramiko import SSHException +import paramiko +from airflow.exceptions import AirflowException from airflow.providers.ssh.hooks.ssh import SSHHook @@ -48,135 +51,82 @@ class SFTPHook(SSHHook): Errors that may occur throughout but should be handled downstream. For consistency reasons with SSHHook, the preferred parameter is "ssh_conn_id". - Please note that it is still possible to use the parameter "ftp_conn_id" - to initialize the hook, but it will be removed in future Airflow versions. :param ssh_conn_id: The :ref:`sftp connection id` - :param ftp_conn_id (Outdated): The :ref:`sftp connection id` + :param ssh_hook: Optional SSH hook (included to support passing of an SSH hook to the SFTP operator) """ - conn_name_attr = 'ssh_conn_id' - default_conn_name = 'sftp_default' - conn_type = 'sftp' - hook_name = 'SFTP' + conn_name_attr = "ssh_conn_id" + default_conn_name = "sftp_default" + conn_type = "sftp" + hook_name = "SFTP" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: return { - "hidden_fields": ['schema'], + "hidden_fields": ["schema"], "relabeling": { - 'login': 'Username', + "login": "Username", }, } def __init__( self, - ssh_conn_id: Optional[str] = 'sftp_default', + ssh_conn_id: str | None = "sftp_default", + ssh_hook: SSHHook | None = None, *args, **kwargs, ) -> None: - ftp_conn_id = kwargs.pop('ftp_conn_id', None) + self.conn: paramiko.SFTPClient | None = None + + # TODO: remove support for ssh_hook when it is removed from SFTPOperator + self.ssh_hook = ssh_hook + + if self.ssh_hook is not None: + warnings.warn( + "Parameter `ssh_hook` is deprecated and will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + if not isinstance(self.ssh_hook, SSHHook): + raise AirflowException( + f"ssh_hook must be an instance of SSHHook, but got {type(self.ssh_hook)}" + ) + self.log.info("ssh_hook is provided. It will be used to generate SFTP connection.") + self.ssh_conn_id = self.ssh_hook.ssh_conn_id + return + + ftp_conn_id = kwargs.pop("ftp_conn_id", None) if ftp_conn_id: warnings.warn( - 'Parameter `ftp_conn_id` is deprecated. Please use `ssh_conn_id` instead.', + "Parameter `ftp_conn_id` is deprecated. Please use `ssh_conn_id` instead.", DeprecationWarning, stacklevel=2, ) ssh_conn_id = ftp_conn_id - kwargs['ssh_conn_id'] = ssh_conn_id - super().__init__(*args, **kwargs) - self.conn = None - self.private_key_pass = None - self.ciphers = None - - # Fail for unverified hosts, unless this is explicitly allowed - self.no_host_key_check = False - - if self.ssh_conn_id is not None: - conn = self.get_connection(self.ssh_conn_id) - if conn.extra is not None: - extra_options = conn.extra_dejson - - # For backward compatibility - # TODO: remove in the next major provider release. - - if 'private_key_pass' in extra_options: - warnings.warn( - 'Extra option `private_key_pass` is deprecated.' 
- 'Please use `private_key_passphrase` instead.' - '`private_key_passphrase` will precede if both options are specified.' - 'The old option `private_key_pass` will be removed in a future release.', - DeprecationWarning, - stacklevel=2, - ) - self.private_key_pass = extra_options.get( - 'private_key_passphrase', extra_options.get('private_key_pass') - ) + kwargs["ssh_conn_id"] = ssh_conn_id + self.ssh_conn_id = ssh_conn_id + + super().__init__(*args, **kwargs) - if 'ignore_hostkey_verification' in extra_options: - warnings.warn( - 'Extra option `ignore_hostkey_verification` is deprecated.' - 'Please use `no_host_key_check` instead.' - 'This option will be removed in a future release.', - DeprecationWarning, - stacklevel=2, - ) - self.no_host_key_check = ( - str(extra_options['ignore_hostkey_verification']).lower() == 'true' - ) - - if 'no_host_key_check' in extra_options: - self.no_host_key_check = str(extra_options['no_host_key_check']).lower() == 'true' - - if 'ciphers' in extra_options: - self.ciphers = extra_options['ciphers'] - - @tenacity.retry( - stop=tenacity.stop_after_delay(10), - wait=tenacity.wait_exponential(multiplier=1, max=10), - retry=tenacity.retry_if_exception_type(SSHException), - reraise=True, - ) - def get_conn(self) -> pysftp.Connection: - """Returns an SFTP connection object""" + def get_conn(self) -> paramiko.SFTPClient: # type: ignore[override] + """Opens an SFTP connection to the remote host""" if self.conn is None: - cnopts = pysftp.CnOpts() - if self.no_host_key_check: - cnopts.hostkeys = None + # TODO: remove support for ssh_hook when it is removed from SFTPOperator + if self.ssh_hook is not None: + self.conn = self.ssh_hook.get_conn().open_sftp() else: - if self.host_key is not None: - cnopts.hostkeys.add(self.remote_host, self.host_key.get_name(), self.host_key) - else: - pass # will fallback to system host keys if none explicitly specified in conn extra - - cnopts.compression = self.compress - cnopts.ciphers = self.ciphers - conn_params = { - 'host': self.remote_host, - 'port': self.port, - 'username': self.username, - 'cnopts': cnopts, - } - if self.password and self.password.strip(): - conn_params['password'] = self.password - if self.pkey: - conn_params['private_key'] = self.pkey - elif self.key_file: - conn_params['private_key'] = self.key_file - if self.private_key_pass: - conn_params['private_key_pass'] = self.private_key_pass - - self.conn = pysftp.Connection(**conn_params) + self.conn = super().get_conn().open_sftp() return self.conn def close_conn(self) -> None: - """Closes the connection""" + """Closes the SFTP connection""" if self.conn is not None: self.conn.close() self.conn = None - def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]: + def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]]: """ Returns a dictionary of {filename: {attributes}} for all files on the remote system (where the MLSD command is supported). 
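A sketch of the hook after the switch from pysftp to paramiko, assuming a configured `sftp_default` connection:

from airflow.providers.sftp.hooks.sftp import SFTPHook

hook = SFTPHook(ssh_conn_id="sftp_default")   # preferred; passing ssh_hook= still works but warns
sftp_client = hook.get_conn()                 # now a paramiko.SFTPClient instead of a pysftp.Connection
print(sftp_client.normalize("."))             # remote working directory (also used by test_connection)
hook.close_conn()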
@@ -184,36 +134,85 @@ def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]: :param path: full path to the remote directory """ conn = self.get_conn() - flist = conn.listdir_attr(path) + flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) files = {} for f in flist: - modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime('%Y%m%d%H%M%S') + modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime("%Y%m%d%H%M%S") # type: ignore files[f.filename] = { - 'size': f.st_size, - 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file', - 'modify': modify, + "size": f.st_size, + "type": "dir" if stat.S_ISDIR(f.st_mode) else "file", # type: ignore + "modify": modify, } return files - def list_directory(self, path: str) -> List[str]: + def list_directory(self, path: str) -> list[str]: """ Returns a list of files on the remote system. :param path: full path to the remote directory to list """ conn = self.get_conn() - files = conn.listdir(path) + files = sorted(conn.listdir(path)) return files - def create_directory(self, path: str, mode: int = 777) -> None: + def mkdir(self, path: str, mode: int = 0o777) -> None: + """ + Creates a directory on the remote system. + The default mode is 0777, but on some systems, the current umask value is first masked out. + + :param path: full path to the remote directory to create + :param mode: int permissions of octal mode for directory + """ + conn = self.get_conn() + conn.mkdir(path, mode=mode) + + def isdir(self, path: str) -> bool: + """ + Checks if the path provided is a directory or not. + + :param path: full path to the remote directory to check + """ + conn = self.get_conn() + try: + result = stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore + except OSError: + result = False + return result + + def isfile(self, path: str) -> bool: + """ + Checks if the path provided is a file or not. + + :param path: full path to the remote file to check + """ + conn = self.get_conn() + try: + result = stat.S_ISREG(conn.stat(path).st_mode) # type: ignore + except OSError: + result = False + return result + + def create_directory(self, path: str, mode: int = 0o777) -> None: """ Creates a directory on the remote system. + The default mode is 0777, but on some systems, the current umask value is first masked out. :param path: full path to the remote directory to create - :param mode: int representation of octal mode for directory + :param mode: int permissions of octal mode for directory """ conn = self.get_conn() - conn.makedirs(path, mode) + if self.isdir(path): + self.log.info("%s already exists", path) + return + elif self.isfile(path): + raise AirflowException(f"{path} already exists and is a file") + else: + dirname, basename = os.path.split(path) + if dirname and not self.isdir(dirname): + self.create_directory(dirname, mode) + if basename: + self.log.info("Creating %s", path) + conn.mkdir(path, mode=mode) def delete_directory(self, path: str) -> None: """ @@ -236,7 +235,7 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None: conn = self.get_conn() conn.get(remote_full_path, local_full_path) - def store_file(self, remote_full_path: str, local_full_path: str) -> None: + def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None: """ Transfers a local file to the remote location. 
If local_full_path_or_buffer is a string path, the file will be read @@ -246,7 +245,7 @@ def store_file(self, remote_full_path: str, local_full_path: str) -> None: :param local_full_path: full path to the local file """ conn = self.get_conn() - conn.put(local_full_path, remote_full_path) + conn.put(local_full_path, remote_full_path, confirm=confirm) def delete_file(self, path: str) -> None: """ @@ -265,7 +264,7 @@ def get_mod_time(self, path: str) -> str: """ conn = self.get_conn() ftp_mdtm = conn.stat(path).st_mtime - return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S') + return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore def path_exists(self, path: str) -> bool: """ @@ -274,10 +273,14 @@ def path_exists(self, path: str) -> bool: :param path: full path to the remote file or directory """ conn = self.get_conn() - return conn.exists(path) + try: + conn.stat(path) + except OSError: + return False + return True @staticmethod - def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool: + def _is_path_match(path: str, prefix: str | None = None, delimiter: str | None = None) -> bool: """ Return True if given path starts with prefix (if set) and ends with delimiter (if set). @@ -292,9 +295,54 @@ def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[ return False return True + def walktree( + self, + path: str, + fcallback: Callable[[str], Any | None], + dcallback: Callable[[str], Any | None], + ucallback: Callable[[str], Any | None], + recurse: bool = True, + ) -> None: + """ + Recursively descend, depth first, the directory tree rooted at + path, calling discreet callback functions for each regular file, + directory and unknown file type. + + :param str path: + root of remote directory to descend, use '.' to start at + :attr:`.pwd` + :param callable fcallback: + callback function to invoke for a regular file. + (form: ``func(str)``) + :param callable dcallback: + callback function to invoke for a directory. (form: ``func(str)``) + :param callable ucallback: + callback function to invoke for an unknown file type. + (form: ``func(str)``) + :param bool recurse: *Default: True* - should it recurse + + :returns: None + """ + conn = self.get_conn() + for entry in self.list_directory(path): + pathname = os.path.join(path, entry) + mode = conn.stat(pathname).st_mode + if stat.S_ISDIR(mode): # type: ignore + # It's a directory, call the dcallback function + dcallback(pathname) + if recurse: + # now, recurse into it + self.walktree(pathname, fcallback, dcallback, ucallback) + elif stat.S_ISREG(mode): # type: ignore + # It's a file, call the fcallback function + fcallback(pathname) + else: + # Unknown file type + ucallback(pathname) + def get_tree_map( - self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None - ) -> Tuple[List[str], List[str], List[str]]: + self, path: str, prefix: str | None = None, delimiter: str | None = None + ) -> tuple[list[str], list[str], list[str]]: """ Return tuple with recursive lists of files, directories and unknown paths from given path. It is possible to filter results by giving prefix and/or delimiter parameters. 
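A sketch of the new path helpers, assuming a configured `sftp_default` connection; remote and local paths are hypothetical:

from airflow.providers.sftp.hooks.sftp import SFTPHook

hook = SFTPHook(ssh_conn_id="sftp_default")

# create_directory() now creates missing parents recursively; re-running it on an
# existing directory is a no-op, while an existing *file* raises AirflowException.
hook.create_directory("/upload/2022/11/reports")
assert hook.isdir("/upload/2022/11/reports")
assert not hook.isfile("/upload/2022/11/reports")

# path_exists() is now implemented with SFTPClient.stat() instead of pysftp's exists().
if not hook.path_exists("/upload/2022/11/reports/summary.csv"):
    hook.store_file("/upload/2022/11/reports/summary.csv", "/tmp/summary.csv", confirm=True)

hook.close_conn()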
@@ -303,16 +351,16 @@ def get_tree_map( :param prefix: if set paths will be added if start with prefix :param delimiter: if set paths will be added if end with delimiter :return: tuple with list of files, dirs and unknown items - :rtype: Tuple[List[str], List[str], List[str]] """ - conn = self.get_conn() - files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str] + files: list[str] = [] + dirs: list[str] = [] + unknowns: list[str] = [] - def append_matching_path_callback(list_): + def append_matching_path_callback(list_: list[str]) -> Callable: return lambda item: list_.append(item) if self._is_path_match(item, prefix, delimiter) else None - conn.walktree( - remotepath=path, + self.walktree( + path=path, fcallback=append_matching_path_callback(files), dcallback=append_matching_path_callback(dirs), ucallback=append_matching_path_callback(unknowns), @@ -321,11 +369,29 @@ def append_matching_path_callback(list_): return files, dirs, unknowns - def test_connection(self) -> Tuple[bool, str]: + def test_connection(self) -> tuple[bool, str]: """Test the SFTP connection by calling path with directory""" try: conn = self.get_conn() - conn.pwd + conn.normalize(".") return True, "Connection successfully tested" except Exception as e: return False, str(e) + + def get_file_by_pattern(self, path, fnmatch_pattern) -> str: + """ + Returning the first matching file based on the given fnmatch type pattern + + :param path: path to be checked + :param fnmatch_pattern: The pattern that will be matched with `fnmatch` + :return: string containing the first found file, or an empty string if none matched + """ + files_list = self.list_directory(path) + + for file in files_list: + if not fnmatch(file, fnmatch_pattern): + pass + else: + return file + + return "" diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index c78c7f4c04d64..8884818ffb493 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -16,38 +16,44 @@ # specific language governing permissions and limitations # under the License. """This module contains SFTP operator.""" +from __future__ import annotations + import os +import warnings from pathlib import Path from typing import Any, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator +from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.ssh.hooks.ssh import SSHHook class SFTPOperation: - """Operation that can be used with SFTP/""" + """Operation that can be used with SFTP""" - PUT = 'put' - GET = 'get' + PUT = "put" + GET = "get" class SFTPOperator(BaseOperator): """ SFTPOperator for transferring files from remote host to local or vice a versa. - This operator uses ssh_hook to open sftp transport channel that serve as basis + This operator uses sftp_hook to open sftp transport channel that serve as basis for file transfer. - :param ssh_hook: predefined ssh_hook to use for remote execution. - Either `ssh_hook` or `ssh_conn_id` needs to be provided. :param ssh_conn_id: :ref:`ssh connection id` from airflow Connections. `ssh_conn_id` will be ignored if `ssh_hook` - is provided. + or `sftp_hook` is provided. + :param sftp_hook: predefined SFTPHook to use + Either `sftp_hook` or `ssh_conn_id` needs to be provided. + :param ssh_hook: Deprecated - predefined SSHHook to use for remote execution + Use `sftp_hook` instead. :param remote_host: remote host to connect (templated) Nullable. 
If provided, it will replace the `remote_host` which was - defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. - :param local_filepath: local file path to get or put. (templated) - :param remote_filepath: remote file path to get or put. (templated) + defined in `sftp_hook`/`ssh_hook` or predefined in the connection of `ssh_conn_id`. + :param local_filepath: local file path or list of local file paths to get or put. (templated) + :param remote_filepath: remote file path or list of remote file paths to get or put. (templated) :param operation: specify operation 'get' or 'put', defaults to put :param confirm: specify if the SFTP operation should be confirmed, defaults to True :param create_intermediate_dirs: create missing intermediate directories when @@ -70,103 +76,115 @@ class SFTPOperator(BaseOperator): """ - template_fields: Sequence[str] = ('local_filepath', 'remote_filepath', 'remote_host') + template_fields: Sequence[str] = ("local_filepath", "remote_filepath", "remote_host") def __init__( self, *, - ssh_hook=None, - ssh_conn_id=None, - remote_host=None, - local_filepath=None, - remote_filepath=None, - operation=SFTPOperation.PUT, - confirm=True, - create_intermediate_dirs=False, + ssh_hook: SSHHook | None = None, + sftp_hook: SFTPHook | None = None, + ssh_conn_id: str | None = None, + remote_host: str | None = None, + local_filepath: str | list[str], + remote_filepath: str | list[str], + operation: str = SFTPOperation.PUT, + confirm: bool = True, + create_intermediate_dirs: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) self.ssh_hook = ssh_hook + self.sftp_hook = sftp_hook self.ssh_conn_id = ssh_conn_id self.remote_host = remote_host - self.local_filepath = local_filepath - self.remote_filepath = remote_filepath self.operation = operation self.confirm = confirm self.create_intermediate_dirs = create_intermediate_dirs + + self.local_filepath_was_str = False + if isinstance(local_filepath, str): + self.local_filepath = [local_filepath] + self.local_filepath_was_str = True + else: + self.local_filepath = local_filepath + + if isinstance(remote_filepath, str): + self.remote_filepath = [remote_filepath] + else: + self.remote_filepath = remote_filepath + + if len(self.local_filepath) != len(self.remote_filepath): + raise ValueError( + f"{len(self.local_filepath)} paths in local_filepath " + f"!= {len(self.remote_filepath)} paths in remote_filepath" + ) + if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT): raise TypeError( f"Unsupported operation value {self.operation}, " f"expected {SFTPOperation.GET} or {SFTPOperation.PUT}." ) - def execute(self, context: Any) -> str: + # TODO: remove support for ssh_hook in next major provider version in hook and operator + if self.ssh_hook is not None and self.sftp_hook is not None: + raise AirflowException( + "Both `ssh_hook` and `sftp_hook` are defined. Please use only one of them." + ) + + if self.ssh_hook is not None: + if not isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_hook is invalid. Trying ssh_conn_id to create SFTPHook.") + self.sftp_hook = SFTPHook(ssh_conn_id=self.ssh_conn_id) + if self.sftp_hook is None: + warnings.warn( + "Parameter `ssh_hook` is deprecated" + "Please use `sftp_hook` instead." 
+ "The old parameter `ssh_hook` will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.sftp_hook = SFTPHook(ssh_hook=self.ssh_hook) + + def execute(self, context: Any) -> str | list[str] | None: file_msg = None try: if self.ssh_conn_id: - if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): - self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + if self.sftp_hook and isinstance(self.sftp_hook, SFTPHook): + self.log.info("ssh_conn_id is ignored when sftp_hook/ssh_hook is provided.") else: self.log.info( - "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook." + "sftp_hook/ssh_hook not provided or invalid. Trying ssh_conn_id to create SFTPHook." ) - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + self.sftp_hook = SFTPHook(ssh_conn_id=self.ssh_conn_id) - if not self.ssh_hook: - raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") + if not self.sftp_hook: + raise AirflowException("Cannot operate without sftp_hook or ssh_conn_id.") if self.remote_host is not None: self.log.info( "remote_host is provided explicitly. " "It will replace the remote_host which was defined " - "in ssh_hook or predefined in connection of ssh_conn_id." + "in sftp_hook or predefined in connection of ssh_conn_id." ) - self.ssh_hook.remote_host = self.remote_host + self.sftp_hook.remote_host = self.remote_host - with self.ssh_hook.get_conn() as ssh_client: - sftp_client = ssh_client.open_sftp() + for local_filepath, remote_filepath in zip(self.local_filepath, self.remote_filepath): if self.operation.lower() == SFTPOperation.GET: - local_folder = os.path.dirname(self.local_filepath) + local_folder = os.path.dirname(local_filepath) if self.create_intermediate_dirs: Path(local_folder).mkdir(parents=True, exist_ok=True) - file_msg = f"from {self.remote_filepath} to {self.local_filepath}" + file_msg = f"from {remote_filepath} to {local_filepath}" self.log.info("Starting to transfer %s", file_msg) - sftp_client.get(self.remote_filepath, self.local_filepath) + self.sftp_hook.retrieve_file(remote_filepath, local_filepath) else: - remote_folder = os.path.dirname(self.remote_filepath) + remote_folder = os.path.dirname(remote_filepath) if self.create_intermediate_dirs: - _make_intermediate_dirs( - sftp_client=sftp_client, - remote_directory=remote_folder, - ) - file_msg = f"from {self.local_filepath} to {self.remote_filepath}" + self.sftp_hook.create_directory(remote_folder) + file_msg = f"from {local_filepath} to {remote_filepath}" self.log.info("Starting to transfer file %s", file_msg) - sftp_client.put(self.local_filepath, self.remote_filepath, confirm=self.confirm) + self.sftp_hook.store_file(remote_filepath, local_filepath, confirm=self.confirm) except Exception as e: raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}") - return self.local_filepath - - -def _make_intermediate_dirs(sftp_client, remote_directory) -> None: - """ - Create all the intermediate directories in a remote host - - :param sftp_client: A Paramiko SFTP client. 
- :param remote_directory: Absolute Path of the directory containing the file - :return: - """ - if remote_directory == '/': - sftp_client.chdir('/') - return - if remote_directory == '': - return - try: - sftp_client.chdir(remote_directory) - except OSError: - dirname, basename = os.path.split(remote_directory.rstrip('/')) - _make_intermediate_dirs(sftp_client, dirname) - sftp_client.mkdir(basename) - sftp_client.chdir(basename) - return + return self.local_filepath[0] if self.local_filepath_was_str else self.local_filepath diff --git a/airflow/providers/sftp/provider.yaml b/airflow/providers/sftp/provider.yaml index 859d40bda1c75..f09a4b1972f3a 100644 --- a/airflow/providers/sftp/provider.yaml +++ b/airflow/providers/sftp/provider.yaml @@ -22,6 +22,10 @@ description: | `SSH File Transfer Protocol (SFTP) `__ versions: + - 4.2.0 + - 4.1.0 + - 4.0.0 + - 3.0.0 - 2.6.0 - 2.5.2 - 2.5.1 @@ -38,8 +42,8 @@ versions: - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 - apache-airflow-providers-ssh>=2.1.0 integrations: @@ -63,8 +67,6 @@ hooks: python-modules: - airflow.providers.sftp.hooks.sftp -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.sftp.hooks.sftp.SFTPHook connection-types: - hook-class-name: airflow.providers.sftp.hooks.sftp.SFTPHook diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py index 904321e9b80c0..39dcf4b113caf 100644 --- a/airflow/providers/sftp/sensors/sftp.py +++ b/airflow/providers/sftp/sensors/sftp.py @@ -16,8 +16,10 @@ # specific language governing permissions and limitations # under the License. """This module contains SFTP sensor.""" +from __future__ import annotations + from datetime import datetime -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Sequence from paramiko.sftp import SFTP_NO_SUCH_FILE @@ -34,42 +36,55 @@ class SFTPSensor(BaseSensorOperator): Waits for a file or directory to be present on SFTP. 
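A DAG sketch using the reworked operator, assuming an `sftp_default` connection; the local and remote paths are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.operators.sftp import SFTPOperator

with DAG(
    dag_id="sftp_multi_upload",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    SFTPOperator(
        task_id="upload_reports",
        sftp_hook=SFTPHook(ssh_conn_id="sftp_default"),      # `sftp_hook` replaces the deprecated `ssh_hook`
        local_filepath=["/tmp/a.csv", "/tmp/b.csv"],         # lists are now accepted ...
        remote_filepath=["/upload/a.csv", "/upload/b.csv"],  # ... and must match in length
        operation="put",
        confirm=True,
        create_intermediate_dirs=True,
    )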
:param path: Remote file or directory path + :param file_pattern: The pattern that will be used to match the file (fnmatch format) :param sftp_conn_id: The connection to run the sensor against :param newer_than: DateTime for which the file or file path should be newer than, comparison is inclusive """ template_fields: Sequence[str] = ( - 'path', - 'newer_than', + "path", + "newer_than", ) def __init__( self, *, path: str, - newer_than: Optional[datetime] = None, - sftp_conn_id: str = 'sftp_default', + file_pattern: str = "", + newer_than: datetime | None = None, + sftp_conn_id: str = "sftp_default", **kwargs, ) -> None: super().__init__(**kwargs) self.path = path - self.hook: Optional[SFTPHook] = None + self.file_pattern = file_pattern + self.hook: SFTPHook | None = None self.sftp_conn_id = sftp_conn_id - self.newer_than: Optional[datetime] = newer_than + self.newer_than: datetime | None = newer_than - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: self.hook = SFTPHook(self.sftp_conn_id) - self.log.info('Poking for %s', self.path) + self.log.info("Poking for %s, with pattern %s", self.path, self.file_pattern) + + if self.file_pattern: + file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern) + if file_from_pattern: + actual_file_to_check = file_from_pattern + else: + return False + else: + actual_file_to_check = self.path + try: - mod_time = self.hook.get_mod_time(self.path) - self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time)) + mod_time = self.hook.get_mod_time(actual_file_to_check) + self.log.info("Found File %s last modified: %s", str(actual_file_to_check), str(mod_time)) except OSError as e: if e.errno != SFTP_NO_SUCH_FILE: raise e return False self.hook.close_conn() if self.newer_than: - _mod_time = convert_to_utc(datetime.strptime(mod_time, '%Y%m%d%H%M%S')) + _mod_time = convert_to_utc(datetime.strptime(mod_time, "%Y%m%d%H%M%S")) _newer_than = convert_to_utc(self.newer_than) return _newer_than <= _mod_time else: diff --git a/airflow/providers/singularity/.latest-doc-only-change.txt b/airflow/providers/singularity/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/singularity/.latest-doc-only-change.txt +++ b/airflow/providers/singularity/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/singularity/CHANGELOG.rst b/airflow/providers/singularity/CHANGELOG.rst index a4df9ec084224..ba69168f24ad2 100644 --- a/airflow/providers/singularity/CHANGELOG.rst +++ b/airflow/providers/singularity/CHANGELOG.rst @@ -16,9 +16,50 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Singularity example DAGs to new design #22464 (#24128)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/singularity/example_dags/__init__.py b/airflow/providers/singularity/example_dags/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/airflow/providers/singularity/example_dags/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/providers/singularity/example_dags/example_singularity.py b/airflow/providers/singularity/example_dags/example_singularity.py deleted file mode 100644 index cf54a28084de0..0000000000000 --- a/airflow/providers/singularity/example_dags/example_singularity.py +++ /dev/null @@ -1,46 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
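Tying together the SFTP sensor hunk above, a sketch of the new `file_pattern` matching (backed by `SFTPHook.get_file_by_pattern`); the directory and pattern are hypothetical:

from datetime import datetime

from airflow.providers.sftp.sensors.sftp import SFTPSensor

# Inside a `with DAG(...)` block:
wait_for_export = SFTPSensor(
    task_id="wait_for_export",
    path="/upload/exports",            # searched as a directory when file_pattern is set
    file_pattern="export_*.csv",       # fnmatch-style pattern
    newer_than=datetime(2022, 11, 1),  # optional, inclusive modification-time cutoff
    sftp_conn_id="sftp_default",
    poke_interval=60,
)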
- -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.operators.bash import BashOperator -from airflow.providers.singularity.operators.singularity import SingularityOperator - -with DAG( - 'singularity_sample', - default_args={'retries': 1}, - schedule_interval=timedelta(minutes=10), - start_date=datetime(2021, 1, 1), - catchup=False, -) as dag: - - t1 = BashOperator(task_id='print_date', bash_command='date') - - t2 = BashOperator(task_id='sleep', bash_command='sleep 5', retries=3) - - t3 = SingularityOperator( - command='/bin/sleep 30', - image='docker://busybox:1.30.1', - task_id='singularity_op_tester', - ) - - t4 = BashOperator(task_id='print_hello', bash_command='echo "hello world!!!"') - - t1 >> [t2, t3] - t3 >> t4 diff --git a/airflow/providers/singularity/operators/singularity.py b/airflow/providers/singularity/operators/singularity.py index 9a5587f0540f6..121566ca90696 100644 --- a/airflow/providers/singularity/operators/singularity.py +++ b/airflow/providers/singularity/operators/singularity.py @@ -15,11 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import ast import os import shutil -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from spython.main import Client @@ -55,12 +56,12 @@ class SingularityOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'command', - 'environment', + "command", + "environment", ) template_ext: Sequence[str] = ( - '.sh', - '.bash', + ".sh", + ".bash", ) template_fields_renderers = {"command": "bash", "environment": "json"} @@ -68,15 +69,15 @@ def __init__( self, *, image: str, - command: Union[str, ast.AST], - start_command: Optional[Union[str, List[str]]] = None, - environment: Optional[Dict[str, Any]] = None, - pull_folder: Optional[str] = None, - working_dir: Optional[str] = None, - force_pull: Optional[bool] = False, - volumes: Optional[List[str]] = None, - options: Optional[List[str]] = None, - auto_remove: Optional[bool] = False, + command: str | ast.AST, + start_command: str | list[str] | None = None, + environment: dict[str, Any] | None = None, + pull_folder: str | None = None, + working_dir: str | None = None, + force_pull: bool | None = False, + volumes: list[str] | None = None, + options: list[str] | None = None, + auto_remove: bool | None = False, **kwargs, ) -> None: @@ -95,17 +96,17 @@ def __init__( self.cli = None self.container = None - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: - self.log.info('Preparing Singularity container %s', self.image) + self.log.info("Preparing Singularity container %s", self.image) self.cli = Client if not self.command: - raise AirflowException('You must define a command.') + raise AirflowException("You must define a command.") # Pull the container if asked, and ensure not a binary file if self.force_pull and not os.path.exists(self.image): - self.log.info('Pulling container %s', self.image) + self.log.info("Pulling container %s", self.image) image = self.cli.pull( # type: ignore[attr-defined] self.image, stream=True, pull_folder=self.pull_folder ) @@ -122,36 +123,36 @@ def execute(self, context: 'Context') -> None: # Prepare list of binds for bind in self.volumes: - self.options += ['--bind', bind] + self.options += ["--bind", bind] # Does the user want a custom working directory? 
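Since the packaged example DAG is removed above, a minimal sketch of SingularityOperator usage, including the list-literal command form handled by `_get_command()`:

from datetime import datetime

from airflow import DAG
from airflow.providers.singularity.operators.singularity import SingularityOperator

with DAG(
    dag_id="singularity_minimal",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    SingularityOperator(
        task_id="singularity_op_tester",
        image="docker://busybox:1.30.1",
        command="/bin/sleep 30",
    )

    SingularityOperator(
        task_id="singularity_list_command",
        image="docker://busybox:1.30.1",
        command="['echo', 'hello']",   # strings starting with '[' are parsed via ast.literal_eval
        volumes=["/data:/data"],       # rendered as `--bind /data:/data`
        working_dir="/data",           # rendered as `--workdir /data`
        environment={"MY_VAR": "1"},   # exported before the instance starts
    )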
if self.working_dir is not None: - self.options += ['--workdir', self.working_dir] + self.options += ["--workdir", self.working_dir] # Export environment before instance is run for enkey, envar in self.environment.items(): - self.log.debug('Exporting %s=%s', envar, enkey) + self.log.debug("Exporting %s=%s", envar, enkey) os.putenv(enkey, envar) os.environ[enkey] = envar # Create a container instance - self.log.debug('Options include: %s', self.options) + self.log.debug("Options include: %s", self.options) self.instance = self.cli.instance( # type: ignore[attr-defined] self.image, options=self.options, args=self.start_command, start=False ) self.instance.start() # type: ignore[attr-defined] self.log.info(self.instance.cmd) # type: ignore[attr-defined] - self.log.info('Created instance %s from %s', self.instance, self.image) + self.log.info("Created instance %s from %s", self.instance, self.image) - self.log.info('Running command %s', self._get_command()) + self.log.info("Running command %s", self._get_command()) self.cli.quiet = True # type: ignore[attr-defined] result = self.cli.execute( # type: ignore[attr-defined] self.instance, self._get_command(), return_result=True ) # Stop the instance - self.log.info('Stopping instance %s', self.instance) + self.log.info("Stopping instance %s", self.instance) self.instance.stop() # type: ignore[attr-defined] if self.auto_remove is True: @@ -159,14 +160,14 @@ def execute(self, context: 'Context') -> None: shutil.rmtree(self.image) # If the container failed, raise the exception - if result['return_code'] != 0: - message = result['message'] - raise AirflowException(f'Singularity failed: {message}') + if result["return_code"] != 0: + message = result["message"] + raise AirflowException(f"Singularity failed: {message}") - self.log.info('Output from command %s', result['message']) + self.log.info("Output from command %s", result["message"]) - def _get_command(self) -> Optional[Any]: - if self.command is not None and self.command.strip().find('[') == 0: # type: ignore + def _get_command(self) -> Any | None: + if self.command is not None and self.command.strip().find("[") == 0: # type: ignore commands = ast.literal_eval(self.command) else: commands = self.command @@ -174,7 +175,7 @@ def _get_command(self) -> Optional[Any]: def on_kill(self) -> None: if self.instance is not None: - self.log.info('Stopping Singularity instance') + self.log.info("Stopping Singularity instance") self.instance.stop() # If an image exists, clean it up diff --git a/airflow/providers/singularity/provider.yaml b/airflow/providers/singularity/provider.yaml index ff9058e6d5d48..ff0ab0945d249 100644 --- a/airflow/providers/singularity/provider.yaml +++ b/airflow/providers/singularity/provider.yaml @@ -22,6 +22,8 @@ description: | `Singularity `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - spython>=0.0.56 integrations: - integration-name: Singularity diff --git a/airflow/providers/slack/.latest-doc-only-change.txt b/airflow/providers/slack/.latest-doc-only-change.txt index 570fad6daee29..c6282d7cb8d14 100644 --- a/airflow/providers/slack/.latest-doc-only-change.txt +++ b/airflow/providers/slack/.latest-doc-only-change.txt @@ -1 +1 @@ -97496ba2b41063fa24393c58c5c648a0cdb5a7f8 +808035e00aaf59a8012c50903a09d3f50bd92ca4 diff --git a/airflow/providers/slack/CHANGELOG.rst b/airflow/providers/slack/CHANGELOG.rst index 
f27f18e4897ae..98f00e5b60f6b 100644 --- a/airflow/providers/slack/CHANGELOG.rst +++ b/airflow/providers/slack/CHANGELOG.rst @@ -16,9 +16,104 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +7.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.3+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers +* In SlackHook and SlackWebhookHook, if both ``extra____foo`` and ``foo`` existed in connection extra + dict, the prefixed version would be used; now, the non-prefixed version will be preferred. You'll see a warning + if there is such a collision. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Allow and prefer non-prefixed extra fields for slack hooks (#27070)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +.. Review and move the new changes to one of the sections above: + * ``Replace urlparse with urlsplit (#27389)`` + +6.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* The hook class ``SlackWebhookHook`` does not inherit from ``HttpHook`` anymore. In practice the + only impact on user-defined classes based on **SlackWebhookHook** and you use attributes + from **HttpHook**. +* Drop support deprecated ``webhook_token`` parameter in ``slack-incoming-webhook`` extra. + +* ``Refactor 'SlackWebhookOperator': Get rid of mandatory http-provider dependency (#26648)`` +* ``Refactor SlackWebhookHook in order to use 'slack_sdk' instead of HttpHook methods (#26452)`` + +Features +~~~~~~~~ + +* ``Move send_file method into SlackHook (#26118)`` +* ``Refactor Slack API Hook and add Connection (#25852)`` +* ``Remove unsafe imports in Slack API Connection (#26459)`` +* ``Add common-sql lower bound for common-sql (#25789)`` +* ``Fix Slack Connections created in the UI (#26845)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Replace SQL with Common SQL in pre commit (#26058)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``AIP-47 - Migrate Slack DAG to new design (#25137)`` + * ``Fix errors in CHANGELOGS for slack and amazon (#26746)`` + * ``Update docs for September Provider's release (#26731)`` + +5.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` +* ``Adding generic 'SqlToSlackOperator' (#24663)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update docstring in 'SqlToSlackOperator' (#24759)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +5.0.0 +..... 
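A sketch of defining a Slack API connection programmatically under the 7.0.0 behaviour, where un-prefixed extra keys are preferred over their `extra__slack__`-prefixed equivalents; the token is a placeholder and `timeout`/`proxy` are assumed to be among the recognised extra fields:

import json

from airflow.models.connection import Connection

# Hypothetical programmatic connection; in practice this is usually created via the UI,
# an environment variable, or a secrets backend. The password field carries the API token.
slack_conn = Connection(
    conn_id="slack_api_default",
    conn_type="slack",
    password="xoxb-000-placeholder",
    # Plain keys are now preferred; "extra__slack__timeout" would only be used
    # if the un-prefixed "timeout" key were absent (a warning is emitted on collision).
    extra=json.dumps({"timeout": 42, "proxy": "http://proxy.example.com:8080"}),
)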
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 4.2.3 ..... diff --git a/airflow/providers/slack/hooks/slack.py b/airflow/providers/slack/hooks/slack.py index c00f0106e1e4d..2bf9bd2acef90 100644 --- a/airflow/providers/slack/hooks/slack.py +++ b/airflow/providers/slack/hooks/slack.py @@ -15,27 +15,79 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Hook for Slack""" -from typing import Any, Optional +from __future__ import annotations + +import json +import warnings +from functools import wraps +from pathlib import Path +from typing import TYPE_CHECKING, Any, Sequence from slack_sdk import WebClient -from slack_sdk.web.slack_response import SlackResponse +from slack_sdk.errors import SlackApiError -from airflow.exceptions import AirflowException +from airflow.compat.functools import cached_property +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.hooks.base import BaseHook +from airflow.providers.slack.utils import ConnectionExtraConfig +from airflow.utils.log.secrets_masker import mask_secret + +if TYPE_CHECKING: + from slack_sdk.http_retry import RetryHandler + from slack_sdk.web.slack_response import SlackResponse + + +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. + """ + + def dec(func): + @wraps(func) + def inner(cls): + field_behaviors = func(cls) + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec class SlackHook(BaseHook): """ - Creates a Slack connection to be used for calls. + Creates a Slack API Connection to be used for calls. + + This class provide a thin wrapper around the ``slack_sdk.WebClient``. + + .. seealso:: + - :ref:`Slack API connection ` + - https://api.slack.com/messaging + - https://slack.dev/python-slack-sdk/web/index.html + + .. warning:: + This hook intend to use `Slack API` connection + and might not work correctly with `Slack Incoming Webhook` and `HTTP` connections. Takes both Slack API token directly and connection that has Slack API token. If both are supplied, Slack API token will be used. Also exposes the rest of slack.WebClient args. + Examples: - .. code-block:: python + .. 
code-block:: python # Create hook - slack_hook = SlackHook(token="xxx") # or slack_hook = SlackHook(slack_conn_id="slack") + slack_hook = SlackHook(slack_conn_id="slack_api_default") # Call generic API with parameters (errors are handled by hook) # For more details check https://api.slack.com/methods/chat.postMessage @@ -45,39 +97,132 @@ class SlackHook(BaseHook): # For more details check https://slack.dev/python-slack-sdk/web/index.html#messaging slack_hook.client.chat_postMessage(channel="#random", text="Hello world!") - :param token: Slack API token :param slack_conn_id: :ref:`Slack connection id ` that has Slack API token in the password field. - :param use_session: A boolean specifying if the client should take advantage of - connection pooling. Default is True. - :param base_url: A string representing the Slack API base URL. Default is - ``https://www.slack.com/api/`` - :param timeout: The maximum number of seconds the client will wait - to connect and receive a response from Slack. Default is 30 seconds. + :param timeout: The maximum number of seconds the client will wait to connect + and receive a response from Slack. If not set than default WebClient value will use. + :param base_url: A string representing the Slack API base URL. + If not set than default WebClient BASE_URL will use (``https://www.slack.com/api/``). + :param proxy: Proxy to make the Slack API call. + :param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebClient``. + :param token: (deprecated) Slack API Token. """ + conn_name_attr = "slack_conn_id" + default_conn_name = "slack_api_default" + conn_type = "slack" + hook_name = "Slack API" + def __init__( self, - token: Optional[str] = None, - slack_conn_id: Optional[str] = None, - **client_args: Any, + token: str | None = None, + slack_conn_id: str | None = None, + base_url: str | None = None, + timeout: int | None = None, + proxy: str | None = None, + retry_handlers: list[RetryHandler] | None = None, + **extra_client_args: Any, ) -> None: + if not token and not slack_conn_id: + raise AirflowException("Either `slack_conn_id` or `token` should be provided.") + if token: + mask_secret(token) + warnings.warn( + "Provide token as hook argument deprecated by security reason and will be removed " + "in a future releases. Please specify token in `Slack API` connection.", + DeprecationWarning, + stacklevel=2, + ) + if not slack_conn_id: + warnings.warn( + "You have not set parameter `slack_conn_id`. 
Currently `Slack API` connection id optional " + "but in a future release it will mandatory.", + FutureWarning, + stacklevel=2, + ) + super().__init__() - self.token = self.__get_token(token, slack_conn_id) - self.client = WebClient(self.token, **client_args) + self._token = token + self.slack_conn_id = slack_conn_id + self.base_url = base_url + self.timeout = timeout + self.proxy = proxy + self.retry_handlers = retry_handlers + self.extra_client_args = extra_client_args + if self.extra_client_args.pop("use_session", None) is not None: + warnings.warn("`use_session` has no affect in slack_sdk.WebClient.", UserWarning, stacklevel=2) + + @cached_property + def client(self) -> WebClient: + """Get the underlying slack_sdk.WebClient (cached).""" + return WebClient(**self._get_conn_params()) + + def get_conn(self) -> WebClient: + """Get the underlying slack_sdk.WebClient (cached).""" + return self.client + + def _get_conn_params(self) -> dict[str, Any]: + """Fetch connection params as a dict and merge it with hook parameters.""" + conn = self.get_connection(self.slack_conn_id) if self.slack_conn_id else None + conn_params: dict[str, Any] = {"retry_handlers": self.retry_handlers} + + if self._token: + conn_params["token"] = self._token + elif conn: + if not conn.password: + raise AirflowNotFoundException( + f"Connection ID {self.slack_conn_id!r} does not contain password (Slack API Token)." + ) + conn_params["token"] = conn.password + + extra_config = ConnectionExtraConfig( + conn_type=self.conn_type, + conn_id=conn.conn_id if conn else None, + extra=conn.extra_dejson if conn else {}, + ) + + # Merge Hook parameters with Connection config + conn_params.update( + { + "timeout": self.timeout or extra_config.getint("timeout", default=None), + "base_url": self.base_url or extra_config.get("base_url", default=None), + "proxy": self.proxy or extra_config.get("proxy", default=None), + } + ) + + # Add additional client args + conn_params.update(self.extra_client_args) + if "logger" not in conn_params: + conn_params["logger"] = self.log + + return {k: v for k, v in conn_params.items() if v is not None} + + @cached_property + def token(self) -> str: + warnings.warn( + "`SlackHook.token` property deprecated and will be removed in a future releases.", + DeprecationWarning, + stacklevel=2, + ) + return self._get_conn_params()["token"] def __get_token(self, token: Any, slack_conn_id: Any) -> str: + warnings.warn( + "`SlackHook.__get_token` method deprecated and will be removed in a future releases.", + DeprecationWarning, + stacklevel=2, + ) if token is not None: return token if slack_conn_id is not None: conn = self.get_connection(slack_conn_id) - if not getattr(conn, 'password', None): - raise AirflowException('Missing token(password) in Slack connection') + if not getattr(conn, "password", None): + raise AirflowException("Missing token(password) in Slack connection") return conn.password - raise AirflowException('Cannot get token: No valid Slack token nor slack_conn_id supplied.') + raise AirflowException("Cannot get token: No valid Slack token nor slack_conn_id supplied.") def call(self, api_method: str, **kwargs) -> SlackResponse: """ @@ -95,3 +240,123 @@ def call(self, api_method: str, **kwargs) -> SlackResponse: iterated on to execute subsequent requests. 
""" return self.client.api_call(api_method, **kwargs) + + def send_file( + self, + *, + channels: str | Sequence[str] | None = None, + file: str | Path | None = None, + content: str | None = None, + filename: str | None = None, + filetype: str | None = None, + initial_comment: str | None = None, + title: str | None = None, + ) -> SlackResponse: + """ + Create or upload an existing file. + + :param channels: Comma-separated list of channel names or IDs where the file will be shared. + If omitting this parameter, then file will send to workspace. + :param file: Path to file which need to be sent. + :param content: File contents. If omitting this parameter, you must provide a file. + :param filename: Displayed filename. + :param filetype: A file type identifier. + :param initial_comment: The message text introducing the file in specified ``channels``. + :param title: Title of file. + + .. seealso:: + - `Slack API files.upload method `_ + - `File types `_ + """ + if not ((not file) ^ (not content)): + raise ValueError("Either `file` or `content` must be provided, not both.") + elif file: + file = Path(file) + with open(file, "rb") as fp: + if not filename: + filename = file.name + return self.client.files_upload( + file=fp, + filename=filename, + filetype=filetype, + initial_comment=initial_comment, + title=title, + channels=channels, + ) + + return self.client.files_upload( + content=content, + filename=filename, + filetype=filetype, + initial_comment=initial_comment, + title=title, + channels=channels, + ) + + def test_connection(self): + """Tests the Slack API connection. + + .. seealso:: + https://api.slack.com/methods/auth.test + """ + try: + response = self.call("auth.test") + response.validate() + except SlackApiError as e: + return False, str(e) + except Exception as e: + return False, f"Unknown error occurred while testing connection: {e}" + + if isinstance(response.data, bytes): + # If response data binary then return simple message + return True, f"Connection successfully tested (url: {response.api_url})." + + try: + return True, json.dumps(response.data) + except TypeError: + return True, str(response) + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """Returns dictionary of widgets to be added for the hook to handle extra values.""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import IntegerField, StringField + from wtforms.validators import NumberRange, Optional + + return { + "timeout": IntegerField( + lazy_gettext("Timeout"), + widget=BS3TextFieldWidget(), + validators=[Optional(strip_whitespace=True), NumberRange(min=1)], + description="Optional. The maximum number of seconds the client will wait to connect " + "and receive a response from Slack API.", + ), + "base_url": StringField( + lazy_gettext("Base URL"), + widget=BS3TextFieldWidget(), + description="Optional. A string representing the Slack API base URL.", + ), + "proxy": StringField( + lazy_gettext("Proxy"), + widget=BS3TextFieldWidget(), + description="Optional. 
Proxy to make the Slack API call.", + ), + } + + @classmethod + @_ensure_prefixes(conn_type="slack") + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """Returns custom field behaviour.""" + return { + "hidden_fields": ["login", "port", "host", "schema", "extra"], + "relabeling": { + "password": "Slack API Token", + }, + "placeholders": { + "password": "xoxb-1234567890123-09876543210987-AbCdEfGhIjKlMnOpQrStUvWx", + "timeout": "30", + "base_url": "https://www.slack.com/api/", + "proxy": "http://localhost:9000", + }, + } diff --git a/airflow/providers/slack/hooks/slack_webhook.py b/airflow/providers/slack/hooks/slack_webhook.py index 27d45c0741090..d44a1766731f8 100644 --- a/airflow/providers/slack/hooks/slack_webhook.py +++ b/airflow/providers/slack/hooks/slack_webhook.py @@ -15,144 +15,489 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + import json import warnings -from typing import Optional +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable +from urllib.parse import urlsplit + +from slack_sdk import WebhookClient +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException -from airflow.providers.http.hooks.http import HttpHook +from airflow.hooks.base import BaseHook +from airflow.models import Connection +from airflow.providers.slack.utils import ConnectionExtraConfig +from airflow.utils.log.secrets_masker import mask_secret + +if TYPE_CHECKING: + from slack_sdk.http_retry import RetryHandler + +DEFAULT_SLACK_WEBHOOK_ENDPOINT = "https://hooks.slack.com/services" +LEGACY_INTEGRATION_PARAMS = ("channel", "username", "icon_emoji", "icon_url") -class SlackWebhookHook(HttpHook): +def check_webhook_response(func: Callable) -> Callable: + """Function decorator that check WebhookResponse and raise an error if status code != 200.""" + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + resp = func(*args, **kwargs) + if resp.status_code != 200: + raise AirflowException( + f"Response body: {resp.body!r}, Status Code: {resp.status_code}. " + "See: https://api.slack.com/messaging/webhooks#handling_errors" + ) + return resp + + return wrapper + + +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. """ - This hook allows you to post messages to Slack using incoming webhooks. - Takes both Slack webhook token directly and connection that has Slack webhook token. - If both supplied, http_conn_id will be used as base_url, - and webhook_token will be taken as endpoint, the relative path of the url. - - Each Slack webhook token can be pre-configured to use a specific channel, username and - icon. You can override these defaults in this hook. - - :param http_conn_id: connection that has Slack webhook token in the password field - :param webhook_token: Slack webhook token - :param message: The message you want to send on Slack - :param attachments: The attachments to send on Slack. Should be a list of - dictionaries representing Slack attachments. - :param blocks: The blocks to send on Slack. Should be a list of - dictionaries representing Slack blocks. 
- :param channel: The channel the message should be posted to - :param username: The username to post to slack with - :param icon_emoji: The emoji to use as icon for the user posting to Slack - :param icon_url: The icon image URL string to use in place of the default icon. - :param link_names: Whether or not to find and link channel and usernames in your - message - :param proxy: Proxy to use to make the Slack webhook call + + def dec(func): + @wraps(func) + def inner(cls): + field_behaviors = func(cls) + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec + + +class SlackWebhookHook(BaseHook): + """ + This class provide a thin wrapper around the ``slack_sdk.WebhookClient``. + This hook allows you to post messages to Slack by using Incoming Webhooks. + + .. seealso:: + - :ref:`Slack Incoming Webhook connection ` + - https://api.slack.com/messaging/webhooks + - https://slack.dev/python-slack-sdk/webhook/index.html + + .. note:: + You cannot override the default channel (chosen by the user who installed your app), + username, or icon when you're using Incoming Webhooks to post messages. + Instead, these values will always inherit from the associated Slack App configuration + (`link `_). + It is possible to change this values only in `Legacy Slack Integration Incoming Webhook + `_. + + .. warning:: + This hook intend to use `Slack Incoming Webhook` connection + and might not work correctly with `Slack API` connection. + + Examples: + .. code-block:: python + + # Create hook + hook = SlackWebhookHook(slack_webhook_conn_id="slack_default") + + # Post message in Slack channel by JSON formatted message + # See: https://api.slack.com/messaging/webhooks#posting_with_webhooks + hook.send_dict({"text": "Hello world!"}) + + # Post simple message in Slack channel + hook.send_text("Hello world!") + + # Use ``slack_sdk.WebhookClient`` + hook.client.send(text="Hello world!") + + :param slack_webhook_conn_id: Slack Incoming Webhook connection id + that has Incoming Webhook token in the password field. + :param timeout: The maximum number of seconds the client will wait to connect + and receive a response from Slack. If not set than default WebhookClient value will use. + :param proxy: Proxy to make the Slack Incoming Webhook call. + :param retry_handlers: List of handlers to customize retry logic in ``slack_sdk.WebhookClient``. + :param webhook_token: (deprecated) Slack Incoming Webhook token. + Use instead Slack Incoming Webhook connection password field. 
""" - conn_name_attr = 'http_conn_id' - default_conn_name = 'slack_default' - conn_type = 'slackwebhook' - hook_name = 'Slack Webhook' + conn_name_attr = "slack_webhook_conn_id" + default_conn_name = "slack_default" + conn_type = "slackwebhook" + hook_name = "Slack Incoming Webhook" def __init__( self, - http_conn_id=None, - webhook_token=None, - message="", - attachments=None, - blocks=None, - channel=None, - username=None, - icon_emoji=None, - icon_url=None, - link_names=False, - proxy=None, - *args, + slack_webhook_conn_id: str | None = None, + webhook_token: str | None = None, + timeout: int | None = None, + proxy: str | None = None, + retry_handlers: list[RetryHandler] | None = None, **kwargs, ): - super().__init__(http_conn_id=http_conn_id, *args, **kwargs) - self.webhook_token = self._get_token(webhook_token, http_conn_id) - self.message = message - self.attachments = attachments - self.blocks = blocks - self.channel = channel - self.username = username - self.icon_emoji = icon_emoji - self.icon_url = icon_url - self.link_names = link_names - self.proxy = proxy + super().__init__() - def _get_token(self, token: str, http_conn_id: Optional[str]) -> str: - """ - Given either a manually set token or a conn_id, return the webhook_token to use. + http_conn_id = kwargs.pop("http_conn_id", None) + if http_conn_id: + warnings.warn( + "Parameter `http_conn_id` is deprecated. Please use `slack_webhook_conn_id` instead.", + DeprecationWarning, + stacklevel=2, + ) + if slack_webhook_conn_id: + raise AirflowException("You cannot provide both `slack_webhook_conn_id` and `http_conn_id`.") + slack_webhook_conn_id = http_conn_id - :param token: The manually provided token - :param http_conn_id: The conn_id provided - :return: webhook_token to use - :rtype: str - """ - if token: - return token - elif http_conn_id: - conn = self.get_connection(http_conn_id) + if not slack_webhook_conn_id and not webhook_token: + raise AirflowException("Either `slack_webhook_conn_id` or `webhook_token` should be provided.") + if webhook_token: + mask_secret(webhook_token) + warnings.warn( + "Provide `webhook_token` as hook argument deprecated by security reason and will be removed " + "in a future releases. Please specify it in `Slack Webhook` connection.", + DeprecationWarning, + stacklevel=2, + ) + if not slack_webhook_conn_id: + warnings.warn( + "You have not set parameter `slack_webhook_conn_id`. 
Currently `Slack Incoming Webhook` " + "connection id optional but in a future release it will mandatory.", + FutureWarning, + stacklevel=2, + ) - if getattr(conn, 'password', None): - return conn.password - else: - extra = conn.extra_dejson - web_token = extra.get('webhook_token', '') + self.slack_webhook_conn_id = slack_webhook_conn_id + self.timeout = timeout + self.proxy = proxy + self.retry_handlers = retry_handlers + self._webhook_token = webhook_token - if web_token: + # Compatibility with previous version of SlackWebhookHook + deprecated_class_attrs = [] + for deprecated_attr in ( + "message", + "attachments", + "blocks", + "channel", + "username", + "icon_emoji", + "icon_url", + "link_names", + ): + if deprecated_attr in kwargs: + deprecated_class_attrs.append(deprecated_attr) + setattr(self, deprecated_attr, kwargs.pop(deprecated_attr)) + if deprecated_attr == "message": + # Slack WebHook Post Request not expected `message` as field, + # so we also set "text" attribute which will check by SlackWebhookHook._resolve_argument + self.text = getattr(self, deprecated_attr) + elif deprecated_attr == "link_names": warnings.warn( - "'webhook_token' in 'extra' is deprecated. Please use 'password' field", - DeprecationWarning, + "`link_names` has no affect, if you want to mention user see: " + "https://api.slack.com/reference/surfaces/formatting#mentioning-users", + UserWarning, stacklevel=2, ) - return web_token + if deprecated_class_attrs: + warnings.warn( + f"Provide {','.join(repr(a) for a in deprecated_class_attrs)} as hook argument(s) " + f"is deprecated and will be removed in a future releases. " + f"Please specify attributes in `{self.__class__.__name__}.send` method instead.", + DeprecationWarning, + stacklevel=2, + ) + + self.extra_client_args = kwargs + + @cached_property + def client(self) -> WebhookClient: + """Get the underlying slack_sdk.webhook.WebhookClient (cached).""" + return WebhookClient(**self._get_conn_params()) + + def get_conn(self) -> WebhookClient: + """Get the underlying slack_sdk.webhook.WebhookClient (cached).""" + return self.client + + @cached_property + def webhook_token(self) -> str: + """Return Slack Webhook Token URL.""" + warnings.warn( + "`SlackHook.webhook_token` property deprecated and will be removed in a future releases.", + DeprecationWarning, + stacklevel=2, + ) + return self._get_conn_params()["url"] + + def _get_conn_params(self) -> dict[str, Any]: + """Fetch connection params as a dict and merge it with hook parameters.""" + default_schema, _, default_host = DEFAULT_SLACK_WEBHOOK_ENDPOINT.partition("://") + if self.slack_webhook_conn_id: + conn = self.get_connection(self.slack_webhook_conn_id) else: - raise AirflowException('Cannot get token: No valid Slack webhook token nor conn_id supplied') + # If slack_webhook_conn_id not specified, then use connection with default schema and host + conn = Connection( + conn_id=None, conn_type=self.conn_type, host=default_schema, password=default_host + ) + extra_config = ConnectionExtraConfig( + conn_type=self.conn_type, + conn_id=conn.conn_id, + extra=conn.extra_dejson, + ) + conn_params: dict[str, Any] = {"retry_handlers": self.retry_handlers} + + webhook_token = None + if self._webhook_token: + self.log.debug("Retrieving Slack Webhook Token from hook attribute.") + webhook_token = self._webhook_token + elif conn.conn_id: + if conn.password: + self.log.debug( + "Retrieving Slack Webhook Token from Connection ID %r password.", + self.slack_webhook_conn_id, + ) + webhook_token = conn.password + + 
webhook_token = webhook_token or "" + if not webhook_token and not conn.host: + raise AirflowException("Cannot get token: No valid Slack token nor valid Connection ID supplied.") + elif webhook_token and "://" in webhook_token: + self.log.debug("Retrieving Slack Webhook Token URL from webhook token.") + url = webhook_token + else: + self.log.debug("Constructing Slack Webhook Token URL.") + if conn.host and "://" in conn.host: + base_url = conn.host + else: + schema = conn.schema if conn.schema else default_schema + host = conn.host if conn.host else default_host + base_url = f"{schema}://{host}" + + base_url = base_url.rstrip("/") + if not webhook_token: + parsed_token = (urlsplit(base_url).path or "").strip("/") + if base_url == DEFAULT_SLACK_WEBHOOK_ENDPOINT or not parsed_token: + # Raise an error in case of password not specified and + # 1. Result of constructing base_url equal https://hooks.slack.com/services + # 2. Empty url path, e.g. if base_url = https://hooks.slack.com + raise AirflowException( + "Cannot get token: No valid Slack token nor valid Connection ID supplied." + ) + mask_secret(parsed_token) + warnings.warn( + f"Found Slack Webhook Token URL in Connection {conn.conn_id!r} `host` " + "and `password` field is empty. This behaviour deprecated " + "and could expose you token in the UI and will be removed in a future releases.", + DeprecationWarning, + stacklevel=2, + ) + url = (base_url.rstrip("/") + "/" + webhook_token.lstrip("/")).rstrip("/") + + conn_params["url"] = url + # Merge Hook parameters with Connection config + conn_params.update( + { + "timeout": self.timeout or extra_config.getint("timeout", default=None), + "proxy": self.proxy or extra_config.get("proxy", default=None), + } + ) + # Add additional client args + conn_params.update(self.extra_client_args) + if "logger" not in conn_params: + conn_params["logger"] = self.log + + return {k: v for k, v in conn_params.items() if v is not None} - def _build_slack_message(self) -> str: + def _resolve_argument(self, name: str, value): """ - Construct the Slack message. All relevant parameters are combined here to a valid - Slack json message. + Resolve message parameters. - :return: Slack message to send - :rtype: str + .. note:: + This method exist for compatibility and merge instance class attributes with + method attributes and not be required when assign class attributes to message + would completely remove. """ - cmd = {} - - if self.channel: - cmd['channel'] = self.channel - if self.username: - cmd['username'] = self.username - if self.icon_emoji: - cmd['icon_emoji'] = self.icon_emoji - if self.icon_url: - cmd['icon_url'] = self.icon_url - if self.link_names: - cmd['link_names'] = 1 - if self.attachments: - cmd['attachments'] = self.attachments - if self.blocks: - cmd['blocks'] = self.blocks - - cmd['text'] = self.message - return json.dumps(cmd) + if value is None and name in ( + "text", + "attachments", + "blocks", + "channel", + "username", + "icon_emoji", + "icon_url", + "link_names", + ): + return getattr(self, name, None) + + return value + + @check_webhook_response + def send_dict(self, body: dict[str, Any] | str, *, headers: dict[str, str] | None = None): + """ + Performs a Slack Incoming Webhook request with given JSON data block. + + :param body: JSON data structure, expected dict or JSON-string. + :param headers: Request headers for this request. 
+ """ + if isinstance(body, str): + try: + body = json.loads(body) + except json.JSONDecodeError as err: + raise AirflowException( + f"Body expected valid JSON string, got {body!r}. Original error:\n * {err}" + ) from None + + if not isinstance(body, dict): + raise TypeError(f"Body expected dictionary, got {type(body).__name__}.") + + if any(legacy_attr in body for legacy_attr in ("channel", "username", "icon_emoji", "icon_url")): + warnings.warn( + "You cannot override the default channel (chosen by the user who installed your app), " + "username, or icon when you're using Incoming Webhooks to post messages. " + "Instead, these values will always inherit from the associated Slack app configuration. " + "See: https://api.slack.com/messaging/webhooks#advanced_message_formatting. " + "It is possible to change this values only in Legacy Slack Integration Incoming Webhook: " + "https://api.slack.com/legacy/custom-integrations/messaging/webhooks#legacy-customizations", + UserWarning, + stacklevel=2, + ) + + return self.client.send_dict(body, headers=headers) + + def send( + self, + *, + text: str | None = None, + attachments: list[dict[str, Any]] | None = None, + blocks: list[dict[str, Any]] | None = None, + response_type: str | None = None, + replace_original: bool | None = None, + delete_original: bool | None = None, + unfurl_links: bool | None = None, + unfurl_media: bool | None = None, + headers: dict[str, str] | None = None, + **kwargs, + ): + """ + Performs a Slack Incoming Webhook request with given arguments. + + :param text: The text message + (even when having blocks, setting this as well is recommended as it works as fallback). + :param attachments: A collection of attachments. + :param blocks: A collection of Block Kit UI components. + :param response_type: The type of message (either 'in_channel' or 'ephemeral'). + :param replace_original: True if you use this option for response_url requests. + :param delete_original: True if you use this option for response_url requests. + :param unfurl_links: Option to indicate whether text url should unfurl. + :param unfurl_media: Option to indicate whether media url should unfurl. + :param headers: Request headers for this request. + """ + body = { + "text": self._resolve_argument("text", text), + "attachments": self._resolve_argument("attachments", attachments), + "blocks": self._resolve_argument("blocks", blocks), + "response_type": response_type, + "replace_original": replace_original, + "delete_original": delete_original, + "unfurl_links": unfurl_links, + "unfurl_media": unfurl_media, + # Legacy Integration Parameters + **{lip: self._resolve_argument(lip, kwargs.pop(lip, None)) for lip in LEGACY_INTEGRATION_PARAMS}, + } + if kwargs: + warnings.warn( + f"Found unexpected keyword-argument(s) {', '.join(repr(k) for k in kwargs)} " + "in `send` method. This argument(s) have no effect.", + UserWarning, + stacklevel=2, + ) + body = {k: v for k, v in body.items() if v is not None} + return self.send_dict(body=body, headers=headers) + + def send_text( + self, + text: str, + *, + unfurl_links: bool | None = None, + unfurl_media: bool | None = None, + headers: dict[str, str] | None = None, + ): + """ + Performs a Slack Incoming Webhook request with given text. + + :param text: The text message. + :param unfurl_links: Option to indicate whether text url should unfurl. + :param unfurl_media: Option to indicate whether media url should unfurl. + :param headers: Request headers for this request. 
+ """ + return self.send(text=text, unfurl_links=unfurl_links, unfurl_media=unfurl_media, headers=headers) + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """Returns dictionary of widgets to be added for the hook to handle extra values.""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import IntegerField, StringField + from wtforms.validators import NumberRange, Optional + + return { + "timeout": IntegerField( + lazy_gettext("Timeout"), + widget=BS3TextFieldWidget(), + validators=[Optional(), NumberRange(min=1)], + description="Optional. The maximum number of seconds the client will wait to connect " + "and receive a response from Slack Incoming Webhook.", + ), + "proxy": StringField( + lazy_gettext("Proxy"), + widget=BS3TextFieldWidget(), + description="Optional. Proxy to make the Slack Incoming Webhook call.", + ), + } + + @classmethod + @_ensure_prefixes(conn_type="slackwebhook") + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """Returns custom field behaviour.""" + return { + "hidden_fields": ["login", "port", "extra"], + "relabeling": { + "host": "Slack Webhook Endpoint", + "password": "Webhook Token", + }, + "placeholders": { + "schema": "https", + "host": "hooks.slack.com/services", + "password": "T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", + "timeout": "30", + "proxy": "http://localhost:9000", + }, + } def execute(self) -> None: - """Remote Popen (actually execute the slack webhook call)""" - proxies = {} - if self.proxy: - # we only need https proxy for Slack, as the endpoint is https - proxies = {'https': self.proxy} - - slack_message = self._build_slack_message() - self.run( - endpoint=self.webhook_token, - data=slack_message, - headers={'Content-type': 'application/json'}, - extra_options={'proxies': proxies, 'check_response': True}, + """ + Remote Popen (actually execute the slack webhook call). + + .. note:: + This method exist for compatibility with previous version of operator + and expected that Slack Incoming Webhook message constructing from class attributes rather than + pass as method arguments. + """ + warnings.warn( + "`SlackWebhookHook.execute` method deprecated and will be removed in a future releases. " + "Please use `SlackWebhookHook.send` or `SlackWebhookHook.send_dict` or " + "`SlackWebhookHook.send_text` methods instead.", + DeprecationWarning, + stacklevel=2, ) + self.send() diff --git a/airflow/providers/slack/operators/slack.py b/airflow/providers/slack/operators/slack.py index 1aa5edc22be4a..4d586772472f2 100644 --- a/airflow/providers/slack/operators/slack.py +++ b/airflow/providers/slack/operators/slack.py @@ -15,11 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import json -from typing import Any, Dict, List, Optional, Sequence +import warnings +from typing import Any, Sequence +from airflow.compat.functools import cached_property from airflow.models import BaseOperator from airflow.providers.slack.hooks.slack import SlackHook +from airflow.utils.log.secrets_masker import mask_secret class SlackAPIOperator(BaseOperator): @@ -29,7 +34,7 @@ class SlackAPIOperator(BaseOperator): In the future additional Slack API Operators will be derived from this class as well. Only one of `slack_conn_id` and `token` is required. 
- :param slack_conn_id: :ref:`Slack connection id ` + :param slack_conn_id: :ref:`Slack API Connection ` which its password is Slack API token. Optional :param token: Slack API token (https://api.slack.com/web). Optional :param method: The Slack API Method to Call (https://api.slack.com/methods). Optional @@ -40,20 +45,26 @@ class SlackAPIOperator(BaseOperator): def __init__( self, *, - slack_conn_id: Optional[str] = None, - token: Optional[str] = None, - method: Optional[str] = None, - api_params: Optional[Dict] = None, + slack_conn_id: str | None = None, + token: str | None = None, + method: str | None = None, + api_params: dict | None = None, **kwargs, ) -> None: super().__init__(**kwargs) - - self.token = token # type: Optional[str] - self.slack_conn_id = slack_conn_id # type: Optional[str] + if token: + mask_secret(token) + self.token = token + self.slack_conn_id = slack_conn_id self.method = method self.api_params = api_params + @cached_property + def hook(self) -> SlackHook: + """Slack Hook.""" + return SlackHook(token=self.token, slack_conn_id=self.slack_conn_id) + def construct_api_call_params(self) -> Any: """ Used by the execute function. Allows templating on the source fields @@ -70,14 +81,9 @@ def construct_api_call_params(self) -> Any: ) def execute(self, **kwargs): - """ - The SlackAPIOperator calls will not fail even if the call is not unsuccessful. - It should not prevent a DAG from completing in success - """ if not self.api_params: self.construct_api_call_params() - slack = SlackHook(token=self.token, slack_conn_id=self.slack_conn_id) - slack.call(self.method, json=self.api_params) + self.hook.call(self.method, json=self.api_params) class SlackAPIPostOperator(SlackAPIOperator): @@ -106,23 +112,23 @@ class SlackAPIPostOperator(SlackAPIOperator): - see https://api.slack.com/reference/block-kit/blocks. 
""" - template_fields: Sequence[str] = ('username', 'text', 'attachments', 'blocks', 'channel') - ui_color = '#FFBA40' + template_fields: Sequence[str] = ("username", "text", "attachments", "blocks", "channel") + ui_color = "#FFBA40" def __init__( self, - channel: str = '#general', - username: str = 'Airflow', - text: str = 'No message has been set.\n' - 'Here is a cat video instead\n' - 'https://www.youtube.com/watch?v=J---aiyznGQ', - icon_url: str = 'https://raw.githubusercontent.com/apache/' - 'airflow/main/airflow/www/static/pin_100.png', - attachments: Optional[List] = None, - blocks: Optional[List] = None, + channel: str = "#general", + username: str = "Airflow", + text: str = "No message has been set.\n" + "Here is a cat video instead\n" + "https://www.youtube.com/watch?v=J---aiyznGQ", + icon_url: str = "https://raw.githubusercontent.com/apache/" + "airflow/main/airflow/www/static/pin_100.png", + attachments: list | None = None, + blocks: list | None = None, **kwargs, ) -> None: - self.method = 'chat.postMessage' + self.method = "chat.postMessage" self.channel = channel self.username = username self.text = text @@ -133,18 +139,18 @@ def __init__( def construct_api_call_params(self) -> Any: self.api_params = { - 'channel': self.channel, - 'username': self.username, - 'text': self.text, - 'icon_url': self.icon_url, - 'attachments': json.dumps(self.attachments), - 'blocks': json.dumps(self.blocks), + "channel": self.channel, + "username": self.username, + "text": self.text, + "icon_url": self.icon_url, + "attachments": json.dumps(self.attachments), + "blocks": json.dumps(self.blocks), } class SlackAPIFileOperator(SlackAPIOperator): """ - Send a file to a slack channel + Send a file to a slack channels Examples: .. code-block:: python @@ -154,7 +160,7 @@ class SlackAPIFileOperator(SlackAPIOperator): task_id="slack_file_upload_1", dag=dag, slack_conn_id="slack", - channel="#general", + channels="#general,#random", initial_comment="Hello World!", filename="/files/dags/test.txt", filetype="txt", @@ -165,63 +171,67 @@ class SlackAPIFileOperator(SlackAPIOperator): task_id="slack_file_upload_2", dag=dag, slack_conn_id="slack", - channel="#general", + channels="#general", initial_comment="Hello World!", content="file content in txt", ) - :param channel: channel in which to sent file on slack name (templated) + :param channels: Comma-separated list of channel names or IDs where the file will be shared. + If set this argument to None, then file will send to associated workspace. (templated) :param initial_comment: message to send to slack. (templated) :param filename: name of the file (templated) - :param filetype: slack filetype. (templated) - - see https://api.slack.com/types/file + :param filetype: slack filetype. (templated) See: https://api.slack.com/types/file#file_types :param content: file content. (templated) + :param title: title of file. 
(templated) + :param channel: (deprecated) channel in which to sent file on slack name """ - template_fields: Sequence[str] = ('channel', 'initial_comment', 'filename', 'filetype', 'content') - ui_color = '#44BEDF' + template_fields: Sequence[str] = ( + "channels", + "initial_comment", + "filename", + "filetype", + "content", + "title", + ) + ui_color = "#44BEDF" def __init__( self, - channel: str = '#general', - initial_comment: str = 'No message has been set!', - filename: Optional[str] = None, - filetype: Optional[str] = None, - content: Optional[str] = None, + channels: str | Sequence[str] | None = None, + initial_comment: str | None = None, + filename: str | None = None, + filetype: str | None = None, + content: str | None = None, + title: str | None = None, + channel: str | None = None, **kwargs, ) -> None: - self.method = 'files.upload' - self.channel = channel + if channel: + warnings.warn( + "Argument `channel` is deprecated and will removed in a future releases. " + "Please use `channels` instead.", + DeprecationWarning, + stacklevel=2, + ) + if channels: + raise ValueError(f"Cannot set both arguments: channel={channel!r} and channels={channels!r}.") + channels = channel + + self.channels = channels self.initial_comment = initial_comment self.filename = filename self.filetype = filetype self.content = content - self.file_params: Dict = {} - super().__init__(method=self.method, **kwargs) + self.title = title + super().__init__(method="files.upload", **kwargs) def execute(self, **kwargs): - """ - The SlackAPIOperator calls will not fail even if the call is not unsuccessful. - It should not prevent a DAG from completing in success - """ - slack = SlackHook(token=self.token, slack_conn_id=self.slack_conn_id) - - # If file content is passed. - if self.content is not None: - self.api_params = { - 'channels': self.channel, - 'content': self.content, - 'initial_comment': self.initial_comment, - } - slack.call(self.method, data=self.api_params) - # If file name is passed. - elif self.filename is not None: - self.api_params = { - 'channels': self.channel, - 'filename': self.filename, - 'filetype': self.filetype, - 'initial_comment': self.initial_comment, - } - with open(self.filename, "rb") as file_handle: - slack.call(self.method, data=self.api_params, files={'file': file_handle}) - file_handle.close() + self.hook.send_file( + channels=self.channels, + # For historical reason SlackAPIFileOperator use filename as reference to file + file=self.filename, + content=self.content, + initial_comment=self.initial_comment, + title=self.title, + ) diff --git a/airflow/providers/slack/operators/slack_webhook.py b/airflow/providers/slack/operators/slack_webhook.py index c9a4c78fb0b2b..d4cf8ebba036c 100644 --- a/airflow/providers/slack/operators/slack_webhook.py +++ b/airflow/providers/slack/operators/slack_webhook.py @@ -15,29 +15,42 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
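A minimal usage sketch for the refactored file upload above (illustrative only: the ``slack_api_default`` connection id, channel names, and file path are placeholders, not part of this diff). ``SlackAPIFileOperator`` now delegates to ``SlackHook.send_file()`` and takes a comma-separated ``channels`` argument in place of the deprecated ``channel``:

.. code-block:: python

    from airflow.providers.slack.operators.slack import SlackAPIFileOperator

    # Upload a local file to two channels; the hook resolves the Slack API token
    # from the referenced connection's password field.
    upload_report = SlackAPIFileOperator(
        task_id="upload_report",
        slack_conn_id="slack_api_default",   # placeholder Slack API connection id
        channels="#general,#random",
        filename="/files/dags/report.csv",   # placeholder path
        initial_comment="Daily report attached.",
        title="Daily report",
    )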
-# -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Sequence -from airflow.providers.http.operators.http import SimpleHttpOperator +from airflow.compat.functools import cached_property +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook if TYPE_CHECKING: from airflow.utils.context import Context -class SlackWebhookOperator(SimpleHttpOperator): +class SlackWebhookOperator(BaseOperator): """ - This operator allows you to post messages to Slack using incoming webhooks. - Takes both Slack webhook token directly and connection that has Slack webhook token. - If both supplied, http_conn_id will be used as base_url, - and webhook_token will be taken as endpoint, the relative path of the url. + This operator allows you to post messages to Slack using Incoming Webhooks. + + .. note:: + You cannot override the default channel (chosen by the user who installed your app), + username, or icon when you're using Incoming Webhooks to post messages. + Instead, these values will always inherit from the associated Slack App configuration + (`link `_). + It is possible to change this values only in `Legacy Slack Integration Incoming Webhook + `_. - Each Slack webhook token can be pre-configured to use a specific channel, username and - icon. You can override these defaults in this hook. + .. warning:: + This operator could take Slack Webhook Token from ``webhook_token`` + as well as from :ref:`Slack Incoming Webhook connection `. + However, provide ``webhook_token`` it is not secure and this attribute + will be removed in the future version of provider. - :param http_conn_id: connection that has Slack webhook token in the extra field - :param webhook_token: Slack webhook token - :param message: The message you want to send on Slack + :param slack_webhook_conn_id: :ref:`Slack Incoming Webhook ` + connection id that has Incoming Webhook token in the password field. + :param message: The formatted text of the message to be published. + If ``blocks`` are included, this will become the fallback text used in notifications. :param attachments: The attachments to send on Slack. Should be a list of dictionaries representing Slack attachments. :param blocks: The blocks to send on Slack. Should be a list of @@ -49,37 +62,82 @@ class SlackWebhookOperator(SimpleHttpOperator): :param link_names: Whether or not to find and link channel and usernames in your message :param proxy: Proxy to use to make the Slack webhook call + :param webhook_token: (deprecated) Slack Incoming Webhook token. + Please use ``slack_webhook_conn_id`` instead. 
""" template_fields: Sequence[str] = ( - 'webhook_token', - 'message', - 'attachments', - 'blocks', - 'channel', - 'username', - 'proxy', + "webhook_token", + "message", + "attachments", + "blocks", + "channel", + "username", + "proxy", ) def __init__( self, *, - http_conn_id: str, - webhook_token: Optional[str] = None, + slack_webhook_conn_id: str | None = None, + webhook_token: str | None = None, message: str = "", - attachments: Optional[list] = None, - blocks: Optional[list] = None, - channel: Optional[str] = None, - username: Optional[str] = None, - icon_emoji: Optional[str] = None, - icon_url: Optional[str] = None, + attachments: list | None = None, + blocks: list | None = None, + channel: str | None = None, + username: str | None = None, + icon_emoji: str | None = None, + icon_url: str | None = None, link_names: bool = False, - proxy: Optional[str] = None, + proxy: str | None = None, **kwargs, ) -> None: - super().__init__(endpoint=webhook_token, **kwargs) - self.http_conn_id = http_conn_id + http_conn_id = kwargs.pop("http_conn_id", None) + if http_conn_id: + warnings.warn( + "Parameter `http_conn_id` is deprecated. Please use `slack_webhook_conn_id` instead.", + DeprecationWarning, + stacklevel=2, + ) + if slack_webhook_conn_id: + raise AirflowException("You cannot provide both `slack_webhook_conn_id` and `http_conn_id`.") + slack_webhook_conn_id = http_conn_id + + # Compatibility with previous version of operator which based on SimpleHttpOperator. + # Users might pass these arguments previously, however its never pass to SlackWebhookHook. + # We remove this arguments if found in ``kwargs`` and notify users if found any. + deprecated_class_attrs = [] + for deprecated_attr in ( + "endpoint", + "method", + "data", + "headers", + "response_check", + "response_filter", + "extra_options", + "log_response", + "auth_type", + "tcp_keep_alive", + "tcp_keep_alive_idle", + "tcp_keep_alive_count", + "tcp_keep_alive_interval", + ): + if deprecated_attr in kwargs: + deprecated_class_attrs.append(deprecated_attr) + kwargs.pop(deprecated_attr) + if deprecated_class_attrs: + warnings.warn( + f"Provide {','.join(repr(a) for a in deprecated_class_attrs)} is deprecated " + f"and as has no affect, please remove it from {self.__class__.__name__} " + "constructor attributes otherwise in future version of provider it might cause an issue.", + DeprecationWarning, + stacklevel=2, + ) + + super().__init__(**kwargs) + self.slack_webhook_conn_id = slack_webhook_conn_id self.webhook_token = webhook_token + self.proxy = proxy self.message = message self.attachments = attachments self.blocks = blocks @@ -88,22 +146,29 @@ def __init__( self.icon_emoji = icon_emoji self.icon_url = icon_url self.link_names = link_names - self.proxy = proxy - self.hook: Optional[SlackWebhookHook] = None - def execute(self, context: 'Context') -> None: + @cached_property + def hook(self) -> SlackWebhookHook: + """Create and return an SlackWebhookHook (cached).""" + return SlackWebhookHook( + slack_webhook_conn_id=self.slack_webhook_conn_id, + proxy=self.proxy, + # Deprecated. SlackWebhookHook will notify user if user provide non-empty ``webhook_token``. 
+ webhook_token=self.webhook_token, + ) + + def execute(self, context: Context) -> None: """Call the SlackWebhookHook to post the provided Slack message""" - self.hook = SlackWebhookHook( - self.http_conn_id, - self.webhook_token, - self.message, - self.attachments, - self.blocks, - self.channel, - self.username, - self.icon_emoji, - self.icon_url, - self.link_names, - self.proxy, + self.hook.send( + text=self.message, + attachments=self.attachments, + blocks=self.blocks, + # Parameters below use for compatibility with previous version of Operator and warn user if it set + # Legacy Integration Parameters + channel=self.channel, + username=self.username, + icon_emoji=self.icon_emoji, + icon_url=self.icon_url, + # Unused Parameters, if not None than warn user + link_names=self.link_names, ) - self.hook.execute() diff --git a/airflow/providers/slack/provider.yaml b/airflow/providers/slack/provider.yaml index 503dbb26d503e..5082a9dec4e10 100644 --- a/airflow/providers/slack/provider.yaml +++ b/airflow/providers/slack/provider.yaml @@ -22,6 +22,10 @@ description: | `Slack `__ versions: + - 7.0.0 + - 6.0.0 + - 5.1.0 + - 5.0.0 - 4.2.3 - 4.2.2 - 4.2.1 @@ -33,8 +37,10 @@ versions: - 2.0.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - slack_sdk>=3.0.0 integrations: - integration-name: Slack @@ -56,9 +62,14 @@ hooks: - airflow.providers.slack.hooks.slack - airflow.providers.slack.hooks.slack_webhook -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook +transfers: + - source-integration-name: Common SQL + target-integration-name: Slack + python-module: airflow.providers.slack.transfers.sql_to_slack + how-to-guide: /docs/apache-airflow-providers-slack/operators/sql_to_slack.rst connection-types: + - hook-class-name: airflow.providers.slack.hooks.slack.SlackHook + connection-type: slack - hook-class-name: airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook connection-type: slackwebhook diff --git a/airflow/providers/jira/__init__.py b/airflow/providers/slack/transfers/__init__.py similarity index 100% rename from airflow/providers/jira/__init__.py rename to airflow/providers/slack/transfers/__init__.py diff --git a/airflow/providers/slack/transfers/sql_to_slack.py b/airflow/providers/slack/transfers/sql_to_slack.py new file mode 100644 index 0000000000000..cf5c01b22c9cf --- /dev/null +++ b/airflow/providers/slack/transfers/sql_to_slack.py @@ -0,0 +1,299 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
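A short sketch of the new SQL-to-Slack transfer introduced below (illustrative only: the connection ids and query are placeholders). ``SqlToSlackOperator`` renders the query result into the ``results_df`` Jinja variable and exposes ``tabulate`` as a template filter, so the message can carry a small text table:

.. code-block:: python

    from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator

    # Post a summary table to Slack through an Incoming Webhook connection.
    sql_to_slack = SqlToSlackOperator(
        task_id="sql_to_slack",
        sql_conn_id="postgres_default",      # placeholder DB-API connection id
        sql="SELECT state, COUNT(*) AS cnt FROM task_instance GROUP BY state",
        slack_conn_id="slack_default",       # placeholder Incoming Webhook connection id
        slack_message=(
            "Task instance summary:\n"
            "{{ results_df | tabulate(tablefmt='pretty', headers='keys') }}"
        ),
    )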
+from __future__ import annotations + +from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence + +from pandas import DataFrame +from tabulate import tabulate + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models import BaseOperator +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.slack.hooks.slack import SlackHook +from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook +from airflow.providers.slack.utils import parse_filename + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class BaseSqlToSlackOperator(BaseOperator): + """ + Operator implements base sql methods for SQL to Slack Transfer operators. + + :param sql: The SQL query to be executed + :param sql_conn_id: reference to a specific DB-API Connection. + :param sql_hook_params: Extra config params to be passed to the underlying hook. + Should match the desired hook constructor params. + :param parameters: The parameters to pass to the SQL query. + """ + + def __init__( + self, + *, + sql: str, + sql_conn_id: str, + sql_hook_params: dict | None = None, + parameters: Iterable | Mapping | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.sql_conn_id = sql_conn_id + self.sql_hook_params = sql_hook_params + self.sql = sql + self.parameters = parameters + + def _get_hook(self) -> DbApiHook: + self.log.debug("Get connection for %s", self.sql_conn_id) + conn = BaseHook.get_connection(self.sql_conn_id) + hook = conn.get_hook(hook_params=self.sql_hook_params) + if not callable(getattr(hook, "get_pandas_df", None)): + raise AirflowException( + "This hook is not supported. The hook class must have get_pandas_df method." + ) + return hook + + def _get_query_results(self) -> DataFrame: + sql_hook = self._get_hook() + + self.log.info("Running SQL query: %s", self.sql) + df = sql_hook.get_pandas_df(self.sql, parameters=self.parameters) + return df + + +class SqlToSlackOperator(BaseSqlToSlackOperator): + """ + Executes an SQL statement in a given SQL connection and sends the results to Slack. The results of the + query are rendered into the 'slack_message' parameter as a Pandas dataframe using a JINJA variable called + '{{ results_df }}'. The 'results_df' variable name can be changed by specifying a different + 'results_df_name' parameter. The Tabulate library is added to the JINJA environment as a filter to + allow the dataframe to be rendered nicely. For example, set 'slack_message' to {{ results_df | + tabulate(tablefmt="pretty", headers="keys") }} to send the results to Slack as an ascii rendered table. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SqlToSlackOperator` + + :param sql: The SQL query to be executed (templated) + :param slack_message: The templated Slack message to send with the data returned from the SQL connection. + You can use the default JINJA variable {{ results_df }} to access the pandas dataframe containing the + SQL results + :param sql_conn_id: reference to a specific database. + :param sql_hook_params: Extra config params to be passed to the underlying hook. + Should match the desired hook constructor params. + :param slack_conn_id: The connection id for Slack. + :param slack_webhook_token: The token to use to authenticate to Slack. If this is not provided, the + 'slack_conn_id' attribute needs to be specified in the 'password' field. 
+ :param slack_channel: The channel to send message. Override default from Slack connection. + :param results_df_name: The name of the JINJA template's dataframe variable, default is 'results_df' + :param parameters: The parameters to pass to the SQL query + """ + + template_fields: Sequence[str] = ("sql", "slack_message") + template_ext: Sequence[str] = (".sql", ".jinja", ".j2") + template_fields_renderers = {"sql": "sql", "slack_message": "jinja"} + times_rendered = 0 + + def __init__( + self, + *, + sql: str, + sql_conn_id: str, + sql_hook_params: dict | None = None, + slack_conn_id: str | None = None, + slack_webhook_token: str | None = None, + slack_channel: str | None = None, + slack_message: str, + results_df_name: str = "results_df", + parameters: Iterable | Mapping | None = None, + **kwargs, + ) -> None: + + super().__init__( + sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params, parameters=parameters, **kwargs + ) + + self.slack_conn_id = slack_conn_id + self.slack_webhook_token = slack_webhook_token + self.slack_channel = slack_channel + self.slack_message = slack_message + self.results_df_name = results_df_name + self.kwargs = kwargs + + if not self.slack_conn_id and not self.slack_webhook_token: + raise AirflowException( + "SqlToSlackOperator requires either a `slack_conn_id` or a `slack_webhook_token` argument" + ) + + def _render_and_send_slack_message(self, context, df) -> None: + # Put the dataframe into the context and render the JINJA template fields + context[self.results_df_name] = df + self.render_template_fields(context) + + slack_hook = self._get_slack_hook() + self.log.info("Sending slack message: %s", self.slack_message) + slack_hook.send(text=self.slack_message, channel=self.slack_channel) + + def _get_slack_hook(self) -> SlackWebhookHook: + return SlackWebhookHook( + slack_webhook_conn_id=self.slack_conn_id, webhook_token=self.slack_webhook_token + ) + + def render_template_fields(self, context, jinja_env=None) -> None: + # If this is the first render of the template fields, exclude slack_message from rendering since + # the SQL results haven't been retrieved yet. + if self.times_rendered == 0: + fields_to_render: Iterable[str] = filter(lambda x: x != "slack_message", self.template_fields) + else: + fields_to_render = self.template_fields + + if not jinja_env: + jinja_env = self.get_template_env() + + # Add the tabulate library into the JINJA environment + jinja_env.filters["tabulate"] = tabulate + + self._do_render_template_fields(self, fields_to_render, context, jinja_env, set()) + self.times_rendered += 1 + + def execute(self, context: Context) -> None: + if not isinstance(self.sql, str): + raise AirflowException("Expected 'sql' parameter should be a string.") + if self.sql is None or self.sql.strip() == "": + raise AirflowException("Expected 'sql' parameter is missing.") + if self.slack_message is None or self.slack_message.strip() == "": + raise AirflowException("Expected 'slack_message' parameter is missing.") + + df = self._get_query_results() + self._render_and_send_slack_message(context, df) + + self.log.debug("Finished sending SQL data to Slack") + + +class SqlToSlackApiFileOperator(BaseSqlToSlackOperator): + """ + Executes an SQL statement in a given SQL connection and sends the results to Slack API as file. + + :param sql: The SQL query to be executed + :param sql_conn_id: reference to a specific DB-API Connection. + :param slack_conn_id: :ref:`Slack API Connection `. + :param slack_filename: Filename for display in slack. 
+ Should contain supported extension which referenced to ``SUPPORTED_FILE_FORMATS``. + It is also possible to set compression in extension: + ``filename.csv.gzip``, ``filename.json.zip``, etc. + :param sql_hook_params: Extra config params to be passed to the underlying hook. + Should match the desired hook constructor params. + :param parameters: The parameters to pass to the SQL query. + :param slack_channels: Comma-separated list of channel names or IDs where the file will be shared. + If omitting this parameter, then file will send to workspace. + :param slack_initial_comment: The message text introducing the file in specified ``slack_channels``. + :param slack_title: Title of file. + :param df_kwargs: Keyword arguments forwarded to ``pandas.DataFrame.to_{format}()`` method. + + Example: + .. code-block:: python + + SqlToSlackApiFileOperator( + task_id="sql_to_slack", + sql="SELECT 1 a, 2 b, 3 c", + sql_conn_id="sql-connection", + slack_conn_id="slack-api-connection", + slack_filename="awesome.json.gz", + slack_channels="#random,#general", + slack_initial_comment="Awesome load to compressed multiline JSON.", + df_kwargs={ + "orient": "records", + "lines": True, + }, + ) + """ + + template_fields: Sequence[str] = ( + "sql", + "slack_channels", + "slack_filename", + "slack_initial_comment", + "slack_title", + ) + template_ext: Sequence[str] = (".sql", ".jinja", ".j2") + template_fields_renderers = {"sql": "sql", "slack_message": "jinja"} + + SUPPORTED_FILE_FORMATS: Sequence[str] = ("csv", "json", "html") + + def __init__( + self, + *, + sql: str, + sql_conn_id: str, + sql_hook_params: dict | None = None, + parameters: Iterable | Mapping | None = None, + slack_conn_id: str, + slack_filename: str, + slack_channels: str | Sequence[str] | None = None, + slack_initial_comment: str | None = None, + slack_title: str | None = None, + df_kwargs: dict | None = None, + **kwargs, + ): + super().__init__( + sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params, parameters=parameters, **kwargs + ) + self.slack_conn_id = slack_conn_id + self.slack_filename = slack_filename + self.slack_channels = slack_channels + self.slack_initial_comment = slack_initial_comment + self.slack_title = slack_title + self.df_kwargs = df_kwargs or {} + + def execute(self, context: Context) -> None: + # Parse file format from filename + output_file_format, _ = parse_filename( + filename=self.slack_filename, + supported_file_formats=self.SUPPORTED_FILE_FORMATS, + ) + + slack_hook = SlackHook(slack_conn_id=self.slack_conn_id) + with NamedTemporaryFile(mode="w+", suffix=f"_{self.slack_filename}") as fp: + # tempfile.NamedTemporaryFile used only for create and remove temporary file, + # pandas will open file in correct mode itself depend on file type. + # So we close file descriptor here for avoid incidentally write anything. + fp.close() + + output_file_name = fp.name + output_file_format = output_file_format.upper() + df_result = self._get_query_results() + if output_file_format == "CSV": + df_result.to_csv(output_file_name, **self.df_kwargs) + elif output_file_format == "JSON": + df_result.to_json(output_file_name, **self.df_kwargs) + elif output_file_format == "HTML": + df_result.to_html(output_file_name, **self.df_kwargs) + else: + # Not expected that this error happen. This only possible + # if SUPPORTED_FILE_FORMATS extended and no actual implementation for specific format. 
+ raise AirflowException(f"Unexpected output file format: {output_file_format}") + + slack_hook.send_file( + channels=self.slack_channels, + file=output_file_name, + filename=self.slack_filename, + initial_comment=self.slack_initial_comment, + title=self.slack_title, + ) diff --git a/airflow/providers/slack/utils/__init__.py b/airflow/providers/slack/utils/__init__.py new file mode 100644 index 0000000000000..1071de6299c27 --- /dev/null +++ b/airflow/providers/slack/utils/__init__.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import warnings +from typing import Any, Sequence + +from airflow.utils.types import NOTSET + + +class ConnectionExtraConfig: + """Helper class for rom Connection Extra. + + :param conn_type: Hook connection type. + :param conn_id: Connection ID uses for appropriate error messages. + :param extra: Connection extra dictionary. + """ + + def __init__(self, conn_type: str, conn_id: str | None = None, extra: dict[str, Any] | None = None): + super().__init__() + self.conn_type = conn_type + self.conn_id = conn_id + self.extra = extra or {} + + def get(self, field, default: Any = NOTSET): + """Get specified field from Connection Extra. + + :param field: Connection extra field name. + :param default: If specified then use as default value if field not present in Connection Extra. + """ + backcompat_key = f"extra__{self.conn_type}__{field}" + if self.extra.get(field) not in (None, ""): + if self.extra.get(backcompat_key) not in (None, ""): + warnings.warn( + f"Conflicting params `{field}` and `{backcompat_key}` found in extras for conn " + f"{self.conn_id}. Using value for `{field}`. Please ensure this is the correct value " + f"and remove the backcompat key `{backcompat_key}`." + ) + return self.extra[field] + elif backcompat_key in self.extra and self.extra[backcompat_key] not in (None, ""): + # Addition validation with non-empty required for connection which created in the UI + # in Airflow 2.2. In these connections always present key-value pair for all prefixed extras + # even if user do not fill this fields. + # In additional fields from `wtforms.IntegerField` might contain None value. + # E.g.: `{'extra__slackwebhook__proxy': '', 'extra__slackwebhook__timeout': None}` + # From Airflow 2.3, using the prefix is no longer required. + return self.extra[backcompat_key] + else: + if default is NOTSET: + raise KeyError( + f"Couldn't find {backcompat_key!r} or {field!r} " + f"in Connection ({self.conn_id!r}) Extra and no default value specified." + ) + return default + + def getint(self, field, default: Any = NOTSET) -> Any: + """Get specified field from Connection Extra and evaluate as integer. + + :param field: Connection extra field name. 
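The precedence implemented by `ConnectionExtraConfig.get()` above is easier to see with a concrete extras dict; the values below are made up for illustration:

```python
# A plain key wins, the legacy "extra__<conn_type>__" key is still honoured,
# and an explicit default avoids the KeyError for missing fields.
from airflow.providers.slack.utils import ConnectionExtraConfig

extra = {
    "timeout": 30,
    "extra__slackwebhook__proxy": "http://proxy:3128",
    "extra__slackwebhook__timeout": None,
}
config = ConnectionExtraConfig(conn_type="slackwebhook", conn_id="slack_default", extra=extra)

config.getint("timeout")             # 30 -- plain key preferred, cast to int
config.get("proxy")                  # "http://proxy:3128" -- backcompat prefix still works
config.get("channel", default=None)  # None -- default instead of KeyError
```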
+ :param default: If specified then use as default value if field not present in Connection Extra. + """ + value = self.get(field=field, default=default) + if value != default: + value = int(value) + return value + + +def parse_filename( + filename: str, supported_file_formats: Sequence[str], fallback: str | None = None +) -> tuple[str, str | None]: + """ + Parse filetype and compression from given filename. + :param filename: filename to parse. + :param supported_file_formats: list of supported file extensions. + :param fallback: fallback to given file format. + :returns: filetype and compression (if specified) + """ + if not filename: + raise ValueError("Expected 'filename' parameter is missing.") + if fallback and fallback not in supported_file_formats: + raise ValueError(f"Invalid fallback value {fallback!r}, expected one of {supported_file_formats}.") + + parts = filename.rsplit(".", 2) + try: + if len(parts) == 1: + raise ValueError(f"No file extension specified in filename {filename!r}.") + if parts[-1] in supported_file_formats: + return parts[-1], None + elif len(parts) == 2: + raise ValueError( + f"Unsupported file format {parts[-1]!r}, expected one of {supported_file_formats}." + ) + else: + if parts[-2] not in supported_file_formats: + raise ValueError( + f"Unsupported file format '{parts[-2]}.{parts[-1]}', " + f"expected one of {supported_file_formats} with compression extension." + ) + return parts[-2], parts[-1] + except ValueError as ex: + if fallback: + return fallback, None + raise ex from None diff --git a/airflow/providers/snowflake/CHANGELOG.rst b/airflow/providers/snowflake/CHANGELOG.rst index ce11f8f1803ca..eba5b2a87885b 100644 --- a/airflow/providers/snowflake/CHANGELOG.rst +++ b/airflow/providers/snowflake/CHANGELOG.rst @@ -16,9 +16,116 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +In SnowflakeHook, if both ``extra__snowflake__foo`` and ``foo`` existed in connection extra +dict, the prefixed version would be used; now, the non-prefixed version will be preferred. + +* ``Update snowflake hook to not use extra prefix (#26764)`` + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + +Bug fixes +~~~~~~~~~ + +* ``Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.3.0 +..... + +Features +~~~~~~~~ + +* ``Add custom handler param in SnowflakeOperator (#25983)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix wrong deprecation warning for 'S3ToSnowflakeOperator' (#26047)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``copy into snowflake from external stage (#25541)`` + +3.2.0 +..... 
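Before the Snowflake changelog continues, a short illustration of `parse_filename()` from the new Slack utils module above (file names and formats are examples only):

```python
# parse_filename() returns (format, compression); an unsupported format raises
# ValueError unless a fallback format is supplied.
from airflow.providers.slack.utils import parse_filename

parse_filename("report.csv", ("csv", "json", "html"))                  # ("csv", None)
parse_filename("report.json.gz", ("csv", "json", "html"))              # ("json", "gz")
parse_filename("report.txt", ("csv", "json", "html"), fallback="csv")  # ("csv", None)
# parse_filename("report.txt", ("csv", "json", "html"))  # would raise ValueError
```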
+ +Features +~~~~~~~~ + +* ``Move all "old" SQL operators to common.sql providers (#25350)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Adding generic 'SqlToSlackOperator' (#24663)`` +* ``Move all SQL classes to common-sql provider (#24836)`` +* ``Pattern parameter in S3ToSnowflakeOperator (#24571)`` + +Bug Fixes +~~~~~~~~~ + +* ``S3ToSnowflakeOperator: escape single quote in s3_keys (#24607)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Bug Fixes +~~~~~~~~~ + +* ``Fix error when SnowflakeHook take empty list in 'sql' param (#23767)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Snowflake system tests to new design #22434 (#24151)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.7.0 ..... diff --git a/airflow/providers/snowflake/example_dags/example_snowflake.py b/airflow/providers/snowflake/example_dags/example_snowflake.py deleted file mode 100644 index 12144884059ec..0000000000000 --- a/airflow/providers/snowflake/example_dags/example_snowflake.py +++ /dev/null @@ -1,135 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example use of Snowflake related operators. 
-""" -from datetime import datetime - -from airflow import DAG -from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator -from airflow.providers.snowflake.transfers.s3_to_snowflake import S3ToSnowflakeOperator -from airflow.providers.snowflake.transfers.snowflake_to_slack import SnowflakeToSlackOperator - -SNOWFLAKE_CONN_ID = 'my_snowflake_conn' -SLACK_CONN_ID = 'my_slack_conn' -# TODO: should be able to rely on connection's schema, but currently param required by S3ToSnowflakeTransfer -SNOWFLAKE_SCHEMA = 'schema_name' -SNOWFLAKE_STAGE = 'stage_name' -SNOWFLAKE_WAREHOUSE = 'warehouse_name' -SNOWFLAKE_DATABASE = 'database_name' -SNOWFLAKE_ROLE = 'role_name' -SNOWFLAKE_SAMPLE_TABLE = 'sample_table' -S3_FILE_PATH = '> [ - snowflake_op_with_params, - snowflake_op_sql_list, - snowflake_op_template_file, - copy_into_table, - snowflake_op_sql_multiple_stmts, - ] - >> slack_report -) diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 29a4b63156c6d..e525efe763439 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -15,22 +15,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import os from contextlib import closing +from functools import wraps from io import StringIO from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Iterable, Mapping from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from snowflake import connector -from snowflake.connector import DictCursor, SnowflakeConnection -from snowflake.connector.util_text import split_statements +from snowflake.connector import DictCursor, SnowflakeConnection, util_text from snowflake.sqlalchemy import URL from sqlalchemy import create_engine from airflow import AirflowException -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.utils.strings import to_boolean @@ -40,6 +42,34 @@ def _try_to_boolean(value: Any): return value +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. + """ + + def dec(func): + @wraps(func) + def inner(): + field_behaviors = func() + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec + + class SnowflakeHook(DbApiHook): """ A client to interact with Snowflake. 
@@ -77,48 +107,46 @@ class SnowflakeHook(DbApiHook): :ref:`howto/operator:SnowflakeOperator` """ - conn_name_attr = 'snowflake_conn_id' - default_conn_name = 'snowflake_default' - conn_type = 'snowflake' - hook_name = 'Snowflake' + conn_name_attr = "snowflake_conn_id" + default_conn_name = "snowflake_default" + conn_type = "snowflake" + hook_name = "Snowflake" supports_autocommit = True + _test_connection_sql = "select 1" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3TextAreaFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import BooleanField, StringField return { - "extra__snowflake__account": StringField(lazy_gettext('Account'), widget=BS3TextFieldWidget()), - "extra__snowflake__warehouse": StringField( - lazy_gettext('Warehouse'), widget=BS3TextFieldWidget() - ), - "extra__snowflake__database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()), - "extra__snowflake__region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()), - "extra__snowflake__role": StringField(lazy_gettext('Role'), widget=BS3TextFieldWidget()), - "extra__snowflake__private_key_file": StringField( - lazy_gettext('Private key (Path)'), widget=BS3TextFieldWidget() - ), - "extra__snowflake__private_key_content": StringField( - lazy_gettext('Private key (Text)'), widget=BS3TextAreaFieldWidget() + "account": StringField(lazy_gettext("Account"), widget=BS3TextFieldWidget()), + "warehouse": StringField(lazy_gettext("Warehouse"), widget=BS3TextFieldWidget()), + "database": StringField(lazy_gettext("Database"), widget=BS3TextFieldWidget()), + "region": StringField(lazy_gettext("Region"), widget=BS3TextFieldWidget()), + "role": StringField(lazy_gettext("Role"), widget=BS3TextFieldWidget()), + "private_key_file": StringField(lazy_gettext("Private key (Path)"), widget=BS3TextFieldWidget()), + "private_key_content": StringField( + lazy_gettext("Private key (Text)"), widget=BS3TextAreaFieldWidget() ), - "extra__snowflake__insecure_mode": BooleanField( - label=lazy_gettext('Insecure mode'), description="Turns off OCSP certificate checks" + "insecure_mode": BooleanField( + label=lazy_gettext("Insecure mode"), description="Turns off OCSP certificate checks" ), } @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + @_ensure_prefixes(conn_type="snowflake") + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" import json return { - "hidden_fields": ['port'], + "hidden_fields": ["port"], "relabeling": {}, "placeholders": { - 'extra': json.dumps( + "extra": json.dumps( { "authenticator": "snowflake oauth", "private_key_file": "private key", @@ -126,17 +154,17 @@ def get_ui_field_behaviour() -> Dict[str, Any]: }, indent=1, ), - 'schema': 'snowflake schema', - 'login': 'snowflake username', - 'password': 'snowflake password', - 'extra__snowflake__account': 'snowflake account name', - 'extra__snowflake__warehouse': 'snowflake warehouse name', - 'extra__snowflake__database': 'snowflake db name', - 'extra__snowflake__region': 'snowflake hosted region', - 'extra__snowflake__role': 'snowflake role', - 'extra__snowflake__private_key_file': 'Path of snowflake private key (PEM Format)', - 'extra__snowflake__private_key_content': 'Content to snowflake private key (PEM format)', - 'extra__snowflake__insecure_mode': 'insecure mode', + "schema": "snowflake schema", 
+ "login": "snowflake username", + "password": "snowflake password", + "account": "snowflake account name", + "warehouse": "snowflake warehouse name", + "database": "snowflake db name", + "region": "snowflake hosted region", + "role": "snowflake role", + "private_key_file": "Path of snowflake private key (PEM Format)", + "private_key_content": "Content to snowflake private key (PEM format)", + "insecure_mode": "insecure mode", }, } @@ -150,37 +178,50 @@ def __init__(self, *args, **kwargs) -> None: self.schema = kwargs.pop("schema", None) self.authenticator = kwargs.pop("authenticator", None) self.session_parameters = kwargs.pop("session_parameters", None) - self.query_ids: List[str] = [] - - def _get_conn_params(self) -> Dict[str, Optional[str]]: + self.query_ids: list[str] = [] + + def _get_field(self, extra_dict, field_name): + backcompat_prefix = "extra__snowflake__" + backcompat_key = f"{backcompat_prefix}{field_name}" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + f"when using this method." + ) + if field_name in extra_dict: + import warnings + + if backcompat_key in extra_dict: + warnings.warn( + f"Conflicting params `{field_name}` and `{backcompat_key}` found in extras. " + f"Using value for `{field_name}`. Please ensure this is the correct " + f"value and remove the backcompat key `{backcompat_key}`." + ) + return extra_dict[field_name] or None + return extra_dict.get(backcompat_key) or None + + def _get_conn_params(self) -> dict[str, str | None]: """ One method to fetch connection params as a dict used in get_uri() and get_connection() """ conn = self.get_connection(self.snowflake_conn_id) # type: ignore[attr-defined] - account = conn.extra_dejson.get('extra__snowflake__account', '') or conn.extra_dejson.get( - 'account', '' - ) - warehouse = conn.extra_dejson.get('extra__snowflake__warehouse', '') or conn.extra_dejson.get( - 'warehouse', '' - ) - database = conn.extra_dejson.get('extra__snowflake__database', '') or conn.extra_dejson.get( - 'database', '' - ) - region = conn.extra_dejson.get('extra__snowflake__region', '') or conn.extra_dejson.get('region', '') - role = conn.extra_dejson.get('extra__snowflake__role', '') or conn.extra_dejson.get('role', '') - schema = conn.schema or '' - authenticator = conn.extra_dejson.get('authenticator', 'snowflake') - session_parameters = conn.extra_dejson.get('session_parameters') - insecure_mode = _try_to_boolean( - conn.extra_dejson.get( - 'extra__snowflake__insecure_mode', conn.extra_dejson.get('insecure_mode', None) - ) - ) + extra_dict = conn.extra_dejson + account = self._get_field(extra_dict, "account") or "" + warehouse = self._get_field(extra_dict, "warehouse") or "" + database = self._get_field(extra_dict, "database") or "" + region = self._get_field(extra_dict, "region") or "" + role = self._get_field(extra_dict, "role") or "" + insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode")) + schema = conn.schema or "" + + # authenticator and session_parameters never supported long name so we don't use _get_field + authenticator = extra_dict.get("authenticator", "snowflake") + session_parameters = extra_dict.get("session_parameters") conn_config = { "user": conn.login, - "password": conn.password or '', + "password": conn.password or "", "schema": self.schema or schema, "database": self.database or database, "account": self.account or account, @@ -193,7 +234,7 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]: 
"application": os.environ.get("AIRFLOW_SNOWFLAKE_PARTNER", "AIRFLOW"), } if insecure_mode: - conn_config['insecure_mode'] = insecure_mode + conn_config["insecure_mode"] = insecure_mode # If private_key_file is specified in the extra json, load the contents of the file as a private key. # If private_key_content is specified in the extra json, use it as a private key. @@ -201,12 +242,8 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]: # The connection password then becomes the passphrase for the private key. # If your private key is not encrypted (not recommended), then leave the password empty. - private_key_file = conn.extra_dejson.get( - 'extra__snowflake__private_key_file' - ) or conn.extra_dejson.get('private_key_file') - private_key_content = conn.extra_dejson.get( - 'extra__snowflake__private_key_content' - ) or conn.extra_dejson.get('private_key_content') + private_key_file = self._get_field(extra_dict, "private_key_file") + private_key_content = self._get_field(extra_dict, "private_key_content") private_key_pem = None if private_key_content and private_key_file: @@ -234,8 +271,8 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]: encryption_algorithm=serialization.NoEncryption(), ) - conn_config['private_key'] = pkb - conn_config.pop('password', None) + conn_config["private_key"] = pkb + conn_config.pop("password", None) return conn_config @@ -244,12 +281,12 @@ def get_uri(self) -> str: conn_params = self._get_conn_params() return self._conn_params_to_sqlalchemy_uri(conn_params) - def _conn_params_to_sqlalchemy_uri(self, conn_params: Dict) -> str: + def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: return URL( **{ k: v for k, v in conn_params.items() - if v and k not in ['session_parameters', 'insecure_mode', 'private_key'] + if v and k not in ["session_parameters", "insecure_mode", "private_key"] } ) @@ -268,13 +305,13 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): """ engine_kwargs = engine_kwargs or {} conn_params = self._get_conn_params() - if 'insecure_mode' in conn_params: - engine_kwargs.setdefault('connect_args', dict()) - engine_kwargs['connect_args']['insecure_mode'] = True - for key in ['session_parameters', 'private_key']: + if "insecure_mode" in conn_params: + engine_kwargs.setdefault("connect_args", dict()) + engine_kwargs["connect_args"]["insecure_mode"] = True + for key in ["session_parameters", "private_key"]: if conn_params.get(key): - engine_kwargs.setdefault('connect_args', dict()) - engine_kwargs['connect_args'][key] = conn_params[key] + engine_kwargs.setdefault("connect_args", dict()) + engine_kwargs["connect_args"][key] = conn_params[key] return create_engine(self._conn_params_to_sqlalchemy_uri(conn_params), **engine_kwargs) def set_autocommit(self, conn, autocommit: Any) -> None: @@ -282,15 +319,17 @@ def set_autocommit(self, conn, autocommit: Any) -> None: conn.autocommit_mode = autocommit def get_autocommit(self, conn): - return getattr(conn, 'autocommit_mode', False) + return getattr(conn, "autocommit_mode", False) def run( self, - sql: Union[str, list], + sql: str | Iterable[str], autocommit: bool = False, - parameters: Optional[Union[Sequence[Any], Dict[Any, Any]]] = None, - handler: Optional[Callable] = None, - ): + parameters: Iterable | Mapping | None = None, + handler: Callable | None = None, + split_statements: bool = True, + return_last: bool = True, + ) -> Any | list[Any] | None: """ Runs a command or a list of commands. 
Pass a list of sql statements to the sql parameter to get them to execute @@ -305,15 +344,22 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the LAST SQL expression if handler was provided. """ self.query_ids = [] + self.scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): - split_statements_tuple = split_statements(StringIO(sql)) - sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] + if split_statements: + split_statements_tuple = util_text.split_statements(StringIO(sql)) + sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] + else: + sql = [self.strip_sql_string(sql)] if sql: - self.log.debug("Executing %d statements against Snowflake DB", len(sql)) + self.log.debug("Executing following statements against Snowflake DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") @@ -322,38 +368,26 @@ def run( # SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here with closing(conn.cursor(DictCursor)) as cur: # type: ignore[type-var] - + results = [] for sql_statement in sql: + self._run_command(cur, sql_statement, parameters) - self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) - if parameters: - cur.execute(sql_statement, parameters) - else: - cur.execute(sql_statement) - - execution_info = [] if handler is not None: - cur = handler(cur) - for row in cur: - self.log.info("Statement execution info - %s", row) - execution_info.append(row) + result = handler(cur) + results.append(result) query_id = cur.sfqid self.log.info("Rows affected: %s", cur.rowcount) self.log.info("Snowflake query id: %s", query_id) self.query_ids.append(query_id) - # If autocommit was set to False for db that supports autocommit, - # or if db does not supports autocommit, we do a manual commit. + # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): conn.commit() - return execution_info - - def test_connection(self): - """Test the Snowflake connection by running a simple query.""" - try: - self.run(sql="select 1") - except Exception as e: - return False, str(e) - return True, "Connection successfully tested" + if handler is None: + return None + elif self.scalar_return_last: + return results[-1] + else: + return results diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 086c1d6fd5bee..cf7835ef65518 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -15,32 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
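The reworked `run()` above changes what callers get back. A sketch of the new semantics, assuming a `snowflake_default` connection exists and using `fetch_all_handler` from the common.sql provider this hook now builds on:

```python
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

hook = SnowflakeHook(snowflake_conn_id="snowflake_default")

# A single string is split into statements (split_statements=True by default);
# with return_last=True (also the default) only the last statement's result is returned.
last_rows = hook.run("CREATE TEMPORARY TABLE t (a INT); SELECT 1 AS a;", handler=fetch_all_handler)

# A list of statements returns one result per statement.
all_rows = hook.run(["SELECT 1", "SELECT 2"], handler=fetch_all_handler)

# Without a handler nothing is returned; executed query ids are still recorded.
hook.run("SELECT 1")
print(hook.query_ids)
```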
-from typing import Any, List, Optional, Sequence, SupportsAbs +from __future__ import annotations -from airflow.models import BaseOperator -from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator -from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +import warnings +from typing import Any, Iterable, Mapping, Sequence, SupportsAbs - -def get_db_hook(self) -> SnowflakeHook: - """ - Create and return SnowflakeHook. - - :return: a SnowflakeHook instance. - :rtype: SnowflakeHook - """ - return SnowflakeHook( - snowflake_conn_id=self.snowflake_conn_id, - warehouse=self.warehouse, - database=self.database, - role=self.role, - schema=self.schema, - authenticator=self.authenticator, - session_parameters=self.session_parameters, - ) +from airflow.providers.common.sql.operators.sql import ( + SQLCheckOperator, + SQLExecuteQueryOperator, + SQLIntervalCheckOperator, + SQLValueCheckOperator, +) -class SnowflakeOperator(BaseOperator): +class SnowflakeOperator(SQLExecuteQueryOperator): """ Executes SQL code in a Snowflake database @@ -75,53 +63,45 @@ class SnowflakeOperator(BaseOperator): the time you connect to Snowflake """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#ededed' + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#ededed" def __init__( self, *, - sql: Any, - snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, - autocommit: bool = True, - do_xcom_push: bool = True, - warehouse: Optional[str] = None, - database: Optional[str] = None, - role: Optional[str] = None, - schema: Optional[str] = None, - authenticator: Optional[str] = None, - session_parameters: Optional[dict] = None, + snowflake_conn_id: str = "snowflake_default", + warehouse: str | None = None, + database: str | None = None, + role: str | None = None, + schema: str | None = None, + authenticator: str | None = None, + session_parameters: dict | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) - self.snowflake_conn_id = snowflake_conn_id - self.sql = sql - self.autocommit = autocommit - self.do_xcom_push = do_xcom_push - self.parameters = parameters - self.warehouse = warehouse - self.database = database - self.role = role - self.schema = schema - self.authenticator = authenticator - self.session_parameters = session_parameters - self.query_ids: List[str] = [] - - def get_db_hook(self) -> SnowflakeHook: - return get_db_hook(self) - - def execute(self, context: Any) -> None: - """Run query on snowflake""" - self.log.info('Executing: %s', self.sql) - hook = self.get_db_hook() - execution_info = hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) - self.query_ids = hook.query_ids - - if self.do_xcom_push: - return execution_info + if any([warehouse, database, role, schema, authenticator, session_parameters]): + hook_params = kwargs.pop("hook_params", {}) + kwargs["hook_params"] = { + "warehouse": warehouse, + "database": database, + "role": role, + "schema": schema, + "authenticator": authenticator, + "session_parameters": session_parameters, + **hook_params, + } + + super().__init__(conn_id=snowflake_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`. 
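As the deprecation warning in this hunk suggests, existing `SnowflakeOperator` tasks map onto the generic `SQLExecuteQueryOperator`. A migration sketch (connection id, SQL, and `hook_params` values are illustrative):

```python
# Equivalent of a former SnowflakeOperator task; Snowflake-specific settings
# travel via hook_params instead of dedicated constructor arguments.
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

run_report_query = SQLExecuteQueryOperator(
    task_id="run_report_query",
    conn_id="snowflake_default",
    sql="SELECT COUNT(*) FROM orders",
    hook_params={"warehouse": "compute_wh", "database": "analytics", "role": "reporter"},
)
```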
+ Also, you can provide `hook_params={'warehouse': , 'database': , + 'role': , 'schema': , 'authenticator': , + 'session_parameters': }`.""", + DeprecationWarning, + stacklevel=2, + ) class SnowflakeCheckOperator(SQLCheckOperator): @@ -179,27 +159,27 @@ class SnowflakeCheckOperator(SQLCheckOperator): the time you connect to Snowflake """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - ui_color = '#ededed' + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + ui_color = "#ededed" def __init__( self, *, - sql: Any, - snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + sql: str, + snowflake_conn_id: str = "snowflake_default", + parameters: Iterable | Mapping | None = None, autocommit: bool = True, do_xcom_push: bool = True, - warehouse: Optional[str] = None, - database: Optional[str] = None, - role: Optional[str] = None, - schema: Optional[str] = None, - authenticator: Optional[str] = None, - session_parameters: Optional[dict] = None, + warehouse: str | None = None, + database: str | None = None, + role: str | None = None, + schema: str | None = None, + authenticator: str | None = None, + session_parameters: dict | None = None, **kwargs, ) -> None: - super().__init__(sql=sql, **kwargs) + super().__init__(sql=sql, parameters=parameters, **kwargs) self.snowflake_conn_id = snowflake_conn_id self.sql = sql self.autocommit = autocommit @@ -211,10 +191,7 @@ def __init__( self.schema = schema self.authenticator = authenticator self.session_parameters = session_parameters - self.query_ids: List[str] = [] - - def get_db_hook(self) -> SnowflakeHook: - return get_db_hook(self) + self.query_ids: list[str] = [] class SnowflakeValueCheckOperator(SQLValueCheckOperator): @@ -256,16 +233,16 @@ def __init__( sql: str, pass_value: Any, tolerance: Any = None, - snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + snowflake_conn_id: str = "snowflake_default", + parameters: Iterable | Mapping | None = None, autocommit: bool = True, do_xcom_push: bool = True, - warehouse: Optional[str] = None, - database: Optional[str] = None, - role: Optional[str] = None, - schema: Optional[str] = None, - authenticator: Optional[str] = None, - session_parameters: Optional[dict] = None, + warehouse: str | None = None, + database: str | None = None, + role: str | None = None, + schema: str | None = None, + authenticator: str | None = None, + session_parameters: dict | None = None, **kwargs, ) -> None: super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs) @@ -280,10 +257,7 @@ def __init__( self.schema = schema self.authenticator = authenticator self.session_parameters = session_parameters - self.query_ids: List[str] = [] - - def get_db_hook(self) -> SnowflakeHook: - return get_db_hook(self) + self.query_ids: list[str] = [] class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator): @@ -331,18 +305,18 @@ def __init__( *, table: str, metrics_thresholds: dict, - date_filter_column: str = 'ds', + date_filter_column: str = "ds", days_back: SupportsAbs[int] = -7, - snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + snowflake_conn_id: str = "snowflake_default", + parameters: Iterable | Mapping | None = None, autocommit: bool = True, do_xcom_push: bool = True, - warehouse: Optional[str] = None, - database: Optional[str] = None, - role: Optional[str] = None, - schema: Optional[str] = None, - authenticator: Optional[str] = None, - 
session_parameters: Optional[dict] = None, + warehouse: str | None = None, + database: str | None = None, + role: str | None = None, + schema: str | None = None, + authenticator: str | None = None, + session_parameters: dict | None = None, **kwargs, ) -> None: super().__init__( @@ -362,7 +336,4 @@ def __init__( self.schema = schema self.authenticator = authenticator self.session_parameters = session_parameters - self.query_ids: List[str] = [] - - def get_db_hook(self) -> SnowflakeHook: - return get_db_hook(self) + self.query_ids: list[str] = [] diff --git a/airflow/providers/snowflake/provider.yaml b/airflow/providers/snowflake/provider.yaml index 735083c6e6f2b..82e2015e0574f 100644 --- a/airflow/providers/snowflake/provider.yaml +++ b/airflow/providers/snowflake/provider.yaml @@ -22,6 +22,11 @@ description: | `Snowflake `__ versions: + - 4.0.0 + - 3.3.0 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.7.0 - 2.6.0 - 2.5.2 @@ -40,8 +45,11 @@ versions: - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - snowflake-connector-python>=2.4.1 + - snowflake-sqlalchemy>=1.1.0 integrations: - integration-name: Snowflake @@ -70,9 +78,15 @@ transfers: target-integration-name: Slack python-module: airflow.providers.snowflake.transfers.snowflake_to_slack how-to-guide: /docs/apache-airflow-providers-snowflake/operators/snowflake_to_slack.rst - -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.snowflake.hooks.snowflake.SnowflakeHook + - source-integration-name: Amazon Simple Storage Service (S3) + target-integration-name: Snowflake + python-module: airflow.providers.snowflake.transfers.copy_into_snowflake + - source-integration-name: Google Cloud Storage (GCS) + target-integration-name: Snowflake + python-module: airflow.providers.snowflake.transfers.copy_into_snowflake + - source-integration-name: Microsoft Azure Blob Storage + target-integration-name: Snowflake + python-module: airflow.providers.snowflake.transfers.copy_into_snowflake connection-types: - hook-class-name: airflow.providers.snowflake.hooks.snowflake.SnowflakeHook diff --git a/airflow/providers/snowflake/transfers/copy_into_snowflake.py b/airflow/providers/snowflake/transfers/copy_into_snowflake.py new file mode 100644 index 0000000000000..736b53e1c5116 --- /dev/null +++ b/airflow/providers/snowflake/transfers/copy_into_snowflake.py @@ -0,0 +1,143 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains abstract operator that child classes implements +COPY INTO
 SQL in Snowflake +""" +from __future__ import annotations + +from typing import Any, Sequence + +from airflow.models import BaseOperator +from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.providers.snowflake.utils.common import enclose_param + + +class CopyFromExternalStageToSnowflakeOperator(BaseOperator): + """ + Executes a COPY INTO command to load files from an external stage in cloud storage into Snowflake + + This operator requires the snowflake_conn_id connection. The snowflake host, login, + and password fields must be set up in the connection. Other inputs can be defined + in the connection or hook instantiation. + + :param namespace: snowflake namespace + :param table: snowflake table + :param file_format: file format name i.e. CSV, AVRO, etc + :param stage: reference to a specific snowflake stage. If the stage's schema is not the same as the + table one, it must be specified + :param prefix: cloud storage location specified to limit the set of files to load + :param files: files to load into table + :param pattern: pattern to load files from external location to table + :param copy_into_postfix: optional SQL postfix for the COPY INTO query, + such as `formatTypeOptions` and `copyOptions` + :param snowflake_conn_id: Reference to :ref:`Snowflake connection id` + :param account: snowflake account name + :param warehouse: name of snowflake warehouse + :param database: name of snowflake database + :param region: name of snowflake region + :param role: name of snowflake role + :param schema: name of snowflake schema + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identity provider + (IdP) that has been defined for your account + ``https://.okta.com`` to authenticate + through native Okta.
+ :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :param copy_options: snowflake COPY INTO syntax copy options + :param validation_mode: snowflake COPY INTO syntax validation mode + + """ + + template_fields: Sequence[str] = ("files",) + template_fields_renderers = {"files": "json"} + + def __init__( + self, + *, + files: list | None = None, + table: str, + stage: str, + prefix: str | None = None, + file_format: str, + schema: str | None = None, + columns_array: list | None = None, + pattern: str | None = None, + warehouse: str | None = None, + database: str | None = None, + autocommit: bool = True, + snowflake_conn_id: str = "snowflake_default", + role: str | None = None, + authenticator: str | None = None, + session_parameters: dict | None = None, + copy_options: str | None = None, + validation_mode: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.files = files + self.table = table + self.stage = stage + self.prefix = prefix + self.file_format = file_format + self.schema = schema + self.columns_array = columns_array + self.pattern = pattern + self.warehouse = warehouse + self.database = database + self.autocommit = autocommit + self.snowflake_conn_id = snowflake_conn_id + self.role = role + self.authenticator = authenticator + self.session_parameters = session_parameters + self.copy_options = copy_options + self.validation_mode = validation_mode + + def execute(self, context: Any) -> None: + snowflake_hook = SnowflakeHook( + snowflake_conn_id=self.snowflake_conn_id, + warehouse=self.warehouse, + database=self.database, + role=self.role, + schema=self.schema, + authenticator=self.authenticator, + session_parameters=self.session_parameters, + ) + + if self.schema: + into = f"{self.schema}.{self.table}" + else: + into = self.table + + if self.columns_array: + into = f"{into}({', '.join(self.columns_array)})" + + sql = f""" + COPY INTO {into} + FROM @{self.stage}/{self.prefix or ""} + {"FILES=" + ",".join(map(enclose_param ,self.files)) if self.files else ""} + {"PATTERN=" + enclose_param(self.pattern) if self.pattern else ""} + FILE_FORMAT={self.file_format} + {self.copy_options or ""} + {self.validation_mode or ""} + """ + self.log.info("Executing COPY command...") + snowflake_hook.run(sql=sql, autocommit=self.autocommit) + self.log.info("COPY command completed") diff --git a/airflow/providers/snowflake/transfers/s3_to_snowflake.py b/airflow/providers/snowflake/transfers/s3_to_snowflake.py index 9b8eecb4dd292..58fbcbffd335a 100644 --- a/airflow/providers/snowflake/transfers/s3_to_snowflake.py +++ b/airflow/providers/snowflake/transfers/s3_to_snowflake.py @@ -15,12 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
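A usage sketch for the new `CopyFromExternalStageToSnowflakeOperator`; the stage, table, and file format names below are made up:

```python
# With these arguments, execute() renders roughly:
#   COPY INTO sales.raw_orders FROM @sales.s3_stage/2022/
#   FILES='orders_1.csv','orders_2.csv' FILE_FORMAT=csv_format
from airflow.providers.snowflake.transfers.copy_into_snowflake import (
    CopyFromExternalStageToSnowflakeOperator,
)

load_orders = CopyFromExternalStageToSnowflakeOperator(
    task_id="load_orders",
    snowflake_conn_id="snowflake_default",
    schema="sales",
    table="raw_orders",
    stage="sales.s3_stage",
    prefix="2022/",
    files=["orders_1.csv", "orders_2.csv"],
    file_format="csv_format",
)
```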
- """This module contains AWS S3 to Snowflake operator.""" -from typing import Any, Optional, Sequence +from __future__ import annotations + +import warnings +from typing import Any, Sequence from airflow.models import BaseOperator from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.providers.snowflake.utils.common import enclose_param class S3ToSnowflakeOperator(BaseOperator): @@ -43,6 +46,9 @@ class S3ToSnowflakeOperator(BaseOperator): defined in the connection's extra JSON) :param database: reference to a specific database in Snowflake connection :param columns_array: reference to a specific columns array in snowflake database + :param pattern: regular expression pattern string specifying the file names and/or paths to match. + Note: regular expression will be automatically enclose in single quotes + and all single quotes in expression will replace by two single quotes. :param snowflake_conn_id: Reference to :ref:`Snowflake connection id` :param role: name of role (will overwrite any role defined in @@ -64,22 +70,32 @@ class S3ToSnowflakeOperator(BaseOperator): def __init__( self, *, - s3_keys: Optional[list] = None, + s3_keys: list | None = None, table: str, stage: str, - prefix: Optional[str] = None, + prefix: str | None = None, file_format: str, - schema: Optional[str] = None, - columns_array: Optional[list] = None, - warehouse: Optional[str] = None, - database: Optional[str] = None, + schema: str | None = None, + columns_array: list | None = None, + pattern: str | None = None, + warehouse: str | None = None, + database: str | None = None, autocommit: bool = True, - snowflake_conn_id: str = 'snowflake_default', - role: Optional[str] = None, - authenticator: Optional[str] = None, - session_parameters: Optional[dict] = None, + snowflake_conn_id: str = "snowflake_default", + role: str | None = None, + authenticator: str | None = None, + session_parameters: dict | None = None, **kwargs, ) -> None: + warnings.warn( + """ + S3ToSnowflakeOperator is deprecated. + Please use + `airflow.providers.snowflake.transfers.copy_into_snowflake.CopyFromExternalStageToSnowflakeOperator`. + """, + DeprecationWarning, + stacklevel=2, + ) super().__init__(**kwargs) self.s3_keys = s3_keys self.table = table @@ -90,6 +106,7 @@ def __init__( self.file_format = file_format self.schema = schema self.columns_array = columns_array + self.pattern = pattern self.autocommit = autocommit self.snowflake_conn_id = snowflake_conn_id self.role = role @@ -119,12 +136,13 @@ def execute(self, context: Any) -> None: f"FROM @{self.stage}/{self.prefix or ''}", ] if self.s3_keys: - files = ", ".join(f"'{key}'" for key in self.s3_keys) + files = ", ".join(map(enclose_param, self.s3_keys)) sql_parts.append(f"files=({files})") sql_parts.append(f"file_format={self.file_format}") - + if self.pattern: + sql_parts.append(f"pattern={enclose_param(self.pattern)}") copy_query = "\n".join(sql_parts) - self.log.info('Executing COPY command...') + self.log.info("Executing COPY command...") snowflake_hook.run(copy_query, self.autocommit) self.log.info("COPY command completed") diff --git a/airflow/providers/snowflake/transfers/snowflake_to_slack.py b/airflow/providers/snowflake/transfers/snowflake_to_slack.py index 2c6138e58edc9..199999fe0b54f 100644 --- a/airflow/providers/snowflake/transfers/snowflake_to_slack.py +++ b/airflow/providers/snowflake/transfers/snowflake_to_slack.py @@ -14,22 +14,15 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union +import warnings +from typing import Iterable, Mapping, Sequence -from pandas import DataFrame -from tabulate import tabulate +from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook -from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook -if TYPE_CHECKING: - from airflow.utils.context import Context - - -class SnowflakeToSlackOperator(BaseOperator): +class SnowflakeToSlackOperator(SqlToSlackOperator): """ Executes an SQL statement in Snowflake and sends the results to Slack. The results of the query are rendered into the 'slack_message' parameter as a Pandas dataframe using a JINJA variable called '{{ @@ -48,7 +41,7 @@ class SnowflakeToSlackOperator(BaseOperator): SQL results :param snowflake_conn_id: Reference to :ref:`Snowflake connection id` - :param slack_conn_id: The connection id for Slack + :param slack_conn_id: The connection id for Slack. :param results_df_name: The name of the JINJA template's dataframe variable, default is 'results_df' :param parameters: The parameters to pass to the SQL query :param warehouse: The Snowflake virtual warehouse to use to run the SQL query @@ -56,11 +49,11 @@ class SnowflakeToSlackOperator(BaseOperator): :param schema: The schema to run the SQL against in Snowflake :param role: The role to use when connecting to Snowflake :param slack_token: The token to use to authenticate to Slack. If this is not provided, the - 'webhook_token' attribute needs to be specified in the 'Extra' JSON field against the slack_conn_id + 'webhook_token' attribute needs to be specified in the 'Extra' JSON field against the slack_conn_id. 
""" - template_fields: Sequence[str] = ('sql', 'slack_message') - template_ext: Sequence[str] = ('.sql', '.jinja', '.j2') + template_fields: Sequence[str] = ("sql", "slack_message") + template_ext: Sequence[str] = (".sql", ".jinja", ".j2") template_fields_renderers = {"sql": "sql", "slack_message": "jinja"} times_rendered = 0 @@ -69,19 +62,17 @@ def __init__( *, sql: str, slack_message: str, - snowflake_conn_id: str = 'snowflake_default', - slack_conn_id: str = 'slack_default', - results_df_name: str = 'results_df', - parameters: Optional[Union[Iterable, Mapping]] = None, - warehouse: Optional[str] = None, - database: Optional[str] = None, - schema: Optional[str] = None, - role: Optional[str] = None, - slack_token: Optional[str] = None, + snowflake_conn_id: str = "snowflake_default", + slack_conn_id: str = "slack_default", + results_df_name: str = "results_df", + parameters: Iterable | Mapping | None = None, + warehouse: str | None = None, + database: str | None = None, + schema: str | None = None, + role: str | None = None, + slack_token: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) - self.snowflake_conn_id = snowflake_conn_id self.sql = sql self.parameters = parameters @@ -94,62 +85,31 @@ def __init__( self.slack_message = slack_message self.results_df_name = results_df_name - def _get_query_results(self) -> DataFrame: - snowflake_hook = self._get_snowflake_hook() - - self.log.info('Running SQL query: %s', self.sql) - df = snowflake_hook.get_pandas_df(self.sql, parameters=self.parameters) - return df - - def _render_and_send_slack_message(self, context, df) -> None: - # Put the dataframe into the context and render the JINJA template fields - context[self.results_df_name] = df - self.render_template_fields(context) - - slack_hook = self._get_slack_hook() - self.log.info('Sending slack message: %s', self.slack_message) - slack_hook.execute() - - def _get_snowflake_hook(self) -> SnowflakeHook: - return SnowflakeHook( - snowflake_conn_id=self.snowflake_conn_id, - warehouse=self.warehouse, - database=self.database, - role=self.role, - schema=self.schema, + warnings.warn( + """ + SnowflakeToSlackOperator is deprecated. + Please use `airflow.providers.slack.transfers.sql_to_slack.SqlToSlackOperator`. + """, + DeprecationWarning, + stacklevel=2, ) - def _get_slack_hook(self) -> SlackWebhookHook: - return SlackWebhookHook( - http_conn_id=self.slack_conn_id, message=self.slack_message, webhook_token=self.slack_token + hook_params = { + "schema": self.schema, + "role": self.role, + "database": self.database, + "warehouse": self.warehouse, + } + cleaned_hook_params = {k: v for k, v in hook_params.items() if v is not None} + + super().__init__( + sql=self.sql, + sql_conn_id=self.snowflake_conn_id, + slack_conn_id=self.slack_conn_id, + slack_webhook_token=self.slack_token, + slack_message=self.slack_message, + results_df_name=self.results_df_name, + parameters=self.parameters, + sql_hook_params=cleaned_hook_params, + **kwargs, ) - - def render_template_fields(self, context, jinja_env=None) -> None: - # If this is the first render of the template fields, exclude slack_message from rendering since - # the snowflake results haven't been retrieved yet. 
- if self.times_rendered == 0: - fields_to_render: Iterable[str] = filter(lambda x: x != 'slack_message', self.template_fields) - else: - fields_to_render = self.template_fields - - if not jinja_env: - jinja_env = self.get_template_env() - - # Add the tabulate library into the JINJA environment - jinja_env.filters['tabulate'] = tabulate - - self._do_render_template_fields(self, fields_to_render, context, jinja_env, set()) - self.times_rendered += 1 - - def execute(self, context: 'Context') -> None: - if not isinstance(self.sql, str): - raise AirflowException("Expected 'sql' parameter should be a string.") - if self.sql is None or self.sql.strip() == "": - raise AirflowException("Expected 'sql' parameter is missing.") - if self.slack_message is None or self.slack_message.strip() == "": - raise AirflowException("Expected 'slack_message' parameter is missing.") - - df = self._get_query_results() - self._render_and_send_slack_message(context, df) - - self.log.debug('Finished sending Snowflake data to Slack') diff --git a/airflow/providers/jira/hooks/__init__.py b/airflow/providers/snowflake/utils/__init__.py similarity index 100% rename from airflow/providers/jira/hooks/__init__.py rename to airflow/providers/snowflake/utils/__init__.py diff --git a/airflow/providers/snowflake/utils/common.py b/airflow/providers/snowflake/utils/common.py new file mode 100644 index 0000000000000..3fe735ad6982e --- /dev/null +++ b/airflow/providers/snowflake/utils/common.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +def enclose_param(param: str) -> str: + """ + Replace all single quotes in parameter by two single quotes and enclose param in single quote. + + .. seealso:: + https://docs.snowflake.com/en/sql-reference/data-types-text.html#single-quoted-string-constants + + Examples: + .. code-block:: python + + enclose_param("without quotes") # Returns: 'without quotes' + enclose_param("'with quotes'") # Returns: '''with quotes''' + enclose_param("Today's sales projections") # Returns: 'Today''s sales projections' + enclose_param("sample/john's.csv") # Returns: 'sample/john''s.csv' + enclose_param(".*'awesome'.*[.]csv") # Returns: '.*''awesome''.*[.]csv' + + :param param: parameter which required single quotes enclosure. 
+ """ + return f"""'{param.replace("'", "''")}'""" diff --git a/airflow/providers/sqlite/.latest-doc-only-change.txt b/airflow/providers/sqlite/.latest-doc-only-change.txt index e7c3c940c9c77..ff7136e07d744 100644 --- a/airflow/providers/sqlite/.latest-doc-only-change.txt +++ b/airflow/providers/sqlite/.latest-doc-only-change.txt @@ -1 +1 @@ -602abe8394fafe7de54df7e73af56de848cdf617 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/sqlite/CHANGELOG.rst b/airflow/providers/sqlite/CHANGELOG.rst index f012a52ba1cb8..c303e4bfb44ad 100644 --- a/airflow/providers/sqlite/CHANGELOG.rst +++ b/airflow/providers/sqlite/CHANGELOG.rst @@ -16,9 +16,94 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.2.1 +..... + +Features +~~~~~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix placeholders in 'TrinoHook', 'PrestoHook', 'SqliteHook' (#25939)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Bug Fixes +~~~~~~~~~ + +* ``Fix ''SqliteHook'' compatibility with SQLAlchemy engine (#23790)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate SQLite example DAGs to new design #22461 (#24150)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... 
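A note for readers following the SnowflakeToSlackOperator deprecation above: the shim simply forwards its arguments to SqlToSlackOperator, so a direct replacement looks roughly like the sketch below. This is a minimal, hypothetical example; the DAG id, connection ids, SQL, and hook params are placeholders, and only the parameter mapping is taken from the super().__init__() call in the shim.

from datetime import datetime

from airflow import DAG
from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator

with DAG(
    dag_id="example_sql_to_slack",        # placeholder DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
):
    SqlToSlackOperator(
        task_id="snowflake_to_slack",
        sql_conn_id="snowflake_default",  # was `snowflake_conn_id` on the old operator
        sql="SELECT COUNT(*) AS row_count FROM my_table",
        slack_conn_id="slack_default",
        slack_message="Snowflake results: {{ results_df }}",
        results_df_name="results_df",
        # `warehouse`/`database`/`schema`/`role` kwargs move into `sql_hook_params`
        sql_hook_params={"warehouse": "my_warehouse", "schema": "my_schema"},
    )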
diff --git a/airflow/providers/sqlite/example_dags/__init__.py b/airflow/providers/sqlite/example_dags/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/airflow/providers/sqlite/example_dags/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/providers/sqlite/example_dags/example_sqlite.py b/airflow/providers/sqlite/example_dags/example_sqlite.py deleted file mode 100644 index b1755996e281c..0000000000000 --- a/airflow/providers/sqlite/example_dags/example_sqlite.py +++ /dev/null @@ -1,86 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example DAG for the use of the SqliteOperator. -In this example, we create two tasks that execute in sequence. -The first task calls an sql command, defined in the SQLite operator, -which when triggered, is performed on the connected sqlite database. -The second task is similar but instead calls the SQL command from an external file. -""" - -from datetime import datetime - -from airflow import DAG -from airflow.providers.sqlite.hooks.sqlite import SqliteHook -from airflow.providers.sqlite.operators.sqlite import SqliteOperator - -dag = DAG( - dag_id='example_sqlite', - schedule_interval='@daily', - start_date=datetime(2021, 1, 1), - tags=['example'], - catchup=False, -) - -# [START howto_operator_sqlite] - -# Example of creating a task that calls a common CREATE TABLE sql command. 
-create_table_sqlite_task = SqliteOperator( - task_id='create_table_sqlite', - sql=r""" - CREATE TABLE Customers ( - customer_id INT PRIMARY KEY, - first_name TEXT, - last_name TEXT - ); - """, - dag=dag, -) - -# [END howto_operator_sqlite] - - -@dag.task(task_id="insert_sqlite_task") -def insert_sqlite_hook(): - sqlite_hook = SqliteHook() - - rows = [('James', '11'), ('James', '22'), ('James', '33')] - target_fields = ['first_name', 'last_name'] - sqlite_hook.insert_rows(table='Customers', rows=rows, target_fields=target_fields) - - -@dag.task(task_id="replace_sqlite_task") -def replace_sqlite_hook(): - sqlite_hook = SqliteHook() - - rows = [('James', '11'), ('James', '22'), ('James', '33')] - target_fields = ['first_name', 'last_name'] - sqlite_hook.insert_rows(table='Customers', rows=rows, target_fields=target_fields, replace=True) - - -# [START howto_operator_sqlite_external_file] - -# Example of creating a task that calls an sql command from an external file. -external_create_table_sqlite_task = SqliteOperator( - task_id='create_table_sqlite_external_file', - sql='create_table.sql', -) - -# [END howto_operator_sqlite_external_file] - -create_table_sqlite_task >> external_create_table_sqlite_task >> insert_sqlite_hook() >> replace_sqlite_hook() diff --git a/airflow/providers/sqlite/hooks/sqlite.py b/airflow/providers/sqlite/hooks/sqlite.py index 6fe055c149d6f..feeab10791fc8 100644 --- a/airflow/providers/sqlite/hooks/sqlite.py +++ b/airflow/providers/sqlite/hooks/sqlite.py @@ -15,19 +15,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import sqlite3 -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook class SqliteHook(DbApiHook): """Interact with SQLite.""" - conn_name_attr = 'sqlite_conn_id' - default_conn_name = 'sqlite_default' - conn_type = 'sqlite' - hook_name = 'Sqlite' + conn_name_attr = "sqlite_conn_id" + default_conn_name = "sqlite_default" + conn_type = "sqlite" + hook_name = "Sqlite" + placeholder = "?" def get_conn(self) -> sqlite3.dbapi2.Connection: """Returns a sqlite connection object""" @@ -41,33 +43,3 @@ def get_uri(self) -> str: conn_id = getattr(self, self.conn_name_attr) airflow_conn = self.get_connection(conn_id) return f"sqlite:///{airflow_conn.host}" - - @staticmethod - def _generate_insert_sql(table, values, target_fields, replace, **kwargs): - """ - Static helper method that generates the INSERT SQL statement. - The REPLACE variant is specific to MySQL syntax. - - :param table: Name of the target table - :param values: The row to insert into the table - :param target_fields: The names of the columns to fill in the table - :param replace: Whether to replace instead of insert - :return: The generated INSERT or REPLACE SQL statement - :rtype: str - """ - placeholders = [ - "?", - ] * len(values) - - if target_fields: - target_fields = ", ".join(target_fields) - target_fields = f"({target_fields})" - else: - target_fields = '' - - if not replace: - sql = "INSERT INTO " - else: - sql = "REPLACE INTO " - sql += f"{table} {target_fields} VALUES ({','.join(placeholders)})" - return sql diff --git a/airflow/providers/sqlite/operators/sqlite.py b/airflow/providers/sqlite/operators/sqlite.py index 7ef97ca2963f4..0bd7bb514b940 100644 --- a/airflow/providers/sqlite/operators/sqlite.py +++ b/airflow/providers/sqlite/operators/sqlite.py @@ -15,13 +15,15 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union +from __future__ import annotations -from airflow.models import BaseOperator -from airflow.providers.sqlite.hooks.sqlite import SqliteHook +import warnings +from typing import Sequence +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -class SqliteOperator(BaseOperator): + +class SqliteOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Sqlite database @@ -37,25 +39,16 @@ class SqliteOperator(BaseOperator): :param parameters: (optional) the parameters to render the SQL query with. """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#cdaaed' - - def __init__( - self, - *, - sql: Union[str, List[str]], - sqlite_conn_id: str = 'sqlite_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.sqlite_conn_id = sqlite_conn_id - self.sql = sql - self.parameters = parameters or [] - - def execute(self, context: Mapping[Any, Any]) -> None: - self.log.info('Executing: %s', self.sql) - hook = SqliteHook(sqlite_conn_id=self.sqlite_conn_id) - hook.run(self.sql, parameters=self.parameters) + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#cdaaed" + + def __init__(self, *, sqlite_conn_id: str = "sqlite_default", **kwargs) -> None: + super().__init__(conn_id=sqlite_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/sqlite/provider.yaml b/airflow/providers/sqlite/provider.yaml index 822cd6c4261aa..f0a2e454530be 100644 --- a/airflow/providers/sqlite/provider.yaml +++ b/airflow/providers/sqlite/provider.yaml @@ -22,6 +22,11 @@ description: | `SQLite `__ versions: + - 3.3.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -32,6 +37,9 @@ versions: - 1.0.1 - 1.0.0 +dependencies: + - apache-airflow-providers-common-sql>=1.3.1 + integrations: - integration-name: SQLite external-doc-url: https://www.sqlite.org/index.html @@ -51,9 +59,6 @@ hooks: python-modules: - airflow.providers.sqlite.hooks.sqlite -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.sqlite.hooks.sqlite.SqliteHook - connection-types: - hook-class-name: airflow.providers.sqlite.hooks.sqlite.SqliteHook connection-type: sqlite diff --git a/airflow/providers/ssh/CHANGELOG.rst b/airflow/providers/ssh/CHANGELOG.rst index 3bf59346ed98a..de43932a860c6 100644 --- a/airflow/providers/ssh/CHANGELOG.rst +++ b/airflow/providers/ssh/CHANGELOG.rst @@ -16,9 +16,99 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. 
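The SqliteOperator change above works like the other deprecation shims in this release: it forwards to SQLExecuteQueryOperator with `conn_id` taken from `sqlite_conn_id`. A minimal, hypothetical sketch of writing the task directly against the new operator (the DAG id, connection id, and SQL are placeholders):

from datetime import datetime

from airflow import DAG
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

with DAG(
    dag_id="example_sqlite_sql_execute_query",  # placeholder DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
):
    SQLExecuteQueryOperator(
        task_id="create_table_sqlite",
        conn_id="sqlite_default",  # was `sqlite_conn_id` on SqliteOperator
        sql="CREATE TABLE IF NOT EXISTS Customers (customer_id INT PRIMARY KEY, first_name TEXT, last_name TEXT)",
    )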
+ +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Added docs regarding templated field (#27301)`` +* ``Added environment to templated SSHOperator fields (#26824)`` +* ``Apply log formatter on every output line in SSHOperator (#27442)`` + +Bug Fixes +~~~~~~~~~ + +* ``A few docs fixups (#26788)`` +* ``SSHOperator ignores cmd_timeout (#27182) (#27184)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``feat: load host keys to save new host key (#25979)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Less verbose logging in ssh operator (#24915)`` +* ``Convert sftp hook to use paramiko instead of pysftp (#24512)`` + +Bug Fixes +~~~~~~~~~ + +* ``Update providers to use functools compat for ''cached_property'' (#24582)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Features +~~~~~~~~ + +* ``Add disabled_algorithms as an extra parameter for SSH connections (#24090)`` + +Bug Fixes +~~~~~~~~~ + +* ``fixing SSHHook bug when using allow_host_key_change param (#24116)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.4.4 ..... diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py index f3f7d11a57ea2..760b22f79660b 100644 --- a/airflow/providers/ssh/hooks/ssh.py +++ b/airflow/providers/ssh/hooks/ssh.py @@ -16,24 +16,21 @@ # specific language governing permissions and limitations # under the License. 
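The changelog entries above mention `disabled_algorithms`, and the hook diff below also adds a `ciphers` argument to SSHHook (both are likewise readable from the connection extras). A minimal, hypothetical usage sketch, where the connection id, algorithm lists, and command are chosen purely for illustration:

from airflow.providers.ssh.hooks.ssh import SSHHook

hook = SSHHook(
    ssh_conn_id="ssh_default",
    conn_timeout=10,
    # paramiko-style mapping; e.g. fall back to plain ssh-rsa against servers
    # that mishandle the rsa-sha2-* signature extensions.
    disabled_algorithms={"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]},
    # Preferred cipher order for the transport.
    ciphers=["aes256-ctr", "aes192-ctr", "aes128-ctr"],
)

with hook.get_conn() as ssh_client:
    exit_status, stdout, stderr = hook.exec_ssh_client_command(
        ssh_client,
        "uname -a",
        get_pty=False,
        environment=None,
        timeout=10,
    )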
"""Hook for SSH connections.""" +from __future__ import annotations + import os -import sys import warnings from base64 import decodebytes from io import StringIO from select import select -from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Sequence import paramiko from paramiko.config import SSH_PORT from sshtunnel import SSHTunnelForwarder from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook @@ -53,8 +50,8 @@ class SSHHook(BaseHook): :param ssh_conn_id: :ref:`ssh connection id` from airflow Connections from where all the required parameters can be fetched like - username, password or key_file. Thought the priority is given to the - param passed during init + username, password or key_file, though priority is given to the + params passed during init. :param remote_host: remote host to connect :param username: username to connect to the remote_host :param password: password of the username to connect to the remote_host @@ -68,10 +65,14 @@ class SSHHook(BaseHook): :param keepalive_interval: send a keepalive packet to remote host every keepalive_interval seconds :param banner_timeout: timeout to wait for banner from the server in seconds + :param disabled_algorithms: dictionary mapping algorithm type to an + iterable of algorithm identifiers, which will be disabled for the + lifetime of the transport + :param ciphers: list of ciphers to use in order of preference """ # List of classes to try loading private keys as, ordered (roughly) by most common to least common - _pkey_loaders: Sequence[Type[paramiko.PKey]] = ( + _pkey_loaders: Sequence[type[paramiko.PKey]] = ( paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key, @@ -79,39 +80,41 @@ class SSHHook(BaseHook): ) _host_key_mappings = { - 'rsa': paramiko.RSAKey, - 'dss': paramiko.DSSKey, - 'ecdsa': paramiko.ECDSAKey, - 'ed25519': paramiko.Ed25519Key, + "rsa": paramiko.RSAKey, + "dss": paramiko.DSSKey, + "ecdsa": paramiko.ECDSAKey, + "ed25519": paramiko.Ed25519Key, } - conn_name_attr = 'ssh_conn_id' - default_conn_name = 'ssh_default' - conn_type = 'ssh' - hook_name = 'SSH' + conn_name_attr = "ssh_conn_id" + default_conn_name = "ssh_default" + conn_type = "ssh" + hook_name = "SSH" @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['schema'], + "hidden_fields": ["schema"], "relabeling": { - 'login': 'Username', + "login": "Username", }, } def __init__( self, - ssh_conn_id: Optional[str] = None, - remote_host: str = '', - username: Optional[str] = None, - password: Optional[str] = None, - key_file: Optional[str] = None, - port: Optional[int] = None, - timeout: Optional[int] = None, - conn_timeout: Optional[int] = None, + ssh_conn_id: str | None = None, + remote_host: str = "", + username: str | None = None, + password: str | None = None, + key_file: str | None = None, + port: int | None = None, + timeout: int | None = None, + conn_timeout: int | None = None, keepalive_interval: int = 30, banner_timeout: float = 30.0, + disabled_algorithms: dict | None = None, + ciphers: list[str] | None = None, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id @@ -125,6 +128,8 @@ def __init__( 
self.conn_timeout = conn_timeout self.keepalive_interval = keepalive_interval self.banner_timeout = banner_timeout + self.disabled_algorithms = disabled_algorithms + self.ciphers = ciphers self.host_proxy_cmd = None # Default values, overridable from Connection @@ -135,7 +140,7 @@ def __init__( self.look_for_keys = True # Placeholder for deprecated __enter__ - self.client: Optional[paramiko.SSHClient] = None + self.client: paramiko.SSHClient | None = None # Use connection to override defaults if self.ssh_conn_id is not None: @@ -154,25 +159,25 @@ def __init__( if "key_file" in extra_options and self.key_file is None: self.key_file = extra_options.get("key_file") - private_key = extra_options.get('private_key') - private_key_passphrase = extra_options.get('private_key_passphrase') + private_key = extra_options.get("private_key") + private_key_passphrase = extra_options.get("private_key_passphrase") if private_key: self.pkey = self._pkey_from_private_key(private_key, passphrase=private_key_passphrase) if "timeout" in extra_options: warnings.warn( - 'Extra option `timeout` is deprecated.' - 'Please use `conn_timeout` instead.' - 'The old option `timeout` will be removed in a future version.', + "Extra option `timeout` is deprecated." + "Please use `conn_timeout` instead." + "The old option `timeout` will be removed in a future version.", DeprecationWarning, stacklevel=2, ) - self.timeout = int(extra_options['timeout']) + self.timeout = int(extra_options["timeout"]) if "conn_timeout" in extra_options and self.conn_timeout is None: - self.conn_timeout = int(extra_options['conn_timeout']) + self.conn_timeout = int(extra_options["conn_timeout"]) - if "compress" in extra_options and str(extra_options["compress"]).lower() == 'false': + if "compress" in extra_options and str(extra_options["compress"]).lower() == "false": self.compress = False host_key = extra_options.get("host_key") @@ -187,31 +192,37 @@ def __init__( if ( "allow_host_key_change" in extra_options - and str(extra_options["allow_host_key_change"]).lower() == 'true' + and str(extra_options["allow_host_key_change"]).lower() == "true" ): self.allow_host_key_change = True if ( "look_for_keys" in extra_options - and str(extra_options["look_for_keys"]).lower() == 'false' + and str(extra_options["look_for_keys"]).lower() == "false" ): self.look_for_keys = False + if "disabled_algorithms" in extra_options: + self.disabled_algorithms = extra_options.get("disabled_algorithms") + + if "ciphers" in extra_options: + self.ciphers = extra_options.get("ciphers") + if host_key is not None: if host_key.startswith("ssh-"): key_type, host_key = host_key.split(None)[:2] key_constructor = self._host_key_mappings[key_type[4:]] else: key_constructor = paramiko.RSAKey - decoded_host_key = decodebytes(host_key.encode('utf-8')) + decoded_host_key = decodebytes(host_key.encode("utf-8")) self.host_key = key_constructor(data=decoded_host_key) self.no_host_key_check = False if self.timeout: warnings.warn( - 'Parameter `timeout` is deprecated.' - 'Please use `conn_timeout` instead.' - 'The old option `timeout` will be removed in a future version.', + "Parameter `timeout` is deprecated." + "Please use `conn_timeout` instead." 
+ "The old option `timeout` will be removed in a future version.", DeprecationWarning, stacklevel=1, ) @@ -237,33 +248,29 @@ def __init__( ) self.username = getuser() - user_ssh_config_filename = os.path.expanduser('~/.ssh/config') + user_ssh_config_filename = os.path.expanduser("~/.ssh/config") if os.path.isfile(user_ssh_config_filename): ssh_conf = paramiko.SSHConfig() with open(user_ssh_config_filename) as config_fd: ssh_conf.parse(config_fd) host_info = ssh_conf.lookup(self.remote_host) - if host_info and host_info.get('proxycommand'): - self.host_proxy_cmd = host_info['proxycommand'] + if host_info and host_info.get("proxycommand"): + self.host_proxy_cmd = host_info["proxycommand"] if not (self.password or self.key_file): - if host_info and host_info.get('identityfile'): - self.key_file = host_info['identityfile'][0] + if host_info and host_info.get("identityfile"): + self.key_file = host_info["identityfile"][0] self.port = self.port or SSH_PORT @cached_property - def host_proxy(self) -> Optional[paramiko.ProxyCommand]: + def host_proxy(self) -> paramiko.ProxyCommand | None: cmd = self.host_proxy_cmd return paramiko.ProxyCommand(cmd) if cmd else None def get_conn(self) -> paramiko.SSHClient: - """ - Opens a ssh connection to the remote host. - - :rtype: paramiko.client.SSHClient - """ - self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id) + """Opens a ssh connection to the remote host.""" + self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id) client = paramiko.SSHClient() if self.allow_host_key_change: @@ -271,11 +278,18 @@ def get_conn(self) -> paramiko.SSHClient: "Remote Identification Change is not verified. " "This won't protect against Man-In-The-Middle attacks" ) + # to avoid BadHostKeyException, skip loading host keys + client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy) else: client.load_system_host_keys() if self.no_host_key_check: self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks") + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + # to avoid BadHostKeyException, skip loading and saving host keys + known_hosts = os.path.expanduser("~/.ssh/known_hosts") + if not self.allow_host_key_change and os.path.isfile(known_hosts): + client.load_host_keys(known_hosts) else: if self.host_key is not None: client_host_keys = client.get_host_keys() @@ -288,11 +302,7 @@ def get_conn(self) -> paramiko.SSHClient: else: pass # will fallback to system host keys if none explicitly specified in conn extra - if self.no_host_key_check or self.allow_host_key_change: - # Default is RejectPolicy - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - connect_kwargs: Dict[str, Any] = dict( + connect_kwargs: dict[str, Any] = dict( hostname=self.remote_host, username=self.username, timeout=self.conn_timeout, @@ -313,6 +323,9 @@ def get_conn(self) -> paramiko.SSHClient: if self.key_file: connect_kwargs.update(key_filename=self.key_file) + if self.disabled_algorithms: + connect_kwargs.update(disabled_algorithms=self.disabled_algorithms) + log_before_sleep = lambda retry_state: self.log.info( "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number ) @@ -328,17 +341,22 @@ def get_conn(self) -> paramiko.SSHClient: if self.keepalive_interval: # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns - # type "Optional[Transport]" and item "None" has no attribute "set_keepalive". 
+ # type "Transport | None" and item "None" has no attribute "set_keepalive". client.get_transport().set_keepalive(self.keepalive_interval) # type: ignore[union-attr] + if self.ciphers: + # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns + # type "Transport | None" and item "None" has no method `get_security_options`". + client.get_transport().get_security_options().ciphers = self.ciphers # type: ignore[union-attr] + self.client = client return client - def __enter__(self) -> 'SSHHook': + def __enter__(self) -> SSHHook: warnings.warn( - 'The contextmanager of SSHHook is deprecated.' - 'Please use get_conn() as a contextmanager instead.' - 'This method will be removed in Airflow 2.0', + "The contextmanager of SSHHook is deprecated." + "Please use get_conn() as a contextmanager instead." + "This method will be removed in Airflow 2.0", category=DeprecationWarning, ) return self @@ -349,7 +367,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.client = None def get_tunnel( - self, remote_port: int, remote_host: str = "localhost", local_port: Optional[int] = None + self, remote_port: int, remote_host: str = "localhost", local_port: int | None = None ) -> SSHTunnelForwarder: """ Creates a tunnel between two hosts. Like ssh -L :host:. @@ -361,9 +379,9 @@ def get_tunnel( :return: sshtunnel.SSHTunnelForwarder object """ if local_port: - local_bind_address: Union[Tuple[str, int], Tuple[str]] = ('localhost', local_port) + local_bind_address: tuple[str, int] | tuple[str] = ("localhost", local_port) else: - local_bind_address = ('localhost',) + local_bind_address = ("localhost",) tunnel_kwargs = dict( ssh_port=self.port, @@ -401,16 +419,16 @@ def create_tunnel( :return: """ warnings.warn( - 'SSHHook.create_tunnel is deprecated, Please' - 'use get_tunnel() instead. But please note that the' - 'order of the parameters have changed' - 'This method will be removed in Airflow 2.0', + "SSHHook.create_tunnel is deprecated, Please" + "use get_tunnel() instead. But please note that the" + "order of the parameters have changed" + "This method will be removed in Airflow 2.0", category=DeprecationWarning, ) return self.get_tunnel(remote_port, remote_host, local_port) - def _pkey_from_private_key(self, private_key: str, passphrase: Optional[str] = None) -> paramiko.PKey: + def _pkey_from_private_key(self, private_key: str, passphrase: str | None = None) -> paramiko.PKey: """ Creates appropriate paramiko key for given private key @@ -419,21 +437,21 @@ def _pkey_from_private_key(self, private_key: str, passphrase: Optional[str] = N :raises AirflowException: if key cannot be read """ if len(private_key.split("\n", 2)) < 2: - raise AirflowException('Key must have BEGIN and END header/footer on separate lines.') + raise AirflowException("Key must have BEGIN and END header/footer on separate lines.") for pkey_class in self._pkey_loaders: try: key = pkey_class.from_private_key(StringIO(private_key), password=passphrase) # Test it actually works. If Paramiko loads an openssh generated key, sometimes it will # happily load it as the wrong type, only to fail when actually used. - key.sign_ssh_data(b'') + key.sign_ssh_data(b"") return key except (paramiko.ssh_exception.SSHException, ValueError): continue raise AirflowException( - 'Private key provided cannot be read by paramiko.' - 'Ensure key provided is valid for one of the following' - 'key formats: RSA, DSS, ECDSA, or Ed25519' + "Private key provided cannot be read by paramiko." 
+ "Ensure key provided is valid for one of the following" + "key formats: RSA, DSS, ECDSA, or Ed25519" ) def exec_ssh_client_command( @@ -441,9 +459,9 @@ def exec_ssh_client_command( ssh_client: paramiko.SSHClient, command: str, get_pty: bool, - environment: Optional[dict], - timeout: Optional[int], - ) -> Tuple[int, bytes, bytes]: + environment: dict | None, + timeout: int | None, + ) -> tuple[int, bytes, bytes]: self.log.info("Running command: %s", command) # set timeout taken as params @@ -460,8 +478,8 @@ def exec_ssh_client_command( stdin.close() channel.shutdown_write() - agg_stdout = b'' - agg_stderr = b'' + agg_stdout = b"" + agg_stderr = b"" # capture any initial output in case channel is closed already stdout_buffer_length = len(stdout.channel.in_buffer) @@ -469,23 +487,28 @@ def exec_ssh_client_command( if stdout_buffer_length > 0: agg_stdout += stdout.channel.recv(stdout_buffer_length) + timedout = False + # read from both stdout and stderr while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready(): readq, _, _ = select([channel], [], [], timeout) + timedout = len(readq) == 0 for recv in readq: if recv.recv_ready(): - line = stdout.channel.recv(len(recv.in_buffer)) - agg_stdout += line - self.log.info(line.decode('utf-8', 'replace').strip('\n')) + output = stdout.channel.recv(len(recv.in_buffer)) + agg_stdout += output + for line in output.decode("utf-8", "replace").strip("\n").splitlines(): + self.log.info(line) if recv.recv_stderr_ready(): - line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer)) - agg_stderr += line - self.log.warning(line.decode('utf-8', 'replace').strip('\n')) + output = stderr.channel.recv_stderr(len(recv.in_stderr_buffer)) + agg_stderr += output + for line in output.decode("utf-8", "replace").strip("\n").splitlines(): + self.log.warning(line) if ( stdout.channel.exit_status_ready() and not stderr.channel.recv_stderr_ready() and not stdout.channel.recv_ready() - ): + ) or timedout: stdout.channel.shutdown_read() try: stdout.channel.close() @@ -499,6 +522,9 @@ def exec_ssh_client_command( stdout.close() stderr.close() + if timedout: + raise AirflowException("SSH command timed out") + exit_status = stdout.channel.recv_exit_status() return exit_status, agg_stdout, agg_stderr diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py index 1d60aae4f85f5..806f72832cbb7 100644 --- a/airflow/providers/ssh/operators/ssh.py +++ b/airflow/providers/ssh/operators/ssh.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import warnings from base64 import b64encode -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -49,10 +50,8 @@ class SSHOperator(BaseOperator): Nullable. If provided, it will replace the `conn_timeout` which was predefined in the connection of `ssh_conn_id`. :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds. - :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds. - Use conn_timeout and cmd_timeout parameters instead. :param environment: a dict of shell environment variables. Note that the - server will reject them silently if `AcceptEnv` is not set in SSH config. 
+ server will reject them silently if `AcceptEnv` is not set in SSH config. (templated) :param get_pty: request a pseudo-terminal from the server. Set to ``True`` to have the remote process killed upon task timeout. The default is ``False`` but note that `get_pty` is forced to ``True`` @@ -60,21 +59,23 @@ class SSHOperator(BaseOperator): :param banner_timeout: timeout to wait for banner from the server in seconds """ - template_fields: Sequence[str] = ('command', 'remote_host') - template_ext: Sequence[str] = ('.sh',) - template_fields_renderers = {"command": "bash"} + template_fields: Sequence[str] = ("command", "environment", "remote_host") + template_ext: Sequence[str] = (".sh",) + template_fields_renderers = { + "command": "bash", + "environment": "python", + } def __init__( self, *, - ssh_hook: Optional["SSHHook"] = None, - ssh_conn_id: Optional[str] = None, - remote_host: Optional[str] = None, - command: Optional[str] = None, - timeout: Optional[int] = None, - conn_timeout: Optional[int] = None, - cmd_timeout: Optional[int] = None, - environment: Optional[dict] = None, + ssh_hook: SSHHook | None = None, + ssh_conn_id: str | None = None, + remote_host: str | None = None, + command: str | None = None, + conn_timeout: int | None = None, + cmd_timeout: int | None = None, + environment: dict | None = None, get_pty: bool = False, banner_timeout: float = 30.0, **kwargs, @@ -84,27 +85,13 @@ def __init__( self.ssh_conn_id = ssh_conn_id self.remote_host = remote_host self.command = command - self.timeout = timeout self.conn_timeout = conn_timeout - self.cmd_timeout = cmd_timeout - if self.conn_timeout is None and self.timeout: - self.conn_timeout = self.timeout - if self.cmd_timeout is None: - self.cmd_timeout = self.timeout if self.timeout else CMD_TIMEOUT + self.cmd_timeout = cmd_timeout if cmd_timeout else CMD_TIMEOUT self.environment = environment self.get_pty = get_pty self.banner_timeout = banner_timeout - if self.timeout: - warnings.warn( - 'Parameter `timeout` is deprecated.' - 'Please use `conn_timeout` and `cmd_timeout` instead.' 
- 'The old option `timeout` will be removed in a future version.', - DeprecationWarning, - stacklevel=2, - ) - - def get_hook(self) -> "SSHHook": + def get_hook(self) -> SSHHook: from airflow.providers.ssh.hooks.ssh import SSHHook if self.ssh_conn_id: @@ -131,51 +118,47 @@ def get_hook(self) -> "SSHHook": return self.ssh_hook - def get_ssh_client(self) -> "SSHClient": + def get_ssh_client(self) -> SSHClient: # Remember to use context manager or call .close() on this when done - self.log.info('Creating ssh_client') + self.log.info("Creating ssh_client") return self.get_hook().get_conn() - def exec_ssh_client_command(self, ssh_client: "SSHClient", command: str): + def exec_ssh_client_command(self, ssh_client: SSHClient, command: str): warnings.warn( - 'exec_ssh_client_command method on SSHOperator is deprecated, call ' - '`ssh_hook.exec_ssh_client_command` instead', + "exec_ssh_client_command method on SSHOperator is deprecated, call " + "`ssh_hook.exec_ssh_client_command` instead", DeprecationWarning, ) assert self.ssh_hook return self.ssh_hook.exec_ssh_client_command( - ssh_client, command, timeout=self.timeout, environment=self.environment, get_pty=self.get_pty + ssh_client, command, timeout=self.cmd_timeout, environment=self.environment, get_pty=self.get_pty ) def raise_for_status(self, exit_status: int, stderr: bytes) -> None: if exit_status != 0: - error_msg = stderr.decode('utf-8') - raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}") + raise AirflowException(f"SSH operator error: exit status = {exit_status}") - def run_ssh_client_command(self, ssh_client: "SSHClient", command: str) -> bytes: + def run_ssh_client_command(self, ssh_client: SSHClient, command: str) -> bytes: assert self.ssh_hook exit_status, agg_stdout, agg_stderr = self.ssh_hook.exec_ssh_client_command( - ssh_client, command, timeout=self.timeout, environment=self.environment, get_pty=self.get_pty + ssh_client, command, timeout=self.cmd_timeout, environment=self.environment, get_pty=self.get_pty ) self.raise_for_status(exit_status, agg_stderr) return agg_stdout - def execute(self, context=None) -> Union[bytes, str]: - result: Union[bytes, str] + def execute(self, context=None) -> bytes | str: + result: bytes | str if self.command is None: raise AirflowException("SSH operator error: SSH command not specified. Aborting.") # Forcing get_pty to True if the command begins with "sudo". 
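Since `environment` is now a templated field on SSHOperator and the old `timeout` argument is gone in favour of `cmd_timeout`, a small, hypothetical usage sketch follows; the DAG id, connection id, script path, and timeout value are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.ssh.operators.ssh import SSHOperator

with DAG(
    dag_id="example_ssh_templated_environment",  # placeholder DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
):
    SSHOperator(
        task_id="run_remote_script",
        ssh_conn_id="ssh_default",
        command="/opt/scripts/refresh.sh",     # placeholder remote command
        cmd_timeout=300,                       # replaces the removed `timeout` argument
        environment={"RUN_DATE": "{{ ds }}"},  # rendered at runtime now that it is templated
    )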
- self.get_pty = self.command.startswith('sudo') or self.get_pty - - try: - with self.get_ssh_client() as ssh_client: - result = self.run_ssh_client_command(ssh_client, self.command) - except Exception as e: - raise AirflowException(f"SSH operator error: {str(e)}") - enable_pickling = conf.getboolean('core', 'enable_xcom_pickling') + self.get_pty = self.command.startswith("sudo") or self.get_pty + + with self.get_ssh_client() as ssh_client: + result = self.run_ssh_client_command(ssh_client, self.command) + enable_pickling = conf.getboolean("core", "enable_xcom_pickling") if not enable_pickling: - result = b64encode(result).decode('utf-8') + result = b64encode(result).decode("utf-8") return result def tunnel(self) -> None: diff --git a/airflow/providers/ssh/provider.yaml b/airflow/providers/ssh/provider.yaml index 3916ad5334b1b..bc753213298ee 100644 --- a/airflow/providers/ssh/provider.yaml +++ b/airflow/providers/ssh/provider.yaml @@ -22,6 +22,10 @@ description: | `Secure Shell (SSH) `__ versions: + - 3.3.0 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.4.4 - 2.4.3 - 2.4.2 @@ -37,8 +41,10 @@ versions: - 1.1.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - paramiko>=2.6.0 + - sshtunnel>=0.3.2 integrations: - integration-name: Secure Shell (SSH) @@ -56,9 +62,6 @@ hooks: python-modules: - airflow.providers.ssh.hooks.ssh -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.ssh.hooks.ssh.SSHHook - connection-types: - hook-class-name: airflow.providers.ssh.hooks.ssh.SSHHook connection-type: ssh diff --git a/airflow/providers/tableau/.latest-doc-only-change.txt b/airflow/providers/tableau/.latest-doc-only-change.txt index d34f7b39802b2..ff7136e07d744 100644 --- a/airflow/providers/tableau/.latest-doc-only-change.txt +++ b/airflow/providers/tableau/.latest-doc-only-change.txt @@ -1 +1 @@ -ef037e702182e4370cb00c853c4fb0e246a0479c +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/tableau/CHANGELOG.rst b/airflow/providers/tableau/CHANGELOG.rst index 16cc93e0cd931..a72289080a543 100644 --- a/airflow/providers/tableau/CHANGELOG.rst +++ b/airflow/providers/tableau/CHANGELOG.rst @@ -16,9 +16,68 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.0.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Removed deprecated classes path tableau_job_status and tableau_refresh_workbook (#27288).`` +* ``Remove deprecated Tableau classes (#27288)`` + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.0.1 +..... + +Bug fixes +~~~~~~~~~ + +* ``Remove Tableau from Salesforce provider (#23747)`` + +.. 
Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Remove "bad characters" from our codebase (#24841)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``AIP-47 - Migrate Tableau DAGs to new design (#24125)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.8 ..... diff --git a/airflow/providers/tableau/example_dags/example_tableau.py b/airflow/providers/tableau/example_dags/example_tableau.py deleted file mode 100644 index 53aba4c074839..0000000000000 --- a/airflow/providers/tableau/example_dags/example_tableau.py +++ /dev/null @@ -1,64 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example dag that performs two refresh operations on a Tableau Workbook aka Extract. The first one -waits until it succeeds. The second does not wait since this is an asynchronous operation and we don't know -when the operation actually finishes. That's why we have another task that checks only that. -""" -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.providers.tableau.operators.tableau import TableauOperator -from airflow.providers.tableau.sensors.tableau import TableauJobStatusSensor - -with DAG( - dag_id='example_tableau', - default_args={'site_id': 'my_site'}, - dagrun_timeout=timedelta(hours=2), - schedule_interval=None, - start_date=datetime(2021, 1, 1), - tags=['example'], -) as dag: - # Refreshes a workbook and waits until it succeeds. - # [START howto_operator_tableau] - task_refresh_workbook_blocking = TableauOperator( - resource='workbooks', - method='refresh', - find='MyWorkbook', - match_with='name', - blocking_refresh=True, - task_id='refresh_tableau_workbook_blocking', - ) - # [END howto_operator_tableau] - # Refreshes a workbook and does not wait until it succeeds. 
- task_refresh_workbook_non_blocking = TableauOperator( - resource='workbooks', - method='refresh', - find='MyWorkbook', - match_with='name', - blocking_refresh=False, - task_id='refresh_tableau_workbook_non_blocking', - ) - # The following task queries the status of the workbook refresh job until it succeeds. - task_check_job_status = TableauJobStatusSensor( - job_id=task_refresh_workbook_non_blocking.output, - task_id='check_tableau_job_status', - ) - - # Task dependency created via XComArgs: - # task_refresh_workbook_non_blocking >> task_check_job_status diff --git a/airflow/providers/tableau/example_dags/example_tableau_refresh_workbook.py b/airflow/providers/tableau/example_dags/example_tableau_refresh_workbook.py deleted file mode 100644 index 31579003b80aa..0000000000000 --- a/airflow/providers/tableau/example_dags/example_tableau_refresh_workbook.py +++ /dev/null @@ -1,62 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This is an example dag that performs two refresh operations on a Tableau Workbook aka Extract. The first one -waits until it succeeds. The second does not wait since this is an asynchronous operation and we don't know -when the operation actually finishes. That's why we have another task that checks only that. -""" -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.providers.tableau.operators.tableau import TableauOperator -from airflow.providers.tableau.sensors.tableau import TableauJobStatusSensor - -with DAG( - dag_id='example_tableau_refresh_workbook', - dagrun_timeout=timedelta(hours=2), - schedule_interval=None, - start_date=datetime(2021, 1, 1), - default_args={'site_id': 'my_site'}, - tags=['example'], -) as dag: - # Refreshes a workbook and waits until it succeeds. - task_refresh_workbook_blocking = TableauOperator( - resource='workbooks', - method='refresh', - find='MyWorkbook', - match_with='name', - blocking_refresh=True, - task_id='refresh_tableau_workbook_blocking', - ) - # Refreshes a workbook and does not wait until it succeeds. - task_refresh_workbook_non_blocking = TableauOperator( - resource='workbooks', - method='refresh', - find='MyWorkbook', - match_with='name', - blocking_refresh=False, - task_id='refresh_tableau_workbook_non_blocking', - ) - # The following task queries the status of the workbook refresh job until it succeeds. 
- task_check_job_status = TableauJobStatusSensor( - job_id=task_refresh_workbook_non_blocking.output, - task_id='check_tableau_job_status', - ) - - # Task dependency created via XComArgs: - # task_refresh_workbook_non_blocking >> task_check_job_status diff --git a/airflow/providers/tableau/hooks/tableau.py b/airflow/providers/tableau/hooks/tableau.py index e0d890b605bfd..fe1650980347d 100644 --- a/airflow/providers/tableau/hooks/tableau.py +++ b/airflow/providers/tableau/hooks/tableau.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import time import warnings from enum import Enum -from typing import Any, Optional, Union +from typing import Any from tableauserverclient import Pager, PersonalAccessTokenAuth, Server, TableauAuth from tableauserverclient.server import Auth @@ -26,15 +28,15 @@ from airflow.hooks.base import BaseHook -def parse_boolean(val: str) -> Union[str, bool]: +def parse_boolean(val: str) -> str | bool: """Try to parse a string into boolean. The string is returned as-is if it does not look like a boolean value. """ val = val.lower() - if val in ('y', 'yes', 't', 'true', 'on', '1'): + if val in ("y", "yes", "t", "true", "on", "1"): return True - if val in ('n', 'no', 'f', 'false', 'off', '0'): + if val in ("n", "no", "f", "false", "off", "0"): return False return val @@ -70,22 +72,22 @@ class TableauHook(BaseHook): containing the credentials to authenticate to the Tableau Server. """ - conn_name_attr = 'tableau_conn_id' - default_conn_name = 'tableau_default' - conn_type = 'tableau' - hook_name = 'Tableau' + conn_name_attr = "tableau_conn_id" + default_conn_name = "tableau_default" + conn_type = "tableau" + hook_name = "Tableau" - def __init__(self, site_id: Optional[str] = None, tableau_conn_id: str = default_conn_name) -> None: + def __init__(self, site_id: str | None = None, tableau_conn_id: str = default_conn_name) -> None: super().__init__() self.tableau_conn_id = tableau_conn_id self.conn = self.get_connection(self.tableau_conn_id) - self.site_id = site_id or self.conn.extra_dejson.get('site_id', '') + self.site_id = site_id or self.conn.extra_dejson.get("site_id", "") self.server = Server(self.conn.host) - verify: Any = self.conn.extra_dejson.get('verify', True) + verify: Any = self.conn.extra_dejson.get("verify", True) if isinstance(verify, str): verify = parse_boolean(verify) self.server.add_http_options( - options_dict={'verify': verify, 'cert': self.conn.extra_dejson.get('cert', None)} + options_dict={"verify": verify, "cert": self.conn.extra_dejson.get("cert", None)} ) self.server.use_server_version() self.tableau_conn = None @@ -103,13 +105,12 @@ def get_conn(self) -> Auth.contextmgr: Sign in to the Tableau Server. :return: an authorized Tableau Server Context Manager object. 
- :rtype: tableauserverclient.server.Auth.contextmgr """ if self.conn.login and self.conn.password: return self._auth_via_password() - if 'token_name' in self.conn.extra_dejson and 'personal_access_token' in self.conn.extra_dejson: + if "token_name" in self.conn.extra_dejson and "personal_access_token" in self.conn.extra_dejson: return self._auth_via_token() - raise NotImplementedError('No Authentication method found for given Credentials!') + raise NotImplementedError("No Authentication method found for given Credentials!") def _auth_via_password(self) -> Auth.contextmgr: tableau_auth = TableauAuth( @@ -125,8 +126,8 @@ def _auth_via_token(self) -> Auth.contextmgr: DeprecationWarning, ) tableau_auth = PersonalAccessTokenAuth( - token_name=self.conn.extra_dejson['token_name'], - personal_access_token=self.conn.extra_dejson['personal_access_token'], + token_name=self.conn.extra_dejson["token_name"], + personal_access_token=self.conn.extra_dejson["personal_access_token"], site_id=self.site_id, ) return self.server.auth.sign_in_with_personal_access_token(tableau_auth) @@ -139,7 +140,6 @@ def get_all(self, resource_name: str) -> Pager: :param resource_name: The name of the resource to paginate. For example: jobs or workbooks. :return: all items by returning a Pager. - :rtype: tableauserverclient.Pager """ try: resource = getattr(self.server, resource_name) @@ -153,8 +153,7 @@ def get_job_status(self, job_id: str) -> TableauJobFinishCode: .. see also:: https://tableau.github.io/server-client-python/docs/api-ref#jobs :param job_id: The id of the job to check. - :return: An Enum that describe the Tableau job’s return code - :rtype: TableauJobFinishCode + :return: An Enum that describe the Tableau job's return code """ return TableauJobFinishCode(int(self.server.jobs.get_by_id(job_id).finish_code)) @@ -164,11 +163,10 @@ def wait_for_state(self, job_id: str, target_state: TableauJobFinishCode, check_ to target_state or different from PENDING. :param job_id: The id of the job to check. - :param target_state: Enum that describe the Tableau job’s target state + :param target_state: Enum that describe the Tableau job's target state :param check_interval: time in seconds that the job should wait in between each instance state checks until operation is completed :return: return True if the job is equal to the target_status, False otherwise. - :rtype: bool """ finish_code = self.get_job_status(job_id=job_id) while finish_code == TableauJobFinishCode.PENDING and finish_code != target_state: diff --git a/airflow/providers/tableau/operators/tableau.py b/airflow/providers/tableau/operators/tableau.py index 7f78d598ad47f..9181e7fea6064 100644 --- a/airflow/providers/tableau/operators/tableau.py +++ b/airflow/providers/tableau/operators/tableau.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
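The TableauHook diff above keeps the hook usable as a context manager, with `get_all`, `get_job_status`, and `wait_for_state` as the main entry points. A minimal, hypothetical sketch; the site id, connection id, and job id are placeholders:

from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode

with TableauHook(site_id="my_site", tableau_conn_id="tableau_default") as tableau_hook:
    # Iterate over all workbooks visible to the authenticated user.
    for workbook in tableau_hook.get_all(resource_name="workbooks"):
        print(workbook.name)

    # Poll every 20 seconds until the job reaches SUCCESS or otherwise leaves PENDING.
    succeeded = tableau_hook.wait_for_state(
        job_id="my-job-id",  # placeholder job id
        target_state=TableauJobFinishCode.SUCCESS,
        check_interval=20,
    )
    print("Job finished successfully:", succeeded)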
-from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -29,15 +31,15 @@ RESOURCES_METHODS = { - 'datasources': ['delete', 'refresh'], - 'groups': ['delete'], - 'projects': ['delete'], - 'schedule': ['delete'], - 'sites': ['delete'], - 'subscriptions': ['delete'], - 'tasks': ['delete', 'run'], - 'users': ['remove'], - 'workbooks': ['delete', 'refresh'], + "datasources": ["delete", "refresh"], + "groups": ["delete"], + "projects": ["delete"], + "schedule": ["delete"], + "sites": ["delete"], + "subscriptions": ["delete"], + "tasks": ["delete", "run"], + "users": ["remove"], + "workbooks": ["delete", "refresh"], } @@ -68,11 +70,11 @@ def __init__( resource: str, method: str, find: str, - match_with: str = 'id', - site_id: Optional[str] = None, + match_with: str = "id", + site_id: str | None = None, blocking_refresh: bool = True, check_interval: float = 20, - tableau_conn_id: str = 'tableau_default', + tableau_conn_id: str = "tableau_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -85,21 +87,20 @@ def __init__( self.blocking_refresh = blocking_refresh self.tableau_conn_id = tableau_conn_id - def execute(self, context: 'Context') -> str: + def execute(self, context: Context) -> str: """ Executes the Tableau API resource and pushes the job id or downloaded file URI to xcom. :param context: The task context during execution. :return: the id of the job that executes the extract refresh or downloaded file URI. - :rtype: str """ available_resources = RESOURCES_METHODS.keys() if self.resource not in available_resources: - error_message = f'Resource not found! Available Resources: {available_resources}' + error_message = f"Resource not found! Available Resources: {available_resources}" raise AirflowException(error_message) available_methods = RESOURCES_METHODS[self.resource] if self.method not in available_methods: - error_message = f'Method not found! Available methods for {self.resource}: {available_methods}' + error_message = f"Method not found! 
Available methods for {self.resource}: {available_methods}" raise AirflowException(error_message) with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook: @@ -113,26 +114,26 @@ def execute(self, context: 'Context') -> str: job_id = response.id - if self.method == 'refresh': + if self.method == "refresh": if self.blocking_refresh: if not tableau_hook.wait_for_state( job_id=job_id, check_interval=self.check_interval, target_state=TableauJobFinishCode.SUCCESS, ): - raise TableauJobFailedException(f'The Tableau Refresh {self.resource} Job failed!') + raise TableauJobFailedException(f"The Tableau Refresh {self.resource} Job failed!") return job_id def _get_resource_id(self, tableau_hook: TableauHook) -> str: - if self.match_with == 'id': + if self.match_with == "id": return self.find for resource in tableau_hook.get_all(resource_name=self.resource): if getattr(resource, self.match_with) == self.find: resource_id = resource.id - self.log.info('Found matching with id %s', resource_id) + self.log.info("Found matching with id %s", resource_id) return resource_id - raise AirflowException(f'{self.resource} with {self.match_with} {self.find} not found!') + raise AirflowException(f"{self.resource} with {self.match_with} {self.find} not found!") diff --git a/airflow/providers/tableau/operators/tableau_refresh_workbook.py b/airflow/providers/tableau/operators/tableau_refresh_workbook.py deleted file mode 100644 index 306c3ded629bf..0000000000000 --- a/airflow/providers/tableau/operators/tableau_refresh_workbook.py +++ /dev/null @@ -1,91 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import warnings -from typing import TYPE_CHECKING, Optional - -from airflow.models import BaseOperator -from airflow.providers.tableau.operators.tableau import TableauOperator - -if TYPE_CHECKING: - from airflow.utils.context import Context - - -warnings.warn( - """This operator is deprecated. Please use `airflow.providers.tableau.operators.tableau`.""", - DeprecationWarning, - stacklevel=2, -) - - -class TableauRefreshWorkbookOperator(BaseOperator): - """ - This operator is deprecated. Please use `airflow.providers.tableau.operators.tableau`. - - Refreshes a Tableau Workbook/Extract - - .. seealso:: https://tableau.github.io/server-client-python/docs/api-ref#workbooks - - :param workbook_name: The name of the workbook to refresh. - :param site_id: The id of the site where the workbook belongs to. - :param blocking: Defines if the job waits until the refresh has finished. - Default: True. - :param tableau_conn_id: The :ref:`Tableau Connection id ` - containing the credentials to authenticate to the Tableau Server. Default: - 'tableau_default'. 
- :param check_interval: time in seconds that the job should wait in - between each instance state checks until operation is completed - """ - - def __init__( - self, - *, - workbook_name: str, - site_id: Optional[str] = None, - blocking: bool = True, - tableau_conn_id: str = 'tableau_default', - check_interval: float = 20, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.workbook_name = workbook_name - self.site_id = site_id - self.blocking = blocking - self.tableau_conn_id = tableau_conn_id - self.check_interval = check_interval - - def execute(self, context: 'Context') -> str: - """ - Executes the Tableau Extract Refresh and pushes the job id to xcom. - - :param context: The task context during execution. - :return: the id of the job that executes the extract refresh - :rtype: str - """ - job_id = TableauOperator( - resource='workbooks', - method='refresh', - find=self.workbook_name, - match_with='name', - site_id=self.site_id, - tableau_conn_id=self.tableau_conn_id, - blocking_refresh=self.blocking, - check_interval=self.check_interval, - task_id='refresh_workbook', - dag=None, - ).execute(context=context) - - return job_id diff --git a/airflow/providers/tableau/provider.yaml b/airflow/providers/tableau/provider.yaml index 5ced7549deb31..9a11fc78823f6 100644 --- a/airflow/providers/tableau/provider.yaml +++ b/airflow/providers/tableau/provider.yaml @@ -22,6 +22,9 @@ description: | `Tableau `__ versions: + - 4.0.0 + - 3.0.1 + - 3.0.0 - 2.1.8 - 2.1.7 - 2.1.6 @@ -34,8 +37,9 @@ versions: - 2.0.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - tableauserverclient integrations: - integration-name: Tableau @@ -49,12 +53,10 @@ operators: - integration-name: Tableau python-modules: - airflow.providers.tableau.operators.tableau - - airflow.providers.tableau.operators.tableau_refresh_workbook sensors: - integration-name: Tableau python-modules: - - airflow.providers.tableau.sensors.tableau_job_status - airflow.providers.tableau.sensors.tableau hooks: @@ -62,9 +64,6 @@ hooks: python-modules: - airflow.providers.tableau.hooks.tableau -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.tableau.hooks.tableau.TableauHook - connection-types: - hook-class-name: airflow.providers.tableau.hooks.tableau.TableauHook connection-type: tableau diff --git a/airflow/providers/tableau/sensors/tableau.py b/airflow/providers/tableau/sensors/tableau.py index 7602d28076051..936e8e9f262e5 100644 --- a/airflow/providers/tableau/sensors/tableau.py +++ b/airflow/providers/tableau/sensors/tableau.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.tableau.hooks.tableau import ( TableauHook, @@ -39,14 +41,14 @@ class TableauJobStatusSensor(BaseSensorOperator): containing the credentials to authenticate to the Tableau Server. 
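# ---------------------------------------------------------------------------
# Editor's note (not part of the diff): a minimal usage sketch for the
# TableauOperator / TableauJobStatusSensor pair changed above. The DAG id,
# task ids and workbook name are illustrative assumptions; "tableau_default"
# is the operators' default connection id.
from datetime import datetime

from airflow import DAG
from airflow.providers.tableau.operators.tableau import TableauOperator
from airflow.providers.tableau.sensors.tableau import TableauJobStatusSensor

with DAG("tableau_refresh_sketch", start_date=datetime(2022, 1, 1), schedule_interval=None, catchup=False) as dag:
    refresh_workbook = TableauOperator(
        task_id="refresh_workbook",
        resource="workbooks",
        method="refresh",
        find="My Workbook",       # match the workbook by name rather than by id
        match_with="name",
        blocking_refresh=False,   # push the job id to XCom and let the sensor do the waiting
        tableau_conn_id="tableau_default",
    )
    wait_for_refresh = TableauJobStatusSensor(
        task_id="wait_for_refresh",
        job_id=refresh_workbook.output,  # XComArg holding the job id returned by execute()
        tableau_conn_id="tableau_default",
    )
    refresh_workbook >> wait_for_refresh
# ---------------------------------------------------------------------------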
""" - template_fields: Sequence[str] = ('job_id',) + template_fields: Sequence[str] = ("job_id",) def __init__( self, *, job_id: str, - site_id: Optional[str] = None, - tableau_conn_id: str = 'tableau_default', + site_id: str | None = None, + tableau_conn_id: str = "tableau_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -54,19 +56,18 @@ def __init__( self.job_id = job_id self.site_id = site_id - def poke(self, context: 'Context') -> bool: + def poke(self, context: Context) -> bool: """ Pokes until the job has successfully finished. :param context: The task context during execution. :return: True if it succeeded and False if not. - :rtype: bool """ with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook: finish_code = tableau_hook.get_job_status(job_id=self.job_id) - self.log.info('Current finishCode is %s (%s)', finish_code.name, finish_code.value) + self.log.info("Current finishCode is %s (%s)", finish_code.name, finish_code.value) if finish_code in (TableauJobFinishCode.ERROR, TableauJobFinishCode.CANCELED): - raise TableauJobFailedException('The Tableau Refresh Workbook Job failed!') + raise TableauJobFailedException("The Tableau Refresh Workbook Job failed!") return finish_code == TableauJobFinishCode.SUCCESS diff --git a/airflow/providers/tableau/sensors/tableau_job_status.py b/airflow/providers/tableau/sensors/tableau_job_status.py deleted file mode 100644 index d2c39c3c196b2..0000000000000 --- a/airflow/providers/tableau/sensors/tableau_job_status.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.tableau.sensors.tableau`.""" - -import warnings - -from airflow.providers.tableau.sensors.tableau import TableauJobStatusSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.tableau.sensors.tableau`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/providers/tabular/.latest-doc-only-change.txt b/airflow/providers/tabular/.latest-doc-only-change.txt new file mode 100644 index 0000000000000..ff7136e07d744 --- /dev/null +++ b/airflow/providers/tabular/.latest-doc-only-change.txt @@ -0,0 +1 @@ +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/tabular/CHANGELOG.rst b/airflow/providers/tabular/CHANGELOG.rst new file mode 100644 index 0000000000000..6e102b63012b4 --- /dev/null +++ b/airflow/providers/tabular/CHANGELOG.rst @@ -0,0 +1,49 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Changelog +--------- + +1.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +1.0.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Set grant type of the Tabular hook (#25099)`` + +1.0.0 +..... + +Initial version of the provider. diff --git a/airflow/mypy/plugin/__init__.py b/airflow/providers/tabular/__init__.py similarity index 100% rename from airflow/mypy/plugin/__init__.py rename to airflow/providers/tabular/__init__.py diff --git a/airflow/providers/alibaba/cloud/example_dags/__init__.py b/airflow/providers/tabular/hooks/__init__.py similarity index 100% rename from airflow/providers/alibaba/cloud/example_dags/__init__.py rename to airflow/providers/tabular/hooks/__init__.py diff --git a/airflow/providers/tabular/hooks/tabular.py b/airflow/providers/tabular/hooks/tabular.py new file mode 100644 index 0000000000000..9390a05a040cd --- /dev/null +++ b/airflow/providers/tabular/hooks/tabular.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any + +import requests +from requests import HTTPError + +from airflow.hooks.base import BaseHook + +DEFAULT_TABULAR_URL = "https://api.tabulardata.io/ws/v1" + +TOKENS_ENDPOINT = "oauth/tokens" + + +class TabularHook(BaseHook): + """ + This hook acts as a base hook for tabular services. It offers the ability to generate temporary, + short-lived session tokens to use within Airflow submitted jobs. + + :param tabular_conn_id: The :ref:`Tabular connection id` + which refers to the information to connect to the Tabular OAuth service. 
+ """ + + conn_name_attr = "tabular_conn_id" + default_conn_name = "tabular_default" + conn_type = "tabular" + hook_name = "Tabular" + + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ["schema", "port"], + "relabeling": { + "host": "Base URL", + "login": "Client ID", + "password": "Client Secret", + }, + "placeholders": { + "host": DEFAULT_TABULAR_URL, + "login": "client_id (token credentials auth)", + "password": "secret (token credentials auth)", + }, + } + + def __init__(self, tabular_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = tabular_conn_id + + def test_connection(self) -> tuple[bool, str]: + """Test the Tabular connection.""" + try: + self.get_conn() + return True, "Successfully fetched token from Tabular" + except HTTPError as e: + return False, f"HTTP Error: {e}: {e.response.text}" + except Exception as e: + return False, str(e) + + def get_conn(self) -> str: + """Obtain a short-lived access token via a client_id and client_secret.""" + conn = self.get_connection(self.conn_id) + base_url = conn.host if conn.host else DEFAULT_TABULAR_URL + base_url = base_url.rstrip("/") + client_id = conn.login + client_secret = conn.password + data = {"client_id": client_id, "client_secret": client_secret, "grant_type": "client_credentials"} + + response = requests.post(f"{base_url}/{TOKENS_ENDPOINT}", data=data) + response.raise_for_status() + + return response.json()["access_token"] + + def get_token_macro(self): + return f"{{{{ conn.{self.conn_id}.get_hook().get_conn() }}}}" diff --git a/airflow/providers/tabular/provider.yaml b/airflow/providers/tabular/provider.yaml new file mode 100644 index 0000000000000..28c265fc98158 --- /dev/null +++ b/airflow/providers/tabular/provider.yaml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +package-name: apache-airflow-providers-tabular +name: Tabular +description: | + `Tabular `__ + +versions: + - 1.1.0 + - 1.0.1 + - 1.0.0 + +dependencies: + - apache-airflow>=2.3.0 + +integrations: + - integration-name: Tabular + external-doc-url: https://tabular.io/docs/ + logo: /integration-logos/tabular/tabular.jpeg + tags: [software] + +hooks: + - integration-name: Tabular + python-modules: + - airflow.providers.tabular.hooks.tabular + +connection-types: + - hook-class-name: airflow.providers.tabular.hooks.tabular.TabularHook + connection-type: tabular diff --git a/airflow/providers/telegram/.latest-doc-only-change.txt b/airflow/providers/telegram/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/telegram/.latest-doc-only-change.txt +++ b/airflow/providers/telegram/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/telegram/CHANGELOG.rst b/airflow/providers/telegram/CHANGELOG.rst index 3ceb2435aa41b..a20efd2b006fd 100644 --- a/airflow/providers/telegram/CHANGELOG.rst +++ b/airflow/providers/telegram/CHANGELOG.rst @@ -16,9 +16,51 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Update old style typing (#26872)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Telegram example DAGs to new design #22468 (#24126)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.0.4 ..... diff --git a/airflow/providers/telegram/example_dags/__init__.py b/airflow/providers/telegram/example_dags/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/airflow/providers/telegram/example_dags/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/providers/telegram/example_dags/example_telegram.py b/airflow/providers/telegram/example_dags/example_telegram.py deleted file mode 100644 index 76e71800ecb2d..0000000000000 --- a/airflow/providers/telegram/example_dags/example_telegram.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example use of Telegram operator. -""" - -from datetime import datetime - -from airflow import DAG -from airflow.providers.telegram.operators.telegram import TelegramOperator - -dag = DAG('example_telegram', start_date=datetime(2021, 1, 1), tags=['example']) - -# [START howto_operator_telegram] - -send_message_telegram_task = TelegramOperator( - task_id='send_message_telegram', - telegram_conn_id='telegram_conn_id', - chat_id='-3222103937', - text='Hello from Airflow!', - dag=dag, -) - -# [END howto_operator_telegram] diff --git a/airflow/providers/telegram/hooks/telegram.py b/airflow/providers/telegram/hooks/telegram.py index 3de4f90a96ba5..bf3d18fe6551e 100644 --- a/airflow/providers/telegram/hooks/telegram.py +++ b/airflow/providers/telegram/hooks/telegram.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
"""Hook for Telegram""" -from typing import Optional +from __future__ import annotations import telegram import tenacity @@ -58,9 +58,9 @@ class TelegramHook(BaseHook): def __init__( self, - telegram_conn_id: Optional[str] = None, - token: Optional[str] = None, - chat_id: Optional[str] = None, + telegram_conn_id: str | None = None, + token: str | None = None, + chat_id: str | None = None, ) -> None: super().__init__() self.token = self.__get_token(token, telegram_conn_id) @@ -72,18 +72,16 @@ def get_conn(self) -> telegram.bot.Bot: Returns the telegram bot client :return: telegram bot client - :rtype: telegram.bot.Bot """ return telegram.bot.Bot(token=self.token) - def __get_token(self, token: Optional[str], telegram_conn_id: Optional[str]) -> str: + def __get_token(self, token: str | None, telegram_conn_id: str | None) -> str: """ Returns the telegram API token :param token: telegram API token :param telegram_conn_id: telegram connection name :return: telegram API token - :rtype: str """ if token is not None: return token @@ -98,14 +96,13 @@ def __get_token(self, token: Optional[str], telegram_conn_id: Optional[str]) -> raise AirflowException("Cannot get token: No valid Telegram connection supplied.") - def __get_chat_id(self, chat_id: Optional[str], telegram_conn_id: Optional[str]) -> Optional[str]: + def __get_chat_id(self, chat_id: str | None, telegram_conn_id: str | None) -> str | None: """ Returns the telegram chat ID for a chat/channel/group :param chat_id: optional chat ID :param telegram_conn_id: telegram connection name :return: telegram chat ID - :rtype: str """ if chat_id is not None: return chat_id @@ -134,10 +131,10 @@ def send_message(self, api_params: dict) -> None: } kwargs.update(api_params) - if 'text' not in kwargs or kwargs['text'] is None: + if "text" not in kwargs or kwargs["text"] is None: raise AirflowException("'text' must be provided for telegram message") - if kwargs['chat_id'] is None: + if kwargs["chat_id"] is None: raise AirflowException("'chat_id' must be provided for telegram message") response = self.connection.send_message(**kwargs) diff --git a/airflow/providers/telegram/operators/telegram.py b/airflow/providers/telegram/operators/telegram.py index 59a92f6546746..e0e3ffc3b8273 100644 --- a/airflow/providers/telegram/operators/telegram.py +++ b/airflow/providers/telegram/operators/telegram.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. 
"""Operator for Telegram""" -from typing import TYPE_CHECKING, Optional, Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -43,17 +45,17 @@ class TelegramOperator(BaseOperator): :param telegram_kwargs: Extra args to be passed to telegram client """ - template_fields: Sequence[str] = ('text', 'chat_id') - ui_color = '#FFBA40' + template_fields: Sequence[str] = ("text", "chat_id") + ui_color = "#FFBA40" def __init__( self, *, telegram_conn_id: str = "telegram_default", - token: Optional[str] = None, - chat_id: Optional[str] = None, + token: str | None = None, + chat_id: str | None = None, text: str = "No message has been set.", - telegram_kwargs: Optional[dict] = None, + telegram_kwargs: dict | None = None, **kwargs, ): self.chat_id = chat_id @@ -68,10 +70,10 @@ def __init__( super().__init__(**kwargs) - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: """Calls the TelegramHook to post the provided Telegram message""" if self.text: - self.telegram_kwargs['text'] = self.text + self.telegram_kwargs["text"] = self.text telegram_hook = TelegramHook( telegram_conn_id=self.telegram_conn_id, diff --git a/airflow/providers/telegram/provider.yaml b/airflow/providers/telegram/provider.yaml index daa3e1e0db33a..fb48632e8c766 100644 --- a/airflow/providers/telegram/provider.yaml +++ b/airflow/providers/telegram/provider.yaml @@ -22,6 +22,8 @@ description: | `Telegram `__ versions: + - 3.1.0 + - 3.0.0 - 2.0.4 - 2.0.3 - 2.0.2 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - python-telegram-bot>=13.0 integrations: - integration-name: Telegram diff --git a/airflow/providers/trino/CHANGELOG.rst b/airflow/providers/trino/CHANGELOG.rst index 32f091423e44a..0f0ffb335f158 100644 --- a/airflow/providers/trino/CHANGELOG.rst +++ b/airflow/providers/trino/CHANGELOG.rst @@ -16,9 +16,114 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` +* ``Bump Trino version to fix non-working DML queries (#27168)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Allow setting client tags for trino connection (#27213)`` + * ``Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Allow session properties for trino connection (#27095)`` + +4.1.0 +..... + +Features +~~~~~~~~ + +* ``trino: Support CertificateAuthentication in the trino hook (#26246)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.0.1 +..... 
+ +Features +~~~~~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +Bug Fixes +~~~~~~~~~ + +* ``Fix placeholders in 'TrinoHook', 'PrestoHook', 'SqliteHook' (#25939)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecated ``hql`` parameter has been removed in ``get_records``, ``get_first``, ``get_pandas_df`` and ``run`` +methods of the ``TrinoHook``. + +* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)`` + +Features +~~~~~~~~ + +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` +* ``Add test_connection method to Trino hook (#24583)`` +* ``Add 'on_kill()' to kill Trino query if the task is killed (#24559)`` +* ``Add TrinoOperator (#24415)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``AIP-47 | Migrate Trino example DAGs to new design (#24118)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.3.0 ..... diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index a6295f34ea42d..63c75446ead19 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import json import os -import warnings -from typing import Any, Callable, Iterable, Optional, overload +from typing import Any, Callable, Iterable, Mapping import trino from trino.exceptions import DatabaseError @@ -26,8 +27,8 @@ from airflow import AirflowException from airflow.configuration import conf -from airflow.hooks.dbapi import DbApiHook from airflow.models import Connection +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING try: @@ -36,27 +37,24 @@ # This is from airflow.utils.operator_helpers, # For the sake of provider backward compatibility, this is hardcoded if import fails # https://github.com/apache/airflow/pull/22416#issuecomment-1075531290 - DEFAULT_FORMAT_PREFIX = 'airflow.ctx.' + DEFAULT_FORMAT_PREFIX = "airflow.ctx." 
def generate_trino_client_info() -> str: """Return json string with dag_id, task_id, execution_date and try_number""" context_var = { - format_map['default'].replace(DEFAULT_FORMAT_PREFIX, ''): os.environ.get( - format_map['env_var_format'], '' + format_map["default"].replace(DEFAULT_FORMAT_PREFIX, ""): os.environ.get( + format_map["env_var_format"], "" ) for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() } - # try_number isn't available in context for airflow < 2.2.5 - # https://github.com/apache/airflow/issues/23059 - try_number = context_var.get('try_number', '') task_info = { - 'dag_id': context_var['dag_id'], - 'task_id': context_var['task_id'], - 'execution_date': context_var['execution_date'], - 'try_number': try_number, - 'dag_run_id': context_var['dag_run_id'], - 'dag_owner': context_var['dag_owner'], + "dag_id": context_var["dag_id"], + "task_id": context_var["task_id"], + "execution_date": context_var["execution_date"], + "try_number": context_var["try_number"], + "dag_run_id": context_var["dag_run_id"], + "dag_owner": context_var["dag_owner"], } return json.dumps(task_info, sort_keys=True) @@ -69,9 +67,9 @@ def _boolify(value): if isinstance(value, bool): return value if isinstance(value, str): - if value.lower() == 'false': + if value.lower() == "false": return False - elif value.lower() == 'true': + elif value.lower() == "true": return True return value @@ -86,10 +84,13 @@ class TrinoHook(DbApiHook): [[340698]] """ - conn_name_attr = 'trino_conn_id' - default_conn_name = 'trino_default' - conn_type = 'trino' - hook_name = 'Trino' + conn_name_attr = "trino_conn_id" + default_conn_name = "trino_default" + conn_type = "trino" + hook_name = "Trino" + query_id = "" + placeholder = "?" + _test_connection_sql = "select 1" def get_conn(self) -> Connection: """Returns a connection object""" @@ -97,29 +98,34 @@ def get_conn(self) -> Connection: extra = db.extra_dejson auth = None user = db.login - if db.password and extra.get('auth') == 'kerberos': - raise AirflowException("Kerberos authorization doesn't support password.") + if db.password and extra.get("auth") in ("kerberos", "certs"): + raise AirflowException(f"The {extra.get('auth')!r} authorization type doesn't support password.") elif db.password: auth = trino.auth.BasicAuthentication(db.login, db.password) # type: ignore[attr-defined] - elif extra.get('auth') == 'jwt': - auth = trino.auth.JWTAuthentication(token=extra.get('jwt__token')) - elif extra.get('auth') == 'kerberos': + elif extra.get("auth") == "jwt": + auth = trino.auth.JWTAuthentication(token=extra.get("jwt__token")) + elif extra.get("auth") == "certs": + auth = trino.auth.CertificateAuthentication( + extra.get("certs__client_cert_path"), + extra.get("certs__client_key_path"), + ) + elif extra.get("auth") == "kerberos": auth = trino.auth.KerberosAuthentication( # type: ignore[attr-defined] - config=extra.get('kerberos__config', os.environ.get('KRB5_CONFIG')), - service_name=extra.get('kerberos__service_name'), - mutual_authentication=_boolify(extra.get('kerberos__mutual_authentication', False)), - force_preemptive=_boolify(extra.get('kerberos__force_preemptive', False)), - hostname_override=extra.get('kerberos__hostname_override'), + config=extra.get("kerberos__config", os.environ.get("KRB5_CONFIG")), + service_name=extra.get("kerberos__service_name"), + mutual_authentication=_boolify(extra.get("kerberos__mutual_authentication", False)), + force_preemptive=_boolify(extra.get("kerberos__force_preemptive", False)), + 
hostname_override=extra.get("kerberos__hostname_override"), sanitize_mutual_error_response=_boolify( - extra.get('kerberos__sanitize_mutual_error_response', True) + extra.get("kerberos__sanitize_mutual_error_response", True) ), - principal=extra.get('kerberos__principal', conf.get('kerberos', 'principal')), - delegate=_boolify(extra.get('kerberos__delegate', False)), - ca_bundle=extra.get('kerberos__ca_bundle'), + principal=extra.get("kerberos__principal", conf.get("kerberos", "principal")), + delegate=_boolify(extra.get("kerberos__delegate", False)), + ca_bundle=extra.get("kerberos__ca_bundle"), ) - if _boolify(extra.get('impersonate_as_owner', False)): - user = os.getenv('AIRFLOW_CTX_DAG_OWNER', None) + if _boolify(extra.get("impersonate_as_owner", False)): + user = os.getenv("AIRFLOW_CTX_DAG_OWNER", None) if user is None: user = db.login http_headers = {"X-Trino-Client-Info": generate_trino_client_info()} @@ -127,15 +133,17 @@ def get_conn(self) -> Connection: host=db.host, port=db.port, user=user, - source=extra.get('source', 'airflow'), - http_scheme=extra.get('protocol', 'http'), + source=extra.get("source", "airflow"), + http_scheme=extra.get("protocol", "http"), http_headers=http_headers, - catalog=extra.get('catalog', 'hive'), + catalog=extra.get("catalog", "hive"), schema=db.schema, auth=auth, # type: ignore[func-returns-value] isolation_level=self.get_isolation_level(), - verify=_boolify(extra.get('verify', True)), + verify=_boolify(extra.get("verify", True)), + session_properties=extra.get("session_properties") or None, + client_tags=extra.get("client_tags") or None, ) return trino_conn @@ -143,100 +151,37 @@ def get_conn(self) -> Connection: def get_isolation_level(self) -> Any: """Returns an isolation level""" db = self.get_connection(self.trino_conn_id) # type: ignore[attr-defined] - isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper() + isolation_level = db.extra_dejson.get("isolation_level", "AUTOCOMMIT").upper() return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) - @staticmethod - def _strip_sql(sql: str) -> str: - return sql.strip().rstrip(';') - - @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None): - """Get a set of records from Trino - - :param sql: SQL statement to be executed. - :param parameters: The parameters to render the SQL query with. - """ - - @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - """:sphinx-autoapi-skip:""" - - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - + def get_records( + self, + sql: str | list[str] = "", + parameters: Iterable | Mapping | None = None, + ) -> Any: + if not isinstance(sql, str): + raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!") try: - return super().get_records(self._strip_sql(sql), parameters) + return super().get_records(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise TrinoException(e) - @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: - """Returns only the first row, regardless of how many rows the query returns. - - :param sql: SQL statement to be executed. - :param parameters: The parameters to render the SQL query with. 
- """ - - @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - """:sphinx-autoapi-skip:""" - - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - + def get_first(self, sql: str | list[str] = "", parameters: Iterable | Mapping | None = None) -> Any: + if not isinstance(sql, str): + raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!") try: - return super().get_first(self._strip_sql(sql), parameters) + return super().get_first(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise TrinoException(e) - @overload def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, **kwargs + self, sql: str = "", parameters: Iterable | Mapping | None = None, **kwargs ): # type: ignore[override] - """Get a pandas dataframe from a sql query. - - :param sql: SQL statement to be executed. - :param parameters: The parameters to render the SQL query with. - """ - - @overload - def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs - ): # type: ignore[override] - """:sphinx-autoapi-skip:""" - - def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs - ): # type: ignore[override] - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - import pandas cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(sql), parameters) + cursor.execute(self.strip_sql_string(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise TrinoException(e) @@ -248,53 +193,29 @@ def get_pandas_df( df = pandas.DataFrame(**kwargs) return df - @overload - def run( - self, - sql: str = "", - autocommit: bool = False, - parameters: Optional[dict] = None, - handler: Optional[Callable] = None, - ) -> None: - """Execute the statement against Trino. Can be used to create views.""" - - @overload - def run( - self, - sql: str = "", - autocommit: bool = False, - parameters: Optional[dict] = None, - handler: Optional[Callable] = None, - hql: str = "", - ) -> None: - """:sphinx-autoapi-skip:""" - def run( self, - sql: str = "", + sql: str | Iterable[str], autocommit: bool = False, - parameters: Optional[dict] = None, - handler: Optional[Callable] = None, - hql: str = "", - ) -> None: - """:sphinx-autoapi-skip:""" - if hql: - warnings.warn( - "The hql parameter has been deprecated. 
You should pass the sql parameter.", - DeprecationWarning, - stacklevel=2, - ) - sql = hql - + parameters: Iterable | Mapping | None = None, + handler: Callable | None = None, + split_statements: bool = False, + return_last: bool = True, + ) -> Any | list[Any] | None: return super().run( - sql=self._strip_sql(sql), autocommit=autocommit, parameters=parameters, handler=handler + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=handler, + split_statements=split_statements, + return_last=return_last, ) def insert_rows( self, table: str, rows: Iterable[tuple], - target_fields: Optional[Iterable[str]] = None, + target_fields: Iterable[str] | None = None, commit_every: int = 0, replace: bool = False, **kwargs, @@ -311,10 +232,22 @@ def insert_rows( """ if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT: self.log.info( - 'Transactions are not enable in trino connection. ' - 'Please use the isolation_level property to enable it. ' - 'Falling back to insert all rows in one transaction.' + "Transactions are not enable in trino connection. " + "Please use the isolation_level property to enable it. " + "Falling back to insert all rows in one transaction." ) commit_every = 0 super().insert_rows(table, rows, target_fields, commit_every, replace) + + @staticmethod + def _serialize_cell(cell: Any, conn: Connection | None = None) -> Any: + """ + Trino will adapt all arguments to the execute() method internally, + hence we return cell without any conversion. + + :param cell: The cell to insert into the table + :param conn: The database connection + :return: The cell + """ + return cell diff --git a/airflow/providers/jira/operators/__init__.py b/airflow/providers/trino/operators/__init__.py similarity index 100% rename from airflow/providers/jira/operators/__init__.py rename to airflow/providers/trino/operators/__init__.py diff --git a/airflow/providers/trino/operators/trino.py b/airflow/providers/trino/operators/trino.py new file mode 100644 index 0000000000000..fc2f496ff8118 --- /dev/null +++ b/airflow/providers/trino/operators/trino.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains the Trino operator.""" +from __future__ import annotations + +import warnings +from typing import Any, Sequence + +from trino.exceptions import TrinoQueryError + +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.providers.trino.hooks.trino import TrinoHook + + +class TrinoOperator(SQLExecuteQueryOperator): + """ + Executes sql code using a specific Trino query Engine. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TrinoOperator` + + :param sql: the SQL code to be executed as a single string, or + a list of str (sql statements), or a reference to a template file. + :param trino_conn_id: id of the connection config for the target Trino + environment + :param autocommit: What to set the connection's autocommit setting to + before executing the query + :param handler: The result handler which is called with the result of each statement. + :param parameters: (optional) the parameters to render the SQL query with. + """ + + template_fields: Sequence[str] = ("sql",) + template_fields_renderers = {"sql": "sql"} + template_ext: Sequence[str] = (".sql",) + ui_color = "#ededed" + + def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs: Any) -> None: + super().__init__(conn_id=trino_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) + + def on_kill(self) -> None: + if self._hook is not None and isinstance(self._hook, TrinoHook): + query_id = "'" + self._hook.query_id + "'" + try: + self.log.info("Stopping query run with queryId - %s", self._hook.query_id) + self._hook.run( + sql=f"CALL system.runtime.kill_query(query_id => {query_id},message => 'Job " + f"killed by " + f"user');", + handler=list, + ) + except TrinoQueryError as e: + self.log.info(str(e)) + self.log.info("Trino query (%s) terminated", query_id) diff --git a/airflow/providers/trino/provider.yaml b/airflow/providers/trino/provider.yaml index 828035fc63b4f..088afabfb8187 100644 --- a/airflow/providers/trino/provider.yaml +++ b/airflow/providers/trino/provider.yaml @@ -22,6 +22,12 @@ description: | `Trino `__ versions: + - 4.2.0 + - 4.1.0 + - 4.0.1 + - 4.0.0 + - 3.1.0 + - 3.0.0 - 2.3.0 - 2.2.0 - 2.1.2 @@ -32,15 +38,25 @@ versions: - 2.0.0 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - pandas>=0.17.1 + - trino>=0.318.0 integrations: - integration-name: Trino external-doc-url: https://trino.io/docs/ logo: /integration-logos/trino/trino-og.png + how-to-guide: + - /docs/apache-airflow-providers-trino/operators/trino.rst tags: [software] +operators: + - integration-name: Trino + python-modules: + - airflow.providers.trino.operators.trino + hooks: - integration-name: Trino python-modules: @@ -52,9 +68,6 @@ transfers: how-to-guide: /docs/apache-airflow-providers-trino/operators/transfer/gcs_to_trino.rst python-module: airflow.providers.trino.transfers.gcs_to_trino -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.trino.hooks.trino.TrinoHook - connection-types: - hook-class-name: airflow.providers.trino.hooks.trino.TrinoHook connection-type: trino diff --git a/airflow/providers/trino/transfers/gcs_to_trino.py b/airflow/providers/trino/transfers/gcs_to_trino.py index 8f00d09fc81b9..721c0b6547cd7 100644 --- a/airflow/providers/trino/transfers/gcs_to_trino.py +++ b/airflow/providers/trino/transfers/gcs_to_trino.py @@ -16,11 +16,12 @@ # specific language governing permissions and limitations # under the License. 
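# ---------------------------------------------------------------------------
# Editor's note (not part of the diff): TrinoOperator above is now a thin,
# deprecated wrapper, so new DAGs can target SQLExecuteQueryOperator from the
# common.sql provider directly. The task id and query are illustrative
# assumptions; "trino_default" is the hook's default connection id.
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

run_trino_query = SQLExecuteQueryOperator(
    task_id="run_trino_query",
    conn_id="trino_default",
    sql="SELECT count(*) FROM tpch.tiny.nation",  # any table readable through the Trino connection works here
)
# ---------------------------------------------------------------------------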
"""This module contains Google Cloud Storage to Trino operator.""" +from __future__ import annotations import csv import json from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook @@ -61,9 +62,9 @@ class GCSToTrinoOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'source_bucket', - 'source_object', - 'trino_table', + "source_bucket", + "source_object", + "trino_table", ) def __init__( @@ -74,10 +75,10 @@ def __init__( trino_table: str, trino_conn_id: str = "trino_default", gcp_conn_id: str = "google_cloud_default", - schema_fields: Optional[Iterable[str]] = None, - schema_object: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + schema_fields: Iterable[str] | None = None, + schema_object: str | None = None, + delegate_to: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -91,7 +92,7 @@ def __init__( self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> None: gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, diff --git a/airflow/providers/vertica/CHANGELOG.rst b/airflow/providers/vertica/CHANGELOG.rst index 0155755a2ba05..e2aebdb5ddef5 100644 --- a/airflow/providers/vertica/CHANGELOG.rst +++ b/airflow/providers/vertica/CHANGELOG.rst @@ -16,9 +16,84 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.3.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Add SQLExecuteQueryOperator (#25717)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + +3.2.1 +..... + +Misc +~~~~ + +* ``Add common-sql lower bound for common-sql (#25789)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +3.2.0 +..... + +Features +~~~~~~~~ + +* ``Optimize log when using VerticaOperator (#25566)`` +* ``Unify DbApiHook.run() method with the methods which override it (#23971)`` + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``Move all SQL classes to common-sql provider (#24836)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... 
+ +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.1.3 ..... diff --git a/airflow/providers/vertica/hooks/vertica.py b/airflow/providers/vertica/hooks/vertica.py index 9530c5965c5a8..434d0d1767c8f 100644 --- a/airflow/providers/vertica/hooks/vertica.py +++ b/airflow/providers/vertica/hooks/vertica.py @@ -15,20 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations from vertica_python import connect -from airflow.hooks.dbapi import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook class VerticaHook(DbApiHook): """Interact with Vertica.""" - conn_name_attr = 'vertica_conn_id' - default_conn_name = 'vertica_default' - conn_type = 'vertica' - hook_name = 'Vertica' + conn_name_attr = "vertica_conn_id" + default_conn_name = "vertica_default" + conn_type = "vertica" + hook_name = "Vertica" supports_autocommit = True def get_conn(self) -> connect: @@ -36,9 +36,9 @@ def get_conn(self) -> connect: conn = self.get_connection(self.vertica_conn_id) # type: ignore conn_config = { "user": conn.login, - "password": conn.password or '', + "password": conn.password or "", "database": conn.schema, - "host": conn.host or 'localhost', + "host": conn.host or "localhost", } if not conn.port: diff --git a/airflow/providers/vertica/operators/vertica.py b/airflow/providers/vertica/operators/vertica.py index 3a30e0ee2723e..28e4a9f15da20 100644 --- a/airflow/providers/vertica/operators/vertica.py +++ b/airflow/providers/vertica/operators/vertica.py @@ -15,16 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, List, Sequence, Union +from __future__ import annotations -from airflow.models import BaseOperator -from airflow.providers.vertica.hooks.vertica import VerticaHook +import warnings +from typing import Any, Sequence -if TYPE_CHECKING: - from airflow.utils.context import Context +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -class VerticaOperator(BaseOperator): +class VerticaOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Vertica database. 
@@ -34,19 +33,16 @@ class VerticaOperator(BaseOperator): Template references are recognized by str ending in '.sql' """ - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ('.sql',) - template_fields_renderers = {'sql': 'sql'} - ui_color = '#b4e0ff' - - def __init__( - self, *, sql: Union[str, List[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any - ) -> None: - super().__init__(**kwargs) - self.vertica_conn_id = vertica_conn_id - self.sql = sql - - def execute(self, context: 'Context') -> None: - self.log.info('Executing: %s', self.sql) - hook = VerticaHook(vertica_conn_id=self.vertica_conn_id) - hook.run(sql=self.sql) + template_fields: Sequence[str] = ("sql",) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + ui_color = "#b4e0ff" + + def __init__(self, *, vertica_conn_id: str = "vertica_default", **kwargs: Any) -> None: + super().__init__(conn_id=vertica_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/vertica/provider.yaml b/airflow/providers/vertica/provider.yaml index 81e4127560432..023a7880096fd 100644 --- a/airflow/providers/vertica/provider.yaml +++ b/airflow/providers/vertica/provider.yaml @@ -22,6 +22,11 @@ description: | `Vertica `__ versions: + - 3.3.0 + - 3.2.1 + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.1.3 - 2.1.2 - 2.1.1 @@ -31,8 +36,10 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - apache-airflow-providers-common-sql>=1.3.1 + - vertica-python>=0.5.1 integrations: - integration-name: Vertica @@ -50,9 +57,6 @@ hooks: python-modules: - airflow.providers.vertica.hooks.vertica -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.vertica.hooks.vertica.VerticaHook - connection-types: - hook-class-name: airflow.providers.vertica.hooks.vertica.VerticaHook connection-type: vertica diff --git a/airflow/providers/yandex/.latest-doc-only-change.txt b/airflow/providers/yandex/.latest-doc-only-change.txt index 28124098645cf..ff7136e07d744 100644 --- a/airflow/providers/yandex/.latest-doc-only-change.txt +++ b/airflow/providers/yandex/.latest-doc-only-change.txt @@ -1 +1 @@ -6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/yandex/CHANGELOG.rst b/airflow/providers/yandex/CHANGELOG.rst index ebf4898a731a1..5d1f508b3debb 100644 --- a/airflow/providers/yandex/CHANGELOG.rst +++ b/airflow/providers/yandex/CHANGELOG.rst @@ -16,9 +16,81 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +3.2.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Features +~~~~~~~~ + +* ``Allow no extra prefix in yandex hook (#27040)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. 
Do not delete the lines(!): + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.3+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +Misc +~~~~ + +* In YandexCloudBaseHook, non-prefixed extra fields are supported and are preferred. E.g. ``folder_id`` will + be preferred if ``extra__yandexcloud__folder_id`` is also present. + +3.1.0 +..... + +Features +~~~~~~~~ + +* ``YandexCloud provider: Support new Yandex SDK features for DataProc (#25158)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + * ``Remove 'hook-class-names' from provider.yaml (#24702)`` + +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Migrate Yandex example DAGs to new design AIP-47 (#24082)`` + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 2.2.3 ..... diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py index a337954496241..79feb86882492 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import json import warnings -from typing import Any, Dict, Optional, Union +from typing import Any import yandexcloud @@ -32,55 +33,55 @@ class YandexCloudBaseHook(BaseHook): :param yandex_conn_id: The connection ID to use when fetching connection info. """ - conn_name_attr = 'yandex_conn_id' - default_conn_name = 'yandexcloud_default' - conn_type = 'yandexcloud' - hook_name = 'Yandex Cloud' + conn_name_attr = "yandex_conn_id" + default_conn_name = "yandexcloud_default" + conn_type = "yandexcloud" + hook_name = "Yandex Cloud" @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: + def get_connection_form_widgets() -> dict[str, Any]: """Returns connection widgets to add to connection form""" from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import PasswordField, StringField return { - "extra__yandexcloud__service_account_json": PasswordField( - lazy_gettext('Service account auth JSON'), + "service_account_json": PasswordField( + lazy_gettext("Service account auth JSON"), widget=BS3PasswordFieldWidget(), - description='Service account auth JSON. Looks like ' + description="Service account auth JSON. 
Looks like " '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' - 'Will be used instead of OAuth token and SA JSON file path field if specified.', + "Will be used instead of OAuth token and SA JSON file path field if specified.", ), - "extra__yandexcloud__service_account_json_path": StringField( - lazy_gettext('Service account auth JSON file path'), + "service_account_json_path": StringField( + lazy_gettext("Service account auth JSON file path"), widget=BS3TextFieldWidget(), - description='Service account auth JSON file path. File content looks like ' + description="Service account auth JSON file path. File content looks like " '{"id", "...", "service_account_id": "...", "private_key": "..."}. ' - 'Will be used instead of OAuth token if specified.', + "Will be used instead of OAuth token if specified.", ), - "extra__yandexcloud__oauth": PasswordField( - lazy_gettext('OAuth Token'), + "oauth": PasswordField( + lazy_gettext("OAuth Token"), widget=BS3PasswordFieldWidget(), - description='User account OAuth token. ' - 'Either this or service account JSON must be specified.', + description="User account OAuth token. " + "Either this or service account JSON must be specified.", ), - "extra__yandexcloud__folder_id": StringField( - lazy_gettext('Default folder ID'), + "folder_id": StringField( + lazy_gettext("Default folder ID"), widget=BS3TextFieldWidget(), - description='Optional. This folder will be used ' - 'to create all new clusters and nodes by default', + description="Optional. This folder will be used " + "to create all new clusters and nodes by default", ), - "extra__yandexcloud__public_ssh_key": StringField( - lazy_gettext('Public SSH key'), + "public_ssh_key": StringField( + lazy_gettext("Public SSH key"), widget=BS3TextFieldWidget(), - description='Optional. This key will be placed to all created Compute nodes' - 'to let you have a root shell there', + description="Optional. This key will be placed to all created Compute nodes" + "to let you have a root shell there", ), } @classmethod - def provider_user_agent(cls) -> Optional[str]: + def provider_user_agent(cls) -> str | None: """Construct User-Agent from Airflow core & provider package versions""" import airflow from airflow.providers_manager import ProvidersManager @@ -89,26 +90,26 @@ def provider_user_agent(cls) -> Optional[str]: manager = ProvidersManager() provider_name = manager.hooks[cls.conn_type].package_name # type: ignore[union-attr] provider = manager.providers[provider_name] - return f'apache-airflow/{airflow.__version__} {provider_name}/{provider.version}' + return f"apache-airflow/{airflow.__version__} {provider_name}/{provider.version}" except KeyError: warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in airflow.ProviderManager") return None @staticmethod - def get_ui_field_behaviour() -> Dict[str, Any]: + def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { - "hidden_fields": ['host', 'schema', 'login', 'password', 'port', 'extra'], + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], "relabeling": {}, } def __init__( self, # Connection id is deprecated. 
Use yandex_conn_id instead - connection_id: Optional[str] = None, - yandex_conn_id: Optional[str] = None, - default_folder_id: Union[dict, bool, None] = None, - default_public_ssh_key: Optional[str] = None, + connection_id: str | None = None, + yandex_conn_id: str | None = None, + default_folder_id: str | None = None, + default_public_ssh_key: str | None = None, ) -> None: super().__init__() if connection_id: @@ -122,32 +123,41 @@ def __init__( self.extras = self.connection.extra_dejson credentials = self._get_credentials() self.sdk = yandexcloud.SDK(user_agent=self.provider_user_agent(), **credentials) - self.default_folder_id = default_folder_id or self._get_field('folder_id', False) - self.default_public_ssh_key = default_public_ssh_key or self._get_field('public_ssh_key', False) + self.default_folder_id = default_folder_id or self._get_field("folder_id", False) + self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key", False) self.client = self.sdk.client - def _get_credentials(self) -> Dict[str, Any]: - service_account_json_path = self._get_field('service_account_json_path', False) - service_account_json = self._get_field('service_account_json', False) - oauth_token = self._get_field('oauth', False) + def _get_credentials(self) -> dict[str, Any]: + service_account_json_path = self._get_field("service_account_json_path", False) + service_account_json = self._get_field("service_account_json", False) + oauth_token = self._get_field("oauth", False) if not (service_account_json or oauth_token or service_account_json_path): raise AirflowException( - 'No credentials are found in connection. Specify either service account ' - 'authentication JSON or user OAuth token in Yandex.Cloud connection' + "No credentials are found in connection. Specify either service account " + "authentication JSON or user OAuth token in Yandex.Cloud connection" ) if service_account_json_path: with open(service_account_json_path) as infile: service_account_json = infile.read() if service_account_json: service_account_key = json.loads(service_account_json) - return {'service_account_key': service_account_key} + return {"service_account_key": service_account_key} else: - return {'token': oauth_token} + return {"token": oauth_token} def _get_field(self, field_name: str, default: Any = None) -> Any: - """Fetches a field from extras, and returns it.""" - long_f = f'extra__yandexcloud__{field_name}' - if hasattr(self, 'extras') and long_f in self.extras: - return self.extras[long_f] - else: + """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" + if not hasattr(self, "extras"): return default + backcompat_prefix = "extra__yandexcloud__" + if field_name.startswith("extra__"): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix " + "when using this method." + ) + if field_name in self.extras: + return self.extras[field_name] + prefixed_name = f"{backcompat_prefix}{field_name}" + if prefixed_name in self.extras: + return self.extras[prefixed_name] + return default diff --git a/airflow/providers/yandex/hooks/yandexcloud_dataproc.py b/airflow/providers/yandex/hooks/yandexcloud_dataproc.py index 6597dba5ee0ef..9b1862205e466 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_dataproc.py +++ b/airflow/providers/yandex/hooks/yandexcloud_dataproc.py @@ -14,7 +14,7 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook diff --git a/airflow/providers/yandex/operators/yandexcloud_dataproc.py b/airflow/providers/yandex/operators/yandexcloud_dataproc.py index 1a9dd1acf05bc..625827d109963 100644 --- a/airflow/providers/yandex/operators/yandexcloud_dataproc.py +++ b/airflow/providers/yandex/operators/yandexcloud_dataproc.py @@ -14,8 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Iterable, Optional, Sequence, Union +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Iterable, Sequence from airflow.models import BaseOperator from airflow.providers.yandex.hooks.yandexcloud_dataproc import DataprocHook @@ -24,6 +27,15 @@ from airflow.utils.context import Context +@dataclass +class InitializationAction: + """Data for initialization action to be run at start of DataProc cluster.""" + + uri: str # Uri of the executable file + args: Sequence[str] # Arguments to the initialization action + timeout: int # Execution timeout + + class DataprocCreateClusterOperator(BaseOperator): """Creates Yandex.Cloud Data Proc cluster. @@ -69,44 +81,60 @@ class DataprocCreateClusterOperator(BaseOperator): in percents. 10-100. By default is not set and default autoscaling strategy is used. :param computenode_decommission_timeout: Timeout to gracefully decommission nodes during downscaling. - In seconds. + In seconds + :param properties: Properties passed to main node software. + Docs: https://cloud.yandex.com/docs/data-proc/concepts/settings-list + :param enable_ui_proxy: Enable UI Proxy feature for forwarding Hadoop components web interfaces + Docs: https://cloud.yandex.com/docs/data-proc/concepts/ui-proxy + :param host_group_ids: Dedicated host groups to place VMs of cluster on. + Docs: https://cloud.yandex.com/docs/compute/concepts/dedicated-host + :param security_group_ids: User security groups. + Docs: https://cloud.yandex.com/docs/data-proc/concepts/network#security-groups :param log_group_id: Id of log group to write logs. By default logs will be sent to default log group. To disable cloud log sending set cluster property dataproc:disable_cloud_logging = true + Docs: https://cloud.yandex.com/docs/data-proc/concepts/logs + :param initialization_actions: Set of init-actions to run when cluster starts. 
+ Docs: https://cloud.yandex.com/docs/data-proc/concepts/init-action """ def __init__( self, *, - folder_id: Optional[str] = None, - cluster_name: Optional[str] = None, - cluster_description: Optional[str] = '', - cluster_image_version: Optional[str] = None, - ssh_public_keys: Optional[Union[str, Iterable[str]]] = None, - subnet_id: Optional[str] = None, - services: Iterable[str] = ('HDFS', 'YARN', 'MAPREDUCE', 'HIVE', 'SPARK'), - s3_bucket: Optional[str] = None, - zone: str = 'ru-central1-b', - service_account_id: Optional[str] = None, - masternode_resource_preset: Optional[str] = None, - masternode_disk_size: Optional[int] = None, - masternode_disk_type: Optional[str] = None, - datanode_resource_preset: Optional[str] = None, - datanode_disk_size: Optional[int] = None, - datanode_disk_type: Optional[str] = None, + folder_id: str | None = None, + cluster_name: str | None = None, + cluster_description: str | None = "", + cluster_image_version: str | None = None, + ssh_public_keys: str | Iterable[str] | None = None, + subnet_id: str | None = None, + services: Iterable[str] = ("HDFS", "YARN", "MAPREDUCE", "HIVE", "SPARK"), + s3_bucket: str | None = None, + zone: str = "ru-central1-b", + service_account_id: str | None = None, + masternode_resource_preset: str | None = None, + masternode_disk_size: int | None = None, + masternode_disk_type: str | None = None, + datanode_resource_preset: str | None = None, + datanode_disk_size: int | None = None, + datanode_disk_type: str | None = None, datanode_count: int = 1, - computenode_resource_preset: Optional[str] = None, - computenode_disk_size: Optional[int] = None, - computenode_disk_type: Optional[str] = None, + computenode_resource_preset: str | None = None, + computenode_disk_size: int | None = None, + computenode_disk_type: str | None = None, computenode_count: int = 0, - computenode_max_hosts_count: Optional[int] = None, - computenode_measurement_duration: Optional[int] = None, - computenode_warmup_duration: Optional[int] = None, - computenode_stabilization_duration: Optional[int] = None, + computenode_max_hosts_count: int | None = None, + computenode_measurement_duration: int | None = None, + computenode_warmup_duration: int | None = None, + computenode_stabilization_duration: int | None = None, computenode_preemptible: bool = False, - computenode_cpu_utilization_target: Optional[int] = None, - computenode_decommission_timeout: Optional[int] = None, - connection_id: Optional[str] = None, - log_group_id: Optional[str] = None, + computenode_cpu_utilization_target: int | None = None, + computenode_decommission_timeout: int | None = None, + connection_id: str | None = None, + properties: dict[str, str] | None = None, + enable_ui_proxy: bool = False, + host_group_ids: Iterable[str] | None = None, + security_group_ids: Iterable[str] | None = None, + log_group_id: str | None = None, + initialization_actions: Iterable[InitializationAction] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -139,11 +167,16 @@ def __init__( self.computenode_preemptible = computenode_preemptible self.computenode_cpu_utilization_target = computenode_cpu_utilization_target self.computenode_decommission_timeout = computenode_decommission_timeout + self.properties = properties + self.enable_ui_proxy = enable_ui_proxy + self.host_group_ids = host_group_ids + self.security_group_ids = security_group_ids self.log_group_id = log_group_id + self.initialization_actions = initialization_actions - self.hook: Optional[DataprocHook] = None + self.hook: DataprocHook | None 
= None - def execute(self, context: 'Context') -> None: + def execute(self, context: Context) -> dict: self.hook = DataprocHook( yandex_conn_id=self.yandex_conn_id, ) @@ -176,41 +209,78 @@ def execute(self, context: 'Context') -> None: computenode_preemptible=self.computenode_preemptible, computenode_cpu_utilization_target=self.computenode_cpu_utilization_target, computenode_decommission_timeout=self.computenode_decommission_timeout, + properties=self.properties, + enable_ui_proxy=self.enable_ui_proxy, + host_group_ids=self.host_group_ids, + security_group_ids=self.security_group_ids, log_group_id=self.log_group_id, + initialization_actions=self.initialization_actions + and [ + self.hook.sdk.wrappers.InitializationAction( + uri=init_action.uri, + args=init_action.args, + timeout=init_action.timeout, + ) + for init_action in self.initialization_actions + ], ) - context['task_instance'].xcom_push(key='cluster_id', value=operation_result.response.id) - context['task_instance'].xcom_push(key='yandexcloud_connection_id', value=self.yandex_conn_id) + cluster_id = operation_result.response.id + context["task_instance"].xcom_push(key="cluster_id", value=cluster_id) + # Deprecated + context["task_instance"].xcom_push(key="yandexcloud_connection_id", value=self.yandex_conn_id) + return cluster_id + + @property + def cluster_id(self): + return self.output -class DataprocDeleteClusterOperator(BaseOperator): - """Deletes Yandex.Cloud Data Proc cluster. + +class DataprocBaseOperator(BaseOperator): + """Base class for DataProc operators working with given cluster. :param connection_id: ID of the Yandex.Cloud Airflow connection. :param cluster_id: ID of the cluster to remove. (templated) """ - template_fields: Sequence[str] = ('cluster_id',) + template_fields: Sequence[str] = ("cluster_id",) - def __init__( - self, *, connection_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs - ) -> None: + def __init__(self, *, yandex_conn_id: str | None = None, cluster_id: str | None = None, **kwargs) -> None: super().__init__(**kwargs) - self.yandex_conn_id = connection_id self.cluster_id = cluster_id - self.hook: Optional[DataprocHook] = None + self.yandex_conn_id = yandex_conn_id - def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.yandex_conn_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.delete_cluster(cluster_id) + def _setup(self, context: Context) -> DataprocHook: + if self.cluster_id is None: + self.cluster_id = context["task_instance"].xcom_pull(key="cluster_id") + if self.yandex_conn_id is None: + xcom_yandex_conn_id = context["task_instance"].xcom_pull(key="yandexcloud_connection_id") + if xcom_yandex_conn_id: + warnings.warn("Implicit pass of `yandex_conn_id` is deprecated, please pass it explicitly") + self.yandex_conn_id = xcom_yandex_conn_id + + return DataprocHook(yandex_conn_id=self.yandex_conn_id) + + def execute(self, context: Context): + raise NotImplementedError() + + +class DataprocDeleteClusterOperator(DataprocBaseOperator): + """Deletes Yandex.Cloud Data Proc cluster. + + :param connection_id: ID of the Yandex.Cloud Airflow connection. + :param cluster_id: ID of the cluster to remove. 
(templated) + """ + + def __init__(self, *, connection_id: str | None = None, cluster_id: str | None = None, **kwargs) -> None: + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) + + def execute(self, context: Context) -> None: + hook = self._setup(context) + hook.client.delete_cluster(self.cluster_id) -class DataprocCreateHiveJobOperator(BaseOperator): +class DataprocCreateHiveJobOperator(DataprocBaseOperator): """Runs Hive job in Data Proc cluster. :param query: Hive query. @@ -224,52 +294,41 @@ class DataprocCreateHiveJobOperator(BaseOperator): :param connection_id: ID of the Yandex.Cloud Airflow connection. """ - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, - query: Optional[str] = None, - query_file_uri: Optional[str] = None, - script_variables: Optional[Dict[str, str]] = None, + query: str | None = None, + query_file_uri: str | None = None, + script_variables: dict[str, str] | None = None, continue_on_failure: bool = False, - properties: Optional[Dict[str, str]] = None, - name: str = 'Hive job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, + properties: dict[str, str] | None = None, + name: str = "Hive job", + cluster_id: str | None = None, + connection_id: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.query = query self.query_file_uri = query_file_uri self.script_variables = script_variables self.continue_on_failure = continue_on_failure self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id - self.hook: Optional[DataprocHook] = None - def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_hive_job( + def execute(self, context: Context) -> None: + hook = self._setup(context) + hook.client.create_hive_job( query=self.query, query_file_uri=self.query_file_uri, script_variables=self.script_variables, continue_on_failure=self.continue_on_failure, properties=self.properties, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) -class DataprocCreateMapReduceJobOperator(BaseOperator): +class DataprocCreateMapReduceJobOperator(DataprocBaseOperator): """Runs Mapreduce job in Data Proc cluster. :param main_jar_file_uri: URI of jar file with job. @@ -286,24 +345,22 @@ class DataprocCreateMapReduceJobOperator(BaseOperator): :param connection_id: ID of the Yandex.Cloud Airflow connection. 
""" - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, - main_class: Optional[str] = None, - main_jar_file_uri: Optional[str] = None, - jar_file_uris: Optional[Iterable[str]] = None, - archive_uris: Optional[Iterable[str]] = None, - file_uris: Optional[Iterable[str]] = None, - args: Optional[Iterable[str]] = None, - properties: Optional[Dict[str, str]] = None, - name: str = 'Mapreduce job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, + main_class: str | None = None, + main_jar_file_uri: str | None = None, + jar_file_uris: Iterable[str] | None = None, + archive_uris: Iterable[str] | None = None, + file_uris: Iterable[str] | None = None, + args: Iterable[str] | None = None, + properties: dict[str, str] | None = None, + name: str = "Mapreduce job", + cluster_id: str | None = None, + connection_id: str | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.main_class = main_class self.main_jar_file_uri = main_jar_file_uri self.jar_file_uris = jar_file_uris @@ -312,19 +369,10 @@ def __init__( self.args = args self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id - self.hook: Optional[DataprocHook] = None - def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_mapreduce_job( + def execute(self, context: Context) -> None: + hook = self._setup(context) + hook.client.create_mapreduce_job( main_class=self.main_class, main_jar_file_uri=self.main_jar_file_uri, jar_file_uris=self.jar_file_uris, @@ -333,11 +381,11 @@ def execute(self, context: 'Context') -> None: args=self.args, properties=self.properties, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) -class DataprocCreateSparkJobOperator(BaseOperator): +class DataprocCreateSparkJobOperator(DataprocBaseOperator): """Runs Spark job in Data Proc cluster. :param main_jar_file_uri: URI of jar file with job. Can be placed in HDFS or S3. @@ -358,27 +406,25 @@ class DataprocCreateSparkJobOperator(BaseOperator): provided in --packages to avoid dependency conflicts. 
""" - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, - main_class: Optional[str] = None, - main_jar_file_uri: Optional[str] = None, - jar_file_uris: Optional[Iterable[str]] = None, - archive_uris: Optional[Iterable[str]] = None, - file_uris: Optional[Iterable[str]] = None, - args: Optional[Iterable[str]] = None, - properties: Optional[Dict[str, str]] = None, - name: str = 'Spark job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, - packages: Optional[Iterable[str]] = None, - repositories: Optional[Iterable[str]] = None, - exclude_packages: Optional[Iterable[str]] = None, + main_class: str | None = None, + main_jar_file_uri: str | None = None, + jar_file_uris: Iterable[str] | None = None, + archive_uris: Iterable[str] | None = None, + file_uris: Iterable[str] | None = None, + args: Iterable[str] | None = None, + properties: dict[str, str] | None = None, + name: str = "Spark job", + cluster_id: str | None = None, + connection_id: str | None = None, + packages: Iterable[str] | None = None, + repositories: Iterable[str] | None = None, + exclude_packages: Iterable[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.main_class = main_class self.main_jar_file_uri = main_jar_file_uri self.jar_file_uris = jar_file_uris @@ -387,22 +433,13 @@ def __init__( self.args = args self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id self.packages = packages self.repositories = repositories self.exclude_packages = exclude_packages - self.hook: Optional[DataprocHook] = None - def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_spark_job( + def execute(self, context: Context) -> None: + hook = self._setup(context) + hook.client.create_spark_job( main_class=self.main_class, main_jar_file_uri=self.main_jar_file_uri, jar_file_uris=self.jar_file_uris, @@ -414,11 +451,11 @@ def execute(self, context: 'Context') -> None: repositories=self.repositories, exclude_packages=self.exclude_packages, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) -class DataprocCreatePysparkJobOperator(BaseOperator): +class DataprocCreatePysparkJobOperator(DataprocBaseOperator): """Runs Pyspark job in Data Proc cluster. :param main_python_file_uri: URI of python file with job. Can be placed in HDFS or S3. @@ -439,27 +476,25 @@ class DataprocCreatePysparkJobOperator(BaseOperator): provided in --packages to avoid dependency conflicts. 
""" - template_fields: Sequence[str] = ('cluster_id',) - def __init__( self, *, - main_python_file_uri: Optional[str] = None, - python_file_uris: Optional[Iterable[str]] = None, - jar_file_uris: Optional[Iterable[str]] = None, - archive_uris: Optional[Iterable[str]] = None, - file_uris: Optional[Iterable[str]] = None, - args: Optional[Iterable[str]] = None, - properties: Optional[Dict[str, str]] = None, - name: str = 'Pyspark job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, - packages: Optional[Iterable[str]] = None, - repositories: Optional[Iterable[str]] = None, - exclude_packages: Optional[Iterable[str]] = None, + main_python_file_uri: str | None = None, + python_file_uris: Iterable[str] | None = None, + jar_file_uris: Iterable[str] | None = None, + archive_uris: Iterable[str] | None = None, + file_uris: Iterable[str] | None = None, + args: Iterable[str] | None = None, + properties: dict[str, str] | None = None, + name: str = "Pyspark job", + cluster_id: str | None = None, + connection_id: str | None = None, + packages: Iterable[str] | None = None, + repositories: Iterable[str] | None = None, + exclude_packages: Iterable[str] | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs) self.main_python_file_uri = main_python_file_uri self.python_file_uris = python_file_uris self.jar_file_uris = jar_file_uris @@ -468,22 +503,13 @@ def __init__( self.args = args self.properties = properties self.name = name - self.cluster_id = cluster_id - self.connection_id = connection_id self.packages = packages self.repositories = repositories self.exclude_packages = exclude_packages - self.hook: Optional[DataprocHook] = None - def execute(self, context: 'Context') -> None: - cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id') - yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull( - key='yandexcloud_connection_id' - ) - self.hook = DataprocHook( - yandex_conn_id=yandex_conn_id, - ) - self.hook.client.create_pyspark_job( + def execute(self, context: Context) -> None: + hook = self._setup(context) + hook.client.create_pyspark_job( main_python_file_uri=self.main_python_file_uri, python_file_uris=self.python_file_uris, jar_file_uris=self.jar_file_uris, @@ -495,5 +521,5 @@ def execute(self, context: 'Context') -> None: repositories=self.repositories, exclude_packages=self.exclude_packages, name=self.name, - cluster_id=cluster_id, + cluster_id=self.cluster_id, ) diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index eb81021cc0815..689b59d42ba7e 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -22,6 +22,9 @@ description: | Yandex including `Yandex.Cloud `__ versions: + - 3.2.0 + - 3.1.0 + - 3.0.0 - 2.2.3 - 2.2.2 - 2.2.1 @@ -31,8 +34,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - yandexcloud>=0.173.0 integrations: - integration-name: Yandex.Cloud @@ -60,9 +64,6 @@ hooks: python-modules: - airflow.providers.yandex.hooks.yandexcloud_dataproc -hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - - airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook - connection-types: - hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook connection-type: yandexcloud diff --git 
a/airflow/providers/zendesk/.latest-doc-only-change.txt b/airflow/providers/zendesk/.latest-doc-only-change.txt index e7c3c940c9c77..ff7136e07d744 100644 --- a/airflow/providers/zendesk/.latest-doc-only-change.txt +++ b/airflow/providers/zendesk/.latest-doc-only-change.txt @@ -1 +1 @@ -602abe8394fafe7de54df7e73af56de848cdf617 +06acf40a4337759797f666d5bb27a5a393b74fed diff --git a/airflow/providers/zendesk/CHANGELOG.rst b/airflow/providers/zendesk/CHANGELOG.rst index b35c73f027cbc..8176819605a07 100644 --- a/airflow/providers/zendesk/CHANGELOG.rst +++ b/airflow/providers/zendesk/CHANGELOG.rst @@ -16,9 +16,55 @@ under the License. +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + Changelog --------- +4.1.0 +..... + +This release of provider is only available for Airflow 2.3+ as explained in the +`Apache Airflow providers support policy `_. + +Misc +~~~~ + +* ``Move min airflow version to 2.3.0 for all providers (#27196)`` + +Bug Fixes +~~~~~~~~~ + +* ``fix zendesk change log (#27363)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add documentation for July 2022 Provider's release (#25030)`` + * ``Enable string normalization in python formatting - providers (#27205)`` + * ``Update docs for September Provider's release (#26731)`` + * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)`` + * ``Prepare docs for new providers release (August 2022) (#25618)`` + * ``Move provider dependencies to inside provider folders (#24672)`` + +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow + providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add explanatory note for contributors about updating Changelog (#24229)`` + * ``Migrate Zendesk example DAGs to new design #22471 (#24129)`` + * ``Prepare docs for May 2022 provider's release (#24231)`` + * ``Update package description to remove double min-airflow specification (#24292)`` + 3.0.3 ..... @@ -51,7 +97,7 @@ Misc ..... Misc -~~~ +~~~~ ``ZendeskHook`` moved from using ``zdesk`` to ``zenpy`` package. Breaking changes diff --git a/airflow/providers/zendesk/hooks/zendesk.py b/airflow/providers/zendesk/hooks/zendesk.py index 2bff0ce34186d..5a8ba2f91061e 100644 --- a/airflow/providers/zendesk/hooks/zendesk.py +++ b/airflow/providers/zendesk/hooks/zendesk.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Tuple, Union +from __future__ import annotations from zenpy import Zenpy from zenpy.lib.api import BaseApi @@ -32,21 +32,21 @@ class ZendeskHook(BaseHook): :param zendesk_conn_id: The Airflow connection used for Zendesk credentials. 
""" - conn_name_attr = 'zendesk_conn_id' - default_conn_name = 'zendesk_default' - conn_type = 'zendesk' - hook_name = 'Zendesk' + conn_name_attr = "zendesk_conn_id" + default_conn_name = "zendesk_default" + conn_type = "zendesk" + hook_name = "Zendesk" def __init__(self, zendesk_conn_id: str = default_conn_name) -> None: super().__init__() self.zendesk_conn_id = zendesk_conn_id - self.base_api: Optional[BaseApi] = None + self.base_api: BaseApi | None = None zenpy_client, url = self._init_conn() self.zenpy_client = zenpy_client self.__url = url self.get = self.zenpy_client.users._get - def _init_conn(self) -> Tuple[Zenpy, str]: + def _init_conn(self) -> tuple[Zenpy, str]: """ Create the Zenpy Client for our Zendesk connection. @@ -55,7 +55,7 @@ def _init_conn(self) -> Tuple[Zenpy, str]: conn = self.get_connection(self.zendesk_conn_id) url = "https://" + conn.host domain = conn.host - subdomain: Optional[str] = None + subdomain: str | None = None if conn.host.count(".") >= 2: dot_splitted_string = conn.host.rsplit(".", 2) subdomain = dot_splitted_string[0] @@ -85,9 +85,9 @@ def search_tickets(self, **kwargs) -> SearchResultGenerator: :param kwargs: (optional) Search fields given to the zenpy search method. :return: SearchResultGenerator of Ticket objects. """ - return self.zenpy_client.search(type='ticket', **kwargs) + return self.zenpy_client.search(type="ticket", **kwargs) - def create_tickets(self, tickets: Union[Ticket, List[Ticket]], **kwargs) -> Union[TicketAudit, JobStatus]: + def create_tickets(self, tickets: Ticket | list[Ticket], **kwargs) -> TicketAudit | JobStatus: """ Create tickets. @@ -98,7 +98,7 @@ def create_tickets(self, tickets: Union[Ticket, List[Ticket]], **kwargs) -> Unio """ return self.zenpy_client.tickets.create(tickets, **kwargs) - def update_tickets(self, tickets: Union[Ticket, List[Ticket]], **kwargs) -> Union[TicketAudit, JobStatus]: + def update_tickets(self, tickets: Ticket | list[Ticket], **kwargs) -> TicketAudit | JobStatus: """ Update tickets. @@ -109,7 +109,7 @@ def update_tickets(self, tickets: Union[Ticket, List[Ticket]], **kwargs) -> Unio """ return self.zenpy_client.tickets.update(tickets, **kwargs) - def delete_tickets(self, tickets: Union[Ticket, List[Ticket]], **kwargs) -> None: + def delete_tickets(self, tickets: Ticket | list[Ticket], **kwargs) -> None: """ Delete tickets, returns nothing on success and raises APIException on failure. diff --git a/airflow/providers/zendesk/provider.yaml b/airflow/providers/zendesk/provider.yaml index bc870d8dfe890..69a24ead5daab 100644 --- a/airflow/providers/zendesk/provider.yaml +++ b/airflow/providers/zendesk/provider.yaml @@ -22,6 +22,8 @@ description: | `Zendesk `__ versions: + - 4.1.0 + - 4.0.0 - 3.0.3 - 3.0.2 - 3.0.1 @@ -31,8 +33,9 @@ versions: - 1.0.1 - 1.0.0 -additional-dependencies: - - apache-airflow>=2.1.0 +dependencies: + - apache-airflow>=2.3.0 + - zenpy>=2.0.24 integrations: - integration-name: Zendesk diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index b5d0297e90abf..6088e3b37347e 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
"""Manages all providers.""" +from __future__ import annotations + import fnmatch import functools import json @@ -27,26 +29,9 @@ from dataclasses import dataclass from functools import wraps from time import perf_counter -from typing import ( - Any, - Callable, - Dict, - List, - MutableMapping, - NamedTuple, - Optional, - Set, - Type, - TypeVar, - Union, - cast, -) - -import jsonschema -from packaging import version as packaging_version +from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast from airflow.exceptions import AirflowOptionalProviderFeatureException -from airflow.hooks.base import BaseHook from airflow.typing_compat import Literal from airflow.utils import yaml from airflow.utils.entry_points import entry_points_with_dist @@ -65,13 +50,44 @@ } +def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str): + """ + If the given field_behaviors dict contains a placeholders node, and there + are placeholders for extra fields (i.e. anything other than the built-in conn + attrs), and if those extra fields are unprefixed, then add the prefix. + + The reason we need to do this is, all custom conn fields live in the same dictionary, + so we need to namespace them with a prefix internally. But for user convenience, + and consistency between the `get_ui_field_behaviour` method and the extra dict itself, + we allow users to supply the unprefixed name. + """ + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {ensure_prefix(k): v for k, v in placeholders.items()} + + return field_behaviors + + +if TYPE_CHECKING: + from airflow.decorators.base import TaskDecorator + from airflow.hooks.base import BaseHook + + class LazyDictWithCache(MutableMapping): """ Dictionary, which in case you set callable, executes the passed callable with `key` attribute at first use - and returns and caches the result. """ - __slots__ = ['_resolved', '_raw_dict'] + __slots__ = ["_resolved", "_raw_dict"] def __init__(self, *args, **kw): self._resolved = set() @@ -110,6 +126,8 @@ def __contains__(self, key): def _create_provider_info_schema_validator(): """Creates JSON schema validator from the provider_info.schema.json""" + import jsonschema + with resource_files("airflow").joinpath("provider_info.schema.json").open("rb") as f: schema = json.load(f) cls = jsonschema.validators.validator_for(schema) @@ -119,6 +137,8 @@ def _create_provider_info_schema_validator(): def _create_customized_form_field_behaviours_schema_validator(): """Creates JSON schema validator from the customized_form_field_behaviours.schema.json""" + import jsonschema + with resource_files("airflow").joinpath("customized_form_field_behaviours.schema.json").open("rb") as f: schema = json.load(f) cls = jsonschema.validators.validator_for(schema) @@ -152,16 +172,16 @@ class ProviderInfo: """ version: str - data: Dict - package_or_source: Union[Literal['source'], Literal['package']] + data: dict + package_or_source: Literal["source"] | Literal["package"] def __post_init__(self): - if self.package_or_source not in ('source', 'package'): + if self.package_or_source not in ("source", "package"): raise ValueError( f"Received {self.package_or_source!r} for `package_or_source`. 
" "Must be either 'package' or 'source'." ) - self.is_source = self.package_or_source == 'source' + self.is_source = self.package_or_source == "source" class HookClassProvider(NamedTuple): @@ -237,7 +257,7 @@ def log_import_warning(class_name, e, provider_package): def _sanity_check( provider_package: str, class_name: str, provider_info: ProviderInfo -) -> Optional[Type[BaseHook]]: +) -> type[BaseHook] | None: """ Performs coherence check on provider classes. For apache-airflow providers - it checks if it starts with appropriate package. For all providers @@ -335,25 +355,25 @@ def __new__(cls): def __init__(self): """Initializes the manager.""" super().__init__() - self._initialized_cache: Dict[str, bool] = {} + self._initialized_cache: dict[str, bool] = {} # Keeps dict of providers keyed by module name - self._provider_dict: Dict[str, ProviderInfo] = {} + self._provider_dict: dict[str, ProviderInfo] = {} # Keeps dict of hooks keyed by connection type - self._hooks_dict: Dict[str, HookInfo] = {} + self._hooks_dict: dict[str, HookInfo] = {} - self._taskflow_decorators: Dict[str, Callable] = LazyDictWithCache() + self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache() # keeps mapping between connection_types and hook class, package they come from - self._hook_provider_dict: Dict[str, HookClassProvider] = {} + self._hook_provider_dict: dict[str, HookClassProvider] = {} # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time - self._hooks_lazy_dict: LazyDictWithCache[str, Union[HookInfo, Callable]] = LazyDictWithCache() + self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache() # Keeps methods that should be used to add custom widgets tuple of keyed by name of the extra field - self._connection_form_widgets: Dict[str, ConnectionFormWidgetInfo] = {} + self._connection_form_widgets: dict[str, ConnectionFormWidgetInfo] = {} # Customizations for javascript fields are kept here - self._field_behaviours: Dict[str, Dict] = {} - self._extra_link_class_name_set: Set[str] = set() - self._logging_class_name_set: Set[str] = set() - self._secrets_backend_class_name_set: Set[str] = set() - self._api_auth_backend_module_names: Set[str] = set() + self._field_behaviours: dict[str, dict] = {} + self._extra_link_class_name_set: set[str] = set() + self._logging_class_name_set: set[str] = set() + self._secrets_backend_class_name_set: set[str] = set() + self._api_auth_backend_module_names: set[str] = set() self._provider_schema_validator = _create_provider_info_schema_validator() self._customized_form_fields_schema_validator = ( _create_customized_form_field_behaviours_schema_validator() @@ -372,6 +392,8 @@ def initialize_providers_list(self): self._provider_dict = OrderedDict(sorted(self._provider_dict.items())) def _verify_all_providers_all_compatible(self): + from packaging import version as packaging_version + for provider_id, info in self._provider_dict.items(): min_version = MIN_PROVIDER_VERSIONS.get(provider_id) if min_version: @@ -431,22 +453,22 @@ def _discover_all_providers_from_packages(self) -> None: together with the code. The runtime version is more relaxed (allows for additional properties) and verifies only the subset of fields that are needed at runtime. 
""" - for entry_point, dist in entry_points_with_dist('apache_airflow_provider'): - package_name = dist.metadata['name'] + for entry_point, dist in entry_points_with_dist("apache_airflow_provider"): + package_name = dist.metadata["name"] if self._provider_dict.get(package_name) is not None: continue log.debug("Loading %s from package %s", entry_point, package_name) version = dist.version provider_info = entry_point.load()() self._provider_schema_validator.validate(provider_info) - provider_info_package_name = provider_info['package-name'] + provider_info_package_name = provider_info["package-name"] if package_name != provider_info_package_name: raise Exception( f"The package '{package_name}' from setuptools and " f"{provider_info_package_name} do not match. Please make sure they are aligned" ) if package_name not in self._provider_dict: - self._provider_dict[package_name] = ProviderInfo(version, provider_info, 'package') + self._provider_dict[package_name] = ProviderInfo(version, provider_info, "package") else: log.warning( "The provider for package '%s' could not be registered from because providers for that " @@ -469,7 +491,14 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None: log.info("You have no providers installed.") return try: + seen = set() for path in airflow.providers.__path__: # type: ignore[attr-defined] + # The same path can appear in the __path__ twice, under non-normalized paths (ie. + # /path/to/repo/airflow/providers and /path/to/repo/./airflow/providers) + path = os.path.realpath(path) + if path in seen: + continue + seen.add(path) self._add_provider_info_from_local_source_files_on_path(path) except Exception as e: log.warning("Error when loading 'provider.yaml' files from airflow sources: %s", e) @@ -500,9 +529,9 @@ def _add_provider_info_from_local_source_file(self, path, package_name) -> None: provider_info = yaml.safe_load(provider_yaml_file) self._provider_schema_validator.validate(provider_info) - version = provider_info['versions'][0] + version = provider_info["versions"][0] if package_name not in self._provider_dict: - self._provider_dict[package_name] = ProviderInfo(version, provider_info, 'source') + self._provider_dict[package_name] = ProviderInfo(version, provider_info, "source") else: log.warning( "The providers for package '%s' could not be registered because providers for that " @@ -514,8 +543,8 @@ def _add_provider_info_from_local_source_file(self, path, package_name) -> None: def _discover_hooks_from_connection_types( self, - hook_class_names_registered: Set[str], - already_registered_warning_connection_types: Set[str], + hook_class_names_registered: set[str], + already_registered_warning_connection_types: set[str], package_name: str, provider: ProviderInfo, ): @@ -535,13 +564,22 @@ def _discover_hooks_from_connection_types( connection_types = provider.data.get("connection-types") if connection_types: for connection_type_dict in connection_types: - connection_type = connection_type_dict['connection-type'] - hook_class_name = connection_type_dict['hook-class-name'] + connection_type = connection_type_dict["connection-type"] + hook_class_name = connection_type_dict["hook-class-name"] hook_class_names_registered.add(hook_class_name) already_registered = self._hook_provider_dict.get(connection_type) if already_registered: if already_registered.package_name != package_name: already_registered_warning_connection_types.add(connection_type) + else: + log.warning( + "The connection type '%s' is already registered in the" + " package 
'%s' with different class names: '%s' and '%s'. ", + connection_type, + package_name, + already_registered.hook_class_name, + hook_class_name, + ) else: self._hook_provider_dict[connection_type] = HookClassProvider( hook_class_name=hook_class_name, package_name=package_name @@ -557,8 +595,8 @@ def _discover_hooks_from_connection_types( def _discover_hooks_from_hook_class_names( self, - hook_class_names_registered: Set[str], - already_registered_warning_connection_types: Set[str], + hook_class_names_registered: set[str], + already_registered_warning_connection_types: set[str], package_name: str, provider: ProviderInfo, provider_uses_connection_types: bool, @@ -632,8 +670,8 @@ def _discover_hooks_from_hook_class_names( def _discover_hooks(self) -> None: """Retrieves all connections defined in the providers via Hooks""" for package_name, provider in self._provider_dict.items(): - duplicated_connection_types: Set[str] = set() - hook_class_names_registered: Set[str] = set() + duplicated_connection_types: set[str] = set() + hook_class_names_registered: set[str] = set() provider_uses_connection_types = self._discover_hooks_from_connection_types( hook_class_names_registered, duplicated_connection_types, package_name, provider ) @@ -651,9 +689,15 @@ def _import_info_from_all_hooks(self): """Force-import all hooks and initialize the connections/fields""" # Retrieve all hooks to make sure that all of them are imported _ = list(self._hooks_lazy_dict.values()) - self._connection_form_widgets = OrderedDict(sorted(self._connection_form_widgets.items())) self._field_behaviours = OrderedDict(sorted(self._field_behaviours.items())) + # Widgets for connection forms are currently used in two places: + # 1. In the UI Connections, expected same order that it defined in Hook. + # 2. cli command - `airflow providers widgets` and expected that it in alphabetical order. + # It is not possible to recover original ordering after sorting, + # that the main reason why original sorting moved to cli part: + # self._connection_form_widgets = OrderedDict(sorted(self._connection_form_widgets.items())) + def _discover_taskflow_decorators(self) -> None: for name, info in self._provider_dict.items(): for taskflow_decorator in info.data.get("task-decorators", []): @@ -668,7 +712,7 @@ def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_pack if name in self._taskflow_decorators: try: existing = self._taskflow_decorators[name] - other_name = f'{existing.__module__}.{existing.__name__}' + other_name = f"{existing.__module__}.{existing.__name__}" except Exception: # If problem importing, then get the value from the functools.partial other_name = self._taskflow_decorators._raw_dict[name].args[0] # type: ignore[attr-defined] @@ -692,11 +736,11 @@ def _get_attr(obj: Any, attr_name: str): def _import_hook( self, - connection_type: Optional[str], + connection_type: str | None, provider_info: ProviderInfo, - hook_class_name: Optional[str] = None, - package_name: Optional[str] = None, - ) -> Optional[HookInfo]: + hook_class_name: str | None = None, + package_name: str | None = None, + ) -> HookInfo | None: """ Imports hook and retrieves hook information. Either connection_type (for lazy loading) or hook_class_name must be set - but not both). Only needs package_name if hook_class_name is @@ -733,12 +777,13 @@ def _import_hook( if hook_class is None: return None try: - module, class_name = hook_class_name.rsplit('.', maxsplit=1) + module, class_name = hook_class_name.rsplit(".", maxsplit=1) # Do not use attr here. 
We want to check only direct class fields not those # inherited from parent hook. This way we add form fields only once for the whole # hierarchy and we add it only from the parent hook that provides those! - if 'get_connection_form_widgets' in hook_class.__dict__: + if "get_connection_form_widgets" in hook_class.__dict__: widgets = hook_class.get_connection_form_widgets() + if widgets: for widget in widgets.values(): if widget.field_class not in allowed_field_classes: @@ -751,7 +796,7 @@ def _import_hook( ) return None self._add_widgets(package_name, hook_class, widgets) - if 'get_ui_field_behaviour' in hook_class.__dict__: + if "get_ui_field_behaviour" in hook_class.__dict__: field_behaviours = hook_class.get_ui_field_behaviour() if field_behaviours: self._add_customized_fields(package_name, hook_class, field_behaviours) @@ -763,7 +808,7 @@ def _import_hook( e, ) return None - hook_connection_type = self._get_attr(hook_class, 'conn_type') + hook_connection_type = self._get_attr(hook_class, "conn_type") if connection_type: if hook_connection_type != connection_type: log.warning( @@ -776,8 +821,8 @@ def _import_hook( connection_type, ) connection_type = hook_connection_type - connection_id_attribute_name: str = self._get_attr(hook_class, 'conn_name_attr') - hook_name: str = self._get_attr(hook_class, 'hook_name') + connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr") + hook_name: str = self._get_attr(hook_class, "hook_name") if not connection_type or not connection_id_attribute_name or not hook_name: log.warning( @@ -795,13 +840,13 @@ def _import_hook( package_name=package_name, hook_name=hook_name, connection_type=connection_type, - connection_testable=hasattr(hook_class, 'test_connection'), + connection_testable=hasattr(hook_class, "test_connection"), ) - def _add_widgets(self, package_name: str, hook_class: type, widgets: Dict[str, Any]): + def _add_widgets(self, package_name: str, hook_class: type, widgets: dict[str, Any]): conn_type = hook_class.conn_type # type: ignore for field_identifier, field in widgets.items(): - if field_identifier.startswith('extra__'): + if field_identifier.startswith("extra__"): prefixed_field_name = field_identifier else: prefixed_field_name = f"extra__{conn_type}__{field_identifier}" @@ -817,10 +862,15 @@ def _add_widgets(self, package_name: str, hook_class: type, widgets: Dict[str, A hook_class.__name__, package_name, field, field_identifier ) - def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: Dict): + def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: dict): try: connection_type = getattr(hook_class, "conn_type") + self._customized_form_fields_schema_validator.validate(customized_fields) + + if connection_type: + customized_fields = _ensure_prefix_for_placeholders(customized_fields, connection_type) + if connection_type in self._field_behaviours: log.warning( "The connection_type %s from package %s and class %s has already been added " @@ -872,13 +922,13 @@ def _discover_auth_backends(self) -> None: self._api_auth_backend_module_names.add(auth_backend_module_name) @property - def providers(self) -> Dict[str, ProviderInfo]: + def providers(self) -> dict[str, ProviderInfo]: """Returns information about available providers.""" self.initialize_providers_list() return self._provider_dict @property - def hooks(self) -> MutableMapping[str, Optional[HookInfo]]: + def hooks(self) -> MutableMapping[str, HookInfo | None]: """ Returns dictionary of 
connection_type-to-hook mapping. Note that the dict can contain None values if a hook discovered cannot be imported! @@ -888,44 +938,47 @@ def hooks(self) -> MutableMapping[str, Optional[HookInfo]]: return self._hooks_lazy_dict @property - def taskflow_decorators(self) -> Dict[str, Callable]: + def taskflow_decorators(self) -> dict[str, TaskDecorator]: self.initialize_providers_taskflow_decorator() return self._taskflow_decorators @property - def extra_links_class_names(self) -> List[str]: + def extra_links_class_names(self) -> list[str]: """Returns set of extra link class names.""" self.initialize_providers_extra_links() return sorted(self._extra_link_class_name_set) @property - def connection_form_widgets(self) -> Dict[str, ConnectionFormWidgetInfo]: - """Returns widgets for connection forms.""" + def connection_form_widgets(self) -> dict[str, ConnectionFormWidgetInfo]: + """ + Returns widgets for connection forms. + Dictionary keys in the same order that it defined in Hook. + """ self.initialize_providers_hooks() self._import_info_from_all_hooks() return self._connection_form_widgets @property - def field_behaviours(self) -> Dict[str, Dict]: + def field_behaviours(self) -> dict[str, dict]: """Returns dictionary with field behaviours for connection types.""" self.initialize_providers_hooks() self._import_info_from_all_hooks() return self._field_behaviours @property - def logging_class_names(self) -> List[str]: + def logging_class_names(self) -> list[str]: """Returns set of log task handlers class names.""" self.initialize_providers_logging() return sorted(self._logging_class_name_set) @property - def secrets_backend_class_names(self) -> List[str]: + def secrets_backend_class_names(self) -> list[str]: """Returns set of secret backend class names.""" self.initialize_providers_secrets_backends() return sorted(self._secrets_backend_class_name_set) @property - def auth_backend_module_names(self) -> List[str]: + def auth_backend_module_names(self) -> list[str]: """Returns set of API auth backend class names.""" self.initialize_providers_auth_backends() return sorted(self._api_auth_backend_module_names) diff --git a/airflow/secrets/__init__.py b/airflow/secrets/__init__.py index 5b12c8a300137..57f90acef9f93 100644 --- a/airflow/secrets/__init__.py +++ b/airflow/secrets/__init__.py @@ -22,7 +22,9 @@ * Metastore database * AWS SSM Parameter store """ -__all__ = ['BaseSecretsBackend', 'DEFAULT_SECRETS_SEARCH_PATH'] +from __future__ import annotations + +__all__ = ["BaseSecretsBackend", "DEFAULT_SECRETS_SEARCH_PATH"] from airflow.secrets.base_secrets import BaseSecretsBackend diff --git a/airflow/secrets/base_secrets.py b/airflow/secrets/base_secrets.py index a9942e9586385..56336f10be22a 100644 --- a/airflow/secrets/base_secrets.py +++ b/airflow/secrets/base_secrets.py @@ -14,9 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
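# ---------------------------------------------------------------------------
# Illustrative sketch only -- not part of the patch above. The ProvidersManager
# properties shown above are read-only views over the discovered providers; the
# "yandexcloud" connection type is just one example of an installed provider.
from airflow.providers_manager import ProvidersManager

manager = ProvidersManager()
hook_info = manager.hooks.get("yandexcloud")  # lazily imported; None if the hook failed to import
if hook_info is not None:
    print(hook_info.hook_class_name, hook_info.package_name)
# Connection form widgets now keep the order declared by each hook (see the comment above);
# alphabetical sorting is left to the CLI side.
for field_name in manager.connection_form_widgets:
    print(field_name)
# ---------------------------------------------------------------------------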
+from __future__ import annotations + import warnings from abc import ABC -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING + +from airflow.exceptions import RemovedInAirflow3Warning if TYPE_CHECKING: from airflow.models.connection import Connection @@ -36,7 +40,7 @@ def build_path(path_prefix: str, secret_id: str, sep: str = "/") -> str: """ return f"{path_prefix}{sep}{secret_id}" - def get_conn_value(self, conn_id: str) -> Optional[str]: + def get_conn_value(self, conn_id: str) -> str | None: """ Retrieve from Secrets Backend a string value representing the Connection object. @@ -47,7 +51,7 @@ def get_conn_value(self, conn_id: str) -> Optional[str]: """ raise NotImplementedError - def deserialize_connection(self, conn_id: str, value: str) -> 'Connection': + def deserialize_connection(self, conn_id: str, value: str) -> Connection: """ Given a serialized representation of the airflow Connection, return an instance. Looks at first character to determine how to deserialize. @@ -59,12 +63,12 @@ def deserialize_connection(self, conn_id: str, value: str) -> 'Connection': from airflow.models.connection import Connection value = value.strip() - if value[0] == '{': + if value[0] == "{": return Connection.from_json(conn_id=conn_id, value=value) else: return Connection(conn_id=conn_id, uri=value) - def get_conn_uri(self, conn_id: str) -> Optional[str]: + def get_conn_uri(self, conn_id: str) -> str | None: """ Get conn_uri from Secrets Backend @@ -75,7 +79,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]: """ raise NotImplementedError() - def get_connection(self, conn_id: str) -> Optional['Connection']: + def get_connection(self, conn_id: str) -> Connection | None: """ Return connection object with a given ``conn_id``. @@ -93,7 +97,7 @@ def get_connection(self, conn_id: str) -> Optional['Connection']: not_implemented_get_conn_value = True warnings.warn( "Method `get_conn_uri` is deprecated. Please use `get_conn_value`.", - PendingDeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) @@ -112,7 +116,7 @@ def get_connection(self, conn_id: str) -> Optional['Connection']: else: return None - def get_connections(self, conn_id: str) -> List['Connection']: + def get_connections(self, conn_id: str) -> list[Connection]: """ Return connection object with a given ``conn_id``. @@ -121,7 +125,7 @@ def get_connections(self, conn_id: str) -> List['Connection']: warnings.warn( "This method is deprecated. Please use " "`airflow.secrets.base_secrets.BaseSecretsBackend.get_connection`.", - PendingDeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) conn = self.get_connection(conn_id=conn_id) @@ -129,7 +133,7 @@ def get_connections(self, conn_id: str) -> List['Connection']: return [conn] return [] - def get_variable(self, key: str) -> Optional[str]: + def get_variable(self, key: str) -> str | None: """ Return value for Airflow Variable @@ -138,7 +142,7 @@ def get_variable(self, key: str) -> Optional[str]: """ raise NotImplementedError() - def get_config(self, key: str) -> Optional[str]: + def get_config(self, key: str) -> str | None: """ Return value for Airflow Config Key diff --git a/airflow/secrets/environment_variables.py b/airflow/secrets/environment_variables.py index 41883ba33d128..1571737b33601 100644 --- a/airflow/secrets/environment_variables.py +++ b/airflow/secrets/environment_variables.py @@ -16,11 +16,12 @@ # specific language governing permissions and limitations # under the License. 
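# ---------------------------------------------------------------------------
# Illustrative sketch only -- not part of the patch above. deserialize_connection
# (above) treats a stored value starting with "{" as JSON and anything else as a
# URI; the environment-variables backend below resolves "my_db" (a made-up conn
# id) from the AIRFLOW_CONN_MY_DB variable.
import os

from airflow.secrets.environment_variables import EnvironmentVariablesBackend

os.environ["AIRFLOW_CONN_MY_DB"] = "mysql://user:pass@localhost:3306/mydb"
conn = EnvironmentVariablesBackend().get_connection("my_db")
if conn is not None:
    print(conn.conn_type, conn.host, conn.port)  # mysql localhost 3306
# ---------------------------------------------------------------------------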
"""Objects relating to sourcing connections from environment variables""" +from __future__ import annotations import os import warnings -from typing import Optional +from airflow.exceptions import RemovedInAirflow3Warning from airflow.secrets import BaseSecretsBackend CONN_ENV_PREFIX = "AIRFLOW_CONN_" @@ -30,7 +31,7 @@ class EnvironmentVariablesBackend(BaseSecretsBackend): """Retrieves Connection object and Variable from environment variable.""" - def get_conn_uri(self, conn_id: str) -> Optional[str]: + def get_conn_uri(self, conn_id: str) -> str | None: """ Return URI representation of Connection conn_id :param conn_id: the connection id @@ -39,15 +40,15 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]: warnings.warn( "This method is deprecated. Please use " "`airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_conn_value`.", - PendingDeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return self.get_conn_value(conn_id) - def get_conn_value(self, conn_id: str) -> Optional[str]: + def get_conn_value(self, conn_id: str) -> str | None: return os.environ.get(CONN_ENV_PREFIX + conn_id.upper()) - def get_variable(self, key: str) -> Optional[str]: + def get_variable(self, key: str) -> str | None: """ Get Airflow Variable from Environment Variable diff --git a/airflow/secrets/local_filesystem.py b/airflow/secrets/local_filesystem.py index 25c4eed3db966..ac019a0b25fbf 100644 --- a/airflow/secrets/local_filesystem.py +++ b/airflow/secrets/local_filesystem.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """Objects relating to retrieving connections and variables from local file""" +from __future__ import annotations + import json import logging import os @@ -23,13 +25,14 @@ from collections import defaultdict from inspect import signature from json import JSONDecodeError -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any from airflow.exceptions import ( AirflowException, AirflowFileParseException, ConnectionNotUnique, FileSyntaxError, + RemovedInAirflow3Warning, ) from airflow.secrets.base_secrets import BaseSecretsBackend from airflow.utils import yaml @@ -42,14 +45,14 @@ from airflow.models.connection import Connection -def get_connection_parameter_names() -> Set[str]: +def get_connection_parameter_names() -> set[str]: """Returns :class:`airflow.models.connection.Connection` constructor parameters.""" from airflow.models.connection import Connection return {k for k in signature(Connection.__init__).parameters.keys() if k != "self"} -def _parse_env_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSyntaxError]]: +def _parse_env_file(file_path: str) -> tuple[dict[str, list[str]], list[FileSyntaxError]]: """ Parse a file in the ``.env`` format. 
@@ -63,8 +66,8 @@ def _parse_env_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSynt with open(file_path) as f: content = f.read() - secrets: Dict[str, List[str]] = defaultdict(list) - errors: List[FileSyntaxError] = [] + secrets: dict[str, list[str]] = defaultdict(list) + errors: list[FileSyntaxError] = [] for line_no, line in enumerate(content.splitlines(), 1): if not line: # Ignore empty line @@ -95,7 +98,7 @@ def _parse_env_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSynt return secrets, errors -def _parse_yaml_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSyntaxError]]: +def _parse_yaml_file(file_path: str) -> tuple[dict[str, list[str]], list[FileSyntaxError]]: """ Parse a file in the YAML format. @@ -118,7 +121,7 @@ def _parse_yaml_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSyn return secrets, [] -def _parse_json_file(file_path: str) -> Tuple[Dict[str, Any], List[FileSyntaxError]]: +def _parse_json_file(file_path: str) -> tuple[dict[str, Any], list[FileSyntaxError]]: """ Parse a file in the JSON format. @@ -148,7 +151,7 @@ def _parse_json_file(file_path: str) -> Tuple[Dict[str, Any], List[FileSyntaxErr } -def _parse_secret_file(file_path: str) -> Dict[str, Any]: +def _parse_secret_file(file_path: str) -> dict[str, Any]: """ Based on the file extension format, selects a parser, and parses the file. @@ -220,14 +223,13 @@ def _create_connection(conn_id: str, value: Any): ) -def load_variables(file_path: str) -> Dict[str, str]: +def load_variables(file_path: str) -> dict[str, str]: """ Load variables from a text file. ``JSON``, `YAML` and ``.env`` files are supported. :param file_path: The location of the file that will be processed. - :rtype: Dict[str, List[str]] """ log.debug("Loading variables from a text file") @@ -240,28 +242,27 @@ def load_variables(file_path: str) -> Dict[str, str]: return variables -def load_connections(file_path) -> Dict[str, List[Any]]: +def load_connections(file_path) -> dict[str, list[Any]]: """This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.",""" warnings.warn( "This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) return {k: [v] for k, v in load_connections_dict(file_path).values()} -def load_connections_dict(file_path: str) -> Dict[str, Any]: +def load_connections_dict(file_path: str) -> dict[str, Any]: """ Load connection from text file. ``JSON``, `YAML` and ``.env`` files are supported. :return: A dictionary where the key contains a connection ID and the value contains the connection. - :rtype: Dict[str, airflow.models.connection.Connection] """ log.debug("Loading connection") - secrets: Dict[str, Any] = _parse_secret_file(file_path) + secrets: dict[str, Any] = _parse_secret_file(file_path) connection_by_conn_id = {} for key, secret_values in list(secrets.items()): if isinstance(secret_values, list): @@ -289,15 +290,13 @@ class LocalFilesystemBackend(BaseSecretsBackend, LoggingMixin): :param connections_file_path: File location with connection data. 
""" - def __init__( - self, variables_file_path: Optional[str] = None, connections_file_path: Optional[str] = None - ): + def __init__(self, variables_file_path: str | None = None, connections_file_path: str | None = None): super().__init__() self.variables_file = variables_file_path self.connections_file = connections_file_path @property - def _local_variables(self) -> Dict[str, str]: + def _local_variables(self) -> dict[str, str]: if not self.variables_file: self.log.debug("The file for variables is not specified. Skipping") # The user may not specify any file. @@ -306,23 +305,23 @@ def _local_variables(self) -> Dict[str, str]: return secrets @property - def _local_connections(self) -> Dict[str, 'Connection']: + def _local_connections(self) -> dict[str, Connection]: if not self.connections_file: self.log.debug("The file for connection is not specified. Skipping") # The user may not specify any file. return {} return load_connections_dict(self.connections_file) - def get_connection(self, conn_id: str) -> Optional['Connection']: + def get_connection(self, conn_id: str) -> Connection | None: if conn_id in self._local_connections: return self._local_connections[conn_id] return None - def get_connections(self, conn_id: str) -> List[Any]: + def get_connections(self, conn_id: str) -> list[Any]: warnings.warn( "This method is deprecated. Please use " "`airflow.secrets.local_filesystem.LocalFilesystemBackend.get_connection`.", - PendingDeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) conn = self.get_connection(conn_id=conn_id) @@ -330,5 +329,5 @@ def get_connections(self, conn_id: str) -> List[Any]: return [conn] return [] - def get_variable(self, key: str) -> Optional[str]: + def get_variable(self, key: str) -> str | None: return self._local_variables.get(key) diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index 100a35d8fd2c2..ce322d0b5c73c 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -16,9 +16,12 @@ # specific language governing permissions and limitations # under the License. """Objects relating to sourcing connections from metastore database""" +from __future__ import annotations + import warnings -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING +from airflow.exceptions import RemovedInAirflow3Warning from airflow.secrets import BaseSecretsBackend from airflow.utils.session import provide_session @@ -30,7 +33,7 @@ class MetastoreBackend(BaseSecretsBackend): """Retrieves Connection object and Variable from airflow metastore database.""" @provide_session - def get_connection(self, conn_id, session=None) -> Optional['Connection']: + def get_connection(self, conn_id, session=None) -> Connection | None: from airflow.models.connection import Connection conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() @@ -38,11 +41,11 @@ def get_connection(self, conn_id, session=None) -> Optional['Connection']: return conn @provide_session - def get_connections(self, conn_id, session=None) -> List['Connection']: + def get_connections(self, conn_id, session=None) -> list[Connection]: warnings.warn( "This method is deprecated. 
Please use " "`airflow.secrets.metastore.MetastoreBackend.get_connection`.", - PendingDeprecationWarning, + RemovedInAirflow3Warning, stacklevel=3, ) conn = self.get_connection(conn_id=conn_id, session=session) diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py index e8fc86af7259c..bc36f5e98c539 100644 --- a/airflow/security/kerberos.py +++ b/airflow/security/kerberos.py @@ -15,7 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations + # Licensed to Cloudera, Inc. under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -31,25 +32,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Kerberos security provider""" +"""Kerberos security provider.""" import logging import shlex -import socket import subprocess import sys import time -from typing import List, Optional from airflow.configuration import conf +from airflow.utils.net import get_hostname -NEED_KRB181_WORKAROUND = None # type: Optional[bool] +NEED_KRB181_WORKAROUND: bool | None = None log = logging.getLogger(__name__) -def renew_from_kt(principal: Optional[str], keytab: str, exit_on_fail: bool = True): +def renew_from_kt(principal: str | None, keytab: str, exit_on_fail: bool = True): """ - Renew kerberos token from keytab + Renew kerberos token from keytab. :param principal: principal :param keytab: keytab file @@ -59,22 +59,22 @@ def renew_from_kt(principal: Optional[str], keytab: str, exit_on_fail: bool = Tr # minutes to give ourselves a large renewal buffer. 
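The deprecated `get_connections()`/`get_conn_uri()` shims above now emit `RemovedInAirflow3Warning` (which, as far as I can tell, is a `DeprecationWarning` subclass from `airflow.exceptions`) instead of `PendingDeprecationWarning`, so callers can target them precisely with warning filters. A sketch using the filesystem backend, which needs no database or files:

```python
# Sketch: turning the new deprecation category into an error, e.g. in a test suite.
import warnings

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.secrets.local_filesystem import LocalFilesystemBackend

warnings.simplefilter("error", RemovedInAirflow3Warning)  # fail fast on deprecated paths

backend = LocalFilesystemBackend()  # no files configured; lookups simply find nothing
try:
    backend.get_connections("my_conn")  # deprecated; use get_connection() instead
except RemovedInAirflow3Warning as exc:
    print(f"caught deprecation: {exc}")
```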
renewal_lifetime = f"{conf.getint('kerberos', 'reinit_frequency')}m" - cmd_principal = principal or conf.get_mandatory_value('kerberos', 'principal').replace( - "_HOST", socket.getfqdn() + cmd_principal = principal or conf.get_mandatory_value("kerberos", "principal").replace( + "_HOST", get_hostname() ) - if conf.getboolean('kerberos', 'forwardable'): - forwardable = '-f' + if conf.getboolean("kerberos", "forwardable"): + forwardable = "-f" else: - forwardable = '-F' + forwardable = "-F" - if conf.getboolean('kerberos', 'include_ip'): - include_ip = '-a' + if conf.getboolean("kerberos", "include_ip"): + include_ip = "-a" else: - include_ip = '-A' + include_ip = "-A" - cmdv: List[str] = [ - conf.get_mandatory_value('kerberos', 'kinit_path'), + cmdv: list[str] = [ + conf.get_mandatory_value("kerberos", "kinit_path"), forwardable, include_ip, "-r", @@ -83,7 +83,7 @@ def renew_from_kt(principal: Optional[str], keytab: str, exit_on_fail: bool = Tr "-t", keytab, # specify keytab "-c", - conf.get_mandatory_value('kerberos', 'ccache'), # specify credentials cache + conf.get_mandatory_value("kerberos", "ccache"), # specify credentials cache cmd_principal, ] log.info("Re-initialising kerberos from keytab: %s", " ".join(shlex.quote(f) for f in cmdv)) @@ -131,10 +131,10 @@ def perform_krb181_workaround(principal: str): :param principal: principal name :return: None """ - cmdv: List[str] = [ - conf.get_mandatory_value('kerberos', 'kinit_path'), + cmdv: list[str] = [ + conf.get_mandatory_value("kerberos", "kinit_path"), "-c", - conf.get_mandatory_value('kerberos', 'ccache'), + conf.get_mandatory_value("kerberos", "ccache"), "-R", ] # Renew ticket_cache @@ -143,8 +143,8 @@ def perform_krb181_workaround(principal: str): ret = subprocess.call(cmdv, close_fds=True) if ret != 0: - principal = f"{principal or conf.get('kerberos', 'principal')}/{socket.getfqdn()}" - ccache = conf.get('kerberos', 'ccache') + principal = f"{principal or conf.get('kerberos', 'principal')}/{get_hostname()}" + ccache = conf.get("kerberos", "ccache") log.error( "Couldn't renew kerberos ticket in order to work around Kerberos 1.8.1 issue. Please check that " "the ticket for '%s' is still renewable:\n $ kinit -f -c %s\nIf the 'renew until' date is the " @@ -159,19 +159,22 @@ def perform_krb181_workaround(principal: str): def detect_conf_var() -> bool: - """Return true if the ticket cache contains "conf" information as is found + """ + Autodetect the Kerberos ticket configuration. + + Return true if the ticket cache contains "conf" information as is found in ticket caches of Kerberos 1.8.1 or later. This is incompatible with the Sun Java Krb5LoginModule in Java6, so we need to take an action to work around it. """ - ticket_cache = conf.get_mandatory_value('kerberos', 'ccache') + ticket_cache = conf.get_mandatory_value("kerberos", "ccache") - with open(ticket_cache, 'rb') as file: + with open(ticket_cache, "rb") as file: # Note: this file is binary, so we check against a bytearray. - return b'X-CACHECONF:' in file.read() + return b"X-CACHECONF:" in file.read() -def run(principal: Optional[str], keytab: str): +def run(principal: str | None, keytab: str): """ Run the kerbros renewer. 
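A standalone sketch of the `kinit` command that `renew_from_kt()` assembles above, with configuration values inlined rather than read from `airflow.cfg`; the keytab path, ccache path, hostname and principal are placeholders:

```python
# Values inlined for illustration; note how "_HOST" in the principal is substituted with
# the local hostname (the real code uses airflow.utils.net.get_hostname()).
import shlex

reinit_frequency = 3600  # stand-in for [kerberos] reinit_frequency
renewal_lifetime = f"{reinit_frequency}m"

hostname = "worker-1.example.com"
cmd_principal = "airflow/_HOST@EXAMPLE.COM".replace("_HOST", hostname)

forwardable = "-f"  # "-F" when [kerberos] forwardable = False
include_ip = "-a"   # "-A" when [kerberos] include_ip = False

cmdv = [
    "/usr/bin/kinit",                              # stand-in for [kerberos] kinit_path
    forwardable,
    include_ip,
    "-r", renewal_lifetime,                        # renewal lifetime
    "-k",                                          # use keytab
    "-t", "/etc/security/keytabs/airflow.keytab",  # placeholder keytab
    "-c", "/tmp/airflow_krb5ccache",               # stand-in for [kerberos] ccache
    cmd_principal,
]
print(" ".join(shlex.quote(part) for part in cmdv))
```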
@@ -185,4 +188,4 @@ def run(principal: Optional[str], keytab: str): while True: renew_from_kt(principal, keytab) - time.sleep(conf.getint('kerberos', 'reinit_frequency')) + time.sleep(conf.getint("kerberos", "reinit_frequency")) diff --git a/airflow/security/permissions.py b/airflow/security/permissions.py index 2d5c0b939976e..adf7af964e571 100644 --- a/airflow/security/permissions.py +++ b/airflow/security/permissions.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations # Resource Constants RESOURCE_ACTION = "Permissions" @@ -32,6 +33,7 @@ RESOURCE_DAG_CODE = "DAG Code" RESOURCE_DAG_RUN = "DAG Runs" RESOURCE_IMPORT_ERROR = "ImportError" +RESOURCE_DAG_WARNING = "DAG Warnings" RESOURCE_JOB = "Jobs" RESOURCE_MY_PASSWORD = "My Password" RESOURCE_MY_PROFILE = "My Profile" @@ -52,6 +54,7 @@ RESOURCE_VARIABLE = "Variables" RESOURCE_WEBSITE = "Website" RESOURCE_XCOM = "XComs" +RESOURCE_DATASET = "Datasets" # Action Constants @@ -66,14 +69,15 @@ DAG_ACTIONS = {ACTION_CAN_READ, ACTION_CAN_EDIT, ACTION_CAN_DELETE} -def resource_name_for_dag(dag_id): - """Returns the resource name for a DAG id.""" - if dag_id == RESOURCE_DAG: - return dag_id +def resource_name_for_dag(root_dag_id: str) -> str: + """Returns the resource name for a DAG id. - if dag_id.startswith(RESOURCE_DAG_PREFIX): - return dag_id - - # To account for SubDags - root_dag_id = dag_id.split(".")[0] + Note that since a sub-DAG should follow the permission of its + parent DAG, you should pass ``DagModel.root_dag_id`` to this function, + for a subdag. A normal dag should pass the ``DagModel.dag_id``. + """ + if root_dag_id == RESOURCE_DAG: + return root_dag_id + if root_dag_id.startswith(RESOURCE_DAG_PREFIX): + return root_dag_id return f"{RESOURCE_DAG_PREFIX}{root_dag_id}" diff --git a/airflow/security/utils.py b/airflow/security/utils.py index c4e2af91d7ab9..6ce61e36a74d5 100644 --- a/airflow/security/utils.py +++ b/airflow/security/utils.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations # Licensed to Cloudera, Inc. under one # or more contributor license agreements. See the NOTICE file @@ -35,12 +36,11 @@ """Various security-related utils.""" import re import socket -from typing import List, Optional from airflow.utils.net import get_hostname -def get_components(principal) -> Optional[List[str]]: +def get_components(principal) -> list[str] | None: """ Returns components retrieved from the kerberos principal. 
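A quick illustration of the `resource_name_for_dag()` contract clarified above: callers pass the *root* dag id (`DagModel.root_dag_id` for a sub-DAG, `DagModel.dag_id` otherwise) and get back the prefixed access-control resource name. The dag ids are illustrative:

```python
from airflow.security.permissions import RESOURCE_DAG, resource_name_for_dag

print(resource_name_for_dag("example_etl"))      # -> "DAG:example_etl"
print(resource_name_for_dag("DAG:example_etl"))  # already prefixed -> returned unchanged
print(resource_name_for_dag(RESOURCE_DAG))       # the umbrella "DAGs" resource -> unchanged
```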
-> (short name, instance (FQDN), realm) @@ -49,15 +49,15 @@ def get_components(principal) -> Optional[List[str]]: """ if not principal: return None - return re.split(r'[/@]', str(principal)) + return re.split(r"[/@]", str(principal)) def replace_hostname_pattern(components, host=None): """Replaces hostname with the right pattern including lowercase of the name.""" fqdn = host - if not fqdn or fqdn == '0.0.0.0': + if not fqdn or fqdn == "0.0.0.0": fqdn = get_hostname() - return f'{components[0]}/{fqdn.lower()}@{components[2]}' + return f"{components[0]}/{fqdn.lower()}@{components[2]}" def get_fqdn(hostname_or_ip=None): @@ -65,7 +65,7 @@ def get_fqdn(hostname_or_ip=None): try: if hostname_or_ip: fqdn = socket.gethostbyaddr(hostname_or_ip)[0] - if fqdn == 'localhost': + if fqdn == "localhost": fqdn = get_hostname() else: fqdn = get_hostname() @@ -77,7 +77,7 @@ def get_fqdn(hostname_or_ip=None): def principal_from_username(username, realm): """Retrieves principal from the user name and realm.""" - if ('@' not in username) and realm: + if ("@" not in username) and realm: username = f"{username}@{realm}" return username diff --git a/airflow/sensors/__init__.py b/airflow/sensors/__init__.py index d3a60b4f79c97..4de283d1107d3 100644 --- a/airflow/sensors/__init__.py +++ b/airflow/sensors/__init__.py @@ -15,5 +15,58 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# fmt: off """Sensors.""" +from __future__ import annotations + +from airflow.utils.deprecation_tools import add_deprecated_classes + +__deprecated_classes = { + 'base_sensor_operator': { + 'BaseSensorOperator': 'airflow.sensors.base.BaseSensorOperator', + }, + 'date_time_sensor': { + 'DateTimeSensor': 'airflow.sensors.date_time.DateTimeSensor', + }, + 'external_task_sensor': { + 'ExternalTaskMarker': 'airflow.sensors.external_task.ExternalTaskMarker', + 'ExternalTaskSensor': 'airflow.sensors.external_task.ExternalTaskSensor', + 'ExternalTaskSensorLink': 'airflow.sensors.external_task.ExternalTaskSensorLink', + }, + 'hdfs_sensor': { + 'HdfsSensor': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsSensor', + }, + 'hive_partition_sensor': { + 'HivePartitionSensor': 'airflow.providers.apache.hive.sensors.hive_partition.HivePartitionSensor', + }, + 'http_sensor': { + 'HttpSensor': 'airflow.providers.http.sensors.http.HttpSensor', + }, + 'metastore_partition_sensor': { + 'MetastorePartitionSensor': ( + 'airflow.providers.apache.hive.sensors.metastore_partition.MetastorePartitionSensor' + ), + }, + 'named_hive_partition_sensor': { + 'NamedHivePartitionSensor': ( + 'airflow.providers.apache.hive.sensors.named_hive_partition.NamedHivePartitionSensor' + ), + }, + 's3_key_sensor': { + 'S3KeySensor': 'airflow.providers.amazon.aws.sensors.s3.S3KeySensor', + }, + 'sql': { + 'SqlSensor': 'airflow.providers.common.sql.sensors.sql.SqlSensor', + }, + 'sql_sensor': { + 'SqlSensor': 'airflow.providers.common.sql.sensors.sql.SqlSensor', + }, + 'time_delta_sensor': { + 'TimeDeltaSensor': 'airflow.sensors.time_delta.TimeDeltaSensor', + }, + 'web_hdfs_sensor': { + 'WebHdfsSensor': 'airflow.providers.apache.hdfs.sensors.web_hdfs.WebHdfsSensor', + }, +} + +add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index f00b3a6761d5b..f20c7d7771dca 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -15,14 +15,14 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import datetime import functools import hashlib import time -import warnings from datetime import timedelta -from typing import Any, Callable, Iterable, Optional, Union +from typing import Any, Callable, Iterable from airflow import settings from airflow.configuration import conf @@ -33,7 +33,6 @@ AirflowSkipException, ) from airflow.models.baseoperator import BaseOperator -from airflow.models.sensorinstance import SensorInstance from airflow.models.skipmixin import SkipMixin from airflow.models.taskreschedule import TaskReschedule from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep @@ -44,7 +43,6 @@ # Google Provider before 3.0.0 imported apply_defaults from here. # See https://github.com/apache/airflow/issues/16035 from airflow.utils.decorators import apply_defaults # noqa: F401 -from airflow.utils.docs import get_docs_url # As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html. _MYSQL_TIMESTAMP_MAX = datetime.datetime(2038, 1, 19, 3, 14, 7, tzinfo=timezone.utc) @@ -59,6 +57,8 @@ def _is_metadatabase_mysql() -> bool: class PokeReturnValue: """ + Optional return value for poke methods. + Sensors can optionally return an instance of the PokeReturnValue class in the poke method. If an XCom value is supplied when the sensor is done, then the XCom value will be pushed through the operator return value. @@ -66,7 +66,7 @@ class PokeReturnValue: :param xcom_value: An optional XCOM value to be returned by the operator. """ - def __init__(self, is_done: bool, xcom_value: Optional[Any] = None) -> None: + def __init__(self, is_done: bool, xcom_value: Any | None = None) -> None: self.xcom_value = xcom_value self.is_done = is_done @@ -83,7 +83,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin): :param soft_fail: Set to true to mark the task as SKIPPED on failure :param poke_interval: Time in seconds that the job should wait in - between each tries + between each try :param timeout: Time, in seconds before the task times out and fails. :param mode: How the sensor operates. Options are: ``{ poke | reschedule }``, default is ``poke``. @@ -99,26 +99,11 @@ class BaseSensorOperator(BaseOperator, SkipMixin): prevent too much load on the scheduler. :param exponential_backoff: allow progressive longer waits between pokes by using exponential backoff algorithm + :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds """ - ui_color = '#e6f1f2' # type: str - valid_modes = ['poke', 'reschedule'] # type: Iterable[str] - - # As the poke context in smart sensor defines the poking job signature only, - # The execution_fields defines other execution details - # for this tasks such as the customer defined timeout, the email and the alert - # setup. Smart sensor serialize these attributes into a different DB column so - # that smart sensor service is able to handle corresponding execution details - # without breaking the sensor poking logic with dedup. - execution_fields = ( - 'poke_interval', - 'retries', - 'execution_timeout', - 'timeout', - 'email', - 'email_on_retry', - 'email_on_failure', - ) + ui_color: str = "#e6f1f2" + valid_modes: Iterable[str] = ["poke", "reschedule"] # Adds one additional dependency for all sensor operators that checks if a # sensor task instance can be rescheduled. 
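The `PokeReturnValue` documented above lets `poke()` push an XCom alongside its done/not-done signal. A sketch of a hypothetical sensor using it; the record-count check is a stand-in for a real lookup:

```python
# Hypothetical sensor: waits until at least `threshold` records exist and also pushes the
# observed count as an XCom via PokeReturnValue.
from __future__ import annotations

from typing import Any

from airflow.sensors.base import BaseSensorOperator, PokeReturnValue


class RecordCountSensor(BaseSensorOperator):
    """Waits until at least ``threshold`` records are available upstream."""

    def __init__(self, *, threshold: int, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.threshold = threshold

    def _count_records(self) -> int:
        # Placeholder for a real lookup (database query, API call, ...).
        return 42

    def poke(self, context) -> bool | PokeReturnValue:
        count = self._count_records()
        self.log.info("Found %d records (threshold %d)", count, self.threshold)
        # Returning PokeReturnValue instead of a bare bool also pushes an XCom when done.
        return PokeReturnValue(is_done=count >= self.threshold, xcom_value=count)
```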
@@ -128,10 +113,11 @@ def __init__( self, *, poke_interval: float = 60, - timeout: float = conf.getfloat('sensors', 'default_timeout'), + timeout: float = conf.getfloat("sensors", "default_timeout"), soft_fail: bool = False, - mode: str = 'poke', + mode: str = "poke", exponential_backoff: bool = False, + max_wait: timedelta | float | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -140,11 +126,16 @@ def __init__( self.timeout = timeout self.mode = mode self.exponential_backoff = exponential_backoff + self.max_wait = self._coerce_max_wait(max_wait) self._validate_input_values() - self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor') - self.sensors_support_sensor_service = set( - map(lambda l: l.strip(), conf.get_mandatory_value('smart_sensor', 'sensors_enabled').split(',')) - ) + + @staticmethod + def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None: + if max_wait is None or isinstance(max_wait, timedelta): + return max_wait + if isinstance(max_wait, (int, float)) and max_wait >= 0: + return timedelta(seconds=max_wait) + raise AirflowException("Operator arg `max_wait` must be timedelta object or a non-negative number") def _validate_input_values(self) -> None: if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: @@ -168,89 +159,20 @@ def _validate_input_values(self) -> None: f"mode since it will take reschedule time over MySQL's TIMESTAMP limit." ) - def poke(self, context: Context) -> Union[bool, PokeReturnValue]: - """ - Function that the sensors defined while deriving this class should - override. - """ - raise AirflowException('Override me.') - - def is_smart_sensor_compatible(self): - check_list = [ - not self.sensor_service_enabled, - self.on_success_callback, - self.on_retry_callback, - self.on_failure_callback, - ] - if any(check_list): - return False - - operator = self.__class__.__name__ - return operator in self.sensors_support_sensor_service - - def register_in_sensor_service(self, ti, context): - """ - Register ti in smart sensor service - - :param ti: Task instance object. - :param context: TaskInstance template context from the ti. - :return: boolean - """ - docs_url = get_docs_url('concepts/smart-sensors.html#migrating-to-deferrable-operators') - warnings.warn( - 'Your sensor is using Smart Sensors, which are deprecated.' - f' Please use Deferrable Operators instead. See {docs_url} for more info.', - DeprecationWarning, - ) - poke_context = self.get_poke_context(context) - execution_context = self.get_execution_context(context) - - return SensorInstance.register(ti, poke_context, execution_context) - - def get_poke_context(self, context): - """ - Return a dictionary with all attributes in poke_context_fields. The - poke_context with operator class can be used to identify a unique - sensor job. - - :param context: TaskInstance template context. - :return: A dictionary with key in poke_context_fields. - """ - if not context: - self.log.info("Function get_poke_context doesn't have a context input.") - - poke_context_fields = getattr(self.__class__, "poke_context_fields", None) - result = {key: getattr(self, key, None) for key in poke_context_fields} - return result - - def get_execution_context(self, context): - """ - Return a dictionary with all attributes in execution_fields. The - execution_context include execution requirement for each sensor task - such as timeout setup, email_alert setup. - - :param context: TaskInstance template context. 
- :return: A dictionary with key in execution_fields. - """ - if not context: - self.log.info("Function get_execution_context doesn't have a context input.") - execution_fields = self.__class__.execution_fields - - result = {key: getattr(self, key, None) for key in execution_fields} - if result['execution_timeout'] and isinstance(result['execution_timeout'], datetime.timedelta): - result['execution_timeout'] = result['execution_timeout'].total_seconds() - return result + def poke(self, context: Context) -> bool | PokeReturnValue: + """Function defined by the sensors while deriving this class should override.""" + raise AirflowException("Override me.") def execute(self, context: Context) -> Any: - started_at: Union[datetime.datetime, float] + started_at: datetime.datetime | float if self.reschedule: # If reschedule, use the start date of the first try (first try can be either the very # first execution of the task, or the first execution after the task was cleared.) - first_try_number = context['ti'].max_tries - self.retries + 1 + first_try_number = context["ti"].max_tries - self.retries + 1 task_reschedules = TaskReschedule.find_for_task_instance( - context['ti'], try_number=first_try_number + context["ti"], try_number=first_try_number ) if not task_reschedules: start_date = timezone.utcnow() @@ -282,10 +204,15 @@ def run_duration() -> float: if run_duration() > self.timeout: # If sensor is in soft fail mode but times out raise AirflowSkipException. + message = ( + f"Sensor has timed out; run duration of {run_duration()} seconds exceeds " + f"the specified timeout of {self.timeout}." + ) + if self.soft_fail: - raise AirflowSkipException(f"Snap. Time is OUT. DAG id: {log_dag_id}") + raise AirflowSkipException(message) else: - raise AirflowSensorTimeout(f"Snap. Time is OUT. 
DAG id: {log_dag_id}") + raise AirflowSensorTimeout(message) if self.reschedule: next_poke_interval = self._get_next_poke_interval(started_at, run_duration, try_number) reschedule_date = timezone.utcnow() + timedelta(seconds=next_poke_interval) @@ -303,7 +230,7 @@ def run_duration() -> float: def _get_next_poke_interval( self, - started_at: Union[datetime.datetime, float], + started_at: datetime.datetime | float, run_duration: Callable[[], float], try_number: int, ) -> float: @@ -321,6 +248,10 @@ def _get_next_poke_interval( delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1) new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds) + + if self.max_wait: + new_interval = min(self.max_wait.total_seconds(), new_interval) + self.log.info("new %s interval is %s", self.mode, new_interval) return new_interval @@ -329,15 +260,15 @@ def prepare_for_execution(self) -> BaseOperator: # Sensors in `poke` mode can block execution of DAGs when running # with single process executor, thus we change the mode to`reschedule` # to allow parallel task being scheduled and executed - if conf.get('core', 'executor') == "DebugExecutor": + if conf.get("core", "executor") == "DebugExecutor": self.log.warning("DebugExecutor changes sensor mode to 'reschedule'.") - task.mode = 'reschedule' + task.mode = "reschedule" return task @property def reschedule(self): """Define mode rescheduled sensors.""" - return self.mode == 'reschedule' + return self.mode == "reschedule" @classmethod def get_serialized_fields(cls): @@ -346,8 +277,9 @@ def get_serialized_fields(cls): def poke_mode_only(cls): """ - Class Decorator for child classes of BaseSensorOperator to indicate - that instances of this class are only safe to use poke mode. + Decorate a subclass of BaseSensorOperator with poke. + + Indicate that instances of this class are only safe to use poke mode. Will decorate all methods in the class to assert they did not change the mode from 'poke'. @@ -357,10 +289,10 @@ def poke_mode_only(cls): def decorate(cls_type): def mode_getter(_): - return 'poke' + return "poke" def mode_setter(_, value): - if value != 'poke': + if value != "poke": raise ValueError("cannot set mode to 'poke'.") if not issubclass(cls_type, BaseSensorOperator): diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py deleted file mode 100644 index 716f03141ace9..0000000000000 --- a/airflow/sensors/base_sensor_operator.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.base`.""" - -import warnings - -from airflow.sensors.base import BaseSensorOperator # noqa - -warnings.warn( - "This module is deprecated. 
Please use `airflow.sensors.base`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/sensors/bash.py b/airflow/sensors/bash.py index c651727e8412d..f1fcb97139e9a 100644 --- a/airflow/sensors/bash.py +++ b/airflow/sensors/bash.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import os from subprocess import PIPE, STDOUT, Popen @@ -27,8 +28,9 @@ class BashSensor(BaseSensorOperator): """ - Executes a bash command/script and returns True if and only if the - return code is 0. + Executes a bash command/script. + + Return True if and only if the return code is 0. :param bash_command: The command, set of commands or reference to a bash script (must be '.sh') to be executed. @@ -38,26 +40,27 @@ class BashSensor(BaseSensorOperator): of inheriting the current process environment, which is the default behavior. (templated) :param output_encoding: output encoding of bash command. + + .. seealso:: + For more information on how to use this sensor,take a look at the guide: + :ref:`howto/operator:BashSensor` """ - template_fields: Sequence[str] = ('bash_command', 'env') + template_fields: Sequence[str] = ("bash_command", "env") - def __init__(self, *, bash_command, env=None, output_encoding='utf-8', **kwargs): + def __init__(self, *, bash_command, env=None, output_encoding="utf-8", **kwargs): super().__init__(**kwargs) self.bash_command = bash_command self.env = env self.output_encoding = output_encoding def poke(self, context: Context): - """ - Execute the bash command in a temporary directory - which will be cleaned afterwards - """ + """Execute the bash command in a temporary directory.""" bash_command = self.bash_command self.log.info("Tmp dir root location: \n %s", gettempdir()) - with TemporaryDirectory(prefix='airflowtmp') as tmp_dir: + with TemporaryDirectory(prefix="airflowtmp") as tmp_dir: with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f: - f.write(bytes(bash_command, 'utf_8')) + f.write(bytes(bash_command, "utf_8")) f.flush() fname = f.name script_location = tmp_dir + "/" + fname @@ -65,7 +68,7 @@ def poke(self, context: Context): self.log.info("Running command: %s", bash_command) with Popen( - ['bash', fname], + ["bash", fname], stdout=PIPE, stderr=STDOUT, close_fds=True, @@ -75,7 +78,7 @@ def poke(self, context: Context): ) as resp: if resp.stdout: self.log.info("Output:") - for line in iter(resp.stdout.readline, b''): + for line in iter(resp.stdout.readline, b""): self.log.info(line.decode(self.output_encoding).strip()) resp.wait() self.log.info("Command exited with return code %s", resp.returncode) diff --git a/airflow/sensors/date_time.py b/airflow/sensors/date_time.py index 52c41753a8a5a..19168e98f3eca 100644 --- a/airflow/sensors/date_time.py +++ b/airflow/sensors/date_time.py @@ -15,9 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
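A usage sketch for the `BashSensor` shown above, combined with the new `max_wait` cap on exponential backoff introduced in `BaseSensorOperator` earlier in this diff; the DAG id, schedule and command are illustrative:

```python
# Illustrative DAG: the sensor succeeds once the (made-up) marker file exists.
from __future__ import annotations

from datetime import datetime, timedelta

from airflow import DAG
from airflow.sensors.bash import BashSensor

with DAG(
    dag_id="bash_sensor_example",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    wait_for_marker = BashSensor(
        task_id="wait_for_marker_file",
        bash_command="test -f /tmp/upstream_done.marker",  # exit code 0 => sensor succeeds
        mode="reschedule",               # free the worker slot between pokes
        poke_interval=30,
        exponential_backoff=True,
        max_wait=timedelta(minutes=10),  # never wait longer than 10 minutes between pokes
        timeout=6 * 60 * 60,
        soft_fail=True,                  # mark SKIPPED instead of FAILED on timeout
    )
```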
+from __future__ import annotations import datetime -from typing import Sequence, Union +from typing import Sequence from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger @@ -56,7 +57,7 @@ class DateTimeSensor(BaseSensorOperator): template_fields: Sequence[str] = ("target_time",) - def __init__(self, *, target_time: Union[str, datetime.datetime], **kwargs) -> None: + def __init__(self, *, target_time: str | datetime.datetime, **kwargs) -> None: super().__init__(**kwargs) # self.target_time can't be a datetime object as it is a template_field @@ -76,9 +77,9 @@ def poke(self, context: Context) -> bool: class DateTimeSensorAsync(DateTimeSensor): """ - Waits until the specified datetime, deferring itself to avoid taking up - a worker slot while it is waiting. + Waits until the specified datetime occurs. + Deferring itself to avoid taking up a worker slot while it is waiting. It is a drop-in replacement for DateTimeSensor. :param target_time: datetime after which the job succeeds. (templated) diff --git a/airflow/sensors/date_time_sensor.py b/airflow/sensors/date_time_sensor.py deleted file mode 100644 index 63a221685af7c..0000000000000 --- a/airflow/sensors/date_time_sensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.date_time`.""" - -import warnings - -from airflow.sensors.date_time import DateTimeSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.date_time`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 40c0a7a5665b7..967bb5a276ce1 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -15,29 +15,41 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
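As the docstring above says, `DateTimeSensorAsync` is a drop-in replacement for `DateTimeSensor` that defers instead of occupying a worker slot. A sketch showing the two side by side; the DAG id and target time are made up:

```python
# Illustrative only: both sensors take the same templated target_time argument.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.sensors.date_time import DateTimeSensor, DateTimeSensorAsync

with DAG(dag_id="date_time_sensor_example", start_date=datetime(2022, 1, 1), catchup=False) as dag:
    # Classic sensor: pokes until the target time has passed.
    wait_poke = DateTimeSensor(
        task_id="wait_half_hour",
        target_time="{{ data_interval_end + macros.timedelta(minutes=30) }}",
    )

    # Deferrable variant: releases its worker slot and resumes via the triggerer.
    wait_deferred = DateTimeSensorAsync(
        task_id="wait_half_hour_deferred",
        target_time="{{ data_interval_end + macros.timedelta(minutes=30) }}",
    )
```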
+from __future__ import annotations import datetime import os -from typing import TYPE_CHECKING, Any, Callable, Collection, FrozenSet, Iterable, Optional, Union +import warnings +from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable +import attr from sqlalchemy import func -from airflow.exceptions import AirflowException -from airflow.models import BaseOperatorLink, DagBag, DagModel, DagRun, TaskInstance +from airflow.exceptions import AirflowException, AirflowSkipException, RemovedInAirflow3Warning +from airflow.models.baseoperator import BaseOperatorLink +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.sensors.base import BaseSensorOperator +from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.session import provide_session from airflow.utils.state import State +if TYPE_CHECKING: + from sqlalchemy.orm import Query -class ExternalTaskSensorLink(BaseOperatorLink): + +class ExternalDagLink(BaseOperatorLink): """ - Operator link for ExternalTaskSensor. It allows users to access - DAG waited with ExternalTaskSensor. + Operator link for ExternalTaskSensor and ExternalTaskMarker. + + It allows users to access DAG waited with ExternalTaskSensor or cleared by ExternalTaskMarker. """ - name = 'External DAG' + name = "External DAG" def get_link(self, operator, dttm): ti = TaskInstance(task=operator, execution_date=dttm) @@ -48,13 +60,34 @@ def get_link(self, operator, dttm): class ExternalTaskSensor(BaseSensorOperator): """ - Waits for a different DAG or a task in a different DAG to complete for a - specific logical date. + Waits for a different DAG, task group, or task to complete for a specific logical date. + + If both `external_task_group_id` and `external_task_id` are ``None`` (default), the sensor + waits for the DAG. + Values for `external_task_group_id` and `external_task_id` can't be set at the same time. + + By default, the ExternalTaskSensor will wait for the external task to + succeed, at which point it will also succeed. However, by default it will + *not* fail if the external task fails, but will continue to check the status + until the sensor times out (thus giving you time to retry the external task + without also having to clear the sensor). + + It is possible to alter the default behavior by setting states which + cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]`` + and ``failed_states=[State.SUCCESS]`` you will flip the behaviour to get a + sensor which goes green when the external task *fails* and immediately goes + red if the external task *succeeds*! + + Note that ``soft_fail`` is respected when examining the failed_states. Thus + if the external task enters a failed state and ``soft_fail == True`` the + sensor will _skip_ rather than fail. As a result, setting ``soft_fail=True`` + and ``failed_states=[State.SKIPPED]`` will result in the sensor skipping if + the external task skips. :param external_dag_id: The dag_id that contains the task you want to wait for :param external_task_id: The task_id that contains the task you want to - wait for. If ``None`` (default value) the sensor waits for the DAG + wait for. :param external_task_ids: The list of task_ids that you want to wait for. If ``None`` (default value) the sensor waits for the DAG. 
Either external_task_id or external_task_ids can be passed to @@ -77,24 +110,21 @@ class ExternalTaskSensor(BaseSensorOperator): or DAG does not exist (default value: False). """ - template_fields = ['external_dag_id', 'external_task_id', 'external_task_ids'] - ui_color = '#19647e' - - @property - def operator_extra_links(self): - """Return operator extra links""" - return [ExternalTaskSensorLink()] + template_fields = ["external_dag_id", "external_task_id", "external_task_ids"] + ui_color = "#19647e" + operator_extra_links = [ExternalDagLink()] def __init__( self, *, external_dag_id: str, - external_task_id: Optional[str] = None, - external_task_ids: Optional[Collection[str]] = None, - allowed_states: Optional[Iterable[str]] = None, - failed_states: Optional[Iterable[str]] = None, - execution_delta: Optional[datetime.timedelta] = None, - execution_date_fn: Optional[Callable] = None, + external_task_id: str | None = None, + external_task_ids: Collection[str] | None = None, + external_task_group_id: str | None = None, + allowed_states: Iterable[str] | None = None, + failed_states: Iterable[str] | None = None, + execution_delta: datetime.timedelta | None = None, + execution_date_fn: Callable | None = None, check_existence: bool = False, **kwargs, ): @@ -112,31 +142,37 @@ def __init__( if external_task_id is not None and external_task_ids is not None: raise ValueError( - 'Only one of `external_task_id` or `external_task_ids` may ' - 'be provided to ExternalTaskSensor; not both.' + "Only one of `external_task_id` or `external_task_ids` may " + "be provided to ExternalTaskSensor; not both." ) if external_task_id is not None: external_task_ids = [external_task_id] - if external_task_ids: + if external_task_group_id and external_task_ids: + raise ValueError( + "Values for `external_task_group_id` and `external_task_id` or `external_task_ids` " + "can't be set at the same time" + ) + + if external_task_ids or external_task_group_id: if not total_states <= set(State.task_states): raise ValueError( - f'Valid values for `allowed_states` and `failed_states` ' - f'when `external_task_id` or `external_task_ids` is not `None`: {State.task_states}' + f"Valid values for `allowed_states` and `failed_states` " + f"when `external_task_id` or `external_task_ids` or `external_task_group_id` " + f"is not `None`: {State.task_states}" ) - if len(external_task_ids) > len(set(external_task_ids)): - raise ValueError('Duplicate task_ids passed in external_task_ids parameter') + elif not total_states <= set(State.dag_states): raise ValueError( - f'Valid values for `allowed_states` and `failed_states` ' - f'when `external_task_id` is `None`: {State.dag_states}' + f"Valid values for `allowed_states` and `failed_states` " + f"when `external_task_id` and `external_task_group_id` is `None`: {State.dag_states}" ) if execution_delta is not None and execution_date_fn is not None: raise ValueError( - 'Only one of `execution_delta` or `execution_date_fn` may ' - 'be provided to ExternalTaskSensor; not both.' + "Only one of `execution_delta` or `execution_date_fn` may " + "be provided to ExternalTaskSensor; not both." 
) self.execution_delta = execution_delta @@ -144,27 +180,42 @@ def __init__( self.external_dag_id = external_dag_id self.external_task_id = external_task_id self.external_task_ids = external_task_ids + self.external_task_group_id = external_task_group_id self.check_existence = check_existence self._has_checked_existence = False - @provide_session - def poke(self, context, session=None): + def _get_dttm_filter(self, context): if self.execution_delta: - dttm = context['logical_date'] - self.execution_delta + dttm = context["logical_date"] - self.execution_delta elif self.execution_date_fn: dttm = self._handle_execution_date_fn(context=context) else: - dttm = context['logical_date'] + dttm = context["logical_date"] + return dttm if isinstance(dttm, list) else [dttm] - dttm_filter = dttm if isinstance(dttm, list) else [dttm] - serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter) + @provide_session + def poke(self, context, session=None): + if self.external_task_ids and len(self.external_task_ids) > len(set(self.external_task_ids)): + raise ValueError("Duplicate task_ids passed in external_task_ids parameter") - self.log.info( - 'Poking for tasks %s in dag %s on %s ... ', - self.external_task_ids, - self.external_dag_id, - serialized_dttm_filter, - ) + dttm_filter = self._get_dttm_filter(context) + serialized_dttm_filter = ",".join(dt.isoformat() for dt in dttm_filter) + + if self.external_task_ids: + self.log.info( + "Poking for tasks %s in dag %s on %s ... ", + self.external_task_ids, + self.external_dag_id, + serialized_dttm_filter, + ) + + if self.external_task_group_id: + self.log.info( + "Poking for task_group '%s' in dag '%s' on %s ... ", + self.external_task_group_id, + self.external_dag_id, + serialized_dttm_filter, + ) # In poke mode this will check dag existence only once if self.check_existence and not self._has_checked_existence: @@ -176,39 +227,69 @@ def poke(self, context, session=None): if self.failed_states: count_failed = self.get_count(dttm_filter, session, self.failed_states) - if count_failed == len(dttm_filter): + # Fail if anything in the list has failed. + if count_failed > 0: if self.external_task_ids: + if self.soft_fail: + raise AirflowSkipException( + f"Some of the external tasks {self.external_task_ids} " + f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail." + ) raise AirflowException( - f'Some of the external tasks {self.external_task_ids} ' - f'in DAG {self.external_dag_id} failed.' + f"Some of the external tasks {self.external_task_ids} " + f"in DAG {self.external_dag_id} failed." ) + elif self.external_task_group_id: + if self.soft_fail: + raise AirflowSkipException( + f"The external task_group '{self.external_task_group_id}' " + f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail." + ) + raise AirflowException( + f"The external task_group '{self.external_task_group_id}' " + f"in DAG '{self.external_dag_id}' failed." + ) + else: - raise AirflowException(f'The external DAG {self.external_dag_id} failed.') + if self.soft_fail: + raise AirflowSkipException( + f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail." 
+ ) + raise AirflowException(f"The external DAG {self.external_dag_id} failed.") return count_allowed == len(dttm_filter) def _check_for_existence(self, session) -> None: - dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first() + dag_to_wait = DagModel.get_current(self.external_dag_id, session) if not dag_to_wait: - raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.') + raise AirflowException(f"The external DAG {self.external_dag_id} does not exist.") - if not os.path.exists(dag_to_wait.fileloc): - raise AirflowException(f'The external DAG {self.external_dag_id} was deleted.') + if not os.path.exists(correct_maybe_zipped(dag_to_wait.fileloc)): + raise AirflowException(f"The external DAG {self.external_dag_id} was deleted.") if self.external_task_ids: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) for external_task_id in self.external_task_ids: if not refreshed_dag_info.has_task(external_task_id): raise AirflowException( - f'The external task {external_task_id} in ' - f'DAG {self.external_dag_id} does not exist.' + f"The external task {external_task_id} in " + f"DAG {self.external_dag_id} does not exist." ) + + if self.external_task_group_id: + refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) + if not refreshed_dag_info.has_task_group(self.external_task_group_id): + raise AirflowException( + f"The external task group '{self.external_task_group_id}' in " + f"DAG '{self.external_dag_id}' does not exist." + ) + self._has_checked_existence = True def get_count(self, dttm_filter, session, states) -> int: """ - Get the count of records against dttm filter and states + Get the count of records against dttm filter and states. :param dttm_filter: date time filter for execution date :param session: airflow session object @@ -222,30 +303,44 @@ def get_count(self, dttm_filter, session, states) -> int: if self.external_task_ids: count = ( - session.query(func.count()) # .count() is inefficient - .filter( - TI.dag_id == self.external_dag_id, - TI.task_id.in_(self.external_task_ids), - TI.state.in_(states), - TI.execution_date.in_(dttm_filter), - ) + self._count_query(TI, session, states, dttm_filter) + .filter(TI.task_id.in_(self.external_task_ids)) .scalar() - ) - count = count / len(self.external_task_ids) - else: + ) / len(self.external_task_ids) + elif self.external_task_group_id: + external_task_group_task_ids = self.get_external_task_group_task_ids(session) count = ( - session.query(func.count()) - .filter( - DR.dag_id == self.external_dag_id, - DR.state.in_(states), - DR.execution_date.in_(dttm_filter), - ) + self._count_query(TI, session, states, dttm_filter) + .filter(TI.task_id.in_(external_task_group_task_ids)) .scalar() - ) + ) / len(external_task_group_task_ids) + else: + count = self._count_query(DR, session, states, dttm_filter).scalar() return count + def _count_query(self, model, session, states, dttm_filter) -> Query: + query = session.query(func.count()).filter( + model.dag_id == self.external_dag_id, + model.state.in_(states), # pylint: disable=no-member + model.execution_date.in_(dttm_filter), + ) + return query + + def get_external_task_group_task_ids(self, session): + refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session) + task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id) + + if task_group: + return [task.task_id for task in task_group] + + # returning default task_id as group_id 
itself, this will avoid any failure in case of + # 'check_existence=False' and will fail on timeout + return [self.external_task_group_id] + def _handle_execution_date_fn(self, context) -> Any: """ + Handle backward compatibility. + This function is to handle backwards compatibility with how this operator was previously where it only passes the execution date, but also allow for the newer implementation to pass all context variables as keyword arguments, to allow @@ -268,6 +363,7 @@ def _handle_execution_date_fn(self, context) -> Any: class ExternalTaskMarker(EmptyOperator): """ Use this operator to indicate that a task on a different DAG depends on this task. + When this task is cleared with "Recursive" selected, Airflow will clear the task on the other DAG and its downstream tasks recursively. Transitive dependencies are followed until the recursion_depth is reached. @@ -281,18 +377,19 @@ class ExternalTaskMarker(EmptyOperator): it slower to clear tasks in the web UI. """ - template_fields = ['external_dag_id', 'external_task_id', 'execution_date'] - ui_color = '#19647e' + template_fields = ["external_dag_id", "external_task_id", "execution_date"] + ui_color = "#19647e" + operator_extra_links = [ExternalDagLink()] # The _serialized_fields are lazily loaded when get_serialized_fields() method is called - __serialized_fields: Optional[FrozenSet[str]] = None + __serialized_fields: frozenset[str] | None = None def __init__( self, *, external_dag_id: str, external_task_id: str, - execution_date: Optional[Union[str, datetime.datetime]] = "{{ logical_date.isoformat() }}", + execution_date: str | datetime.datetime | None = "{{ logical_date.isoformat() }}", recursion_depth: int = 10, **kwargs, ): @@ -305,7 +402,7 @@ def __init__( self.execution_date = execution_date else: raise TypeError( - f'Expected str or datetime.datetime type for execution_date. Got {type(execution_date)}' + f"Expected str or datetime.datetime type for execution_date. Got {type(execution_date)}" ) if recursion_depth <= 0: @@ -318,3 +415,19 @@ def get_serialized_fields(cls): if not cls.__serialized_fields: cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"recursion_depth"}) return cls.__serialized_fields + + +@attr.s(auto_attribs=True) +class ExternalTaskSensorLink(ExternalDagLink): + """ + This external link is deprecated. + Please use :class:`airflow.sensors.external_task.ExternalDagLink`. + """ + + def __attrs_post_init__(self): + warnings.warn( + "This external link is deprecated. " + "Please use :class:`airflow.sensors.external_task.ExternalDagLink`.", + RemovedInAirflow3Warning, + stacklevel=2, + ) diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py deleted file mode 100644 index bc24a4d1f27eb..0000000000000 --- a/airflow/sensors/external_task_sensor.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.external_task`.""" - -import warnings - -from airflow.sensors.external_task import ( # noqa - ExternalTaskMarker, - ExternalTaskSensor, - ExternalTaskSensorLink, -) - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.external_task`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py index 9aa56900e9bf2..dea1643d9ae21 100644 --- a/airflow/sensors/filesystem.py +++ b/airflow/sensors/filesystem.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +from __future__ import annotations import datetime import os @@ -40,12 +40,18 @@ class FileSensor(BaseSensorOperator): the base path set within the connection), can be a glob. :param recursive: when set to ``True``, enables recursive directory matching behavior of ``**`` in glob filepath parameter. Defaults to ``False``. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:FileSensor` + + """ - template_fields: Sequence[str] = ('filepath',) - ui_color = '#91818a' + template_fields: Sequence[str] = ("filepath",) + ui_color = "#91818a" - def __init__(self, *, filepath, fs_conn_id='fs_default', recursive=False, **kwargs): + def __init__(self, *, filepath, fs_conn_id="fs_default", recursive=False, **kwargs): super().__init__(**kwargs) self.filepath = filepath self.fs_conn_id = fs_conn_id @@ -55,12 +61,12 @@ def poke(self, context: Context): hook = FSHook(self.fs_conn_id) basepath = hook.get_path() full_path = os.path.join(basepath, self.filepath) - self.log.info('Poking for file %s', full_path) + self.log.info("Poking for file %s", full_path) for path in glob(full_path, recursive=self.recursive): if os.path.isfile(path): - mod_time = datetime.datetime.fromtimestamp(os.path.getmtime(path)).strftime('%Y%m%d%H%M%S') - self.log.info('Found File %s last modified: %s', str(path), mod_time) + mod_time = datetime.datetime.fromtimestamp(os.path.getmtime(path)).strftime("%Y%m%d%H%M%S") + self.log.info("Found File %s last modified: %s", str(path), mod_time) return True for _, _, files in os.walk(path): diff --git a/airflow/sensors/hdfs_sensor.py b/airflow/sensors/hdfs_sensor.py deleted file mode 100644 index 0d5690085beb9..0000000000000 --- a/airflow/sensors/hdfs_sensor.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hdfs.sensors.hdfs`.""" - -import warnings - -from airflow.providers.apache.hdfs.sensors.hdfs import HdfsSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/hive_partition_sensor.py b/airflow/sensors/hive_partition_sensor.py deleted file mode 100644 index 8f6f08ae3f552..0000000000000 --- a/airflow/sensors/hive_partition_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.sensors.hive_partition`.""" - -import warnings - -from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.sensors.hive_partition`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py deleted file mode 100644 index 96dce065b50e8..0000000000000 --- a/airflow/sensors/http_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.http.sensors.http`.""" - -import warnings - -from airflow.providers.http.sensors.http import HttpSensor # noqa - -warnings.warn( - "This module is deprecated. 
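The files deleted above (hdfs_sensor.py, hive_partition_sensor.py, http_sensor.py, and several more below) were thin deprecation shims that only re-exported provider classes. The sketch below shows how user DAG code migrates, using the provider paths quoted in those shims; it assumes the corresponding apache-airflow-providers-* distributions are installed.

    # Old imports (the shim modules removed by this diff):
    # from airflow.sensors.hdfs_sensor import HdfsSensor
    # from airflow.sensors.hive_partition_sensor import HivePartitionSensor
    # from airflow.sensors.http_sensor import HttpSensor

    # New imports, as named in the deleted modules' deprecation messages:
    from airflow.providers.apache.hdfs.sensors.hdfs import HdfsSensor
    from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor
    from airflow.providers.http.sensors.http import HttpSensor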
Please use `airflow.providers.http.sensors.http`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/metastore_partition_sensor.py b/airflow/sensors/metastore_partition_sensor.py deleted file mode 100644 index 812c86fc57c0a..0000000000000 --- a/airflow/sensors/metastore_partition_sensor.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.apache.hive.sensors.metastore_partition`. -""" - -import warnings - -from airflow.providers.apache.hive.sensors.metastore_partition import MetastorePartitionSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.sensors.metastore_partition`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/named_hive_partition_sensor.py b/airflow/sensors/named_hive_partition_sensor.py deleted file mode 100644 index 574c2ce04402c..0000000000000 --- a/airflow/sensors/named_hive_partition_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.sensors.named_hive_partition`.""" - -import warnings - -from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hive.sensors.named_hive_partition`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/python.py b/airflow/sensors/python.py index c139c39a6a80e..615e4e20eea18 100644 --- a/airflow/sensors/python.py +++ b/airflow/sensors/python.py @@ -15,9 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence +from __future__ import annotations -from airflow.sensors.base import BaseSensorOperator +from typing import Any, Callable, Mapping, Sequence + +from airflow.sensors.base import BaseSensorOperator, PokeReturnValue from airflow.utils.context import Context, context_merge from airflow.utils.operator_helpers import determine_kwargs @@ -40,17 +42,21 @@ class PythonSensor(BaseSensorOperator): will get templated by the Airflow engine sometime between ``__init__`` and ``execute`` takes place and are made available in your callable's context after the template has been applied. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:PythonSensor` """ - template_fields: Sequence[str] = ('templates_dict', 'op_args', 'op_kwargs') + template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs") def __init__( self, *, python_callable: Callable, - op_args: Optional[List] = None, - op_kwargs: Optional[Mapping[str, Any]] = None, - templates_dict: Optional[Dict] = None, + op_args: list | None = None, + op_kwargs: Mapping[str, Any] | None = None, + templates_dict: dict | None = None, **kwargs, ): super().__init__(**kwargs) @@ -59,10 +65,10 @@ def __init__( self.op_kwargs = op_kwargs or {} self.templates_dict = templates_dict - def poke(self, context: Context) -> bool: + def poke(self, context: Context) -> PokeReturnValue | bool: context_merge(context, self.op_kwargs, templates_dict=self.templates_dict) self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context) self.log.info("Poking callable: %s", str(self.python_callable)) return_value = self.python_callable(*self.op_args, **self.op_kwargs) - return bool(return_value) + return PokeReturnValue(bool(return_value)) diff --git a/airflow/sensors/s3_key_sensor.py b/airflow/sensors/s3_key_sensor.py deleted file mode 100644 index d0f7c40b6d657..0000000000000 --- a/airflow/sensors/s3_key_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3_key`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3_key`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/s3_prefix_sensor.py b/airflow/sensors/s3_prefix_sensor.py deleted file mode 100644 index d5ea72b37d15a..0000000000000 --- a/airflow/sensors/s3_prefix_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
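Given the PythonSensor hunks above (context merging, determine_kwargs, and the new PokeReturnValue-based return), a small usage sketch may help; the DAG id, callable, and interval values are invented, and an Airflow version accepting the ``schedule`` argument is assumed.

    import pendulum

    from airflow import DAG
    from airflow.sensors.python import PythonSensor


    def _partition_ready(**context) -> bool:
        # determine_kwargs() passes matching context variables to the callable;
        # any truthy/falsy return is coerced via bool() as shown in the diff above.
        return context["logical_date"].day % 2 == 0


    with DAG(
        dag_id="python_sensor_example",  # hypothetical DAG
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule=None,
    ):
        PythonSensor(
            task_id="wait_for_partition",
            python_callable=_partition_ready,
            poke_interval=30,  # seconds between pokes
            timeout=60 * 60,   # give up after an hour
        )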
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`.""" - -import warnings - -from airflow.providers.amazon.aws.sensors.s3_prefix import S3PrefixSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/smart_sensor.py b/airflow/sensors/smart_sensor.py deleted file mode 100644 index bc22ab9c541eb..0000000000000 --- a/airflow/sensors/smart_sensor.py +++ /dev/null @@ -1,757 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import datetime -import json -import logging -import traceback -from logging.config import DictConfigurator # type: ignore -from time import sleep - -from sqlalchemy import and_, or_, tuple_ - -from airflow.compat.functools import cached_property -from airflow.exceptions import AirflowException, AirflowTaskTimeout -from airflow.models import BaseOperator, DagRun, SensorInstance, SkipMixin, TaskInstance -from airflow.settings import LOGGING_CLASS_PATH -from airflow.stats import Stats -from airflow.utils import helpers, timezone -from airflow.utils.context import Context -from airflow.utils.email import send_email -from airflow.utils.log.logging_mixin import set_context -from airflow.utils.module_loading import import_string -from airflow.utils.net import get_hostname -from airflow.utils.session import provide_session -from airflow.utils.state import PokeState, State -from airflow.utils.timeout import timeout - -config = import_string(LOGGING_CLASS_PATH) -handler_config = config['handlers']['task'] -try: - formatter_config = config['formatters'][handler_config['formatter']] -except Exception as err: - formatter_config = None - print(err) -dictConfigurator = DictConfigurator(config) - - -class SensorWork: - """ - This class stores a sensor work with decoded context value. It is only used - inside of smart sensor. Create a sensor work based on sensor instance record. - A sensor work object has the following attributes: - `dag_id`: sensor_instance dag_id. - `task_id`: sensor_instance task_id. 
- `execution_date`: sensor_instance execution_date. - `try_number`: sensor_instance try_number - `poke_context`: Decoded poke_context for the sensor task. - `execution_context`: Decoded execution_context. - `hashcode`: This is the signature of poking job. - `operator`: The sensor operator class. - `op_classpath`: The sensor operator class path - `encoded_poke_context`: The raw data from sensor_instance poke_context column. - `log`: The sensor work logger which will mock the corresponding task instance log. - - :param si: The sensor_instance ORM object. - """ - - def __init__(self, si): - self.dag_id = si.dag_id - self.task_id = si.task_id - self.execution_date = si.execution_date - self.try_number = si.try_number - - self.poke_context = json.loads(si.poke_context) if si.poke_context else {} - self.execution_context = json.loads(si.execution_context) if si.execution_context else {} - self.hashcode = si.hashcode - self.start_date = si.start_date - self.operator = si.operator - self.op_classpath = si.op_classpath - self.encoded_poke_context = si.poke_context - self.si = si - - def __eq__(self, other): - if not isinstance(other, SensorWork): - return NotImplemented - - return ( - self.dag_id == other.dag_id - and self.task_id == other.task_id - and self.execution_date == other.execution_date - and self.try_number == other.try_number - ) - - @staticmethod - def create_new_task_handler(): - """ - Create task log handler for a sensor work. - :return: log handler - """ - from airflow.utils.log.secrets_masker import _secrets_masker - - handler_config_copy = {k: handler_config[k] for k in handler_config} - del handler_config_copy['filters'] - - formatter_config_copy = {k: formatter_config[k] for k in formatter_config} - handler = dictConfigurator.configure_handler(handler_config_copy) - formatter = dictConfigurator.configure_formatter(formatter_config_copy) - handler.setFormatter(formatter) - - # We want to share the _global_ filterer instance, not create a new one - handler.addFilter(_secrets_masker()) - return handler - - @cached_property - def log(self): - """Return logger for a sensor instance object.""" - # The created log_id is used inside of smart sensor as the key to fetch - # the corresponding in memory log handler. - si = self.si - si.raw = False # Otherwise set_context will fail - log_id = "-".join( - [si.dag_id, si.task_id, si.execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f"), str(si.try_number)] - ) - logger = logging.getLogger(f'airflow.task.{log_id}') - - if len(logger.handlers) == 0: - handler = self.create_new_task_handler() - logger.addHandler(handler) - set_context(logger, si) - - line_break = "-" * 120 - logger.info(line_break) - logger.info( - "Processing sensor task %s in smart sensor service on host: %s", self.ti_key, get_hostname() - ) - logger.info(line_break) - return logger - - def close_sensor_logger(self): - """Close log handler for a sensor work.""" - for handler in self.log.handlers: - try: - handler.close() - except Exception as e: - print(e) - - @property - def ti_key(self): - """Key for the task instance that maps to the sensor work.""" - return self.dag_id, self.task_id, self.execution_date - - @property - def cache_key(self): - """Key used to query in smart sensor for cached sensor work.""" - return self.operator, self.encoded_poke_context - - -class CachedPokeWork: - """ - Wrapper class for the poke work inside smart sensor. It saves - the sensor_task used to poke and recent poke result state. - state: poke state. 
- sensor_task: The cached object for executing the poke function. - last_poke_time: The latest time this cached work being called. - to_flush: If we should flush the cached work. - """ - - def __init__(self): - self.state = None - self.sensor_task = None - self.last_poke_time = None - self.to_flush = False - - def set_state(self, state): - """ - Set state for cached poke work. - :param state: The sensor_instance state. - """ - self.state = state - self.last_poke_time = timezone.utcnow() - - def clear_state(self): - """Clear state for cached poke work.""" - self.state = None - - def set_to_flush(self): - """Mark this poke work to be popped from cached dict after current loop.""" - self.to_flush = True - - def is_expired(self): - """ - The cached task object expires if there is no poke for 20 minutes. - :return: Boolean - """ - return self.to_flush or (timezone.utcnow() - self.last_poke_time).total_seconds() > 1200 - - -class SensorExceptionInfo: - """ - Hold sensor exception information and the type of exception. For possible transient - infra failure, give the task more chance to retry before fail it. - """ - - def __init__( - self, - exception_info, - is_infra_failure=False, - infra_failure_retry_window=datetime.timedelta(minutes=130), - ): - self._exception_info = exception_info - self._is_infra_failure = is_infra_failure - self._infra_failure_retry_window = infra_failure_retry_window - - self._infra_failure_timeout = None - self.set_infra_failure_timeout() - self.fail_current_run = self.should_fail_current_run() - - def set_latest_exception(self, exception_info, is_infra_failure=False): - """ - This function set the latest exception information for sensor exception. If the exception - implies an infra failure, this function will check the recorded infra failure timeout - which was set at the first infra failure exception arrives. There is a 6 hours window - for retry without failing current run. - - :param exception_info: Details of the exception information. - :param is_infra_failure: If current exception was caused by transient infra failure. - There is a retry window _infra_failure_retry_window that the smart sensor will - retry poke function without failing current task run. - """ - self._exception_info = exception_info - self._is_infra_failure = is_infra_failure - - self.set_infra_failure_timeout() - self.fail_current_run = self.should_fail_current_run() - - def set_infra_failure_timeout(self): - """ - Set the time point when the sensor should be failed if it kept getting infra - failure. 
- :return: - """ - # Only set the infra_failure_timeout if there is no existing one - if not self._is_infra_failure: - self._infra_failure_timeout = None - elif self._infra_failure_timeout is None: - self._infra_failure_timeout = timezone.utcnow() + self._infra_failure_retry_window - - def should_fail_current_run(self): - """:return: Should the sensor fail""" - return not self.is_infra_failure or timezone.utcnow() > self._infra_failure_timeout - - @property - def exception_info(self): - """:return: exception msg.""" - return self._exception_info - - @property - def is_infra_failure(self): - """:return: If the exception is an infra failure""" - return self._is_infra_failure - - def is_expired(self): - """:return: If current exception need to be kept.""" - if not self._is_infra_failure: - return True - return timezone.utcnow() > self._infra_failure_timeout + datetime.timedelta(minutes=30) - - -class SmartSensorOperator(BaseOperator, SkipMixin): - """ - Smart sensor operators are derived from this class. - - Smart Sensor operators keep refresh a dictionary by visiting DB. - Taking qualified active sensor tasks. Different from sensor operator, - Smart sensor operators poke for all sensor tasks in the dictionary at - a time interval. When a criteria is met or fail by time out, it update - all sensor task state in task_instance table - - :param soft_fail: Set to true to mark the task as SKIPPED on failure - :param poke_interval: Time in seconds that the job should wait in - between each tries. - :param smart_sensor_timeout: Time, in seconds before the internal sensor - job times out if poke_timeout is not defined. - :param shard_min: shard code lower bound (inclusive) - :param shard_max: shard code upper bound (exclusive) - :param poke_timeout: Time, in seconds before the task times out and fails. - """ - - ui_color = '#e6f1f2' - - def __init__( - self, - poke_interval=180, - smart_sensor_timeout=60 * 60 * 24 * 7, - soft_fail=False, - shard_min=0, - shard_max=100000, - poke_timeout=6.0, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - # super(SmartSensorOperator, self).__init__(*args, **kwargs) - self.poke_interval = poke_interval - self.soft_fail = soft_fail - self.timeout = smart_sensor_timeout - self._validate_input_values() - self.hostname = "" - - self.sensor_works = [] - self.cached_dedup_works = {} - self.cached_sensor_exceptions = {} - - self.max_tis_per_query = 50 - self.shard_min = shard_min - self.shard_max = shard_max - self.poke_timeout = poke_timeout - - def _validate_input_values(self): - if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: - raise AirflowException("The poke_interval must be a non-negative number") - if not isinstance(self.timeout, (int, float)) or self.timeout < 0: - raise AirflowException("The timeout must be a non-negative number") - - @provide_session - def _load_sensor_works(self, session=None): - """ - Refresh sensor instances need to be handled by this operator. Create smart sensor - internal object based on the information persisted in the sensor_instance table. - - """ - SI = SensorInstance - with Stats.timer() as timer: - query = ( - session.query(SI) - .filter(SI.state == State.SENSING) - .filter(SI.shardcode < self.shard_max, SI.shardcode >= self.shard_min) - ) - tis = query.all() - - self.log.info("Performance query %s tis, time: %.3f", len(tis), timer.duration) - - # Query without checking dagrun state might keep some failed dag_run tasks alive. 
- # Join with DagRun table will be very slow based on the number of sensor tasks we - # need to handle. We query all smart tasks in this operator - # and expect scheduler correct the states in _change_state_for_tis_without_dagrun() - - sensor_works = [] - for ti in tis: - try: - sensor_works.append(SensorWork(ti)) - except Exception: - self.log.exception("Exception at creating sensor work for ti %s", ti.key) - - self.log.info("%d tasks detected.", len(sensor_works)) - - new_sensor_works = [x for x in sensor_works if x not in self.sensor_works] - - self._update_ti_hostname(new_sensor_works) - - self.sensor_works = sensor_works - - @provide_session - def _update_ti_hostname(self, sensor_works, session=None): - """ - Update task instance hostname for new sensor works. - - :param sensor_works: Smart sensor internal object for a sensor task. - :param session: The sqlalchemy session. - """ - DR = DagRun - TI = TaskInstance - - def update_ti_hostname_with_count(count, sensor_works): - # Using or_ instead of in_ here to prevent from full table scan. - if session.bind.dialect.name == 'mssql': - ti_filter = or_( - and_( - TI.dag_id == ti_key.dag_id, - TI.task_id == ti_key.task_id, - DR.execution_date == ti_key.execution_date, - ) - for ti_key in sensor_works - ) - else: - ti_keys = [(x.dag_id, x.task_id, x.execution_date) for x in sensor_works] - ti_filter = or_( - tuple_(TI.dag_id, TI.task_id, DR.execution_date) == ti_key for ti_key in ti_keys - ) - - for ti in session.query(TI).join(TI.dag_run).filter(ti_filter): - ti.hostname = self.hostname - session.commit() - - return count + len(sensor_works) - - count = helpers.reduce_in_chunks( - update_ti_hostname_with_count, sensor_works, 0, self.max_tis_per_query - ) - if count: - self.log.info("Updated hostname on %s tis.", count) - - @provide_session - def _mark_multi_state(self, operator, poke_hash, encoded_poke_context, state, session=None): - """ - Mark state for multiple tasks in the task_instance table to a new state if they have - the same signature as the poke_hash. - - :param operator: The sensor's operator class name. - :param poke_hash: The hash code generated from sensor's poke context. - :param encoded_poke_context: The raw encoded poke_context. - :param state: Set multiple sensor tasks to this state. - :param session: The sqlalchemy session. 
- """ - - def mark_state(ti, sensor_instance): - ti.state = state - sensor_instance.state = state - if state in State.finished: - ti.end_date = end_date - ti.set_duration() - - SI = SensorInstance - TI = TaskInstance - - count_marked = 0 - query_result = [] - try: - query_result = ( - session.query(TI, SI) - .join( - TI, - and_( - TI.dag_id == SI.dag_id, - TI.task_id == SI.task_id, - TI.execution_date == SI.execution_date, - ), - ) - .filter(SI.state == State.SENSING) - .filter(SI.hashcode == poke_hash) - .filter(SI.operator == operator) - .with_for_update() - .all() - ) - - end_date = timezone.utcnow() - for ti, sensor_instance in query_result: - if sensor_instance.poke_context != encoded_poke_context: - continue - - ti.hostname = self.hostname - if ti.state == State.SENSING: - mark_state(ti=ti, sensor_instance=sensor_instance) - count_marked += 1 - else: - # ti.state != State.SENSING - sensor_instance.state = ti.state - - session.commit() - - except Exception: - self.log.warning( - "Exception _mark_multi_state in smart sensor for hashcode %s", - str(poke_hash), # cast to str in advance for highlighting - exc_info=True, - ) - self.log.info("Marked %s tasks out of %s to state %s", count_marked, len(query_result), state) - - @provide_session - def _retry_or_fail_task(self, sensor_work, error, session=None): - """ - Change single task state for sensor task. For final state, set the end_date. - Since smart sensor take care all retries in one process. Failed sensor tasks - logically experienced all retries and the try_number should be set to max_tries. - - :param sensor_work: The sensor_work with exception. - :param error: The error message for this sensor_work. - :param session: The sqlalchemy session. - """ - - def email_alert(task_instance, error_info): - try: - subject, html_content, _ = task_instance.get_email_subject_content(error_info) - email = sensor_work.execution_context.get('email') - - send_email(email, subject, html_content) - except Exception: - sensor_work.log.warning("Exception alerting email.", exc_info=True) - - def handle_failure(sensor_work, ti): - if sensor_work.execution_context.get('retries') and ti.try_number <= ti.max_tries: - # retry - ti.state = State.UP_FOR_RETRY - if sensor_work.execution_context.get('email_on_retry') and sensor_work.execution_context.get( - 'email' - ): - sensor_work.log.info("%s sending email alert for retry", sensor_work.ti_key) - email_alert(ti, error) - else: - ti.state = State.FAILED - if sensor_work.execution_context.get( - 'email_on_failure' - ) and sensor_work.execution_context.get('email'): - sensor_work.log.info("%s sending email alert for failure", sensor_work.ti_key) - email_alert(ti, error) - - try: - dag_id, task_id, execution_date = sensor_work.ti_key - TI = TaskInstance - SI = SensorInstance - sensor_instance = ( - session.query(SI) - .filter(SI.dag_id == dag_id, SI.task_id == task_id, SI.execution_date == execution_date) - .with_for_update() - .first() - ) - - if sensor_instance.hashcode != sensor_work.hashcode: - # Return without setting state - return - - ti = ( - session.query(TI) - .filter(TI.dag_id == dag_id, TI.task_id == task_id, TI.execution_date == execution_date) - .with_for_update() - .first() - ) - - if ti: - if ti.state == State.SENSING: - ti.hostname = self.hostname - handle_failure(sensor_work, ti) - - sensor_instance.state = State.FAILED - ti.end_date = timezone.utcnow() - ti.set_duration() - else: - sensor_instance.state = ti.state - session.merge(sensor_instance) - session.merge(ti) - session.commit() - - 
sensor_work.log.info( - "Task %s got an error: %s. Set the state to failed. Exit.", str(sensor_work.ti_key), error - ) - sensor_work.close_sensor_logger() - - except AirflowException: - sensor_work.log.warning("Exception on failing %s", sensor_work.ti_key, exc_info=True) - - def _check_and_handle_ti_timeout(self, sensor_work): - """ - Check if a sensor task in smart sensor is timeout. Could be either sensor operator timeout - or general operator execution_timeout. - - :param sensor_work: SensorWork - """ - task_timeout = sensor_work.execution_context.get('timeout', self.timeout) - task_execution_timeout = sensor_work.execution_context.get('execution_timeout') - if task_execution_timeout: - task_timeout = min(task_timeout, task_execution_timeout) - - if (timezone.utcnow() - sensor_work.start_date).total_seconds() > task_timeout: - error = "Sensor Timeout" - sensor_work.log.exception(error) - self._retry_or_fail_task(sensor_work, error) - - def _handle_poke_exception(self, sensor_work): - """ - Fail task if accumulated exceptions exceeds retries. - - :param sensor_work: SensorWork - """ - sensor_exception = self.cached_sensor_exceptions.get(sensor_work.cache_key) - error = sensor_exception.exception_info - sensor_work.log.exception("Handling poke exception: %s", error) - - if sensor_exception.fail_current_run: - if sensor_exception.is_infra_failure: - sensor_work.log.exception( - "Task %s failed by infra failure in smart sensor.", sensor_work.ti_key - ) - # There is a risk for sensor object cached in smart sensor keep throwing - # exception and cause an infra failure. To make sure the sensor tasks after - # retry will not fall into same object and have endless infra failure, - # we mark the sensor task after an infra failure so that it can be popped - # before next poke loop. 
- cache_key = sensor_work.cache_key - self.cached_dedup_works[cache_key].set_to_flush() - else: - sensor_work.log.exception("Task %s failed by exceptions.", sensor_work.ti_key) - self._retry_or_fail_task(sensor_work, error) - else: - sensor_work.log.info("Exception detected, retrying without failing current run.") - self._check_and_handle_ti_timeout(sensor_work) - - def _process_sensor_work_with_cached_state(self, sensor_work, state): - if state == PokeState.LANDED: - sensor_work.log.info("Task %s succeeded", str(sensor_work.ti_key)) - sensor_work.close_sensor_logger() - - if state == PokeState.NOT_LANDED: - # Handle timeout if connection valid but not landed yet - self._check_and_handle_ti_timeout(sensor_work) - elif state == PokeState.POKE_EXCEPTION: - self._handle_poke_exception(sensor_work) - - def _execute_sensor_work(self, sensor_work): - ti_key = sensor_work.ti_key - log = sensor_work.log or self.log - log.info("Sensing ti: %s", str(ti_key)) - log.info("Poking with arguments: %s", sensor_work.encoded_poke_context) - - cache_key = sensor_work.cache_key - if cache_key not in self.cached_dedup_works: - # create an empty cached_work for a new cache_key - self.cached_dedup_works[cache_key] = CachedPokeWork() - - cached_work = self.cached_dedup_works[cache_key] - - if cached_work.state is not None: - # Have a valid cached state, don't poke twice in certain time interval - self._process_sensor_work_with_cached_state(sensor_work, cached_work.state) - return - - try: - with timeout(seconds=self.poke_timeout): - if self.poke(sensor_work): - # Got a landed signal, mark all tasks waiting for this partition - cached_work.set_state(PokeState.LANDED) - - self._mark_multi_state( - sensor_work.operator, - sensor_work.hashcode, - sensor_work.encoded_poke_context, - State.SUCCESS, - ) - - log.info("Task %s succeeded", str(ti_key)) - sensor_work.close_sensor_logger() - else: - # Not landed yet. Handle possible timeout - cached_work.set_state(PokeState.NOT_LANDED) - self._check_and_handle_ti_timeout(sensor_work) - - self.cached_sensor_exceptions.pop(cache_key, None) - except Exception as e: - # The retry_infra_failure decorator inside hive_hooks will raise exception with - # is_infra_failure == True. Long poking timeout here is also considered an infra - # failure. Other exceptions should fail. - is_infra_failure = getattr(e, 'is_infra_failure', False) or isinstance(e, AirflowTaskTimeout) - exception_info = traceback.format_exc() - cached_work.set_state(PokeState.POKE_EXCEPTION) - - if cache_key in self.cached_sensor_exceptions: - self.cached_sensor_exceptions[cache_key].set_latest_exception( - exception_info, is_infra_failure=is_infra_failure - ) - else: - self.cached_sensor_exceptions[cache_key] = SensorExceptionInfo( - exception_info, is_infra_failure=is_infra_failure - ) - - self._handle_poke_exception(sensor_work) - - def flush_cached_sensor_poke_results(self): - """Flush outdated cached sensor states saved in previous loop.""" - for key, cached_work in self.cached_dedup_works.copy().items(): - if cached_work.is_expired(): - self.cached_dedup_works.pop(key, None) - else: - cached_work.state = None - - for ti_key, sensor_exception in self.cached_sensor_exceptions.copy().items(): - if sensor_exception.fail_current_run or sensor_exception.is_expired(): - self.cached_sensor_exceptions.pop(ti_key, None) - - def poke(self, sensor_work): - """ - Function that the sensors defined while deriving this class should - override. 
- - """ - cached_work = self.cached_dedup_works[sensor_work.cache_key] - if not cached_work.sensor_task: - init_args = dict(list(sensor_work.poke_context.items()) + [('task_id', sensor_work.task_id)]) - operator_class = import_string(sensor_work.op_classpath) - cached_work.sensor_task = operator_class(**init_args) - - return cached_work.sensor_task.poke(sensor_work.poke_context) - - def _emit_loop_stats(self): - try: - count_poke = 0 - count_poke_success = 0 - count_poke_exception = 0 - count_exception_failures = 0 - count_infra_failure = 0 - for cached_work in self.cached_dedup_works.values(): - if cached_work.state is None: - continue - count_poke += 1 - if cached_work.state == PokeState.LANDED: - count_poke_success += 1 - elif cached_work.state == PokeState.POKE_EXCEPTION: - count_poke_exception += 1 - for cached_exception in self.cached_sensor_exceptions.values(): - if cached_exception.is_infra_failure and cached_exception.fail_current_run: - count_infra_failure += 1 - if cached_exception.fail_current_run: - count_exception_failures += 1 - - Stats.gauge("smart_sensor_operator.poked_tasks", count_poke) - Stats.gauge("smart_sensor_operator.poked_success", count_poke_success) - Stats.gauge("smart_sensor_operator.poked_exception", count_poke_exception) - Stats.gauge("smart_sensor_operator.exception_failures", count_exception_failures) - Stats.gauge("smart_sensor_operator.infra_failures", count_infra_failure) - except Exception: - self.log.exception("Exception at getting loop stats %s") - - def execute(self, context: Context): - started_at = timezone.utcnow() - - self.hostname = get_hostname() - while True: - poke_start_time = timezone.utcnow() - - self.flush_cached_sensor_poke_results() - - self._load_sensor_works() - self.log.info("Loaded %s sensor_works", len(self.sensor_works)) - Stats.gauge("smart_sensor_operator.loaded_tasks", len(self.sensor_works)) - - for sensor_work in self.sensor_works: - self._execute_sensor_work(sensor_work) - - duration = timezone.utcnow() - poke_start_time - duration_seconds = duration.total_seconds() - - self.log.info("Taking %s seconds to execute %s tasks.", duration_seconds, len(self.sensor_works)) - - Stats.timing("smart_sensor_operator.loop_duration", duration) - Stats.gauge("smart_sensor_operator.executed_tasks", len(self.sensor_works)) - self._emit_loop_stats() - - if duration_seconds < self.poke_interval: - sleep(self.poke_interval - duration_seconds) - if (timezone.utcnow() - started_at).total_seconds() > self.timeout: - self.log.info("Time is out for smart sensor.") - return - - def on_kill(self): - pass diff --git a/airflow/sensors/sql.py b/airflow/sensors/sql.py deleted file mode 100644 index a35d7566ceb41..0000000000000 --- a/airflow/sensors/sql.py +++ /dev/null @@ -1,111 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Sequence - -from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook -from airflow.hooks.dbapi import DbApiHook -from airflow.sensors.base import BaseSensorOperator -from airflow.utils.context import Context - - -class SqlSensor(BaseSensorOperator): - """ - Runs a sql statement repeatedly until a criteria is met. It will keep trying until - success or failure criteria are met, or if the first cell is not in (0, '0', '', None). - Optional success and failure callables are called with the first cell returned as the argument. - If success callable is defined the sensor will keep retrying until the criteria is met. - If failure callable is defined and the criteria is met the sensor will raise AirflowException. - Failure criteria is evaluated before success criteria. A fail_on_empty boolean can also - be passed to the sensor in which case it will fail if no rows have been returned - - :param conn_id: The connection to run the sensor against - :param sql: The sql to run. To pass, it needs to return at least one cell - that contains a non-zero / empty string value. - :param parameters: The parameters to render the SQL query with (optional). - :param success: Success criteria for the sensor is a Callable that takes first_cell - as the only argument, and returns a boolean (optional). - :param failure: Failure criteria for the sensor is a Callable that takes first_cell - as the only argument and return a boolean (optional). - :param fail_on_empty: Explicitly fail on no rows returned. - :param hook_params: Extra config params to be passed to the underlying hook. - Should match the desired hook constructor params. - """ - - template_fields: Sequence[str] = ('sql',) - template_ext: Sequence[str] = ( - '.hql', - '.sql', - ) - ui_color = '#7c7287' - - def __init__( - self, - *, - conn_id, - sql, - parameters=None, - success=None, - failure=None, - fail_on_empty=False, - hook_params=None, - **kwargs, - ): - self.conn_id = conn_id - self.sql = sql - self.parameters = parameters - self.success = success - self.failure = failure - self.fail_on_empty = fail_on_empty - self.hook_params = hook_params - super().__init__(**kwargs) - - def _get_hook(self): - conn = BaseHook.get_connection(self.conn_id) - hook = conn.get_hook(hook_params=self.hook_params) - if not isinstance(hook, DbApiHook): - raise AirflowException( - f'The connection type is not supported by {self.__class__.__name__}. ' - f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}' - ) - return hook - - def poke(self, context: Context): - hook = self._get_hook() - - self.log.info('Poking: %s (with parameters %s)', self.sql, self.parameters) - records = hook.get_records(self.sql, self.parameters) - if not records: - if self.fail_on_empty: - raise AirflowException("No rows returned, raising as per fail_on_empty flag") - else: - return False - first_cell = records[0][0] - if self.failure is not None: - if callable(self.failure): - if self.failure(first_cell): - raise AirflowException(f"Failure criteria met. 
self.failure({first_cell}) returned True") - else: - raise AirflowException(f"self.failure is present, but not callable -> {self.failure}") - if self.success is not None: - if callable(self.success): - return self.success(first_cell) - else: - raise AirflowException(f"self.success is present, but not callable -> {self.success}") - return bool(first_cell) diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py deleted file mode 100644 index 8a077db534496..0000000000000 --- a/airflow/sensors/sql_sensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.sql`.""" - -import warnings - -from airflow.sensors.sql import SqlSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.sql`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/sensors/time_delta.py b/airflow/sensors/time_delta.py index 5b336c3f88ccd..a73d123c3d761 100644 --- a/airflow/sensors/time_delta.py +++ b/airflow/sensors/time_delta.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger @@ -27,6 +28,12 @@ class TimeDeltaSensor(BaseSensorOperator): Waits for a timedelta after the run's data interval. :param delta: time length to wait after the data interval before succeeding. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:TimeDeltaSensor` + + """ def __init__(self, *, delta, **kwargs): @@ -34,22 +41,28 @@ def __init__(self, *, delta, **kwargs): self.delta = delta def poke(self, context: Context): - target_dttm = context['data_interval_end'] + target_dttm = context["data_interval_end"] target_dttm += self.delta - self.log.info('Checking if the time (%s) has come', target_dttm) + self.log.info("Checking if the time (%s) has come", target_dttm) return timezone.utcnow() > target_dttm class TimeDeltaSensorAsync(TimeDeltaSensor): """ - A drop-in replacement for TimeDeltaSensor that defers itself to avoid - taking up a worker slot while it is waiting. + A deferrable drop-in replacement for TimeDeltaSensor. + + Will defer itself to avoid taking up a worker slot while it is waiting. :param delta: time length to wait after the data interval before succeeding. + + ..
seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:TimeDeltaSensorAsync` + """ def execute(self, context: Context): - target_dttm = context['data_interval_end'] + target_dttm = context["data_interval_end"] target_dttm += self.delta self.defer(trigger=DateTimeTrigger(moment=target_dttm), method_name="execute_complete") diff --git a/airflow/sensors/time_delta_sensor.py b/airflow/sensors/time_delta_sensor.py deleted file mode 100644 index 73f32c2fc82fc..0000000000000 --- a/airflow/sensors/time_delta_sensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.sensors.time_delta`.""" - -import warnings - -from airflow.sensors.time_delta import TimeDeltaSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.sensors.time_delta`.", DeprecationWarning, stacklevel=2 -) diff --git a/airflow/sensors/time_sensor.py b/airflow/sensors/time_sensor.py index 117390925def6..12b26d06bdd50 100644 --- a/airflow/sensors/time_sensor.py +++ b/airflow/sensors/time_sensor.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import datetime from airflow.sensors.base import BaseSensorOperator @@ -28,6 +30,11 @@ class TimeSensor(BaseSensorOperator): Waits until the specified time of the day. :param target_time: time after which the job succeeds + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:TimeSensor` + """ def __init__(self, *, target_time, **kwargs): @@ -35,26 +42,33 @@ def __init__(self, *, target_time, **kwargs): self.target_time = target_time def poke(self, context: Context): - self.log.info('Checking if the time (%s) has come', self.target_time) + self.log.info("Checking if the time (%s) has come", self.target_time) return timezone.make_naive(timezone.utcnow(), self.dag.timezone).time() > self.target_time class TimeSensorAsync(BaseSensorOperator): """ - Waits until the specified time of the day, freeing up a worker slot while - it is waiting. + Waits until the specified time of the day. + + This frees up a worker slot while it is waiting. :param target_time: time after which the job succeeds + + .. 
seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:TimeSensorAsync` """ def __init__(self, *, target_time, **kwargs): super().__init__(**kwargs) self.target_time = target_time - self.target_datetime = timezone.coerce_datetime( + aware_time = timezone.coerce_datetime( datetime.datetime.combine(datetime.datetime.today(), self.target_time) ) + self.target_datetime = timezone.convert_to_utc(aware_time) + def execute(self, context: Context): self.defer( trigger=DateTimeTrigger(moment=self.target_datetime), diff --git a/airflow/sensors/web_hdfs_sensor.py b/airflow/sensors/web_hdfs_sensor.py deleted file mode 100644 index 8f9324e7c2b7b..0000000000000 --- a/airflow/sensors/web_hdfs_sensor.py +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.apache.hdfs.sensors.web_hdfs`.""" - -import warnings - -from airflow.providers.apache.hdfs.sensors.web_hdfs import WebHdfsSensor # noqa - -warnings.warn( - "This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.web_hdfs`.", - DeprecationWarning, - stacklevel=2, -) diff --git a/airflow/sensors/weekday.py b/airflow/sensors/weekday.py index 5bb4db646f7c4..1cc3e36ddfabc 100644 --- a/airflow/sensors/weekday.py +++ b/airflow/sensors/weekday.py @@ -15,8 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import warnings +from typing import Iterable +from airflow.exceptions import RemovedInAirflow3Warning from airflow.sensors.base import BaseSensorOperator from airflow.utils import timezone from airflow.utils.context import Context @@ -25,9 +29,10 @@ class DayOfWeekSensor(BaseSensorOperator): """ - Waits until the first specified day of the week. For example, if the execution - day of the task is '2018-12-22' (Saturday) and you pass 'FRIDAY', the task will wait - until next Friday. + Waits until the first specified day of the week. + + For example, if the execution day of the task is '2018-12-22' (Saturday) + and you pass 'FRIDAY', the task will wait until next Friday. **Example** (with single day): :: @@ -65,13 +70,28 @@ class DayOfWeekSensor(BaseSensorOperator): * ``{WeekDay.TUESDAY}`` * ``{WeekDay.SATURDAY, WeekDay.SUNDAY}`` + To use ``WeekDay`` enum, import it from ``airflow.utils.weekday`` + :param use_task_logical_date: If ``True``, uses task's logical date to compare with week_day. Execution Date is Useful for backfilling. If ``False``, uses system's day of the week. Useful when you don't want to run anything on weekdays on the system. 
+ :param use_task_execution_day: deprecated parameter, same effect as `use_task_logical_date` + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/operator:DayOfWeekSensor` + """ - def __init__(self, *, week_day, use_task_logical_date=False, use_task_execution_day=False, **kwargs): + def __init__( + self, + *, + week_day: str | Iterable[str] | WeekDay | Iterable[WeekDay], + use_task_logical_date: bool = False, + use_task_execution_day: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.week_day = week_day self.use_task_logical_date = use_task_logical_date @@ -79,18 +99,18 @@ def __init__(self, *, week_day, use_task_logical_date=False, use_task_execution_ self.use_task_logical_date = use_task_execution_day warnings.warn( "Parameter ``use_task_execution_day`` is deprecated. Use ``use_task_logical_date``.", - DeprecationWarning, + RemovedInAirflow3Warning, stacklevel=2, ) self._week_day_num = WeekDay.validate_week_day(week_day) - def poke(self, context: Context): + def poke(self, context: Context) -> bool: self.log.info( - 'Poking until weekday is in %s, Today is %s', + "Poking until weekday is in %s, Today is %s", self.week_day, WeekDay(timezone.utcnow().isoweekday()).name, ) if self.use_task_logical_date: - return context['logical_date'].isoweekday() in self._week_day_num + return context["logical_date"].isoweekday() in self._week_day_num else: return timezone.utcnow().isoweekday() in self._week_day_num diff --git a/airflow/sentry.py b/airflow/sentry.py index 074ecb332ca6c..a5d1ff5c97843 100644 --- a/airflow/sentry.py +++ b/airflow/sentry.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Sentry Integration""" +from __future__ import annotations + import logging from functools import wraps @@ -48,7 +49,7 @@ def flush(self): Sentry: DummySentry = DummySentry() -if conf.getboolean("sentry", 'sentry_on', fallback=False): +if conf.getboolean("sentry", "sentry_on", fallback=False): import sentry_sdk # Verify blinker installation @@ -106,7 +107,7 @@ def __init__(self): ", ".join(unsupported_options), ) - sentry_config_opts['before_send'] = conf.getimport('sentry', 'before_send', fallback=None) + sentry_config_opts["before_send"] = conf.getimport("sentry", "before_send", fallback=None) if dsn: sentry_sdk.init(dsn=dsn, integrations=integrations, **sentry_config_opts) @@ -160,14 +161,14 @@ def wrapper(_self, *args, **kwargs): # tags and breadcrumbs to a specific Task Instance try: - session = kwargs.get('session', args[session_args_idx]) + session = kwargs.get("session", args[session_args_idx]) except IndexError: session = None with sentry_sdk.push_scope(): try: # Is a LocalTaskJob get the task instance - if hasattr(_self, 'task_instance'): + if hasattr(_self, "task_instance"): task_instance = _self.task_instance else: task_instance = _self diff --git a/airflow/serialization/__init__.py b/airflow/serialization/__init__.py index c6d1a147f91f6..ea277edbaf0f4 100644 --- a/airflow/serialization/__init__.py +++ b/airflow/serialization/__init__.py @@ -15,5 +15,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
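Returning to the DayOfWeekSensor signature above (``week_day`` now typed as ``str | Iterable[str] | WeekDay | Iterable[WeekDay]``, plus ``use_task_logical_date``), a brief sketch with invented DAG and task ids; WeekDay is imported from the path named in the docstring, and an Airflow version accepting the ``schedule`` argument is assumed.

    import pendulum

    from airflow import DAG
    from airflow.sensors.weekday import DayOfWeekSensor
    from airflow.utils.weekday import WeekDay

    with DAG(
        dag_id="weekday_sensor_example",  # hypothetical DAG
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule="@daily",
    ):
        DayOfWeekSensor(
            task_id="wait_for_weekend",
            week_day={WeekDay.SATURDAY, WeekDay.SUNDAY},
            use_task_logical_date=True,  # compare against the run's logical date, not the wall clock
        )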
- """DAG serialization.""" diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index f4227a6f7aed2..f2332616132ac 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Enums for DAG serialization.""" +from __future__ import annotations from enum import Enum, unique @@ -26,8 +26,8 @@ class Encoding(str, Enum): """Enum of encoding constants.""" - TYPE = '__type' - VAR = '__var' + TYPE = "__type" + VAR = "__var" # Supported types for encoding. primitives and list are not encoded. @@ -35,17 +35,19 @@ class Encoding(str, Enum): class DagAttributeTypes(str, Enum): """Enum of supported attribute types of DAG.""" - DAG = 'dag' - OP = 'operator' - DATETIME = 'datetime' - TIMEDELTA = 'timedelta' - TIMEZONE = 'timezone' - RELATIVEDELTA = 'relativedelta' - DICT = 'dict' - SET = 'set' - TUPLE = 'tuple' - POD = 'k8s.V1Pod' - TASK_GROUP = 'taskgroup' - EDGE_INFO = 'edgeinfo' - PARAM = 'param' - XCOM_REF = 'xcomref' + DAG = "dag" + OP = "operator" + DATETIME = "datetime" + TIMEDELTA = "timedelta" + TIMEZONE = "timezone" + RELATIVEDELTA = "relativedelta" + DICT = "dict" + SET = "set" + TUPLE = "tuple" + POD = "k8s.V1Pod" + TASK_GROUP = "taskgroup" + EDGE_INFO = "edgeinfo" + PARAM = "param" + XCOM_REF = "xcomref" + DATASET = "dataset" + SIMPLE_TASK_INSTANCE = "simple_task_instance" diff --git a/airflow/serialization/helpers.py b/airflow/serialization/helpers.py index b26c0df9efd2e..5026f070d96df 100644 --- a/airflow/serialization/helpers.py +++ b/airflow/serialization/helpers.py @@ -14,14 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Serialized DAG and BaseOperator""" -from typing import Any, Union +from __future__ import annotations + +from typing import Any from airflow.settings import json -def serialize_template_field(template_field: Any) -> Union[str, dict, list, int, float]: +def serialize_template_field(template_field: Any) -> str | dict | list | int | float: """ Return a serializable representation of the templated_field. If a templated_field contains a Class or Instance for recursive templating, store them diff --git a/airflow/serialization/json_schema.py b/airflow/serialization/json_schema.py index 272d669e3d869..a505713e81283 100644 --- a/airflow/serialization/json_schema.py +++ b/airflow/serialization/json_schema.py @@ -15,18 +15,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
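For reviewers less familiar with the serialization layer, the Encoding and DagAttributeTypes enums above describe the small ``__type``/``__var`` envelope used throughout serialized DAGs; the ``typed_dataset`` definition added to schema.json just below has the same shape. The snippet is a hand-rolled illustration only: real encoding happens in serialized_objects.py, and the uri/extra values are invented.

    from airflow.serialization.enums import DagAttributeTypes, Encoding

    # {"__type": ..., "__var": ...} envelope, built by hand purely for illustration.
    encoded_dataset = {
        Encoding.TYPE: DagAttributeTypes.DATASET,                      # -> "__type": "dataset"
        Encoding.VAR: {"uri": "s3://bucket/my-table", "extra": None},  # -> "__var": payload
    }

    # Both enums mix in str, so members compare equal to their string values.
    assert encoded_dataset[Encoding.TYPE] == "dataset"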
- """jsonschema for validating serialized DAG and operator.""" +from __future__ import annotations import pkgutil -from typing import Iterable - -import jsonschema +from typing import TYPE_CHECKING, Iterable from airflow.exceptions import AirflowException from airflow.settings import json from airflow.typing_compat import Protocol +if TYPE_CHECKING: + import jsonschema + class Validator(Protocol): """ @@ -50,7 +51,7 @@ def iter_errors(self, instance) -> Iterable[jsonschema.exceptions.ValidationErro def load_dag_schema_dict() -> dict: """Load & return Json Schema for DAG as Python dict""" - schema_file_name = 'schema.json' + schema_file_name = "schema.json" schema_file = pkgutil.get_data(__name__, schema_file_name) if schema_file is None: @@ -62,5 +63,7 @@ def load_dag_schema_dict() -> dict: def load_dag_schema() -> Validator: """Load & Validate Json Schema for DAG""" + import jsonschema + schema = load_dag_schema_dict() return jsonschema.Draft7Validator(schema) diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 423950bb5e74a..19d6c4d5c9a4e 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -53,6 +53,34 @@ { "type": "integer" } ] }, + "dataset": { + "type": "object", + "properties": { + "uri": { "type": "string" }, + "extra": { + "anyOf": [ + {"type": "null"}, + { "$ref": "#/definitions/dict" } + ] + } + }, + "required": [ "uri", "extra" ] + }, + "typed_dataset": { + "type": "object", + "properties": { + "__type": { + "type": "string", + "constant": "dataset" + }, + "__var": { "$ref": "#/definitions/dataset" } + }, + "required": [ + "__type", + "__var" + ], + "additionalProperties": false + }, "dict": { "description": "A python dictionary containing values of any type", "type": "object" @@ -90,6 +118,11 @@ { "$ref": "#/definitions/typed_relativedelta" } ] }, + "dataset_triggers": { + "type": "array", + "items": { "$ref": "#/definitions/typed_dataset" } + }, + "owner_links": { "type": "object" }, "timetable": { "type": "object", "properties": { @@ -100,7 +133,13 @@ "catchup": { "type": "boolean" }, "is_subdag": { "type": "boolean" }, "fileloc": { "type" : "string"}, - "orientation": { "type" : "string"}, + "_processor_dags_folder": { + "anyOf": [ + { "type": "null" }, + {"type": "string"} + ] + }, + "orientation": { "type" : "string"}, "_description": { "type" : "string"}, "_concurrency": { "type" : "number"}, "_max_active_tasks": { "type" : "number"}, @@ -214,13 +253,13 @@ "doc_yaml": { "type": "string" }, "doc_rst": { "type": "string" }, "_is_mapped": { "const": true, "$comment": "only present when True" }, - "mapped_kwargs": { "type": "object" }, + "expand_input": { "type": "object" }, "partial_kwargs": { "type": "object" } }, "dependencies": { - "mapped_kwargs": ["partial_kwargs", "_is_mapped"], - "partial_kwargs": ["mapped_kwargs", "_is_mapped"], - "_is_mapped": ["mapped_kwargs", "partial_kwargs"] + "expand_input": ["partial_kwargs", "_is_mapped"], + "partial_kwargs": ["expand_input", "_is_mapped"], + "_is_mapped": ["expand_input", "partial_kwargs"] }, "additionalProperties": true }, @@ -241,6 +280,7 @@ ], "properties": { "_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]}, + "is_mapped": { "type": "boolean" }, "prefix_group_id": { "type": "boolean" }, "children": { "$ref": "#/definitions/dict" }, "tooltip": { "type": "string" }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 3e674b2f8d0a2..6cb33cd203277 100644 --- 
a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -14,46 +14,50 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Serialized DAG and BaseOperator""" -import contextlib +from __future__ import annotations + +import collections.abc import datetime import enum import logging +import warnings import weakref from dataclasses import dataclass from inspect import Parameter, signature -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Union import cattr +import lazy_object_proxy import pendulum from dateutil import relativedelta from pendulum.tz.timezone import FixedTimezone, Timezone from airflow.compat.functools import cache from airflow.configuration import conf -from airflow.exceptions import AirflowException, SerializationError +from airflow.datasets import Dataset +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.models.connection import Connection from airflow.models.dag import DAG, create_timetable +from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, create_expand_input, get_map_type_key from airflow.models.mappedoperator import MappedOperator from airflow.models.operator import Operator from airflow.models.param import Param, ParamsDict +from airflow.models.taskinstance import SimpleTaskInstance from airflow.models.taskmixin import DAGNode -from airflow.models.xcom_arg import XComArg -from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg from airflow.providers_manager import ProvidersManager -from airflow.sensors.external_task import ExternalTaskSensor from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field from airflow.serialization.json_schema import Validator, load_dag_schema -from airflow.settings import json +from airflow.settings import DAGS_FOLDER, json from airflow.timetables.base import Timetable from airflow.utils.code_utils import get_python_source from airflow.utils.docs import get_docs_url from airflow.utils.module_loading import as_importable_string, import_string from airflow.utils.operator_resources import Resources -from airflow.utils.task_group import TaskGroup +from airflow.utils.task_group import MappedTaskGroup, TaskGroup if TYPE_CHECKING: from airflow.ti_deps.deps.base_ti_dep import BaseTIDep @@ -68,17 +72,18 @@ log = logging.getLogger(__name__) -_OPERATOR_EXTRA_LINKS: Set[str] = { +_OPERATOR_EXTRA_LINKS: set[str] = { "airflow.operators.trigger_dagrun.TriggerDagRunLink", - "airflow.sensors.external_task.ExternalTaskSensorLink", + "airflow.sensors.external_task.ExternalDagLink", # Deprecated names, so that existing serialized dags load straight away. + "airflow.sensors.external_task.ExternalTaskSensorLink", "airflow.operators.dagrun_operator.TriggerDagRunLink", "airflow.sensors.external_task_sensor.ExternalTaskSensorLink", } @cache -def get_operator_extra_links() -> Set[str]: +def get_operator_extra_links() -> set[str]: """Get the operator extra links. This includes both the built-in ones, and those come from the providers. 
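Note that the allow-list above keeps the deprecated ExternalTaskSensorLink paths next to the renamed ExternalDagLink, so DAGs serialized before the rename still resolve their extra links. A small sketch of how such an importable string is turned back into a class (the chosen entry is just one item from the list above):

    from airflow.utils.module_loading import import_string

    # Resolve an operator extra link class from its serialized importable string.
    link_cls = import_string("airflow.operators.trigger_dagrun.TriggerDagRunLink")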
@@ -88,7 +93,7 @@ def get_operator_extra_links() -> Set[str]: @cache -def _get_default_mapped_partial() -> Dict[str, Any]: +def _get_default_mapped_partial() -> dict[str, Any]: """Get default partial kwargs in a mapped operator. This is used to simplify a serialized mapped operator by excluding default @@ -97,27 +102,27 @@ def _get_default_mapped_partial() -> Dict[str, Any]: don't need to store them. """ # Use the private _expand() method to avoid the empty kwargs check. - default_partial_kwargs = BaseOperator.partial(task_id="_")._expand().partial_kwargs - return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR] + default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs + return BaseSerialization.serialize(default)[Encoding.VAR] -def encode_relativedelta(var: relativedelta.relativedelta) -> Dict[str, Any]: +def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]: encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v} if var.weekday and var.weekday.n: # Every n'th Friday for example - encoded['weekday'] = [var.weekday.weekday, var.weekday.n] + encoded["weekday"] = [var.weekday.weekday, var.weekday.n] elif var.weekday: - encoded['weekday'] = [var.weekday.weekday] + encoded["weekday"] = [var.weekday.weekday] return encoded -def decode_relativedelta(var: Dict[str, Any]) -> relativedelta.relativedelta: - if 'weekday' in var: - var['weekday'] = relativedelta.weekday(*var['weekday']) # type: ignore +def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta: + if "weekday" in var: + var["weekday"] = relativedelta.weekday(*var["weekday"]) # type: ignore return relativedelta.relativedelta(**var) -def encode_timezone(var: Timezone) -> Union[str, int]: +def encode_timezone(var: Timezone) -> str | int: """Encode a Pendulum Timezone for serialization. Airflow only supports timezone objects that implements Pendulum's Timezone @@ -139,12 +144,12 @@ def encode_timezone(var: Timezone) -> Union[str, int]: ) -def decode_timezone(var: Union[str, int]) -> Timezone: +def decode_timezone(var: str | int) -> Timezone: """Decode a previously serialized Pendulum Timezone.""" return pendulum.tz.timezone(var) -def _get_registered_timetable(importable_string: str) -> Optional[Type[Timetable]]: +def _get_registered_timetable(importable_string: str) -> type[Timetable] | None: from airflow import plugins_manager if importable_string.startswith("airflow.timetables."): @@ -161,10 +166,14 @@ def __init__(self, type_string: str) -> None: self.type_string = type_string def __str__(self) -> str: - return f"Timetable class {self.type_string!r} is not registered" + return ( + f"Timetable class {self.type_string!r} is not registered or " + "you have a top level database access that disrupted the session. " + "Please check the airflow best practices documentation." + ) -def _encode_timetable(var: Timetable) -> Dict[str, Any]: +def _encode_timetable(var: Timetable) -> dict[str, Any]: """Encode a timetable instance. This delegates most of the serialization work to the type, so the behavior @@ -177,7 +186,7 @@ def _encode_timetable(var: Timetable) -> Dict[str, Any]: return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()} -def _decode_timetable(var: Dict[str, Any]) -> Timetable: +def _decode_timetable(var: dict[str, Any]) -> Timetable: """Decode a previously serialized timetable. 
Most of the deserialization logic is delegated to the actual type, which @@ -191,15 +200,71 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable: class _XComRef(NamedTuple): + """Used to store info needed to create XComArg. + + We can't turn it in to a XComArg until we've loaded _all_ the tasks, so when + deserializing an operator, we need to create something in its place, and + post-process it in ``deserialize_dag``. """ - Used to store info needed to create XComArg when deserializing MappedOperator. - We can't turn it in to a XComArg until we've loaded _all_ the tasks, so when deserializing an operator we - need to create _something_, and then post-process it in deserialize_dag + data: dict + + def deref(self, dag: DAG) -> XComArg: + return deserialize_xcom_arg(self.data, dag) + + +# These two should be kept in sync. Note that these are intentionally not using +# the type declarations in expandinput.py so we always remember to update +# serialization logic when adding new ExpandInput variants. If you add things to +# the unions, be sure to update _ExpandInputRef to match. +_ExpandInputOriginalValue = Union[ + # For .expand(**kwargs). + Mapping[str, Any], + # For expand_kwargs(arg). + XComArg, + Collection[Union[XComArg, Mapping[str, Any]]], +] +_ExpandInputSerializedValue = Union[ + # For .expand(**kwargs). + Mapping[str, Any], + # For expand_kwargs(arg). + _XComRef, + Collection[Union[_XComRef, Mapping[str, Any]]], +] + + +class _ExpandInputRef(NamedTuple): + """Used to store info needed to create a mapped operator's expand input. + + This references a ``ExpandInput`` type, but replaces ``XComArg`` objects + with ``_XComRef`` (see documentation on the latter type for reasoning). """ - task_id: str key: str + value: _ExpandInputSerializedValue + + @classmethod + def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None: + """Validate we've covered all ``ExpandInput.value`` types. + + This function does not actually do anything, but is called during + serialization so Mypy will *statically* check we have handled all + possible ExpandInput cases. + """ + + def deref(self, dag: DAG) -> ExpandInput: + """De-reference into a concrete ExpandInput object. + + If you add more cases here, be sure to update _ExpandInputOriginalValue + and _ExpandInputSerializedValue to match the logic. + """ + if isinstance(self.value, _XComRef): + value: Any = self.value.deref(dag) + elif isinstance(self.value, collections.abc.Mapping): + value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()} + else: + value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value] + return create_expand_input(self.key, value) class BaseSerialization: @@ -215,48 +280,46 @@ class BaseSerialization: # Object types that are always excluded in serialization. _excluded_types = (logging.Logger, Connection, type) - _json_schema: Optional[Validator] = None + _json_schema: Validator | None = None # Should the extra operator link be loaded via plugins when # de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links # are not loaded to not run User code in Scheduler. 
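Tying the two helper types above together: a mapped operator's expand input is stored as a small {"type": ..., "value": ...} mapping (see serialize_mapped_operator further down), with _XComRef placeholders standing in for XComArg objects until the whole DAG is available. A hypothetical payload is sketched below; the "dict-of-lists" type string is an assumption about what get_map_type_key() returns for .expand(**kwargs):

    # Hypothetical serialized expand input for `.expand(x=[1, 2, 3])`.
    serialized_expand_input = {
        "type": "dict-of-lists",  # assumed key; produced by get_map_type_key()
        "value": {"__type": "dict", "__var": {"x": [1, 2, 3]}},
    }
    # _ExpandInputRef(type, value).deref(dag) later rebuilds a concrete
    # ExpandInput once every task of the DAG has been deserialized.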
_load_operator_extra_links = True - _CONSTRUCTOR_PARAMS: Dict[str, Parameter] = {} + _CONSTRUCTOR_PARAMS: dict[str, Parameter] = {} SERIALIZER_VERSION = 1 @classmethod - def to_json(cls, var: Union[DAG, BaseOperator, dict, list, set, tuple]) -> str: + def to_json(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> str: """Stringifies DAGs and operators contained by var and returns a JSON string of var.""" return json.dumps(cls.to_dict(var), ensure_ascii=True) @classmethod - def to_dict(cls, var: Union[DAG, BaseOperator, dict, list, set, tuple]) -> dict: + def to_dict(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> dict: """Stringifies DAGs and operators contained by var and returns a dict of var.""" # Don't call on this class directly - only SerializedDAG or # SerializedBaseOperator should be used as the "entrypoint" raise NotImplementedError() @classmethod - def from_json(cls, serialized_obj: str) -> Union['BaseSerialization', dict, list, set, tuple]: + def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set | tuple: """Deserializes json_str and reconstructs all DAGs and operators it contains.""" return cls.from_dict(json.loads(serialized_obj)) @classmethod - def from_dict( - cls, serialized_obj: Dict[Encoding, Any] - ) -> Union['BaseSerialization', dict, list, set, tuple]: + def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> BaseSerialization | dict | list | set | tuple: """Deserializes a python dict stored with type decorators and reconstructs all DAGs and operators it contains. """ - return cls._deserialize(serialized_obj) + return cls.deserialize(serialized_obj) @classmethod - def validate_schema(cls, serialized_obj: Union[str, dict]) -> None: + def validate_schema(cls, serialized_obj: str | dict) -> None: """Validate serialized_obj satisfies JSON schema.""" if cls._json_schema is None: - raise AirflowException(f'JSON schema of {cls.__name__:s} is not set.') + raise AirflowException(f"JSON schema of {cls.__name__:s} is not set.") if isinstance(serialized_obj, dict): cls._json_schema.validate(serialized_obj) @@ -266,7 +329,7 @@ def validate_schema(cls, serialized_obj: Union[str, dict]) -> None: raise TypeError("Invalid type: Only dict and str are supported.") @staticmethod - def _encode(x: Any, type_: Any) -> Dict[Encoding, Any]: + def _encode(x: Any, type_: Any) -> dict[Encoding, Any]: """Encode data by a JSON dict.""" return {Encoding.VAR: x, Encoding.TYPE: type_} @@ -290,10 +353,10 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: @classmethod def serialize_to_json( - cls, object_to_serialize: Union["BaseOperator", "MappedOperator", DAG], decorated_fields: Set - ) -> Dict[str, Any]: + cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set + ) -> dict[str, Any]: """Serializes an object to json""" - serialized_object: Dict[str, Any] = {} + serialized_object: dict[str, Any] = {} keys_to_serialize = object_to_serialize.get_serialized_fields() for key in keys_to_serialize: # None is ignored in serialized form and is added back in deserialization. 
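serialize/deserialize remain internal helpers (:meta private:); the supported round trip goes through the concrete subclasses. A minimal sketch with SerializedDAG, using a throwaway DAG:

    import pendulum
    from airflow import DAG
    from airflow.serialization.serialized_objects import SerializedDAG

    with DAG(
        dag_id="example",
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule_interval=None,
    ) as dag:
        pass

    payload = SerializedDAG.to_dict(dag)   # {"__version": 1, "dag": {...}}
    SerializedDAG.validate_schema(payload)
    restored = SerializedDAG.from_dict(payload)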
@@ -301,19 +364,27 @@ def serialize_to_json( if cls._is_excluded(value, key, object_to_serialize): continue - if key in decorated_fields: - serialized_object[key] = cls._serialize(value) + if key == "_operator_name": + # when operator_name matches task_type, we can remove + # it to reduce the JSON payload + task_type = getattr(object_to_serialize, "_task_type", None) + if value != task_type: + serialized_object[key] = cls.serialize(value) + elif key in decorated_fields: + serialized_object[key] = cls.serialize(value) elif key == "timetable" and value is not None: serialized_object[key] = _encode_timetable(value) else: - value = cls._serialize(value) + value = cls.serialize(value) if isinstance(value, dict) and Encoding.TYPE in value: value = value[Encoding.VAR] serialized_object[key] = value return serialized_object @classmethod - def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for recursive types in mypy + def serialize( + cls, var: Any, *, strict: bool = False + ) -> Any: # Unfortunately there is no support for recursive types in mypy """Helper function of depth first search for serialization. The serialization protocol is: @@ -323,6 +394,8 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r step decode VAR according to TYPE; (3) Operator has a special field CLASS to record the original class name for displaying in UI. + + :meta private: """ if cls._is_primitive(var): # enum.IntEnum is an int instance, it causes json dumps error so we use its value. @@ -330,10 +403,12 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r return var.value return var elif isinstance(var, dict): - return cls._encode({str(k): cls._serialize(v) for k, v in var.items()}, type_=DAT.DICT) + return cls._encode( + {str(k): cls.serialize(v, strict=strict) for k, v in var.items()}, type_=DAT.DICT + ) elif isinstance(var, list): - return [cls._serialize(v) for v in var] - elif _has_kubernetes() and isinstance(var, k8s.V1Pod): + return [cls.serialize(v, strict=strict) for v in var] + elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod): json_pod = PodGenerator.serialize_pod(var) return cls._encode(json_pod, type_=DAT.POD) elif isinstance(var, DAG): @@ -357,30 +432,39 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r elif isinstance(var, set): # FIXME: casts set to list in customized serialization in future. try: - return cls._encode(sorted(cls._serialize(v) for v in var), type_=DAT.SET) + return cls._encode(sorted(cls.serialize(v, strict=strict) for v in var), type_=DAT.SET) except TypeError: - return cls._encode([cls._serialize(v) for v in var], type_=DAT.SET) + return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.SET) elif isinstance(var, tuple): # FIXME: casts tuple to list in customized serialization in future. 
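One behavioural note on the new strict flag introduced in the serialize signature above: unrecognised objects are still stringified by default, but strict=True turns that silent fallback (see the final else branch a little further below) into a SerializationError. A sketch:

    from airflow.exceptions import SerializationError
    from airflow.serialization.serialized_objects import BaseSerialization

    class Opaque:
        """Neither a JSON primitive nor a type the serializer recognises."""

    BaseSerialization.serialize(Opaque())                   # falls back to str(...)
    try:
        BaseSerialization.serialize(Opaque(), strict=True)  # raises instead
    except SerializationError:
        pass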
- return cls._encode([cls._serialize(v) for v in var], type_=DAT.TUPLE) + return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.TUPLE) elif isinstance(var, TaskGroup): - return SerializedTaskGroup.serialize_task_group(var) + return TaskGroupSerialization.serialize_task_group(var) elif isinstance(var, Param): return cls._encode(cls._serialize_param(var), type_=DAT.PARAM) elif isinstance(var, XComArg): - return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF) + return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF) + elif isinstance(var, Dataset): + return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET) + elif isinstance(var, SimpleTaskInstance): + return cls._encode(cls.serialize(var.__dict__, strict=strict), type_=DAT.SIMPLE_TASK_INSTANCE) else: - log.debug('Cast type %s to str in serialization.', type(var)) + log.debug("Cast type %s to str in serialization.", type(var)) + if strict: + raise SerializationError("Encountered unexpected type") return str(var) @classmethod - def _deserialize(cls, encoded_var: Any) -> Any: - """Helper function of depth first search for deserialization.""" + def deserialize(cls, encoded_var: Any) -> Any: + """Helper function of depth first search for deserialization. + + :meta private: + """ # JSON primitives (except for dict) are not encoded. if cls._is_primitive(encoded_var): return encoded_var elif isinstance(encoded_var, list): - return [cls._deserialize(v) for v in encoded_var] + return [cls.deserialize(v) for v in encoded_var] if not isinstance(encoded_var, dict): raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}") @@ -388,7 +472,7 @@ def _deserialize(cls, encoded_var: Any) -> Any: type_ = encoded_var[Encoding.TYPE] if type_ == DAT.DICT: - return {k: cls._deserialize(v) for k, v in var.items()} + return {k: cls.deserialize(v) for k, v in var.items()} elif type_ == DAT.DAG: return SerializedDAG.deserialize_dag(var) elif type_ == DAT.OP: @@ -407,15 +491,19 @@ def _deserialize(cls, encoded_var: Any) -> Any: elif type_ == DAT.RELATIVEDELTA: return decode_relativedelta(var) elif type_ == DAT.SET: - return {cls._deserialize(v) for v in var} + return {cls.deserialize(v) for v in var} elif type_ == DAT.TUPLE: - return tuple(cls._deserialize(v) for v in var) + return tuple(cls.deserialize(v) for v in var) elif type_ == DAT.PARAM: return cls._deserialize_param(var) elif type_ == DAT.XCOM_REF: - return cls._deserialize_xcomref(var) + return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG. 
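The new Dataset branch above emits the typed_dataset shape declared in schema.json earlier in this patch, and deserialize rebuilds the object from it. A small round-trip sketch:

    from airflow.datasets import Dataset
    from airflow.serialization.serialized_objects import BaseSerialization

    dataset = Dataset(uri="s3://bucket/key", extra={"team": "data"})
    encoded = BaseSerialization.serialize(dataset)
    # roughly: {"__type": "dataset",
    #           "__var": {"uri": "s3://bucket/key", "extra": {"team": "data"}}}
    restored = BaseSerialization.deserialize(encoded)  # a Dataset again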
+ elif type_ == DAT.DATASET: + return Dataset(**var) + elif type_ == DAT.SIMPLE_TASK_INSTANCE: + return SimpleTaskInstance(**cls.deserialize(var)) else: - raise TypeError(f'Invalid type {type_!s} in deserialization.') + raise TypeError(f"Invalid type {type_!s} in deserialization.") _deserialize_datetime = pendulum.from_timestamp _deserialize_timezone = pendulum.tz.timezone @@ -456,36 +544,43 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) - def _serialize_param(cls, param: Param): return dict( __class=f"{param.__module__}.{param.__class__.__name__}", - default=cls._serialize(param.value), - description=cls._serialize(param.description), - schema=cls._serialize(param.schema), + default=cls.serialize(param.value), + description=cls.serialize(param.description), + schema=cls.serialize(param.schema), ) @classmethod - def _deserialize_param(cls, param_dict: Dict): + def _deserialize_param(cls, param_dict: dict): """ In 2.2.0, Param attrs were assumed to be json-serializable and were not run through - this class's ``_serialize`` method. So before running through ``_deserialize``, + this class's ``serialize`` method. So before running through ``deserialize``, we first verify that it's necessary to do. """ - class_name = param_dict['__class'] - class_ = import_string(class_name) # type: Type[Param] - attrs = ('default', 'description', 'schema') + class_name = param_dict["__class"] + class_: type[Param] = import_string(class_name) + attrs = ("default", "description", "schema") kwargs = {} + + def is_serialized(val): + if isinstance(val, dict): + return Encoding.TYPE in val + if isinstance(val, list): + return all(isinstance(item, dict) and Encoding.TYPE in item for item in val) + return False + for attr in attrs: if attr not in param_dict: continue val = param_dict[attr] - is_serialized = isinstance(val, dict) and '__type' in val - if is_serialized: - deserialized_val = cls._deserialize(param_dict[attr]) + if is_serialized(val): + deserialized_val = cls.deserialize(param_dict[attr]) kwargs[attr] = deserialized_val else: kwargs[attr] = val return class_(**kwargs) @classmethod - def _serialize_params_dict(cls, params: Union[ParamsDict, dict]): + def _serialize_params_dict(cls, params: ParamsDict | dict): """Serialize Params dict for a DAG/Task""" serialized_params = {} for k, v in params.items(): @@ -504,7 +599,7 @@ def _serialize_params_dict(cls, params: Union[ParamsDict, dict]): return serialized_params @classmethod - def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict: + def _deserialize_params_dict(cls, encoded_params: dict) -> ParamsDict: """Deserialize a DAG's Params dict""" op_params = {} for k, v in encoded_params.items(): @@ -516,37 +611,63 @@ def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict: return ParamsDict(op_params) - @classmethod - def _serialize_xcomarg(cls, arg: XComArg) -> dict: - return {"key": arg.key, "task_id": arg.operator.task_id} - - @classmethod - def _deserialize_xcomref(cls, encoded: dict) -> _XComRef: - return _XComRef(key=encoded['key'], task_id=encoded['task_id']) - class DependencyDetector: - """Detects dependencies between DAGs.""" + """ + Detects dependencies between DAGs. 
+ + :meta private: + """ @staticmethod - def detect_task_dependencies(task: Operator) -> Optional['DagDependency']: + def detect_task_dependencies(task: Operator) -> list[DagDependency]: + from airflow.operators.trigger_dagrun import TriggerDagRunOperator + from airflow.sensors.external_task import ExternalTaskSensor + """Detects dependencies caused by tasks""" + deps = [] if isinstance(task, TriggerDagRunOperator): - return DagDependency( - source=task.dag_id, - target=getattr(task, "trigger_dag_id"), - dependency_type="trigger", - dependency_id=task.task_id, + deps.append( + DagDependency( + source=task.dag_id, + target=getattr(task, "trigger_dag_id"), + dependency_type="trigger", + dependency_id=task.task_id, + ) ) elif isinstance(task, ExternalTaskSensor): - return DagDependency( - source=getattr(task, "external_dag_id"), - target=task.dag_id, - dependency_type="sensor", - dependency_id=task.task_id, + deps.append( + DagDependency( + source=getattr(task, "external_dag_id"), + target=task.dag_id, + dependency_type="sensor", + dependency_id=task.task_id, + ) ) + for obj in task.outlets or []: + if isinstance(obj, Dataset): + deps.append( + DagDependency( + source=task.dag_id, + target="dataset", + dependency_type="dataset", + dependency_id=obj.uri, + ) + ) + return deps - return None + @staticmethod + def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]: + """Detects dependencies set directly on the DAG object.""" + if not dag: + return + for x in dag.dataset_triggers: + yield DagDependency( + source="dataset", + target=dag.dag_id, + dependency_type="dataset", + dependency_id=x.uri, + ) class SerializedBaseOperator(BaseOperator, BaseSerialization): @@ -556,7 +677,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): Class specific attributes used by UI are move to object attributes. """ - _decorated_fields = {'executor_config'} + _decorated_fields = {"executor_config"} _CONSTRUCTOR_PARAMS = { k: v.default @@ -564,13 +685,11 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): if v.default is not v.empty } - dependency_detector = conf.getimport('scheduler', 'dependency_detector') - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # task_type is used by UI to display the correct class type, because UI only # receives BaseOperator from deserialized DAGs. - self._task_type = 'BaseOperator' + self._task_type = "BaseOperator" # Move class attributes into object attributes. self.ui_color = BaseOperator.ui_color self.ui_fgcolor = BaseOperator.ui_fgcolor @@ -588,9 +707,27 @@ def task_type(self) -> str: def task_type(self, task_type: str): self._task_type = task_type + @property + def operator_name(self) -> str: + # Overwrites operator_name of BaseOperator to use _operator_name instead of + # __class__.operator_name. + return self._operator_name + + @operator_name.setter + def operator_name(self, operator_name: str): + self._operator_name = operator_name + @classmethod - def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: - serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator)) + def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: + serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator)) + # Handle expand_input and op_kwargs_expand_input. + expansion_kwargs = op._get_specified_expand_input() + if TYPE_CHECKING: # Let Mypy check the input type for us! 
+ _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value) + serialized_op[op._expand_input_attr] = { + "type": get_map_type_key(expansion_kwargs), + "value": cls.serialize(expansion_kwargs.value), + } # Simplify partial_kwargs by comparing it to the most barebone object. # Remove all entries that are simply default values. @@ -603,44 +740,32 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: if v == default: del serialized_partial[k] - # Simplify op_kwargs format. It must be a dict, so we flatten it. - with contextlib.suppress(KeyError): - op_kwargs = serialized_op["mapped_kwargs"]["op_kwargs"] - assert op_kwargs[Encoding.TYPE] == DAT.DICT - serialized_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] - with contextlib.suppress(KeyError): - op_kwargs = serialized_op["partial_kwargs"]["op_kwargs"] - assert op_kwargs[Encoding.TYPE] == DAT.DICT - serialized_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] - with contextlib.suppress(KeyError): - op_kwargs = serialized_op["mapped_op_kwargs"] - assert op_kwargs[Encoding.TYPE] == DAT.DICT - serialized_op["mapped_op_kwargs"] = op_kwargs[Encoding.VAR] - serialized_op["_is_mapped"] = True return serialized_op @classmethod - def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]: + def serialize_operator(cls, op: BaseOperator) -> dict[str, Any]: return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps) @classmethod - def _serialize_node(cls, op: Union[BaseOperator, MappedOperator], include_deps: bool) -> Dict[str, Any]: + def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]: """Serializes operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) - serialize_op['_task_type'] = getattr(op, "_task_type", type(op).__name__) - serialize_op['_task_module'] = getattr(op, "_task_module", type(op).__module__) + serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__) + serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__) + if op.operator_name != serialize_op["_task_type"]: + serialize_op["_operator_name"] = op.operator_name # Used to determine if an Operator is inherited from EmptyOperator - serialize_op['_is_empty'] = op.inherits_from_empty_operator + serialize_op["_is_empty"] = op.inherits_from_empty_operator if op.operator_extra_links: - serialize_op['_operator_extra_links'] = cls._serialize_operator_extra_links( + serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links( op.operator_extra_links ) if include_deps: - serialize_op['deps'] = cls._serialize_deps(op.deps) + serialize_op["deps"] = cls._serialize_deps(op.deps) # Store all template_fields as they are if there are JSON Serializable # If not, store them as strings @@ -651,12 +776,12 @@ def _serialize_node(cls, op: Union[BaseOperator, MappedOperator], include_deps: serialize_op[template_field] = serialize_template_field(value) if op.params: - serialize_op['params'] = cls._serialize_params_dict(op.params) + serialize_op["params"] = cls._serialize_params_dict(op.params) return serialize_op @classmethod - def _serialize_deps(cls, op_deps: Iterable["BaseTIDep"]) -> List[str]: + def _serialize_deps(cls, op_deps: Iterable[BaseTIDep]) -> list[str]: from airflow import plugins_manager plugins_manager.initialize_ti_deps_plugins() @@ -667,7 +792,7 @@ def _serialize_deps(cls, op_deps: Iterable["BaseTIDep"]) -> List[str]: for dep in op_deps: klass = type(dep) module_name = 
klass.__module__ - qualname = f'{module_name}.{klass.__name__}' + qualname = f"{module_name}.{klass.__name__}" if ( not qualname.startswith("airflow.ti_deps.deps.") and qualname not in plugins_manager.registered_ti_dep_classes @@ -682,7 +807,7 @@ def _serialize_deps(cls, op_deps: Iterable["BaseTIDep"]) -> List[str]: return sorted(deps) @classmethod - def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None: + def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: if "label" not in encoded_op: # Handle deserialization of old data before the introduction of TaskGroup encoded_op["label"] = encoded_op["task_id"] @@ -690,6 +815,9 @@ def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None: # Extra Operator Links defined in Plugins op_extra_links_from_plugin = {} + if "_operator_name" not in encoded_op: + encoded_op["_operator_name"] = encoded_op["_task_type"] + # We don't want to load Extra Operator links in Scheduler if cls._load_operator_extra_links: from airflow import plugins_manager @@ -718,6 +846,10 @@ def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None: # Todo: TODO: Remove in Airflow 3.0 when dummy operator is removed if k == "_is_dummy": k = "_is_empty" + + if k in ("_outlets", "_inlets"): + # `_outlets` -> `outlets` + k = k[1:] if k == "_downstream_task_ids": # Upgrade from old format/name k = "downstream_task_ids" @@ -752,18 +884,18 @@ def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None: v = cls._deserialize_deps(v) elif k == "params": v = cls._deserialize_params_dict(v) - elif k in ("mapped_kwargs", "partial_kwargs"): - if "op_kwargs" not in v: - op_kwargs: Optional[dict] = None - else: - op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()} - v = {arg: cls._deserialize(value) for arg, value in v.items()} - if op_kwargs is not None: - v["op_kwargs"] = op_kwargs - elif k == "mapped_op_kwargs": - v = {arg: cls._deserialize(value) for arg, value in v.items()} + if op.params: # Merge existing params if needed. + v, new = op.params, v + v.update(new) + elif k == "partial_kwargs": + v = {arg: cls.deserialize(value) for arg, value in v.items()} + elif k in {"expand_input", "op_kwargs_expand_input"}: + v = _ExpandInputRef(v["type"], cls.deserialize(v["value"])) elif k in cls._decorated_fields or k not in op.get_serialized_fields(): - v = cls._deserialize(v) + v = cls.deserialize(v) + elif k in ("outlets", "inlets"): + v = cls.deserialize(v) + # else use v as it is setattr(op, k, v) @@ -783,15 +915,19 @@ def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None: setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False))) @classmethod - def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Operator: + def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: """Deserializes an operator from a JSON object.""" op: Operator if encoded_op.get("_is_mapped", False): # Most of these will be loaded later, these are just some stand-ins. 
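For readers tracing populate_operator above: older serialized payloads are upgraded in place, with several legacy field names rewritten to their current counterparts before being set on the operator. A summary of the renames handled in this hunk (a plain mapping for reference, not an Airflow API):

    legacy_renames = {
        "_is_dummy": "_is_empty",
        "_outlets": "outlets",
        "_inlets": "inlets",
        "_downstream_task_ids": "downstream_task_ids",
    }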
op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()} + try: + operator_name = encoded_op["_operator_name"] + except KeyError: + operator_name = encoded_op["_task_type"] op = MappedOperator( operator_class=op_data, - mapped_kwargs={}, + expand_input=EXPAND_INPUT_EMPTY, partial_kwargs={}, task_id=encoded_op["task_id"], params={}, @@ -805,25 +941,51 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Operator: is_empty=False, task_module=encoded_op["_task_module"], task_type=encoded_op["_task_type"], + operator_name=operator_name, dag=None, task_group=None, start_date=None, end_date=None, - expansion_kwargs_attr=encoded_op["_expansion_kwargs_attr"], + disallow_kwargs_override=encoded_op["_disallow_kwargs_override"], + expand_input_attr=encoded_op["_expand_input_attr"], ) else: - op = SerializedBaseOperator(task_id=encoded_op['task_id']) + op = SerializedBaseOperator(task_id=encoded_op["task_id"]) cls.populate_operator(op, encoded_op) return op @classmethod - def detect_dependencies(cls, op: Operator) -> Optional['DagDependency']: + def detect_dependencies(cls, op: Operator) -> set[DagDependency]: """Detects between DAG dependencies for the operator.""" - return cls.dependency_detector.detect_task_dependencies(op) + + def get_custom_dep() -> list[DagDependency]: + """ + If custom dependency detector is configured, use it. + + TODO: Remove this logic in 3.0. + """ + custom_dependency_detector_cls = conf.getimport("scheduler", "dependency_detector", fallback=None) + if not ( + custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector + ): + warnings.warn( + "Use of a custom dependency detector is deprecated. " + "Support will be removed in a future release.", + RemovedInAirflow3Warning, + ) + dep = custom_dependency_detector_cls().detect_task_dependencies(op) + if type(dep) is DagDependency: + return [dep] + return [] + + dependency_detector = DependencyDetector() + deps = set(dependency_detector.detect_task_dependencies(op)) + deps.update(get_custom_dep()) # todo: remove in 3.0 + return deps @classmethod - def _is_excluded(cls, var: Any, attrname: str, op: "DAGNode"): + def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): if var is not None and op.has_dag() and attrname.endswith("_date"): # If this date is the same as the matching field in the dag, then # don't store it again at the task level. @@ -833,7 +995,7 @@ def _is_excluded(cls, var: Any, attrname: str, op: "DAGNode"): return super()._is_excluded(var, attrname, op) @classmethod - def _deserialize_deps(cls, deps: List[str]) -> Set["BaseTIDep"]: + def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]: from airflow import plugins_manager plugins_manager.initialize_ti_deps_plugins() @@ -857,7 +1019,7 @@ def _deserialize_deps(cls, deps: List[str]) -> Set["BaseTIDep"]: return instances @classmethod - def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> Dict[str, BaseOperatorLink]: + def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]: """ Deserialize Operator Links if the Classes are registered in Airflow Plugins. Error is raised if the OperatorLink is not found in Plugins too. @@ -956,15 +1118,15 @@ class SerializedDAG(DAG, BaseSerialization): not pickle-able. SerializedDAG works for all DAGs. 
""" - _decorated_fields = {'schedule_interval', 'default_args', '_access_control'} + _decorated_fields = {"schedule_interval", "default_args", "_access_control"} @staticmethod def __get_constructor_defaults(): param_to_attr = { - 'max_active_tasks': '_max_active_tasks', - 'description': '_description', - 'default_view': '_default_view', - 'access_control': '_access_control', + "max_active_tasks": "_max_active_tasks", + "description": "_description", + "default_view": "_default_view", + "access_control": "_access_control", } return { param_to_attr.get(k, k): v.default @@ -975,7 +1137,7 @@ def __get_constructor_defaults(): _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore del __get_constructor_defaults - _json_schema = load_dag_schema() + _json_schema = lazy_object_proxy.Proxy(load_dag_schema) @classmethod def serialize_dag(cls, dag: DAG) -> dict: @@ -983,6 +1145,8 @@ def serialize_dag(cls, dag: DAG) -> dict: try: serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields) + serialized_dag["_processor_dags_folder"] = DAGS_FOLDER + # If schedule_interval is backed by timetable, serialize only # timetable; vice versa for a timetable backed by schedule_interval. if dag.timetable.summary == dag.schedule_interval: @@ -990,13 +1154,15 @@ def serialize_dag(cls, dag: DAG) -> dict: else: del serialized_dag["timetable"] - serialized_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()] - serialized_dag["dag_dependencies"] = [ - vars(t) - for t in (SerializedBaseOperator.detect_dependencies(task) for task in dag.task_dict.values()) - if t is not None - ] - serialized_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group) + serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()] + dag_deps = { + dep + for task in dag.task_dict.values() + for dep in SerializedBaseOperator.detect_dependencies(task) + } + dag_deps.update(DependencyDetector.detect_dag_dependencies(dag)) + serialized_dag["dag_dependencies"] = [x.__dict__ for x in dag_deps] + serialized_dag["_task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group) # Edge info in the JSON exactly matches our internal structure serialized_dag["edge_info"] = dag.edge_info @@ -1004,19 +1170,19 @@ def serialize_dag(cls, dag: DAG) -> dict: # has_on_*_callback are only stored if the value is True, as the default is False if dag.has_on_success_callback: - serialized_dag['has_on_success_callback'] = True + serialized_dag["has_on_success_callback"] = True if dag.has_on_failure_callback: - serialized_dag['has_on_failure_callback'] = True + serialized_dag["has_on_failure_callback"] = True return serialized_dag except SerializationError: raise except Exception as e: - raise SerializationError(f'Failed to serialize DAG {dag.dag_id!r}: {e}') + raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}") @classmethod - def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': + def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: """Deserializes a DAG from a JSON object.""" - dag = SerializedDAG(dag_id=encoded_dag['_dag_id']) + dag = SerializedDAG(dag_id=encoded_dag["_dag_id"]) for k, v in encoded_dag.items(): if k == "_downstream_task_ids": @@ -1039,9 +1205,11 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': elif k == "timetable": v = _decode_timetable(v) elif k in cls._decorated_fields: - v = cls._deserialize(v) + v = cls.deserialize(v) elif k == "params": v = 
cls._deserialize_params_dict(v) + elif k == "dataset_triggers": + v = cls.deserialize(v) # else use v as it is setattr(dag, k, v) @@ -1056,8 +1224,11 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': # Set _task_group if "_task_group" in encoded_dag: - dag._task_group = SerializedTaskGroup.deserialize_task_group( - encoded_dag["_task_group"], None, dag.task_dict, dag + dag._task_group = TaskGroupSerialization.deserialize_task_group( + encoded_dag["_task_group"], + None, + dag.task_dict, + dag, ) else: # This must be old data that had no task_group. Create a root TaskGroup and add @@ -1084,15 +1255,13 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': setattr(task, date_attr, getattr(dag, date_attr)) if task.subdag is not None: - setattr(task.subdag, 'parent_dag', dag) + setattr(task.subdag, "parent_dag", dag) - if isinstance(task, MappedOperator): - expansion_kwargs = task._get_expansion_kwargs() - for k, v in expansion_kwargs.items(): - if not isinstance(v, _XComRef): - continue - - expansion_kwargs[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key) + # Dereference expand_input and op_kwargs_expand_input. + for k in ("expand_input", "op_kwargs_expand_input"): + kwargs_ref = getattr(task, k, None) + if isinstance(kwargs_ref, _ExpandInputRef): + setattr(task, k, kwargs_ref.deref(dag)) for task_id in task.downstream_task_ids: # Bypass set_upstream etc here - it does more than we want @@ -1110,19 +1279,19 @@ def to_dict(cls, var: Any) -> dict: return json_dict @classmethod - def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG': + def from_dict(cls, serialized_obj: dict) -> SerializedDAG: """Deserializes a python dict in to the DAG and operators it contains.""" - ver = serialized_obj.get('__version', '') + ver = serialized_obj.get("__version", "") if ver != cls.SERIALIZER_VERSION: raise ValueError(f"Unsure how to deserialize version {ver!r}") - return cls.deserialize_dag(serialized_obj['dag']) + return cls.deserialize_dag(serialized_obj["dag"]) -class SerializedTaskGroup(TaskGroup, BaseSerialization): - """A JSON serializable representation of TaskGroup.""" +class TaskGroupSerialization(BaseSerialization): + """JSON serializable representation of a task group.""" @classmethod - def serialize_task_group(cls, task_group: TaskGroup) -> Optional[Dict[str, Any]]: + def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None: """Serializes TaskGroup into a JSON object.""" if not task_group: return None @@ -1130,7 +1299,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> Optional[Dict[str, Any]] # task_group.xxx_ids needs to be sorted here, because task_group.xxx_ids is a set, # when converting set to list, the order is uncertain. 
# When calling json.dumps(self.data, sort_keys=True) to generate dag_hash, misjudgment will occur - serialize_group = { + encoded = { "_group_id": task_group._group_id, "prefix_group_id": task_group.prefix_group_id, "tooltip": task_group.tooltip, @@ -1139,48 +1308,67 @@ def serialize_task_group(cls, task_group: TaskGroup) -> Optional[Dict[str, Any]] "children": { label: child.serialize_for_task_group() for label, child in task_group.children.items() }, - "upstream_group_ids": cls._serialize(sorted(task_group.upstream_group_ids)), - "downstream_group_ids": cls._serialize(sorted(task_group.downstream_group_ids)), - "upstream_task_ids": cls._serialize(sorted(task_group.upstream_task_ids)), - "downstream_task_ids": cls._serialize(sorted(task_group.downstream_task_ids)), + "upstream_group_ids": cls.serialize(sorted(task_group.upstream_group_ids)), + "downstream_group_ids": cls.serialize(sorted(task_group.downstream_group_ids)), + "upstream_task_ids": cls.serialize(sorted(task_group.upstream_task_ids)), + "downstream_task_ids": cls.serialize(sorted(task_group.downstream_task_ids)), } - return serialize_group + if isinstance(task_group, MappedTaskGroup): + expand_input = task_group._expand_input + encoded["expand_input"] = { + "type": get_map_type_key(expand_input), + "value": cls.serialize(expand_input.value), + } + encoded["is_mapped"] = True + + return encoded @classmethod def deserialize_task_group( cls, - encoded_group: Dict[str, Any], - parent_group: Optional[TaskGroup], - task_dict: Dict[str, Operator], + encoded_group: dict[str, Any], + parent_group: TaskGroup | None, + task_dict: dict[str, Operator], dag: SerializedDAG, ) -> TaskGroup: """Deserializes a TaskGroup from a JSON object.""" - group_id = cls._deserialize(encoded_group["_group_id"]) + group_id = cls.deserialize(encoded_group["_group_id"]) kwargs = { - key: cls._deserialize(encoded_group[key]) + key: cls.deserialize(encoded_group[key]) for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"] } - group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs) + + if not encoded_group.get("is_mapped"): + group = TaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs) + else: + xi = encoded_group["expand_input"] + group = MappedTaskGroup( + group_id=group_id, + parent_group=parent_group, + dag=dag, + expand_input=_ExpandInputRef(xi["type"], cls.deserialize(xi["value"])).deref(dag), + **kwargs, + ) def set_ref(task: Operator) -> Operator: task.task_group = weakref.proxy(group) return task group.children = { - label: set_ref(task_dict[val]) # type: ignore - if _type == DAT.OP # type: ignore - else SerializedTaskGroup.deserialize_task_group(val, group, task_dict, dag=dag) + label: set_ref(task_dict[val]) + if _type == DAT.OP + else cls.deserialize_task_group(val, group, task_dict, dag=dag) for label, (_type, val) in encoded_group["children"].items() } - group.upstream_group_ids.update(cls._deserialize(encoded_group["upstream_group_ids"])) - group.downstream_group_ids.update(cls._deserialize(encoded_group["downstream_group_ids"])) - group.upstream_task_ids.update(cls._deserialize(encoded_group["upstream_task_ids"])) - group.downstream_task_ids.update(cls._deserialize(encoded_group["downstream_task_ids"])) + group.upstream_group_ids.update(cls.deserialize(encoded_group["upstream_group_ids"])) + group.downstream_group_ids.update(cls.deserialize(encoded_group["downstream_group_ids"])) + group.upstream_task_ids.update(cls.deserialize(encoded_group["upstream_task_ids"])) + 
group.downstream_task_ids.update(cls.deserialize(encoded_group["downstream_task_ids"])) return group -@dataclass +@dataclass(frozen=True) class DagDependency: """Dataclass for representing dependencies between DAGs. These are calculated during serialization and attached to serialized DAGs. @@ -1189,12 +1377,17 @@ class DagDependency: source: str target: str dependency_type: str - dependency_id: str + dependency_id: str | None = None @property def node_id(self): """Node ID for graph rendering""" - return f"{self.dependency_type}:{self.source}:{self.target}:{self.dependency_id}" + val = f"{self.dependency_type}" + if not self.dependency_type == "dataset": + val += f":{self.source}:{self.target}" + if self.dependency_id: + val += f":{self.dependency_id}" + return val def _has_kubernetes() -> bool: @@ -1209,8 +1402,8 @@ def _has_kubernetes() -> bool: from airflow.kubernetes.pod_generator import PodGenerator - globals()['k8s'] = k8s - globals()['PodGenerator'] = PodGenerator + globals()["k8s"] = k8s + globals()["PodGenerator"] = PodGenerator # isort: on HAS_KUBERNETES = True diff --git a/airflow/settings.py b/airflow/settings.py index e8bf80a2d929b..537c49141a87b 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import atexit import functools import json @@ -22,7 +24,7 @@ import os import sys import warnings -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Callable import pendulum import sqlalchemy @@ -33,6 +35,7 @@ from sqlalchemy.pool import NullPool from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf # NOQA F401 +from airflow.exceptions import RemovedInAirflow3Warning from airflow.executors import executor_constants from airflow.logging_config import configure_logging from airflow.utils.orm_event_handlers import setup_event_handlers @@ -43,7 +46,7 @@ log = logging.getLogger(__name__) -TIMEZONE = pendulum.tz.timezone('UTC') +TIMEZONE = pendulum.tz.timezone("UTC") try: tz = conf.get_mandatory_value("core", "default_timezone") if tz == "system": @@ -55,13 +58,13 @@ log.info("Configured default timezone %s", TIMEZONE) -HEADER = '\n'.join( +HEADER = "\n".join( [ - r' ____________ _____________', - r' ____ |__( )_________ __/__ /________ __', - r'____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / /', - r'___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /', - r' _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/', + r" ____________ _____________", + r" ____ |__( )_________ __/__ /________ __", + r"____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / /", + r"___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /", + r" _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/", ] ) @@ -70,14 +73,14 @@ # the prefix to append to gunicorn worker processes after init GUNICORN_WORKER_READY_PREFIX = "[ready] " -LOG_FORMAT = conf.get('logging', 'log_format') -SIMPLE_LOG_FORMAT = conf.get('logging', 'simple_log_format') +LOG_FORMAT = conf.get("logging", "log_format") +SIMPLE_LOG_FORMAT = conf.get("logging", "simple_log_format") -SQL_ALCHEMY_CONN: Optional[str] = None -PLUGINS_FOLDER: Optional[str] = None -LOGGING_CLASS_PATH: Optional[str] = None -DONOT_MODIFY_HANDLERS: Optional[bool] = None -DAGS_FOLDER: str = os.path.expanduser(conf.get_mandatory_value('core', 'DAGS_FOLDER')) +SQL_ALCHEMY_CONN: str | None = None +PLUGINS_FOLDER: str | None = None +LOGGING_CLASS_PATH: str | None = None 
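With DagDependency now frozen and dependency_id optional, dataset edges detected from task outlets or dag.dataset_triggers share a single node id per dataset URI, while trigger and sensor edges keep the source/target pair in their ids. For example:

    from airflow.serialization.serialized_objects import DagDependency

    dataset_dep = DagDependency(
        source="producer_dag", target="dataset",
        dependency_type="dataset", dependency_id="s3://bucket/key",
    )
    trigger_dep = DagDependency(
        source="upstream_dag", target="downstream_dag",
        dependency_type="trigger", dependency_id="trigger_task",
    )
    dataset_dep.node_id  # "dataset:s3://bucket/key"
    trigger_dep.node_id  # "trigger:upstream_dag:downstream_dag:trigger_task"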
+DONOT_MODIFY_HANDLERS: bool | None = None +DAGS_FOLDER: str = os.path.expanduser(conf.get_mandatory_value("core", "DAGS_FOLDER")) engine: Engine Session: Callable[..., SASession] @@ -91,8 +94,11 @@ "deferred": "mediumpurple", "failed": "red", "queued": "gray", + "removed": "lightgrey", + "restarting": "violet", "running": "lime", "scheduled": "tan", + "shutdown": "blue", "skipped": "hotpink", "success": "green", "up_for_reschedule": "turquoise", @@ -222,7 +228,7 @@ def get_airflow_context_vars(context): return {} -def get_dagbag_import_timeout(dag_file_path: str) -> Union[int, float]: +def get_dagbag_import_timeout(dag_file_path: str) -> int | float: """ This setting allows for dynamic control of the DAG file parsing timeout based on the DAG file path. @@ -231,7 +237,7 @@ def get_dagbag_import_timeout(dag_file_path: str) -> Union[int, float]: If the return value is less than or equal to 0, it means no timeout during the DAG parsing. """ - return conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT') + return conf.getfloat("core", "DAGBAG_IMPORT_TIMEOUT") def configure_vars(): @@ -240,30 +246,30 @@ def configure_vars(): global DAGS_FOLDER global PLUGINS_FOLDER global DONOT_MODIFY_HANDLERS - SQL_ALCHEMY_CONN = conf.get('database', 'SQL_ALCHEMY_CONN') - DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) + SQL_ALCHEMY_CONN = conf.get("database", "SQL_ALCHEMY_CONN") + DAGS_FOLDER = os.path.expanduser(conf.get("core", "DAGS_FOLDER")) - PLUGINS_FOLDER = conf.get('core', 'plugins_folder', fallback=os.path.join(AIRFLOW_HOME, 'plugins')) + PLUGINS_FOLDER = conf.get("core", "plugins_folder", fallback=os.path.join(AIRFLOW_HOME, "plugins")) # If donot_modify_handlers=True, we do not modify logging handlers in task_run command # If the flag is set to False, we remove all handlers from the root logger # and add all handlers from 'airflow.task' logger to the root Logger. This is done # to get all the logs from the print & log statements in the DAG files before a task is run # The handlers are restored after the task completes execution. 
- DONOT_MODIFY_HANDLERS = conf.getboolean('logging', 'donot_modify_handlers', fallback=False) + DONOT_MODIFY_HANDLERS = conf.getboolean("logging", "donot_modify_handlers", fallback=False) -def configure_orm(disable_connection_pool=False): +def configure_orm(disable_connection_pool=False, pool_class=None): """Configure ORM using SQLAlchemy""" from airflow.utils.log.secrets_masker import mask_secret log.debug("Setting up DB connection pool (PID %s)", os.getpid()) global engine global Session - engine_args = prepare_engine_args(disable_connection_pool) + engine_args = prepare_engine_args(disable_connection_pool, pool_class) - if conf.has_option('database', 'sql_alchemy_connect_args'): - connect_args = conf.getimport('database', 'sql_alchemy_connect_args') + if conf.has_option("database", "sql_alchemy_connect_args"): + connect_args = conf.getimport("database", "sql_alchemy_connect_args") else: connect_args = {} @@ -281,12 +287,12 @@ def configure_orm(disable_connection_pool=False): expire_on_commit=False, ) ) - if engine.dialect.name == 'mssql': + if engine.dialect.name == "mssql": session = Session() try: result = session.execute( sqlalchemy.text( - 'SELECT is_read_committed_snapshot_on FROM sys.databases WHERE name=:database_name' + "SELECT is_read_committed_snapshot_on FROM sys.databases WHERE name=:database_name" ), params={"database_name": engine.url.database}, ) @@ -305,15 +311,15 @@ def configure_orm(disable_connection_pool=False): DEFAULT_ENGINE_ARGS = { - 'postgresql': { - 'executemany_mode': 'values', - 'executemany_values_page_size': 10000, - 'executemany_batch_page_size': 2000, + "postgresql": { + "executemany_mode": "values", + "executemany_values_page_size": 10000, + "executemany_batch_page_size": 2000, }, } -def prepare_engine_args(disable_connection_pool=False): +def prepare_engine_args(disable_connection_pool=False, pool_class=None): """Prepare SQLAlchemy engine args""" default_args = {} for dialect, default in DEFAULT_ENGINE_ARGS.items(): @@ -322,17 +328,20 @@ def prepare_engine_args(disable_connection_pool=False): break engine_args: dict = conf.getjson( - 'database', 'sql_alchemy_engine_args', fallback=default_args + "database", "sql_alchemy_engine_args", fallback=default_args ) # type: ignore - if disable_connection_pool or not conf.getboolean('database', 'SQL_ALCHEMY_POOL_ENABLED'): - engine_args['poolclass'] = NullPool + if pool_class: + # Don't use separate settings for size etc, only those from sql_alchemy_engine_args + engine_args["poolclass"] = pool_class + elif disable_connection_pool or not conf.getboolean("database", "SQL_ALCHEMY_POOL_ENABLED"): + engine_args["poolclass"] = NullPool log.debug("settings.prepare_engine_args(): Using NullPool") - elif not SQL_ALCHEMY_CONN.startswith('sqlite'): + elif not SQL_ALCHEMY_CONN.startswith("sqlite"): # Pool size engine args not supported by sqlite. # If no config value is defined for the pool size, select a reasonable value. # 0 means no limit, which could lead to exceeding the Database connection limit. - pool_size = conf.getint('database', 'SQL_ALCHEMY_POOL_SIZE', fallback=5) + pool_size = conf.getint("database", "SQL_ALCHEMY_POOL_SIZE", fallback=5) # The maximum overflow size of the pool. # When the number of checked-out connections reaches the size set in pool_size, @@ -340,24 +349,24 @@ def prepare_engine_args(disable_connection_pool=False): # When those additional connections are returned to the pool, they are disconnected and discarded. 
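A usage note for the new pool_class parameter threaded through configure_orm and prepare_engine_args above (and reconfigure_orm further below): when a pool class is supplied, the pool size/recycle/pre-ping options in this block are skipped and the given SQLAlchemy class is used directly. A minimal sketch:

    from sqlalchemy.pool import NullPool

    from airflow import settings

    # Rebuild the engine/session with an explicit pool class; the
    # sql_alchemy_pool_* settings are ignored in this case.
    settings.reconfigure_orm(pool_class=NullPool)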
# It follows then that the total number of simultaneous connections # the pool will allow is pool_size + max_overflow, - # and the total number of “sleeping” connections the pool will allow is pool_size. + # and the total number of "sleeping" connections the pool will allow is pool_size. # max_overflow can be set to -1 to indicate no overflow limit; # no limit will be placed on the total number # of concurrent connections. Defaults to 10. - max_overflow = conf.getint('database', 'SQL_ALCHEMY_MAX_OVERFLOW', fallback=10) + max_overflow = conf.getint("database", "SQL_ALCHEMY_MAX_OVERFLOW", fallback=10) # The DB server already has a value for wait_timeout (number of seconds after # which an idle sleeping connection should be killed). Since other DBs may # co-exist on the same server, SQLAlchemy should set its # pool_recycle to an equal or smaller value. - pool_recycle = conf.getint('database', 'SQL_ALCHEMY_POOL_RECYCLE', fallback=1800) + pool_recycle = conf.getint("database", "SQL_ALCHEMY_POOL_RECYCLE", fallback=1800) # Check connection at the start of each connection pool checkout. - # Typically, this is a simple statement like “SELECT 1”, but may also make use + # Typically, this is a simple statement like "SELECT 1", but may also make use # of some DBAPI-specific method to test the connection for liveness. # More information here: # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic - pool_pre_ping = conf.getboolean('database', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True) + pool_pre_ping = conf.getboolean("database", "SQL_ALCHEMY_POOL_PRE_PING", fallback=True) log.debug( "settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, " @@ -367,10 +376,10 @@ def prepare_engine_args(disable_connection_pool=False): pool_recycle, os.getpid(), ) - engine_args['pool_size'] = pool_size - engine_args['pool_recycle'] = pool_recycle - engine_args['pool_pre_ping'] = pool_pre_ping - engine_args['max_overflow'] = max_overflow + engine_args["pool_size"] = pool_size + engine_args["pool_recycle"] = pool_recycle + engine_args["pool_pre_ping"] = pool_pre_ping + engine_args["max_overflow"] = max_overflow # The default isolation level for MySQL (REPEATABLE READ) can introduce inconsistencies when # running multiple schedulers, as repeated queries on the same session may read from stale snapshots. @@ -383,12 +392,12 @@ def prepare_engine_args(disable_connection_pool=False): # Select queries are running. This is by default enforced during init/upgrade. More information: # https://docs.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql - if SQL_ALCHEMY_CONN.startswith(('mysql', 'mssql')): - engine_args['isolation_level'] = 'READ COMMITTED' + if SQL_ALCHEMY_CONN.startswith(("mysql", "mssql")): + engine_args["isolation_level"] = "READ COMMITTED" # Allow the user to specify an encoding for their DB otherwise default # to utf-8 so jobs & users with non-latin1 characters can still use us. 
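# A minimal sketch of the new pool_class parameter threaded through
# configure_orm()/prepare_engine_args() above: passing an explicit SQLAlchemy
# pool class skips the individual pool_size/max_overflow/pool_recycle settings,
# and only sql_alchemy_engine_args still apply. Assumes an Airflow 2.x
# environment where airflow.settings is importable.
from sqlalchemy.pool import NullPool

from airflow import settings

# Build the engine/session without a connection pool, e.g. for a process that
# must not reuse connections inherited from its parent.
settings.configure_orm(pool_class=NullPool)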
- engine_args['encoding'] = conf.get('database', 'SQL_ENGINE_ENCODING', fallback='utf-8') + engine_args["encoding"] = conf.get("database", "SQL_ENGINE_ENCODING", fallback="utf-8") return engine_args @@ -407,22 +416,22 @@ def dispose_orm(): engine = None -def reconfigure_orm(disable_connection_pool=False): +def reconfigure_orm(disable_connection_pool=False, pool_class=None): """Properly close database connections and re-configure ORM""" dispose_orm() - configure_orm(disable_connection_pool=disable_connection_pool) + configure_orm(disable_connection_pool=disable_connection_pool, pool_class=pool_class) def configure_adapters(): """Register Adapters and DB Converters""" from pendulum import DateTime as Pendulum - if SQL_ALCHEMY_CONN.startswith('sqlite'): + if SQL_ALCHEMY_CONN.startswith("sqlite"): from sqlite3 import register_adapter - register_adapter(Pendulum, lambda val: val.isoformat(' ')) + register_adapter(Pendulum, lambda val: val.isoformat(" ")) - if SQL_ALCHEMY_CONN.startswith('mysql'): + if SQL_ALCHEMY_CONN.startswith("mysql"): try: import MySQLdb.converters @@ -441,7 +450,7 @@ def validate_session(): """Validate ORM Session""" global engine - worker_precheck = conf.getboolean('celery', 'worker_precheck', fallback=False) + worker_precheck = conf.getboolean("celery", "worker_precheck", fallback=False) if not worker_precheck: return True else: @@ -457,11 +466,10 @@ def validate_session(): return conn_status -def configure_action_logging(): +def configure_action_logging() -> None: """ Any additional configuration (register callback) for airflow.utils.action_loggers module - :rtype: None """ @@ -472,7 +480,7 @@ def prepare_syspath(): # Add ./config/ for loading custom log parsers etc, or # airflow_local_settings etc. - config_path = os.path.join(AIRFLOW_HOME, 'config') + config_path = os.path.join(AIRFLOW_HOME, "config") if config_path not in sys.path: sys.path.append(config_path) @@ -482,21 +490,21 @@ def prepare_syspath(): def get_session_lifetime_config(): """Gets session timeout configs and handles outdated configs gracefully.""" - session_lifetime_minutes = conf.get('webserver', 'session_lifetime_minutes', fallback=None) - session_lifetime_days = conf.get('webserver', 'session_lifetime_days', fallback=None) + session_lifetime_minutes = conf.get("webserver", "session_lifetime_minutes", fallback=None) + session_lifetime_days = conf.get("webserver", "session_lifetime_days", fallback=None) uses_deprecated_lifetime_configs = session_lifetime_days or conf.get( - 'webserver', 'force_log_out_after', fallback=None + "webserver", "force_log_out_after", fallback=None ) minutes_per_day = 24 * 60 - default_lifetime_minutes = '43200' + default_lifetime_minutes = "43200" if uses_deprecated_lifetime_configs and session_lifetime_minutes == default_lifetime_minutes: warnings.warn( - '`session_lifetime_days` option from `[webserver]` section has been ' - 'renamed to `session_lifetime_minutes`. The new option allows to configure ' - 'session lifetime in minutes. The `force_log_out_after` option has been removed ' - 'from `[webserver]` section. Please update your configuration.', - category=DeprecationWarning, + "`session_lifetime_days` option from `[webserver]` section has been " + "renamed to `session_lifetime_minutes`. The new option allows to configure " + "session lifetime in minutes. The `force_log_out_after` option has been removed " + "from `[webserver]` section. 
Please update your configuration.", + category=RemovedInAirflow3Warning, ) if session_lifetime_days: session_lifetime_minutes = minutes_per_day * int(session_lifetime_days) @@ -505,7 +513,7 @@ def get_session_lifetime_config(): session_lifetime_days = 30 session_lifetime_minutes = minutes_per_day * session_lifetime_days - logging.debug('User session lifetime is set to %s minutes.', session_lifetime_minutes) + logging.debug("User session lifetime is set to %s minutes.", session_lifetime_minutes) return int(session_lifetime_minutes) @@ -534,7 +542,7 @@ def import_local_settings(): globals()["task_policy"] = globals()["policy"] del globals()["policy"] - if not hasattr(task_instance_mutation_hook, 'is_noop'): + if not hasattr(task_instance_mutation_hook, "is_noop"): task_instance_mutation_hook.is_noop = False log.info("Loaded airflow_local_settings from %s .", airflow_local_settings.__file__) @@ -572,52 +580,52 @@ def initialize(): KILOBYTE = 1024 MEGABYTE = KILOBYTE * KILOBYTE -WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'} +WEB_COLORS = {"LIGHTBLUE": "#4d9de0", "LIGHTORANGE": "#FF9933"} # Updating serialized DAG can not be faster than a minimum interval to reduce database # write rate. -MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint('core', 'min_serialized_dag_update_interval', fallback=30) +MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint("core", "min_serialized_dag_update_interval", fallback=30) # If set to True, serialized DAGs is compressed before writing to DB, -COMPRESS_SERIALIZED_DAGS = conf.getboolean('core', 'compress_serialized_dags', fallback=False) +COMPRESS_SERIALIZED_DAGS = conf.getboolean("core", "compress_serialized_dags", fallback=False) # Fetching serialized DAG can not be faster than a minimum interval to reduce database # read rate. This config controls when your DAGs are updated in the Webserver -MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint('core', 'min_serialized_dag_fetch_interval', fallback=10) +MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint("core", "min_serialized_dag_fetch_interval", fallback=10) CAN_FORK = hasattr(os, "fork") EXECUTE_TASKS_NEW_PYTHON_INTERPRETER = not CAN_FORK or conf.getboolean( - 'core', - 'execute_tasks_new_python_interpreter', + "core", + "execute_tasks_new_python_interpreter", fallback=False, ) -ALLOW_FUTURE_EXEC_DATES = conf.getboolean('scheduler', 'allow_trigger_in_future', fallback=False) +ALLOW_FUTURE_EXEC_DATES = conf.getboolean("scheduler", "allow_trigger_in_future", fallback=False) # Whether or not to check each dagrun against defined SLAs -CHECK_SLAS = conf.getboolean('core', 'check_slas', fallback=True) +CHECK_SLAS = conf.getboolean("core", "check_slas", fallback=True) -USE_JOB_SCHEDULE = conf.getboolean('scheduler', 'use_job_schedule', fallback=True) +USE_JOB_SCHEDULE = conf.getboolean("scheduler", "use_job_schedule", fallback=True) # By default Airflow plugins are lazily-loaded (only loaded when required). Set it to False, # if you want to load plugins whenever 'airflow' is invoked via cli or loaded from module. -LAZY_LOAD_PLUGINS = conf.getboolean('core', 'lazy_load_plugins', fallback=True) +LAZY_LOAD_PLUGINS = conf.getboolean("core", "lazy_load_plugins", fallback=True) # By default Airflow providers are lazily-discovered (discovery and imports happen only when required). # Set it to False, if you want to discover providers whenever 'airflow' is invoked via cli or # loaded from module. 
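# A short worked example of the fallback arithmetic in
# get_session_lifetime_config() above: a deprecated [webserver]
# session_lifetime_days value is converted to minutes, and the shipped default
# of 30 days corresponds to the 43200-minute default. The 7-day value below is
# an invented legacy setting.
minutes_per_day = 24 * 60

legacy_session_lifetime_days = 7
print(minutes_per_day * legacy_session_lifetime_days)  # 10080 minutes

print(minutes_per_day * 30)  # 43200 minutes, the current default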
-LAZY_LOAD_PROVIDERS = conf.getboolean('core', 'lazy_discover_providers', fallback=True) +LAZY_LOAD_PROVIDERS = conf.getboolean("core", "lazy_discover_providers", fallback=True) # Determines if the executor utilizes Kubernetes -IS_K8S_OR_K8SCELERY_EXECUTOR = conf.get('core', 'EXECUTOR') in { +IS_K8S_OR_K8SCELERY_EXECUTOR = conf.get("core", "EXECUTOR") in { executor_constants.KUBERNETES_EXECUTOR, executor_constants.CELERY_KUBERNETES_EXECUTOR, executor_constants.LOCAL_KUBERNETES_EXECUTOR, } -HIDE_SENSITIVE_VAR_CONN_FIELDS = conf.getboolean('core', 'hide_sensitive_var_conn_fields') +HIDE_SENSITIVE_VAR_CONN_FIELDS = conf.getboolean("core", "hide_sensitive_var_conn_fields") # By default this is off, but is automatically configured on when running task # instances @@ -636,7 +644,9 @@ def initialize(): # UIAlert('Visit airflow.apache.org', html=True), # ] # -DASHBOARD_UIALERTS: List["UIAlert"] = [] +DASHBOARD_UIALERTS: list[UIAlert] = [] # Prefix used to identify tables holding data moved during migration. AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved" + +DAEMON_UMASK: str = conf.get("core", "daemon_umask", fallback="0o077") diff --git a/airflow/smart_sensor_dags/__init__.py b/airflow/smart_sensor_dags/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/airflow/smart_sensor_dags/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/smart_sensor_dags/smart_sensor_group.py b/airflow/smart_sensor_dags/smart_sensor_group.py deleted file mode 100644 index df6329c407567..0000000000000 --- a/airflow/smart_sensor_dags/smart_sensor_group.py +++ /dev/null @@ -1,56 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
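# A minimal sketch of populating DASHBOARD_UIALERTS (shown above) from
# airflow_local_settings.py; the UIAlert import path is assumed to be
# airflow.www.utils, and the message text is invented.
from airflow.www.utils import UIAlert

DASHBOARD_UIALERTS = [
    UIAlert("Scheduled maintenance on Saturday 02:00-04:00 UTC"),
]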
- -"""Smart sensor DAGs managing all smart sensor tasks.""" -from datetime import datetime, timedelta - -from airflow.configuration import conf -from airflow.models import DAG -from airflow.sensors.smart_sensor import SmartSensorOperator - -num_smart_sensor_shard = conf.getint("smart_sensor", "shards") -shard_code_upper_limit = conf.getint('smart_sensor', 'shard_code_upper_limit') - -for i in range(num_smart_sensor_shard): - shard_min = (i * shard_code_upper_limit) / num_smart_sensor_shard - shard_max = ((i + 1) * shard_code_upper_limit) / num_smart_sensor_shard - - dag_id = f'smart_sensor_group_shard_{i}' - dag = DAG( - dag_id=dag_id, - schedule_interval=timedelta(minutes=5), - max_active_tasks=1, - max_active_runs=1, - catchup=False, - dagrun_timeout=timedelta(hours=24), - start_date=datetime(2021, 1, 1), - ) - - SmartSensorOperator( - task_id='smart_sensor_task', - dag=dag, - retries=100, - retry_delay=timedelta(seconds=10), - priority_weight=999, - shard_min=shard_min, - shard_max=shard_max, - poke_timeout=10, - smart_sensor_timeout=timedelta(hours=24).total_seconds(), - ) - - globals()[dag_id] = dag diff --git a/airflow/stats.py b/airflow/stats.py index ddcb307c3ef2d..92ff3809f695b 100644 --- a/airflow/stats.py +++ b/airflow/stats.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import datetime import logging @@ -22,7 +23,7 @@ import string import time from functools import wraps -from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Callable, TypeVar, cast from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, InvalidStatsNameException @@ -65,7 +66,7 @@ def gauge(cls, stat: str, value: float, rate: int = 1, delta: bool = False) -> N """Gauge stat""" @classmethod - def timing(cls, stat: str, dt: Union[float, datetime.timedelta]) -> None: + def timing(cls, stat: str, dt: float | datetime.timedelta) -> None: """Stats timing""" @classmethod @@ -125,8 +126,8 @@ class Timer: # pystatsd and dogstatsd both have a timer class, but present different API # so we can't use this as a mixin on those, instead this class is contains the "real" timer - _start_time: Optional[int] - duration: Optional[int] + _start_time: int | None + duration: int | None def __init__(self, real_timer=None): self.real_timer = real_timer @@ -178,7 +179,7 @@ def timer(cls, *args, **kwargs): # Only characters in the character set are considered valid # for the stat_name if stat_name_default_handler is used. -ALLOWED_CHARACTERS = set(string.ascii_letters + string.digits + '_.-') +ALLOWED_CHARACTERS = set(string.ascii_letters + string.digits + "_.-") def stat_name_default_handler(stat_name, max_length=250) -> str: @@ -186,21 +187,22 @@ def stat_name_default_handler(stat_name, max_length=250) -> str: if necessary and return the transformed stat name. """ if not isinstance(stat_name, str): - raise InvalidStatsNameException('The stat_name has to be a string') + raise InvalidStatsNameException("The stat_name has to be a string") if len(stat_name) > max_length: raise InvalidStatsNameException( f"The stat_name ({stat_name}) has to be less than {max_length} characters." ) if not all((c in ALLOWED_CHARACTERS) for c in stat_name): raise InvalidStatsNameException( - f"The stat name ({stat_name}) has to be composed with characters in {ALLOWED_CHARACTERS}." 
+ f"The stat name ({stat_name}) has to be composed of ASCII " + f"alphabets, numbers, or the underscore, dot, or dash characters." ) return stat_name def get_current_handler_stat_name_func() -> Callable[[str], str]: """Get Stat Name Handler from airflow.cfg""" - return conf.getimport('metrics', 'stat_name_handler') or stat_name_default_handler + return conf.getimport("metrics", "stat_name_handler") or stat_name_default_handler T = TypeVar("T", bound=Callable) @@ -219,7 +221,7 @@ def wrapper(_self, stat=None, *args, **kwargs): stat = handler_stat_name_func(stat) return fn(_self, stat, *args, **kwargs) except InvalidStatsNameException: - log.error('Invalid stat name: %s.', stat, exc_info=True) + log.exception("Invalid stat name: %s.", stat) return None return cast(T, wrapper) @@ -231,7 +233,7 @@ class AllowListValidator: def __init__(self, allow_list=None): if allow_list: - self.allow_list = tuple(item.strip().lower() for item in allow_list.split(',')) + self.allow_list = tuple(item.strip().lower() for item in allow_list.split(",")) else: self.allow_list = None @@ -318,7 +320,7 @@ def gauge(self, stat, value, rate=1, delta=False, tags=None): return None @validate_stat - def timing(self, stat, dt: Union[float, datetime.timedelta], tags: Optional[List[str]] = None): + def timing(self, stat, dt: float | datetime.timedelta, tags: list[str] | None = None): """Stats timing""" if self.allow_list_validator.test(stat): tags = tags or [] @@ -338,7 +340,7 @@ def timer(self, stat=None, *args, tags=None, **kwargs): class _Stats(type): factory = None - instance: Optional[StatsLogger] = None + instance: StatsLogger | None = None def __getattr__(cls, name): if not cls.instance: @@ -352,10 +354,10 @@ def __getattr__(cls, name): def __init__(cls, *args, **kwargs): super().__init__(cls) if cls.__class__.factory is None: - is_datadog_enabled_defined = conf.has_option('metrics', 'statsd_datadog_enabled') - if is_datadog_enabled_defined and conf.getboolean('metrics', 'statsd_datadog_enabled'): + is_datadog_enabled_defined = conf.has_option("metrics", "statsd_datadog_enabled") + if is_datadog_enabled_defined and conf.getboolean("metrics", "statsd_datadog_enabled"): cls.__class__.factory = cls.get_dogstatsd_logger - elif conf.getboolean('metrics', 'statsd_on'): + elif conf.getboolean("metrics", "statsd_on"): cls.__class__.factory = cls.get_statsd_logger else: cls.__class__.factory = DummyStatsLogger @@ -367,7 +369,7 @@ def get_statsd_logger(cls): # and previously it would crash with None is callable if it was called without it. 
from statsd import StatsClient - stats_class = conf.getimport('metrics', 'statsd_custom_client_path', fallback=None) + stats_class = conf.getimport("metrics", "statsd_custom_client_path", fallback=None) if stats_class: if not issubclass(stats_class, StatsClient): @@ -382,11 +384,11 @@ def get_statsd_logger(cls): stats_class = StatsClient statsd = stats_class( - host=conf.get('metrics', 'statsd_host'), - port=conf.getint('metrics', 'statsd_port'), - prefix=conf.get('metrics', 'statsd_prefix'), + host=conf.get("metrics", "statsd_host"), + port=conf.getint("metrics", "statsd_port"), + prefix=conf.get("metrics", "statsd_prefix"), ) - allow_list_validator = AllowListValidator(conf.get('metrics', 'statsd_allow_list', fallback=None)) + allow_list_validator = AllowListValidator(conf.get("metrics", "statsd_allow_list", fallback=None)) return SafeStatsdLogger(statsd, allow_list_validator) @classmethod @@ -395,12 +397,12 @@ def get_dogstatsd_logger(cls): from datadog import DogStatsd dogstatsd = DogStatsd( - host=conf.get('metrics', 'statsd_host'), - port=conf.getint('metrics', 'statsd_port'), - namespace=conf.get('metrics', 'statsd_prefix'), + host=conf.get("metrics", "statsd_host"), + port=conf.getint("metrics", "statsd_port"), + namespace=conf.get("metrics", "statsd_prefix"), constant_tags=cls.get_constant_tags(), ) - dogstatsd_allow_list = conf.get('metrics', 'statsd_allow_list', fallback=None) + dogstatsd_allow_list = conf.get("metrics", "statsd_allow_list", fallback=None) allow_list_validator = AllowListValidator(dogstatsd_allow_list) return SafeDogStatsdLogger(dogstatsd, allow_list_validator) @@ -408,11 +410,11 @@ def get_dogstatsd_logger(cls): def get_constant_tags(cls): """Get constant DataDog tags to add to all stats""" tags = [] - tags_in_string = conf.get('metrics', 'statsd_datadog_tags', fallback=None) - if tags_in_string is None or tags_in_string == '': + tags_in_string = conf.get("metrics", "statsd_datadog_tags", fallback=None) + if tags_in_string is None or tags_in_string == "": return tags else: - for key_value in tags_in_string.split(','): + for key_value in tags_in_string.split(","): tags.append(key_value) return tags diff --git a/airflow/task/task_runner/__init__.py b/airflow/task/task_runner/__init__.py index dc63883e5b4bd..fce577c33fef8 100644 --- a/airflow/task/task_runner/__init__.py +++ b/airflow/task/task_runner/__init__.py @@ -15,17 +15,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from __future__ import annotations import logging from airflow.configuration import conf from airflow.exceptions import AirflowConfigException +from airflow.task.task_runner.base_task_runner import BaseTaskRunner from airflow.utils.module_loading import import_string log = logging.getLogger(__name__) -_TASK_RUNNER_NAME = conf.get('core', 'TASK_RUNNER') +_TASK_RUNNER_NAME = conf.get("core", "TASK_RUNNER") STANDARD_TASK_RUNNER = "StandardTaskRunner" @@ -37,14 +38,13 @@ } -def get_task_runner(local_task_job): +def get_task_runner(local_task_job) -> BaseTaskRunner: """ Get the task runner that can be used to run the given job. :param local_task_job: The LocalTaskJob associated with the TaskInstance that needs to be executed. :return: The task runner to use to run the task. 
- :rtype: airflow.task.task_runner.base_task_runner.BaseTaskRunner """ if _TASK_RUNNER_NAME in CORE_TASK_RUNNERS: log.debug("Loading core task runner: %s", _TASK_RUNNER_NAME) diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py index 55dcf05d34b43..095c9b7300ff8 100644 --- a/airflow/task/task_runner/base_task_runner.py +++ b/airflow/task/task_runner/base_task_runner.py @@ -15,18 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Base task runner""" +"""Base task runner.""" +from __future__ import annotations + import os import subprocess import threading +from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.platform import IS_WINDOWS if not IS_WINDOWS: # ignored to avoid flake complaining on Linux from pwd import getpwnam # noqa -from typing import Optional from airflow.configuration import conf from airflow.exceptions import AirflowConfigException @@ -35,13 +37,14 @@ from airflow.utils.net import get_hostname from airflow.utils.platform import getuser -PYTHONPATH_VAR = 'PYTHONPATH' +PYTHONPATH_VAR = "PYTHONPATH" class BaseTaskRunner(LoggingMixin): """ - Runs Airflow task instances by invoking the `airflow tasks run` command with raw - mode enabled in a subprocess. + Runs Airflow task instances via CLI. + + Invoke the `airflow tasks run` command with raw mode enabled in a subprocess. :param local_task_job: The local task job associated with running the associated task instance. @@ -57,7 +60,7 @@ def __init__(self, local_task_job): self.run_as_user = self._task_instance.run_as_user else: try: - self.run_as_user = conf.get('core', 'default_impersonation') + self.run_as_user = conf.get("core", "default_impersonation") except AirflowConfigException: self.run_as_user = None @@ -72,14 +75,14 @@ def __init__(self, local_task_job): cfg_path = tmp_configuration_copy(chmod=0o600, include_env=True, include_cmds=True) # Give ownership of file to user; only they can read and write - subprocess.check_call(['sudo', 'chown', self.run_as_user, cfg_path], close_fds=True) + subprocess.check_call(["sudo", "chown", self.run_as_user, cfg_path], close_fds=True) # propagate PYTHONPATH environment variable - pythonpath_value = os.environ.get(PYTHONPATH_VAR, '') - popen_prepend = ['sudo', '-E', '-H', '-u', self.run_as_user] + pythonpath_value = os.environ.get(PYTHONPATH_VAR, "") + popen_prepend = ["sudo", "-E", "-H", "-u", self.run_as_user] if pythonpath_value: - popen_prepend.append(f'{PYTHONPATH_VAR}={pythonpath_value}') + popen_prepend.append(f"{PYTHONPATH_VAR}={pythonpath_value}") else: # Always provide a copy of the configuration file settings. Since @@ -103,49 +106,51 @@ def _read_task_logs(self, stream): while True: line = stream.readline() if isinstance(line, bytes): - line = line.decode('utf-8') + line = line.decode("utf-8") if not line: break self.log.info( - 'Job %s: Subtask %s %s', + "Job %s: Subtask %s %s", self._task_instance.job_id, self._task_instance.task_id, - line.rstrip('\n'), + line.rstrip("\n"), ) - def run_command(self, run_with=None): + def run_command(self, run_with=None) -> subprocess.Popen: """ Run the task command. :param run_with: list of tokens to run the task command with e.g. 
``['bash', '-c']`` :return: the process that was run - :rtype: subprocess.Popen """ run_with = run_with or [] full_cmd = run_with + self._command self.log.info("Running on host: %s", get_hostname()) - self.log.info('Running: %s', full_cmd) - - if IS_WINDOWS: - proc = subprocess.Popen( - full_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True, - close_fds=True, - env=os.environ.copy(), - ) - else: - proc = subprocess.Popen( - full_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True, - close_fds=True, - env=os.environ.copy(), - preexec_fn=os.setsid, - ) + self.log.info("Running: %s", full_cmd) + with _airflow_parsing_context_manager( + dag_id=self._task_instance.dag_id, + task_id=self._task_instance.task_id, + ): + if IS_WINDOWS: + proc = subprocess.Popen( + full_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + close_fds=True, + env=os.environ.copy(), + ) + else: + proc = subprocess.Popen( + full_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + close_fds=True, + env=os.environ.copy(), + preexec_fn=os.setsid, + ) # Start daemon thread to read subprocess logging output log_reader = threading.Thread( @@ -160,11 +165,12 @@ def start(self): """Start running the task instance in a subprocess.""" raise NotImplementedError() - def return_code(self) -> Optional[int]: + def return_code(self, timeout: int = 0) -> int | None: """ + Extract the return code. + :return: The return code associated with running the task instance or None if the task is not yet done. - :rtype: int """ raise NotImplementedError() @@ -176,6 +182,6 @@ def on_finish(self) -> None: """A callback that should be called when this is done running.""" if self._cfg_path and os.path.isfile(self._cfg_path): if self.run_as_user: - subprocess.call(['sudo', 'rm', self._cfg_path], close_fds=True) + subprocess.call(["sudo", "rm", self._cfg_path], close_fds=True) else: os.remove(self._cfg_path) diff --git a/airflow/task/task_runner/cgroup_task_runner.py b/airflow/task/task_runner/cgroup_task_runner.py index d6c6e53abf935..0bd3f616dee07 100644 --- a/airflow/task/task_runner/cgroup_task_runner.py +++ b/airflow/task/task_runner/cgroup_task_runner.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -"""Task runner for cgroup to run Airflow task""" +"""Task runner for cgroup to run Airflow task.""" +from __future__ import annotations import datetime import os @@ -33,9 +33,10 @@ class CgroupTaskRunner(BaseTaskRunner): """ - Runs the raw Airflow task in a cgroup that has containment for memory and - cpu. It uses the resource requirements defined in the task to construct - the settings for the cgroup. + Runs the raw Airflow task in a cgroup container. + + With containment for memory and cpu. It uses the resource requirements + defined in the task to construct the settings for the cgroup. Cgroup must be mounted first otherwise CgroupTaskRunner will not be able to work. @@ -72,14 +73,13 @@ def __init__(self, local_task_job): self._created_mem_cgroup = False self._cur_user = getuser() - def _create_cgroup(self, path): + def _create_cgroup(self, path) -> trees.Node: """ Create the specified cgroup. :param path: The path of the cgroup to create. E.g. cpu/mygroup/mysubgroup :return: the Node associated with the created cgroup. 
- :rtype: cgroupspy.nodes.Node """ node = trees.Tree().root path_split = path.split(os.sep) @@ -161,9 +161,9 @@ def start(self): # Start the process w/ cgroups self.log.debug("Starting task process with cgroups cpu,memory: %s", cgroup_name) - self.process = self.run_command(['cgexec', '-g', f'cpu,memory:{cgroup_name}']) + self.process = self.run_command(["cgexec", "-g", f"cpu,memory:{cgroup_name}"]) - def return_code(self): + def return_code(self, timeout: int = 0) -> int | None: return_code = self.process.poll() # TODO(plypaul) Monitoring the control file in the cgroup fs is better than # checking the return code here. The PR to use this is here: @@ -189,7 +189,7 @@ def _log_memory_usage(self, mem_cgroup_node): def byte_to_gb(num_bytes, precision=2): return round(num_bytes / (1024 * 1024 * 1024), precision) - with open(mem_cgroup_node.full_path + '/memory.max_usage_in_bytes') as f: + with open(mem_cgroup_node.full_path + "/memory.max_usage_in_bytes") as f: max_usage_in_bytes = int(f.read().strip()) used_gb = byte_to_gb(max_usage_in_bytes) @@ -217,10 +217,11 @@ def on_finish(self): super().on_finish() @staticmethod - def _get_cgroup_names(): + def _get_cgroup_names() -> dict[str, str]: """ + Get the mapping between the subsystem name and the cgroup name. + :return: a mapping between the subsystem name to the cgroup name - :rtype: dict[str, str] """ with open("/proc/self/cgroup") as file: lines = file.readlines() diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index 53f873ec1f7b0..4d2d55e9276d6 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -15,17 +15,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Standard task runner""" +"""Standard task runner.""" +from __future__ import annotations + import logging import os -from typing import Optional import psutil from setproctitle import setproctitle from airflow.settings import CAN_FORK from airflow.task.task_runner.base_task_runner import BaseTaskRunner -from airflow.utils.process_utils import reap_process_group +from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager +from airflow.utils.process_utils import reap_process_group, set_new_process_group class StandardTaskRunner(BaseTaskRunner): @@ -34,6 +36,7 @@ class StandardTaskRunner(BaseTaskRunner): def __init__(self, local_task_job): super().__init__(local_task_job) self._rc = None + self.dag = local_task_job.task_instance.task.dag def start(self): if CAN_FORK and not self.run_as_user: @@ -53,7 +56,7 @@ def _start_by_fork(self): return psutil.Process(pid) else: # Start a new process group - os.setpgid(0, 0) + set_new_process_group() import signal signal.signal(signal.SIGINT, signal.SIG_DFL) @@ -62,7 +65,6 @@ def _start_by_fork(self): from airflow import settings from airflow.cli.cli_parser import get_parser from airflow.sentry import Sentry - from airflow.utils.cli import get_dag # Force a new SQLAlchemy session. We can't share open DB handles # between process. The cli code will re-create this as part of its @@ -77,20 +79,21 @@ def _start_by_fork(self): # We prefer the job_id passed on the command-line because at this time, the # task instance may not have been updated. 
job_id = getattr(args, "job_id", self._task_instance.job_id) - self.log.info('Running: %s', self._command) - self.log.info('Job %s: Subtask %s', job_id, self._task_instance.task_id) + self.log.info("Running: %s", self._command) + self.log.info("Job %s: Subtask %s", job_id, self._task_instance.task_id) proc_title = "airflow task runner: {0.dag_id} {0.task_id} {0.execution_date_or_run_id}" if job_id is not None: proc_title += " {0.job_id}" setproctitle(proc_title.format(args)) - return_code = 0 try: - # parse dag file since `airflow tasks run --local` does not parse dag file - dag = get_dag(args.subdir, args.dag_id) - args.func(args, dag=dag) - return_code = 0 + with _airflow_parsing_context_manager( + dag_id=self._task_instance.dag_id, + task_id=self._task_instance.task_id, + ): + args.func(args, dag=self.dag) + return_code = 0 except Exception as exc: return_code = 1 @@ -129,7 +132,7 @@ def _start_by_fork(self): # deleted at os._exit() os._exit(return_code) - def return_code(self, timeout: int = 0) -> Optional[int]: + def return_code(self, timeout: int = 0) -> int | None: # We call this multiple times, but we can only wait on the process once if self._rc is not None or not self.process: return self._rc @@ -163,6 +166,6 @@ def terminate(self): # If either we or psutil gives out a -9 return code, it likely means # an OOM happened self.log.error( - 'Job %s was killed before it finished (likely due to running out of memory)', + "Job %s was killed before it finished (likely due to running out of memory)", self._task_instance.job_id, ) diff --git a/airflow/templates.py b/airflow/templates.py index 6ec010f618fd3..2b466a49526b4 100644 --- a/airflow/templates.py +++ b/airflow/templates.py @@ -15,6 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
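# A stdlib-only sketch of the pattern the task runners above rely on: start the
# child in its own session/process group (POSIX only) so the whole group can be
# signalled or reaped later, and drain its combined stdout/stderr from a daemon
# thread. This is a simplified illustration, not the runner code itself.
import os
import subprocess
import sys
import threading


def _stream_output(stream):
    for line in iter(stream.readline, ""):
        print("subtask:", line.rstrip("\n"))


proc = subprocess.Popen(
    [sys.executable, "-c", "print('hello from the subtask')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    universal_newlines=True,
    close_fds=True,
    preexec_fn=os.setsid,  # new process group; not available on Windows
)
reader = threading.Thread(target=_stream_output, args=(proc.stdout,), daemon=True)
reader.start()
proc.wait()
reader.join()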
+from __future__ import annotations + +import datetime import jinja2.nativetypes import jinja2.sandbox @@ -44,30 +47,40 @@ class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnv """SandboxedEnvironment for Airflow task templates.""" -def ds_filter(value): - return value.strftime('%Y-%m-%d') +def ds_filter(value: datetime.date | datetime.time | None) -> str | None: + if value is None: + return None + return value.strftime("%Y-%m-%d") -def ds_nodash_filter(value): - return value.strftime('%Y%m%d') +def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: + if value is None: + return None + return value.strftime("%Y%m%d") -def ts_filter(value): +def ts_filter(value: datetime.date | datetime.time | None) -> str | None: + if value is None: + return None return value.isoformat() -def ts_nodash_filter(value): - return value.strftime('%Y%m%dT%H%M%S') +def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: + if value is None: + return None + return value.strftime("%Y%m%dT%H%M%S") -def ts_nodash_with_tz_filter(value): - return value.isoformat().replace('-', '').replace(':', '') +def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None: + if value is None: + return None + return value.isoformat().replace("-", "").replace(":", "") FILTERS = { - 'ds': ds_filter, - 'ds_nodash': ds_nodash_filter, - 'ts': ts_filter, - 'ts_nodash': ts_nodash_filter, - 'ts_nodash_with_tz': ts_nodash_with_tz_filter, + "ds": ds_filter, + "ds_nodash": ds_nodash_filter, + "ts": ts_filter, + "ts_nodash": ts_nodash_filter, + "ts_nodash_with_tz": ts_nodash_with_tz_filter, } diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index 8a0ae14dc097d..829e396417d71 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -15,8 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING import attr from sqlalchemy.orm.session import Session @@ -71,16 +72,18 @@ class DepContext: ignore_task_deps: bool = False ignore_ti_state: bool = False ignore_unmapped_tasks: bool = False - finished_tis: Optional[List["TaskInstance"]] = None + finished_tis: list[TaskInstance] | None = None - def ensure_finished_tis(self, dag_run: "DagRun", session: Session) -> List["TaskInstance"]: + have_changed_ti_states: bool = False + """Have any of the TIs state's been changed as a result of evaluating dependencies""" + + def ensure_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskInstance]: """ This method makes sure finished_tis is populated if it's currently None. This is for the strange feature of running tasks without dag_run. :param dag_run: The DagRun for which to find finished tasks :return: A list of all the finished tasks of this DAG and execution_date - :rtype: list[airflow.models.TaskInstance] """ if self.finished_tis is None: finished_tis = dag_run.get_task_instances(state=State.finished, session=session) diff --git a/airflow/ti_deps/dependencies_deps.py b/airflow/ti_deps/dependencies_deps.py index cfacb11297be5..cf02513eab468 100644 --- a/airflow/ti_deps/dependencies_deps.py +++ b/airflow/ti_deps/dependencies_deps.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
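# A quick check of the None-tolerant template filters added above; assumes an
# Airflow 2.x install so that airflow.templates is importable.
import datetime

from airflow.templates import ds_filter, ts_nodash_filter

print(ds_filter(datetime.datetime(2022, 7, 1, 12, 30)))         # 2022-07-01
print(ts_nodash_filter(datetime.datetime(2022, 7, 1, 12, 30)))  # 20220701T123000
print(ds_filter(None))                                          # None instead of raising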
+from __future__ import annotations

 from airflow.ti_deps.dependencies_states import (
     BACKFILL_QUEUEABLE_STATES,
diff --git a/airflow/ti_deps/dependencies_states.py b/airflow/ti_deps/dependencies_states.py
index d6f4b473724d3..543ce3528a1b5 100644
--- a/airflow/ti_deps/dependencies_states.py
+++ b/airflow/ti_deps/dependencies_states.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations

 from airflow.utils.state import State
diff --git a/airflow/ti_deps/deps/base_ti_dep.py b/airflow/ti_deps/deps/base_ti_dep.py
index 109545b8c541d..9e7d49699c08e 100644
--- a/airflow/ti_deps/deps/base_ti_dep.py
+++ b/airflow/ti_deps/deps/base_ti_dep.py
@@ -15,12 +15,18 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations

-from typing import NamedTuple
+from typing import TYPE_CHECKING, Any, Iterator, NamedTuple

 from airflow.ti_deps.dep_context import DepContext
 from airflow.utils.session import provide_session

+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.models.taskinstance import TaskInstance
+

 class BaseTIDep:
     """
@@ -37,27 +43,29 @@ class BaseTIDep:
     # to some tasks (e.g. depends_on_past is not specified by all tasks).
     IS_TASK_DEP = False

-    def __init__(self):
-        pass
-
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return isinstance(self, type(other))

-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(type(self))

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"<TIDep({self.name})>"

     @property
-    def name(self):
-        """
-        The human-readable name for the dependency. Use the classname as the default name
-        if this method is not overridden in the subclass.
-        """
-        return getattr(self, 'NAME', self.__class__.__name__)
+    def name(self) -> str:
+        """The human-readable name for the dependency.

-    def _get_dep_statuses(self, ti, session, dep_context):
+        Use the class name as the default if ``NAME`` is not provided.
+        """
+        return getattr(self, "NAME", self.__class__.__name__)
+
+    def _get_dep_statuses(
+        self,
+        ti: TaskInstance,
+        session: Session,
+        dep_context: DepContext,
+    ) -> Iterator[TIDepStatus]:
         """
         Abstract method that returns an iterable of TIDepStatus objects that describe
         whether the given task instance has this dependency met.
@@ -72,7 +80,12 @@ def _get_dep_statuses(self, ti, session, dep_context):
         raise NotImplementedError

     @provide_session
-    def get_dep_statuses(self, ti, session, dep_context=None):
+    def get_dep_statuses(
+        self,
+        ti: TaskInstance,
+        session: Session,
+        dep_context: DepContext | None = None,
+    ) -> Iterator[TIDepStatus]:
         """
         Wrapper around the private _get_dep_statuses method that contains some global checks
         for all dependencies.
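# A minimal sketch of the BaseTIDep contract shown above: a subclass implements
# _get_dep_statuses() and yields passing/failing statuses built with the
# provided helpers. The weekday rule is a made-up example, and evaluating it
# for real requires a TaskInstance and an ORM session.
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep


class NotOnWeekendDep(BaseTIDep):
    NAME = "Not On Weekend"
    IGNORABLE = True

    def _get_dep_statuses(self, ti, session, dep_context):
        if ti.execution_date.weekday() >= 5:  # Saturday or Sunday
            yield self._failing_status(reason="Task instances do not run on weekends.")
        else:
            yield self._passing_status(reason="The execution date falls on a weekday.")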
@@ -81,21 +94,20 @@ def get_dep_statuses(self, ti, session, dep_context=None): :param session: database session :param dep_context: the context for which this dependency should be evaluated for """ - if dep_context is None: - dep_context = DepContext() + cxt = DepContext() if dep_context is None else dep_context - if self.IGNORABLE and dep_context.ignore_all_deps: + if self.IGNORABLE and cxt.ignore_all_deps: yield self._passing_status(reason="Context specified all dependencies should be ignored.") return - if self.IS_TASK_DEP and dep_context.ignore_task_deps: + if self.IS_TASK_DEP and cxt.ignore_task_deps: yield self._passing_status(reason="Context specified all task dependencies should be ignored.") return - yield from self._get_dep_statuses(ti, session, dep_context) + yield from self._get_dep_statuses(ti, session, cxt) @provide_session - def is_met(self, ti, session, dep_context=None): + def is_met(self, ti: TaskInstance, session: Session, dep_context: DepContext | None = None) -> bool: """ Returns whether or not this dependency is met for a given task instance. A dependency is considered met if all of the dependency statuses it reports are @@ -109,7 +121,12 @@ def is_met(self, ti, session, dep_context=None): return all(status.passed for status in self.get_dep_statuses(ti, session, dep_context)) @provide_session - def get_failure_reasons(self, ti, session, dep_context=None): + def get_failure_reasons( + self, + ti: TaskInstance, + session: Session, + dep_context: DepContext | None = None, + ) -> Iterator[str]: """ Returns an iterable of strings that explain why this dependency wasn't met. @@ -122,10 +139,10 @@ def get_failure_reasons(self, ti, session, dep_context=None): if not dep_status.passed: yield dep_status.reason - def _failing_status(self, reason=''): + def _failing_status(self, reason: str = "") -> TIDepStatus: return TIDepStatus(self.name, False, reason) - def _passing_status(self, reason=''): + def _passing_status(self, reason: str = "") -> TIDepStatus: return TIDepStatus(self.name, True, reason) diff --git a/airflow/ti_deps/deps/dag_ti_slots_available_dep.py b/airflow/ti_deps/deps/dag_ti_slots_available_dep.py index e9dadd1d594c7..550695afeb04a 100644 --- a/airflow/ti_deps/deps/dag_ti_slots_available_dep.py +++ b/airflow/ti_deps/deps/dag_ti_slots_available_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session diff --git a/airflow/ti_deps/deps/dag_unpaused_dep.py b/airflow/ti_deps/deps/dag_unpaused_dep.py index cb1784d373ba4..a854d21632d7d 100644 --- a/airflow/ti_deps/deps/dag_unpaused_dep.py +++ b/airflow/ti_deps/deps/dag_unpaused_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session diff --git a/airflow/ti_deps/deps/dagrun_backfill_dep.py b/airflow/ti_deps/deps/dagrun_backfill_dep.py index 949af43d38a4d..f45e564577e2d 100644 --- a/airflow/ti_deps/deps/dagrun_backfill_dep.py +++ b/airflow/ti_deps/deps/dagrun_backfill_dep.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module defines dep for making sure DagRun not a backfill.""" +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py index 0910aec7c7056..781ab0ebaf022 100644 --- a/airflow/ti_deps/deps/dagrun_exists_dep.py +++ b/airflow/ti_deps/deps/dagrun_exists_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session diff --git a/airflow/ti_deps/deps/exec_date_after_start_date_dep.py b/airflow/ti_deps/deps/exec_date_after_start_date_dep.py index a8b34963ccd5a..09e0d8c229615 100644 --- a/airflow/ti_deps/deps/exec_date_after_start_date_dep.py +++ b/airflow/ti_deps/deps/exec_date_after_start_date_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session diff --git a/airflow/ti_deps/deps/mapped_task_expanded.py b/airflow/ti_deps/deps/mapped_task_expanded.py index 149644ebc926c..87a804006be45 100644 --- a/airflow/ti_deps/deps/mapped_task_expanded.py +++ b/airflow/ti_deps/deps/mapped_task_expanded.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep diff --git a/airflow/ti_deps/deps/not_in_retry_period_dep.py b/airflow/ti_deps/deps/not_in_retry_period_dep.py index 17ba4d7bae00e..b3b5d4ec568fe 100644 --- a/airflow/ti_deps/deps/not_in_retry_period_dep.py +++ b/airflow/ti_deps/deps/not_in_retry_period_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils import timezone diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py index 7c7a8221420dc..64161ade88d62 100644 --- a/airflow/ti_deps/deps/not_previously_skipped_dep.py +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep diff --git a/airflow/ti_deps/deps/pool_slots_available_dep.py b/airflow/ti_deps/deps/pool_slots_available_dep.py index bc45b6cebf02e..138dbac9c7106 100644 --- a/airflow/ti_deps/deps/pool_slots_available_dep.py +++ b/airflow/ti_deps/deps/pool_slots_available_dep.py @@ -15,8 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- """This module defines dep for pool slots availability""" +from __future__ import annotations from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.ti_deps.deps.base_ti_dep import BaseTIDep diff --git a/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow/ti_deps/deps/prev_dagrun_dep.py index 391563a12770a..dce4b8527f887 100644 --- a/airflow/ti_deps/deps/prev_dagrun_dep.py +++ b/airflow/ti_deps/deps/prev_dagrun_dep.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from sqlalchemy import func from airflow.models.taskinstance import TaskInstance as TI @@ -66,7 +68,7 @@ def _get_dep_statuses(self, ti: TI, session, dep_context): yield self._passing_status(reason="This task instance was the first task instance for its task.") return - previous_ti = last_dagrun.get_task_instance(ti.task_id, session=session) + previous_ti = last_dagrun.get_task_instance(ti.task_id, map_index=ti.map_index, session=session) if not previous_ti: if ti.task.ignore_first_depends_on_past: has_historical_ti = ( diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py index 9086822ceac8d..6ac9f492f0062 100644 --- a/airflow/ti_deps/deps/ready_to_reschedule.py +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.models.taskreschedule import TaskReschedule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep @@ -40,7 +41,13 @@ def _get_dep_statuses(self, ti, session, dep_context): considered as passed. This dependency fails if the latest reschedule request's reschedule date is still in future. """ - if not getattr(ti.task, "reschedule", False): + from airflow.models.mappedoperator import MappedOperator + + is_mapped = isinstance(ti.task, MappedOperator) + if not is_mapped and not getattr(ti.task, "reschedule", False): + # Mapped sensors don't have the reschedule property (it can only + # be calculated after unmapping), so we don't check them here. + # They are handled below by checking TaskReschedule instead. yield self._passing_status(reason="Task is not in reschedule mode.") return @@ -62,6 +69,11 @@ def _get_dep_statuses(self, ti, session, dep_context): .first() ) if not task_reschedule: + # Because mapped sensors don't have the reschedule property, here's the last resort + # and we need a slightly different passing reason + if is_mapped: + yield self._passing_status(reason="The task is mapped and not in reschedule mode") + return yield self._passing_status(reason="There is no reschedule request for this task instance.") return diff --git a/airflow/ti_deps/deps/runnable_exec_date_dep.py b/airflow/ti_deps/deps/runnable_exec_date_dep.py index ab26b9ef5b6df..b7a23e91e27e6 100644 --- a/airflow/ti_deps/deps/runnable_exec_date_dep.py +++ b/airflow/ti_deps/deps/runnable_exec_date_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils import timezone diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py index 264edb86fd683..5b5f4f515acf0 100644 --- a/airflow/ti_deps/deps/task_concurrency_dep.py +++ b/airflow/ti_deps/deps/task_concurrency_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session diff --git a/airflow/ti_deps/deps/task_not_running_dep.py b/airflow/ti_deps/deps/task_not_running_dep.py index 59df2fff07112..fd76873466d33 100644 --- a/airflow/ti_deps/deps/task_not_running_dep.py +++ b/airflow/ti_deps/deps/task_not_running_dep.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. """Contains the TaskNotRunningDep.""" +from __future__ import annotations from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.session import provide_session @@ -40,4 +41,4 @@ def _get_dep_statuses(self, ti, session, dep_context=None): yield self._passing_status(reason="Task is not in running state.") return - yield self._failing_status(reason='Task is in the running state') + yield self._failing_status(reason="Task is in the running state") diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 1a65467896084..7d78b591af323 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -15,24 +15,58 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from collections import Counter -from typing import TYPE_CHECKING +import collections +import collections.abc +import functools +from typing import TYPE_CHECKING, Iterator, NamedTuple -from sqlalchemy import func +from sqlalchemy import and_, func, or_ from airflow.ti_deps.dep_context import DepContext -from airflow.ti_deps.deps.base_ti_dep import BaseTIDep -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import State +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus +from airflow.utils.state import TaskInstanceState from airflow.utils.trigger_rule import TriggerRule as TR if TYPE_CHECKING: from sqlalchemy.orm import Session + from sqlalchemy.sql.expression import ColumnOperators from airflow.models.taskinstance import TaskInstance +class _UpstreamTIStates(NamedTuple): + """States of the upstream tis for a specific ti. + + This is used to determine whether the specific ti can run in this iteration. + """ + + success: int + skipped: int + failed: int + upstream_failed: int + removed: int + done: int + + @classmethod + def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTIStates: + """Calculate states for a task instance. 
+ + :param ti: the ti that we want to calculate deps for + :param finished_tis: all the finished tasks of the dag_run + """ + counter = collections.Counter(ti.state for ti in finished_upstreams) + return _UpstreamTIStates( + success=counter.get(TaskInstanceState.SUCCESS, 0), + skipped=counter.get(TaskInstanceState.SKIPPED, 0), + failed=counter.get(TaskInstanceState.FAILED, 0), + upstream_failed=counter.get(TaskInstanceState.UPSTREAM_FAILED, 0), + removed=counter.get(TaskInstanceState.REMOVED, 0), + done=sum(counter.values()), + ) + + class TriggerRuleDep(BaseTIDep): """ Determines if a task's upstream tasks are in a state that allows a given task instance @@ -43,159 +77,200 @@ class TriggerRuleDep(BaseTIDep): IGNORABLE = True IS_TASK_DEP = True - @staticmethod - def _get_states_count_upstream_ti(task, finished_tis): - """ - This function returns the states of the upstream tis for a specific ti in order to determine - whether this ti can run in this iteration - - :param ti: the ti that we want to calculate deps for - :param finished_tis: all the finished tasks of the dag_run - """ - counter = Counter(ti.state for ti in finished_tis if ti.task_id in task.upstream_task_ids) - return ( - counter.get(State.SUCCESS, 0), - counter.get(State.SKIPPED, 0), - counter.get(State.FAILED, 0), - counter.get(State.UPSTREAM_FAILED, 0), - sum(counter.values()), - ) - - @provide_session - def _get_dep_statuses(self, ti, session, dep_context: DepContext): - # Checking that all upstream dependencies have succeeded - if not ti.task.upstream_list: + def _get_dep_statuses( + self, + ti: TaskInstance, + session: Session, + dep_context: DepContext, + ) -> Iterator[TIDepStatus]: + # Checking that all upstream dependencies have succeeded. + if not ti.task.upstream_task_ids: yield self._passing_status(reason="The task instance did not have any upstream tasks.") return - if ti.task.trigger_rule == TR.ALWAYS: yield self._passing_status(reason="The task had a always trigger rule set.") return - # see if the task name is in the task upstream for our task - successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti( - task=ti.task, finished_tis=dep_context.ensure_finished_tis(ti.get_dagrun(session), session) - ) + yield from self._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session) - yield from self._evaluate_trigger_rule( - ti=ti, - successes=successes, - skipped=skipped, - failed=failed, - upstream_failed=upstream_failed, - done=done, - flag_upstream_failed=dep_context.flag_upstream_failed, - session=session, - ) + def _evaluate_trigger_rule( + self, + *, + ti: TaskInstance, + dep_context: DepContext, + session: Session, + ) -> Iterator[TIDepStatus]: + """Evaluate whether ``ti``'s trigger rule was met. - @staticmethod - def _count_upstreams(ti: "TaskInstance", *, session: "Session"): + :param ti: Task instance to evaluate the trigger rule of. + :param dep_context: The current dependency context. + :param session: Database session. + """ + from airflow.models.abstractoperator import NotMapped + from airflow.models.expandinput import NotFullyPopulated + from airflow.models.operator import needs_expansion from airflow.models.taskinstance import TaskInstance - # Optimization: Don't need to hit the database if no upstreams are mapped. 
- upstream_task_ids = ti.task.upstream_task_ids - if ti.task.dag and not any(ti.task.dag.get_task(tid).is_mapped for tid in upstream_task_ids): - return len(upstream_task_ids) + task = ti.task + upstream_tasks = {t.task_id: t for t in task.upstream_list} + trigger_rule = task.trigger_rule + + @functools.lru_cache() + def _get_expanded_ti_count() -> int: + """Get how many tis the current task is supposed to be expanded into. - # We don't naively count task instances because it is not guaranteed - # that all upstreams have been created in the database at this point. - # Instead, we look for already-expanded tasks, and add them to the raw - # task count without considering mapping. - mapped_tis_addition = ( - session.query(func.count()) - .filter( - TaskInstance.dag_id == ti.dag_id, - TaskInstance.run_id == ti.run_id, - TaskInstance.task_id.in_(upstream_task_ids), - TaskInstance.map_index > 0, + This extra closure allows us to query the database only when needed, + and at most once. + """ + return task.get_mapped_ti_count(ti.run_id, session=session) + + @functools.lru_cache() + def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: + """Get the given task's map indexes relevant to the current ti. + + This extra closure allows us to query the database only when needed, + and at most once for each task (instead of once for each expanded + task instance of the same task). + """ + try: + expanded_ti_count = _get_expanded_ti_count() + except (NotFullyPopulated, NotMapped): + return None + return ti.get_relevant_upstream_map_indexes( + upstream_tasks[upstream_id], + expanded_ti_count, + session=session, ) - .scalar() + + def _is_relevant_upstream(upstream: TaskInstance) -> bool: + """Whether a task instance is a "relevant upstream" of the current task.""" + # Not actually an upstream task. + if upstream.task_id not in task.upstream_task_ids: + return False + # The current task is not in a mapped task group. All tis from an + # upstream task are relevant. + if task.get_closest_mapped_task_group() is None: + return True + # The upstream ti is not expanded. The upstream may be mapped or + # not, but the ti is relevant either way. + if upstream.map_index < 0: + return True + # Now we need to perform fine-grained check on whether this specific + # upstream ti's map index is relevant. + relevant = _get_relevant_upstream_map_indexes(upstream.task_id) + if relevant is None: + return True + if relevant == upstream.map_index: + return True + if isinstance(relevant, collections.abc.Container) and upstream.map_index in relevant: + return True + return False + + finished_upstream_tis = ( + finished_ti + for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session) + if _is_relevant_upstream(finished_ti) ) - return len(upstream_task_ids) + mapped_tis_addition + upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis) - @provide_session - def _evaluate_trigger_rule( - self, - ti, - successes, - skipped, - failed, - upstream_failed, - done, - flag_upstream_failed, - session: "Session" = NEW_SESSION, - ): - """ - Yields a dependency status that indicate whether the given task instance's trigger - rule was met. 
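The rewritten `_evaluate_trigger_rule` wraps its expensive lookups (`_get_expanded_ti_count`, `_get_relevant_upstream_map_indexes`) in `functools.lru_cache` so each database round-trip happens at most once, and only if actually needed. A rough sketch of that lazy-memoization pattern, with a stand-in `expensive_query` instead of a real session call (hypothetical names):

```python
from __future__ import annotations

import functools


def evaluate(upstream_ids: list[str]) -> None:
    calls = 0

    @functools.lru_cache()
    def expensive_query(upstream_id: str) -> int:
        # Stands in for a database round-trip; runs at most once per upstream_id.
        nonlocal calls
        calls += 1
        return len(upstream_id)

    for upstream_id in upstream_ids:
        expensive_query(upstream_id)  # Repeated lookups hit the cache.

    print(f"{calls} real queries for {len(upstream_ids)} lookups")


evaluate(["up_a", "up_b", "up_a", "up_a"])  # prints "2 real queries for 4 lookups"
```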
+ success = upstream_states.success + skipped = upstream_states.skipped + failed = upstream_states.failed + upstream_failed = upstream_states.upstream_failed + removed = upstream_states.removed + done = upstream_states.done - :param ti: the task instance to evaluate the trigger rule of - :param successes: Number of successful upstream tasks - :param skipped: Number of skipped upstream tasks - :param failed: Number of failed upstream tasks - :param upstream_failed: Number of upstream_failed upstream tasks - :param done: Number of completed upstream tasks - :param flag_upstream_failed: This is a hack to generate - the upstream_failed state creation while checking to see - whether the task instance is runnable. It was the shortest - path to add the feature - :param session: database session - """ - task = ti.task - upstream = self._count_upstreams(ti, session=session) - trigger_rule: TR = task.trigger_rule + def _iter_upstream_conditions() -> Iterator[ColumnOperators]: + # Optimization: If the current task is not in a mapped task group, + # it depends on all upstream task instances. + if task.get_closest_mapped_task_group() is None: + yield TaskInstance.task_id.in_(upstream_tasks) + return + # Otherwise we need to figure out which map indexes are depended on + # for each upstream by the current task instance. + for upstream_id in upstream_tasks: + map_indexes = _get_relevant_upstream_map_indexes(upstream_id) + if map_indexes is None: # All tis of this upstream are dependencies. + yield (TaskInstance.task_id == upstream_id) + continue + # At this point we know we want to depend on only selected tis + # of this upstream task. Since the upstream may not have been + # expanded at this point, we also depend on the non-expanded ti + # to ensure at least one ti is included for the task. + yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0) + if isinstance(map_indexes, range) and map_indexes.step == 1: + yield and_( + TaskInstance.task_id == upstream_id, + TaskInstance.map_index >= map_indexes.start, + TaskInstance.map_index < map_indexes.stop, + ) + elif isinstance(map_indexes, collections.abc.Container): + yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes)) + else: + yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes) + + # Optimization: Don't need to hit the database if all upstreams are + # "simple" tasks (no task or task group mapping involved). + if not any(needs_expansion(t) for t in upstream_tasks.values()): + upstream = len(upstream_tasks) + else: + upstream = ( + session.query(func.count()) + .filter(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) + .filter(or_(*_iter_upstream_conditions())) + .scalar() + ) upstream_done = done >= upstream - upstream_tasks_state = { - "total": upstream, - "successes": successes, - "skipped": skipped, - "failed": failed, - "upstream_failed": upstream_failed, - "done": done, - } - # TODO(aoen): Ideally each individual trigger rules would be its own class, but - # this isn't very feasible at the moment since the database queries need to be - # bundled together for efficiency. 
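`_iter_upstream_conditions` above builds one predicate per upstream task and OR-joins them into a single count query. A small sketch of how such OR-of-AND filters compose in SQLAlchemy, using lightweight `column()` expressions rather than Airflow's real `TaskInstance` model (assumed names and map indexes):

```python
from sqlalchemy import and_, column, or_

task_id = column("task_id")
map_index = column("map_index")

conditions = [
    # Depend on every ti of this upstream (no mapping involved).
    task_id == "up_simple",
    # Depend only on selected map indexes of a mapped upstream, plus the
    # unexpanded ti (map_index < 0) in case expansion has not happened yet.
    and_(task_id == "up_mapped", map_index < 0),
    and_(task_id == "up_mapped", map_index.in_([0, 2])),
]

# Renders the combined WHERE clause used for the count query.
print(or_(*conditions))
```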
- # handling instant state assignment based on trigger rules - if flag_upstream_failed: + + changed = False + if dep_context.flag_upstream_failed: if trigger_rule == TR.ALL_SUCCESS: if upstream_failed or failed: - ti.set_state(State.UPSTREAM_FAILED, session) + changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session) elif skipped: - ti.set_state(State.SKIPPED, session) + changed = ti.set_state(TaskInstanceState.SKIPPED, session) + elif removed and success and ti.map_index > -1: + if ti.map_index >= success: + changed = ti.set_state(TaskInstanceState.REMOVED, session) elif trigger_rule == TR.ALL_FAILED: - if successes or skipped: - ti.set_state(State.SKIPPED, session) + if success or skipped: + changed = ti.set_state(TaskInstanceState.SKIPPED, session) elif trigger_rule == TR.ONE_SUCCESS: if upstream_done and done == skipped: # if upstream is done and all are skipped mark as skipped - ti.set_state(State.SKIPPED, session) - elif upstream_done and successes <= 0: - # if upstream is done and there are no successes mark as upstream failed - ti.set_state(State.UPSTREAM_FAILED, session) + changed = ti.set_state(TaskInstanceState.SKIPPED, session) + elif upstream_done and success <= 0: + # if upstream is done and there are no success mark as upstream failed + changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session) elif trigger_rule == TR.ONE_FAILED: if upstream_done and not (failed or upstream_failed): - ti.set_state(State.SKIPPED, session) + changed = ti.set_state(TaskInstanceState.SKIPPED, session) + elif trigger_rule == TR.ONE_DONE: + if upstream_done and not (failed or success): + changed = ti.set_state(TaskInstanceState.SKIPPED, session) elif trigger_rule == TR.NONE_FAILED: if upstream_failed or failed: - ti.set_state(State.UPSTREAM_FAILED, session) + changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session) elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: if upstream_failed or failed: - ti.set_state(State.UPSTREAM_FAILED, session) + changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session) elif skipped == upstream: - ti.set_state(State.SKIPPED, session) + changed = ti.set_state(TaskInstanceState.SKIPPED, session) elif trigger_rule == TR.NONE_SKIPPED: if skipped: - ti.set_state(State.SKIPPED, session) + changed = ti.set_state(TaskInstanceState.SKIPPED, session) elif trigger_rule == TR.ALL_SKIPPED: - if successes or failed: - ti.set_state(State.SKIPPED, session) + if success or failed: + changed = ti.set_state(TaskInstanceState.SKIPPED, session) + + if changed: + dep_context.have_changed_ti_states = True if trigger_rule == TR.ONE_SUCCESS: - if successes <= 0: + if success <= 0: yield self._failing_status( reason=( f"Task's trigger rule '{trigger_rule}' requires one upstream task success, " - f"but none were found. upstream_tasks_state={upstream_tasks_state}, " + f"but none were found. upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) @@ -204,29 +279,43 @@ def _evaluate_trigger_rule( yield self._failing_status( reason=( f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, " - f"but none were found. upstream_tasks_state={upstream_tasks_state}, " + f"but none were found. 
upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) + ) + elif trigger_rule == TR.ONE_DONE: + if success + failed <= 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}'" + "requires at least one upstream task failure or success" + f"but none were failed or success. upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) elif trigger_rule == TR.ALL_SUCCESS: - num_failures = upstream - successes + num_failures = upstream - success + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " f"succeeded, but found {num_failures} non-success(es). " - f"upstream_tasks_state={upstream_tasks_state}, " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) elif trigger_rule == TR.ALL_FAILED: - num_successes = upstream - failed - upstream_failed - if num_successes > 0: + num_success = upstream - failed - upstream_failed + if ti.map_index > -1: + num_success -= removed + if num_success > 0: yield self._failing_status( reason=( f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have failed, " - f"but found {num_successes} non-failure(s). " - f"upstream_tasks_state={upstream_tasks_state}, " + f"but found {num_success} non-failure(s). " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) @@ -236,29 +325,33 @@ def _evaluate_trigger_rule( reason=( f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " f"completed, but found {upstream_done} task(s) that were not done. " - f"upstream_tasks_state={upstream_tasks_state}, " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) elif trigger_rule == TR.NONE_FAILED: - num_failures = upstream - successes - skipped + num_failures = upstream - success - skipped + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " f"succeeded or been skipped, but found {num_failures} non-success(es). " - f"upstream_tasks_state={upstream_tasks_state}, " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: - num_failures = upstream - successes - skipped + num_failures = upstream - success - skipped + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " f"succeeded or been skipped, but found {num_failures} non-success(es). " - f"upstream_tasks_state={upstream_tasks_state}, " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) @@ -268,7 +361,7 @@ def _evaluate_trigger_rule( reason=( f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not have been " f"skipped, but found {skipped} task(s) skipped. " - f"upstream_tasks_state={upstream_tasks_state}, " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) @@ -279,7 +372,7 @@ def _evaluate_trigger_rule( reason=( f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been " f"skipped, but found {num_non_skipped} task(s) in non skipped state. 
" - f"upstream_tasks_state={upstream_tasks_state}, " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) ) diff --git a/airflow/ti_deps/deps/valid_state_dep.py b/airflow/ti_deps/deps/valid_state_dep.py index 4216ed3c417da..e4a8a678290c6 100644 --- a/airflow/ti_deps/deps/valid_state_dep.py +++ b/airflow/ti_deps/deps/valid_state_dep.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations from airflow.exceptions import AirflowException from airflow.ti_deps.deps.base_ti_dep import BaseTIDep @@ -37,7 +38,7 @@ def __init__(self, valid_states): super().__init__() if not valid_states: - raise AirflowException('ValidStatesDep received an empty set of valid states.') + raise AirflowException("ValidStatesDep received an empty set of valid states.") self._valid_states = valid_states def __eq__(self, other): diff --git a/airflow/timetables/_cron.py b/airflow/timetables/_cron.py new file mode 100644 index 0000000000000..b1e315a7d14f1 --- /dev/null +++ b/airflow/timetables/_cron.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +from typing import Any + +from cron_descriptor import CasingTypeEnum, ExpressionDescriptor, FormatException, MissingFieldException +from croniter import CroniterBadCronError, CroniterBadDateError, croniter +from pendulum import DateTime +from pendulum.tz.timezone import Timezone + +from airflow.compat.functools import cached_property +from airflow.exceptions import AirflowTimetableInvalid +from airflow.utils.dates import cron_presets +from airflow.utils.timezone import convert_to_utc, make_aware, make_naive + + +def _is_schedule_fixed(expression: str) -> bool: + """Figures out if the schedule has a fixed time (e.g. 3 AM every day). + + :return: True if the schedule has a fixed time, False if not. + + Detection is done by "peeking" the next two cron trigger time; if the + two times have the same minute and hour value, the schedule is fixed, + and we *don't* need to perform the DST fix. + + This assumes DST happens on whole minute changes (e.g. 12:59 -> 12:00). 
+ """ + cron = croniter(expression) + next_a = cron.get_next(datetime.datetime) + next_b = cron.get_next(datetime.datetime) + return next_b.minute == next_a.minute and next_b.hour == next_a.hour + + +class CronMixin: + """Mixin to provide interface to work with croniter.""" + + def __init__(self, cron: str, timezone: str | Timezone) -> None: + self._expression = cron_presets.get(cron, cron) + + if isinstance(timezone, str): + timezone = Timezone(timezone) + self._timezone = timezone + + descriptor = ExpressionDescriptor( + expression=self._expression, casing_type=CasingTypeEnum.Sentence, use_24hour_time_format=True + ) + try: + # checking for more than 5 parameters in Cron and avoiding evaluation for now, + # as Croniter has inconsistent evaluation with other libraries + if len(croniter(self._expression).expanded) > 5: + raise FormatException() + interval_description = descriptor.get_description() + except (CroniterBadCronError, FormatException, MissingFieldException): + interval_description = "" + self.description = interval_description + + def __eq__(self, other: Any) -> bool: + """Both expression and timezone should match. + + This is only for testing purposes and should not be relied on otherwise. + """ + if not isinstance(other, type(self)): + return NotImplemented + return self._expression == other._expression and self._timezone == other._timezone + + @property + def summary(self) -> str: + return self._expression + + def validate(self) -> None: + try: + croniter(self._expression) + except (CroniterBadCronError, CroniterBadDateError) as e: + raise AirflowTimetableInvalid(str(e)) + + @cached_property + def _should_fix_dst(self) -> bool: + # This is lazy so instantiating a schedule does not immediately raise + # an exception. Validity is checked with validate() during DAG-bagging. + return not _is_schedule_fixed(self._expression) + + def _get_next(self, current: DateTime) -> DateTime: + """Get the first schedule after specified time, with DST fixed.""" + naive = make_naive(current, self._timezone) + cron = croniter(self._expression, start_time=naive) + scheduled = cron.get_next(datetime.datetime) + if not self._should_fix_dst: + return convert_to_utc(make_aware(scheduled, self._timezone)) + delta = scheduled - naive + return convert_to_utc(current.in_timezone(self._timezone) + delta) + + def _get_prev(self, current: DateTime) -> DateTime: + """Get the first schedule before specified time, with DST fixed.""" + naive = make_naive(current, self._timezone) + cron = croniter(self._expression, start_time=naive) + scheduled = cron.get_prev(datetime.datetime) + if not self._should_fix_dst: + return convert_to_utc(make_aware(scheduled, self._timezone)) + delta = naive - scheduled + return convert_to_utc(current.in_timezone(self._timezone) - delta) + + def _align_to_next(self, current: DateTime) -> DateTime: + """Get the next scheduled time. + + This is ``current + interval``, unless ``current`` falls right on the + interval boundary, when ``current`` is returned. + """ + next_time = self._get_next(current) + if self._get_prev(next_time) != current: + return next_time + return current + + def _align_to_prev(self, current: DateTime) -> DateTime: + """Get the prev scheduled time. + + This is ``current - interval``, unless ``current`` falls right on the + interval boundary, when ``current`` is returned. 
+ """ + prev_time = self._get_prev(current) + if self._get_next(prev_time) != current: + return prev_time + return current diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py index 168f2af82d68d..19a86e96b3e8f 100644 --- a/airflow/timetables/base.py +++ b/airflow/timetables/base.py @@ -14,12 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Dict, NamedTuple, Optional +from typing import TYPE_CHECKING, Any, NamedTuple, Sequence from pendulum import DateTime -from airflow.typing_compat import Protocol +from airflow.typing_compat import Protocol, runtime_checkable + +if TYPE_CHECKING: + from airflow.utils.types import DagRunType class DataInterval(NamedTuple): @@ -33,7 +37,7 @@ class DataInterval(NamedTuple): end: DateTime @classmethod - def exact(cls, at: DateTime) -> "DataInterval": + def exact(cls, at: DateTime) -> DataInterval: """Represent an "interval" containing only an exact time.""" return cls(start=at, end=at) @@ -54,8 +58,8 @@ class TimeRestriction(NamedTuple): created by Airflow. """ - earliest: Optional[DateTime] - latest: Optional[DateTime] + earliest: DateTime | None + latest: DateTime | None catchup: bool @@ -76,12 +80,12 @@ class DagRunInfo(NamedTuple): """The data interval this DagRun to operate over.""" @classmethod - def exact(cls, at: DateTime) -> "DagRunInfo": + def exact(cls, at: DateTime) -> DagRunInfo: """Represent a run on an exact time.""" return cls(run_after=at, data_interval=DataInterval.exact(at)) @classmethod - def interval(cls, start: DateTime, end: DateTime) -> "DagRunInfo": + def interval(cls, start: DateTime, end: DateTime) -> DagRunInfo: """Represent a run on a continuous schedule. In such a schedule, each data interval starts right after the previous @@ -91,7 +95,7 @@ def interval(cls, start: DateTime, end: DateTime) -> "DagRunInfo": return cls(run_after=end, data_interval=DataInterval(start, end)) @property - def logical_date(self: "DagRunInfo") -> DateTime: + def logical_date(self: DagRunInfo) -> DateTime: """Infer the logical date to represent a DagRun. This replaces ``execution_date`` in Airflow 2.1 and prior. The idea is @@ -100,6 +104,7 @@ def logical_date(self: "DagRunInfo") -> DateTime: return self.data_interval.start +@runtime_checkable class Timetable(Protocol): """Protocol that all Timetable classes are expected to implement.""" @@ -114,7 +119,7 @@ class Timetable(Protocol): """Whether this timetable runs periodically. This defaults to and should generally be *True*, but some special setups - like ``schedule_interval=None`` and ``"@once"`` set it to *False*. + like ``schedule=None`` and ``"@once"`` set it to *False*. """ can_run: bool = True @@ -124,8 +129,14 @@ class Timetable(Protocol): this to *False*. """ + run_ordering: Sequence[str] = ("data_interval_end", "execution_date") + """How runs triggered from this timetable should be ordered in UI. + + This should be a list of field names on the DAG run object. + """ + @classmethod - def deserialize(cls, data: Dict[str, Any]) -> "Timetable": + def deserialize(cls, data: dict[str, Any]) -> Timetable: """Deserialize a timetable from data. This is called when a serialized DAG is deserialized. 
``data`` will be @@ -134,7 +145,7 @@ def deserialize(cls, data: Dict[str, Any]) -> "Timetable": """ return cls() - def serialize(self) -> Dict[str, Any]: + def serialize(self) -> dict[str, Any]: """Serialize the timetable for JSON encoding. This is called during DAG serialization to store timetable information @@ -175,9 +186,9 @@ def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: def next_dagrun_info( self, *, - last_automated_data_interval: Optional[DataInterval], + last_automated_data_interval: DataInterval | None, restriction: TimeRestriction, - ) -> Optional[DagRunInfo]: + ) -> DagRunInfo | None: """Provide information to schedule the next DagRun. The default implementation raises ``NotImplementedError``. @@ -193,3 +204,13 @@ def next_dagrun_info( a DagRunInfo object when asked at another time. """ raise NotImplementedError() + + def generate_run_id( + self, + *, + run_type: DagRunType, + logical_date: DateTime, + data_interval: DataInterval | None, + **extra, + ) -> str: + return run_type.generate_run_id(logical_date) diff --git a/airflow/timetables/events.py b/airflow/timetables/events.py index 8024045f014a8..a59f4fc5b2077 100644 --- a/airflow/timetables/events.py +++ b/airflow/timetables/events.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import itertools -from typing import Iterable, Optional +from typing import Iterable import pendulum from pendulum import DateTime @@ -25,8 +27,9 @@ class EventsTimetable(Timetable): """ - Timetable that schedules DAG runs at specific listed datetimes. Suitable for - predictable but truly irregular scheduling such as sporting events. + Timetable that schedules DAG runs at specific listed datetimes. + + Suitable for predictable but truly irregular scheduling such as sporting events. :param event_dates: List of datetimes for the DAG to run at. Duplicates will be ignored. Must be finite and of reasonable size as it will be loaded in its entirety. @@ -43,7 +46,7 @@ def __init__( event_dates: Iterable[DateTime], restrict_to_events: bool = False, presorted: bool = False, - description: Optional[str] = None, + description: str | None = None, ): self.event_dates = list(event_dates) # Must be reversible and indexable @@ -70,9 +73,9 @@ def __repr__(self): def next_dagrun_info( self, *, - last_automated_data_interval: Optional[DataInterval], + last_automated_data_interval: DataInterval | None, restriction: TimeRestriction, - ) -> Optional[DagRunInfo]: + ) -> DagRunInfo | None: if last_automated_data_interval is None: next_event = self.event_dates[0] else: diff --git a/airflow/timetables/interval.py b/airflow/timetables/interval.py index 5ed9cd21d2571..50478c6f7551e 100644 --- a/airflow/timetables/interval.py +++ b/airflow/timetables/interval.py @@ -14,21 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
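The `EventsTimetable` shown above takes an explicit, finite list of datetimes. A hedged usage sketch (the DAG id is hypothetical, and this assumes the Airflow 2.4-style `schedule=` parameter accepting a timetable object):

```python
import pendulum

from airflow.models.dag import DAG
from airflow.timetables.events import EventsTimetable

timetable = EventsTimetable(
    event_dates=[
        pendulum.datetime(2022, 10, 1, 9, tz="UTC"),
        pendulum.datetime(2022, 10, 15, 9, tz="UTC"),
    ],
    restrict_to_events=True,  # manual runs map to the listed events rather than "now"
    description="Match days",
)

with DAG(
    dag_id="match_day_report",  # hypothetical DAG
    schedule=timetable,
    start_date=pendulum.datetime(2022, 9, 1, tz="UTC"),
    catchup=False,
):
    pass  # tasks would go here
```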
+from __future__ import annotations import datetime -from typing import Any, Dict, Optional, Union +from typing import Any, Union -from cron_descriptor import CasingTypeEnum, ExpressionDescriptor, FormatException, MissingFieldException -from croniter import CroniterBadCronError, CroniterBadDateError, croniter from dateutil.relativedelta import relativedelta from pendulum import DateTime -from pendulum.tz.timezone import Timezone -from airflow.compat.functools import cached_property from airflow.exceptions import AirflowTimetableInvalid +from airflow.timetables._cron import CronMixin from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable -from airflow.utils.dates import cron_presets -from airflow.utils.timezone import convert_to_utc, make_aware, make_naive +from airflow.utils.timezone import convert_to_utc Delta = Union[datetime.timedelta, relativedelta] @@ -41,7 +38,7 @@ class _DataIntervalTimetable(Timetable): instance), and schedule a DagRun at the end of each interval. """ - def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime: + def _skip_to_latest(self, earliest: DateTime | None) -> DateTime: """Bound the earliest time a run can be scheduled. This is called when ``catchup=False``. See docstring of subclasses for @@ -49,8 +46,8 @@ def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime: """ raise NotImplementedError() - def _align(self, current: DateTime) -> DateTime: - """Align given time to the scheduled. + def _align_to_next(self, current: DateTime) -> DateTime: + """Align given time to the next scheduled time. For fixed schedules (e.g. every midnight); this finds the next time that aligns to the declared time, if the given time does not align. If the @@ -58,6 +55,19 @@ def _align(self, current: DateTime) -> DateTime: """ raise NotImplementedError() + def _align_to_prev(self, current: DateTime) -> DateTime: + """Align given time to the previous scheduled time. + + For fixed schedules (e.g. every midnight); this finds the prev time that + aligns to the declared time, if the given time does not align. If the + schedule is not fixed (e.g. every hour), the given time is returned. + + It is not enough to use ``_get_prev(_align_to_next())``, since when a + DAG's schedule changes, this alternative would make the first scheduling + after the schedule change remain the same. + """ + raise NotImplementedError() + def _get_next(self, current: DateTime) -> DateTime: """Get the first schedule after the current time.""" raise NotImplementedError() @@ -69,14 +79,14 @@ def _get_prev(self, current: DateTime) -> DateTime: def next_dagrun_info( self, *, - last_automated_data_interval: Optional[DataInterval], + last_automated_data_interval: DataInterval | None, restriction: TimeRestriction, - ) -> Optional[DagRunInfo]: + ) -> DagRunInfo | None: earliest = restriction.earliest if not restriction.catchup: earliest = self._skip_to_latest(earliest) elif earliest is not None: - earliest = self._align(earliest) + earliest = self._align_to_next(earliest) if last_automated_data_interval is None: # First run; schedule the run at the first available time matching # the schedule, and retrospectively create a data interval for it. @@ -84,40 +94,25 @@ def next_dagrun_info( return None start = earliest else: # There's a previous run. + # Alignment is needed when DAG has new schedule interval. 
+ align_last_data_interval_end = self._align_to_prev(last_automated_data_interval.end) if earliest is not None: # Catchup is False or DAG has new start date in the future. # Make sure we get the later one. - start = max(last_automated_data_interval.end, earliest) + start = max(align_last_data_interval_end, earliest) else: # Data interval starts from the end of the previous interval. - start = last_automated_data_interval.end + start = align_last_data_interval_end if restriction.latest is not None and start > restriction.latest: return None end = self._get_next(start) return DagRunInfo.interval(start=start, end=end) -def _is_schedule_fixed(expression: str) -> bool: - """Figures out if the schedule has a fixed time (e.g. 3 AM every day). - - :return: True if the schedule has a fixed time, False if not. - - Detection is done by "peeking" the next two cron trigger time; if the - two times have the same minute and hour value, the schedule is fixed, - and we *don't* need to perform the DST fix. - - This assumes DST happens on whole minute changes (e.g. 12:59 -> 12:00). - """ - cron = croniter(expression) - next_a = cron.get_next(datetime.datetime) - next_b = cron.get_next(datetime.datetime) - return next_b.minute == next_a.minute and next_b.hour == next_a.hour - - -class CronDataIntervalTimetable(_DataIntervalTimetable): +class CronDataIntervalTimetable(CronMixin, _DataIntervalTimetable): """Timetable that schedules data intervals with a cron expression. - This corresponds to ``schedule_interval=``, where ```` is either + This corresponds to ``schedule=``, where ```` is either a five/six-segment representation, or one of ``cron_presets``. The implementation extends on croniter to add timezone awareness. This is @@ -127,94 +122,18 @@ class CronDataIntervalTimetable(_DataIntervalTimetable): Don't pass ``@once`` in here; use ``OnceTimetable`` instead. """ - def __init__(self, cron: str, timezone: Union[str, Timezone]) -> None: - self._expression = cron_presets.get(cron, cron) - - if isinstance(timezone, str): - timezone = Timezone(timezone) - self._timezone = timezone - - descriptor = ExpressionDescriptor( - expression=self._expression, casing_type=CasingTypeEnum.Sentence, use_24hour_time_format=True - ) - try: - # checking for more than 5 parameters in Cron and avoiding evaluation for now, - # as Croniter has inconsistent evaluation with other libraries - if len(croniter(self._expression).expanded) > 5: - raise FormatException() - interval_description = descriptor.get_description() - except (CroniterBadCronError, FormatException, MissingFieldException): - interval_description = "" - self.description = interval_description - @classmethod - def deserialize(cls, data: Dict[str, Any]) -> "Timetable": + def deserialize(cls, data: dict[str, Any]) -> Timetable: from airflow.serialization.serialized_objects import decode_timezone return cls(data["expression"], decode_timezone(data["timezone"])) - def __eq__(self, other: Any) -> bool: - """Both expression and timezone should match. - - This is only for testing purposes and should not be relied on otherwise. 
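The `_align_to_prev(last_automated_data_interval.end)` call added above handles DAGs whose schedule changed between runs: the old interval end may no longer sit on the new schedule's boundary. A rough croniter sketch of the effect, with made-up dates and a daily schedule switched to weekly:

```python
import datetime

from croniter import croniter

# The previous automated interval ended under the old daily schedule.
last_end = datetime.datetime(2022, 8, 16)  # a Tuesday

new_expr = "0 0 * * 1"  # the DAG now runs weekly, Mondays at midnight

# Align the stale end down to the new schedule's boundary, then step forward.
aligned_start = croniter(new_expr, last_end).get_prev(datetime.datetime)
next_end = croniter(new_expr, aligned_start).get_next(datetime.datetime)

print(aligned_start)  # expected: 2022-08-15 00:00:00 (the preceding Monday)
print(next_end)       # expected: 2022-08-22 00:00:00 (the following Monday)
```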
- """ - if not isinstance(other, CronDataIntervalTimetable): - return NotImplemented - return self._expression == other._expression and self._timezone == other._timezone - - @property - def summary(self) -> str: - return self._expression - - def serialize(self) -> Dict[str, Any]: + def serialize(self) -> dict[str, Any]: from airflow.serialization.serialized_objects import encode_timezone return {"expression": self._expression, "timezone": encode_timezone(self._timezone)} - def validate(self) -> None: - try: - croniter(self._expression) - except (CroniterBadCronError, CroniterBadDateError) as e: - raise AirflowTimetableInvalid(str(e)) - - @cached_property - def _should_fix_dst(self) -> bool: - # This is lazy so instantiating a schedule does not immediately raise - # an exception. Validity is checked with validate() during DAG-bagging. - return not _is_schedule_fixed(self._expression) - - def _get_next(self, current: DateTime) -> DateTime: - """Get the first schedule after specified time, with DST fixed.""" - naive = make_naive(current, self._timezone) - cron = croniter(self._expression, start_time=naive) - scheduled = cron.get_next(datetime.datetime) - if not self._should_fix_dst: - return convert_to_utc(make_aware(scheduled, self._timezone)) - delta = scheduled - naive - return convert_to_utc(current.in_timezone(self._timezone) + delta) - - def _get_prev(self, current: DateTime) -> DateTime: - """Get the first schedule before specified time, with DST fixed.""" - naive = make_naive(current, self._timezone) - cron = croniter(self._expression, start_time=naive) - scheduled = cron.get_prev(datetime.datetime) - if not self._should_fix_dst: - return convert_to_utc(make_aware(scheduled, self._timezone)) - delta = naive - scheduled - return convert_to_utc(current.in_timezone(self._timezone) - delta) - - def _align(self, current: DateTime) -> DateTime: - """Get the next scheduled time. - - This is ``current + interval``, unless ``current`` falls right on the - interval boundary, when ``current`` is returned. - """ - next_time = self._get_next(current) - if self._get_prev(next_time) != current: - return next_time - return current - - def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime: + def _skip_to_latest(self, earliest: DateTime | None) -> DateTime: """Bound the earliest time a run can be scheduled. The logic is that we move start_date up until one period before, so the @@ -235,20 +154,20 @@ def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime: raise AssertionError("next schedule shouldn't be earlier") if earliest is None: return new_start - return max(new_start, self._align(earliest)) + return max(new_start, self._align_to_next(earliest)) def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: # Get the last complete period before run_after, e.g. if a DAG run is # scheduled at each midnight, the data interval of a manually triggered # run at 1am 25th is between 0am 24th and 0am 25th. - end = self._get_prev(self._align(run_after)) + end = self._align_to_prev(run_after) return DataInterval(start=self._get_prev(end), end=end) class DeltaDataIntervalTimetable(_DataIntervalTimetable): """Timetable that schedules data intervals with a time delta. - This corresponds to ``schedule_interval=``, where ```` is + This corresponds to ``schedule=``, where ```` is either a ``datetime.timedelta`` or ``dateutil.relativedelta.relativedelta`` instance. 
""" @@ -257,7 +176,7 @@ def __init__(self, delta: Delta) -> None: self._delta = delta @classmethod - def deserialize(cls, data: Dict[str, Any]) -> "Timetable": + def deserialize(cls, data: dict[str, Any]) -> Timetable: from airflow.serialization.serialized_objects import decode_relativedelta delta = data["delta"] @@ -278,7 +197,7 @@ def __eq__(self, other: Any) -> bool: def summary(self) -> str: return str(self._delta) - def serialize(self) -> Dict[str, Any]: + def serialize(self) -> dict[str, Any]: from airflow.serialization.serialized_objects import encode_relativedelta delta: Any @@ -299,10 +218,13 @@ def _get_next(self, current: DateTime) -> DateTime: def _get_prev(self, current: DateTime) -> DateTime: return convert_to_utc(current - self._delta) - def _align(self, current: DateTime) -> DateTime: + def _align_to_next(self, current: DateTime) -> DateTime: + return current + + def _align_to_prev(self, current: DateTime) -> DateTime: return current - def _skip_to_latest(self, earliest: Optional[DateTime]) -> DateTime: + def _skip_to_latest(self, earliest: DateTime | None) -> DateTime: """Bound the earliest time a run can be scheduled. The logic is that we move start_date up until one period before, so the diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index 516bd291ab5a9..3ddd31b6911bb 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -14,22 +14,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Dict, Optional - -from pendulum import DateTime +import operator +from typing import TYPE_CHECKING, Any, Collection from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable +if TYPE_CHECKING: + from pendulum import DateTime + from sqlalchemy import Session + + from airflow.models.dataset import DatasetEvent + from airflow.utils.types import DagRunType + class _TrivialTimetable(Timetable): """Some code reuse for "trivial" timetables that has nothing complex.""" periodic = False can_run = False + run_ordering = ("execution_date",) @classmethod - def deserialize(cls, data: Dict[str, Any]) -> "Timetable": + def deserialize(cls, data: dict[str, Any]) -> Timetable: return cls() def __eq__(self, other: Any) -> bool: @@ -41,7 +49,7 @@ def __eq__(self, other: Any) -> bool: return NotImplemented return True - def serialize(self) -> Dict[str, Any]: + def serialize(self) -> dict[str, Any]: return {} def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: @@ -51,7 +59,7 @@ def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: class NullTimetable(_TrivialTimetable): """Timetable that never schedules anything. - This corresponds to ``schedule_interval=None``. + This corresponds to ``schedule=None``. """ description: str = "Never, external triggers only" @@ -63,16 +71,16 @@ def summary(self) -> str: def next_dagrun_info( self, *, - last_automated_data_interval: Optional[DataInterval], + last_automated_data_interval: DataInterval | None, restriction: TimeRestriction, - ) -> Optional[DagRunInfo]: + ) -> DagRunInfo | None: return None class OnceTimetable(_TrivialTimetable): """Timetable that schedules the execution once as soon as possible. - This corresponds to ``schedule_interval="@once"``. + This corresponds to ``schedule="@once"``. 
""" description: str = "Once, as soon as possible" @@ -84,9 +92,9 @@ def summary(self) -> str: def next_dagrun_info( self, *, - last_automated_data_interval: Optional[DataInterval], + last_automated_data_interval: DataInterval | None, restriction: TimeRestriction, - ) -> Optional[DagRunInfo]: + ) -> DagRunInfo | None: if last_automated_data_interval is not None: return None # Already run, no more scheduling. if restriction.earliest is None: # No start date, won't run. @@ -98,3 +106,49 @@ def next_dagrun_info( if restriction.latest is not None and run_after > restriction.latest: return None return DagRunInfo.exact(run_after) + + +class DatasetTriggeredTimetable(NullTimetable): + """Timetable that never schedules anything. + + This should not be directly used anywhere, but only set if a DAG is triggered by datasets. + + :meta private: + """ + + description: str = "Triggered by datasets" + + @property + def summary(self) -> str: + return "Dataset" + + def generate_run_id( + self, + *, + run_type: DagRunType, + logical_date: DateTime, + data_interval: DataInterval | None, + session: Session | None = None, + events: Collection[DatasetEvent] | None = None, + **extra, + ) -> str: + from airflow.models.dagrun import DagRun + + return DagRun.generate_run_id(run_type, logical_date) + + def data_interval_for_events( + self, + logical_date: DateTime, + events: Collection[DatasetEvent], + ) -> DataInterval: + + if not events: + return DataInterval(logical_date, logical_date) + + start = min( + events, key=operator.attrgetter("source_dag_run.data_interval_start") + ).source_dag_run.data_interval_start + end = max( + events, key=operator.attrgetter("source_dag_run.data_interval_end") + ).source_dag_run.data_interval_end + return DataInterval(start, end) diff --git a/airflow/timetables/trigger.py b/airflow/timetables/trigger.py new file mode 100644 index 0000000000000..7807542da5a63 --- /dev/null +++ b/airflow/timetables/trigger.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +from typing import Any + +from dateutil.relativedelta import relativedelta +from pendulum import DateTime +from pendulum.tz.timezone import Timezone + +from airflow.timetables._cron import CronMixin +from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable + + +class CronTriggerTimetable(CronMixin, Timetable): + """Timetable that triggers DAG runs according to a cron expression. + + This is different from ``CronDataIntervalTimetable``, where the cron + expression specifies the *data interval* of a DAG run. With this timetable, + the data intervals are specified independently from the cron expression. 
+ Also for the same reason, this timetable kicks off a DAG run immediately at + the start of the period (similar to POSIX cron), instead of needing to wait + for one data interval to pass. + + Don't pass ``@once`` in here; use ``OnceTimetable`` instead. + """ + + def __init__( + self, + cron: str, + *, + timezone: str | Timezone, + interval: datetime.timedelta | relativedelta = datetime.timedelta(), + ) -> None: + super().__init__(cron, timezone) + self._interval = interval + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> Timetable: + from airflow.serialization.serialized_objects import decode_relativedelta, decode_timezone + + interval: datetime.timedelta | relativedelta + if isinstance(data["interval"], dict): + interval = decode_relativedelta(data["interval"]) + else: + interval = datetime.timedelta(seconds=data["interval"]) + return cls(data["expression"], timezone=decode_timezone(data["timezone"]), interval=interval) + + def serialize(self) -> dict[str, Any]: + from airflow.serialization.serialized_objects import encode_relativedelta, encode_timezone + + interval: float | dict[str, Any] + if isinstance(self._interval, datetime.timedelta): + interval = self._interval.total_seconds() + else: + interval = encode_relativedelta(self._interval) + timezone = encode_timezone(self._timezone) + return {"expression": self._expression, "timezone": timezone, "interval": interval} + + def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: + return DataInterval(run_after - self._interval, run_after) + + def next_dagrun_info( + self, + *, + last_automated_data_interval: DataInterval | None, + restriction: TimeRestriction, + ) -> DagRunInfo | None: + if restriction.catchup: + if last_automated_data_interval is not None: + next_start_time = self._get_next(last_automated_data_interval.end) + elif restriction.earliest is None: + return None # Don't know where to catch up from, give up. + else: + next_start_time = self._align_to_next(restriction.earliest) + else: + start_time_candidates = [self._align_to_next(DateTime.utcnow())] + if last_automated_data_interval is not None: + start_time_candidates.append(self._get_next(last_automated_data_interval.end)) + if restriction.earliest is not None: + start_time_candidates.append(self._align_to_next(restriction.earliest)) + next_start_time = max(start_time_candidates) + if restriction.latest is not None and restriction.latest < next_start_time: + return None + return DagRunInfo.interval(next_start_time - self._interval, next_start_time) diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index ed5197545435b..6bfc0883efd6d 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import abc -from typing import Any, AsyncIterator, Dict, Tuple +from typing import Any, AsyncIterator from airflow.utils.log.logging_mixin import LoggingMixin @@ -39,7 +40,7 @@ def __init__(self, **kwargs): pass @abc.abstractmethod - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """ Returns the information needed to reconstruct this Trigger. @@ -65,10 +66,13 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: and then rely on cleanup() being called when they are no longer needed. 
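Returning to the new `CronTriggerTimetable` above: it fires at the cron time itself, with the data interval supplied separately via `interval`. A hedged usage sketch (hypothetical DAG; assumes the 2.4-style `schedule=` parameter accepting a timetable):

```python
import datetime

import pendulum

from airflow.models.dag import DAG
from airflow.timetables.trigger import CronTriggerTimetable

with DAG(
    dag_id="nightly_report",  # hypothetical
    # Kicks off at 09:00 UTC sharp, covering the preceding nine hours of data.
    schedule=CronTriggerTimetable(
        "0 9 * * *",
        timezone="UTC",
        interval=datetime.timedelta(hours=9),
    ),
    start_date=pendulum.datetime(2022, 9, 1, tz="UTC"),
    catchup=False,
):
    pass  # tasks would go here
```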
""" raise NotImplementedError("Triggers must implement run()") + yield # To convince Mypy this is an async iterator. def cleanup(self) -> None: """ - Called when the trigger is no longer needed and it's being removed + Cleanup the trigger. + + Called when the trigger is no longer needed, and it's being removed from the active triggerer process. """ diff --git a/airflow/triggers/temporal.py b/airflow/triggers/temporal.py index 9b4d21dca4b31..3967940a7e9ba 100644 --- a/airflow/triggers/temporal.py +++ b/airflow/triggers/temporal.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import asyncio import datetime -from typing import Any, Dict, Tuple +from typing import Any from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils import timezone @@ -25,6 +26,8 @@ class DateTimeTrigger(BaseTrigger): """ + Trigger based on a datetime. + A trigger that fires exactly once, at the given datetime, give or take a few seconds. @@ -43,7 +46,7 @@ def __init__(self, moment: datetime.datetime): else: self.moment = moment - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment}) async def run(self): @@ -67,6 +70,8 @@ async def run(self): class TimeDeltaTrigger(DateTimeTrigger): """ + Create DateTimeTriggers based on delays. + Subclass to create DateTimeTriggers based on time delays rather than exact moments. diff --git a/airflow/triggers/testing.py b/airflow/triggers/testing.py index 0d7d75f4ab547..ce38982209e17 100644 --- a/airflow/triggers/testing.py +++ b/airflow/triggers/testing.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Any, Dict, Tuple +from typing import Any from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -27,7 +28,7 @@ class SuccessTrigger(BaseTrigger): Should only be used for testing. """ - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: return ("airflow.triggers.testing.SuccessTrigger", {}) async def run(self): @@ -41,7 +42,7 @@ class FailureTrigger(BaseTrigger): Should only be used for testing. """ - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: return ("airflow.triggers.testing.FailureTrigger", {}) async def run(self): diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py index 163889b8a2975..5f2321962a9f8 100644 --- a/airflow/typing_compat.py +++ b/airflow/typing_compat.py @@ -15,16 +15,35 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """ This module provides helper code to make type annotation within Airflow codebase easier. """ +from __future__ import annotations + +__all__ = [ + "Literal", + "ParamSpec", + "Protocol", + "TypedDict", + "TypeGuard", + "runtime_checkable", +] + +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol, TypedDict, runtime_checkable +else: + from typing_extensions import Protocol, TypedDict, runtime_checkable + +# Literal in 3.8 is limited to one single argument, not e.g. "Literal[1, 2]". 
+if sys.version_info >= (3, 9): + from typing import Literal +else: + from typing_extensions import Literal -try: - # Literal, Protocol and TypedDict are only added to typing module starting from - # python 3.8 we can safely remove this shim import after Airflow drops - # support for <3.8 - from typing import Literal, Protocol, TypedDict, runtime_checkable # type: ignore -except ImportError: - from typing_extensions import Literal, Protocol, TypedDict, runtime_checkable # type: ignore # noqa +if sys.version_info >= (3, 10): + from typing import ParamSpec, TypeGuard +else: + from typing_extensions import ParamSpec, TypeGuard diff --git a/airflow/ui/.env.example b/airflow/ui/.env.example deleted file mode 100644 index 454c936d56e64..0000000000000 --- a/airflow/ui/.env.example +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -WEBSERVER_URL = 'http://127.0.0.1:28080' diff --git a/airflow/ui/.eslintrc.js b/airflow/ui/.eslintrc.js deleted file mode 100644 index 8e27638a715bc..0000000000000 --- a/airflow/ui/.eslintrc.js +++ /dev/null @@ -1,37 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -/* -* Linting config -*/ -module.exports = { - env: { - jest: true, - }, - extends: ['airbnb-typescript', 'plugin:react-hooks/recommended'], - parserOptions: { - project: './tsconfig.json', - }, - rules: { - 'react/prop-types': 0, - 'react/jsx-props-no-spreading': 0, - 'arrow-body-style': 1, - 'react/jsx-one-expression-per-line': 1, - }, -}; diff --git a/airflow/ui/.gitignore b/airflow/ui/.gitignore deleted file mode 100644 index 2ea437be6e5a6..0000000000000 --- a/airflow/ui/.gitignore +++ /dev/null @@ -1,40 +0,0 @@ -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* - -# Diagnostic reports (https://nodejs.org/api/report.html) -report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json - -# Runtime data -pids -*.pid -*.seed -*.pid.lock - -# Dependency directories -node_modules/ - -# TypeScript cache -*.tsbuildinfo - -# Optional npm cache directory -.npm - -# Optional eslint cache -.eslintcache - -# Yarn Integrity file -.yarn-integrity - -# yarn v2 -.yarn/cache -.yarn/unplugged -.yarn/build-state.yml -.yarn/install-state.gz -.pnp.* - -# Neutrino build directory -build diff --git a/airflow/ui/.neutrinorc.js b/airflow/ui/.neutrinorc.js deleted file mode 100644 index d9658016d1007..0000000000000 --- a/airflow/ui/.neutrinorc.js +++ /dev/null @@ -1,74 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -/* - Config for running and building the app -*/ -require('dotenv').config(); -const typescript = require('neutrinojs-typescript'); -const typescriptLint = require('neutrinojs-typescript-eslint'); -const react = require('@neutrinojs/react'); -const jest = require('@neutrinojs/jest'); -const eslint = require('@neutrinojs/eslint'); -const { resolve } = require('path'); -const copy = require('@neutrinojs/copy'); - -module.exports = { - options: { - root: __dirname, - }, - use: [ - (neutrino) => { - // Aliases for internal modules - neutrino.config.resolve.alias.set('root', resolve(__dirname)); - neutrino.config.resolve.alias.set('src', resolve(__dirname, 'src')); - neutrino.config.resolve.alias.set('views', resolve(__dirname, 'src/views')); - neutrino.config.resolve.alias.set('utils', resolve(__dirname, 'src/utils')); - neutrino.config.resolve.alias.set('providers', resolve(__dirname, 'src/providers')); - neutrino.config.resolve.alias.set('components', resolve(__dirname, 'src/components')); - neutrino.config.resolve.alias.set('interfaces', resolve(__dirname, 'src/interfaces')); - neutrino.config.resolve.alias.set('api', resolve(__dirname, 'src/api')); - }, - typescript(), - // Modify typescript config in .tsconfig.json - typescriptLint(), - eslint({ - eslint: { - // Modify eslint config in .eslintrc.js config instead - useEslintrc: true, - }, - }), - jest({ - moduleDirectories: ['node_modules', 'src'], - }), - react({ - env: [ - 'WEBSERVER_URL' - ], - html: { - title: 'Apache Airflow', - } - }), - copy({ - patterns: [ - { from: 'src/static/favicon.ico', to: '.' }, - ], - }), - ], -}; diff --git a/airflow/ui/README.md b/airflow/ui/README.md deleted file mode 100644 index e80bce40701d4..0000000000000 --- a/airflow/ui/README.md +++ /dev/null @@ -1,71 +0,0 @@ - - -# Airflow UI - -## Built with: - -- [React](https://reactjs.org/) - a JavaScript library for building user interfaces -- [TypeScript](https://www.typescriptlang.org/) - extends JavaScript by adding types. -- [Neutrino](https://neutrinojs.org/) - lets you build web and Node.js applications with shared presets or configurations. -- [Chakra UI](https://chakra-ui.com/) - a simple, modular and accessible component library that gives you all the building blocks you need to build your React applications. -- [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) - write tests that focus on functionality instead of implementation -- [React Query](https://react-query.tanstack.com/) - powerful async data handler. all API calls go through this - -## Environment variables - -To communicate with the API you need to adjust some environment variables for the webserver and this UI. - -Be sure to allow CORS headers and set up an auth backend on your Airflow instance. - -``` -export AIRFLOW__API__AUTH_BACKENDS=airflow.api.auth.backend.basic_auth -export AIRFLOW__API__ACCESS_CONTROL_ALLOW_HEADERS=* -export AIRFLOW__API__ACCESS_CONTROL_ALLOW_METHODS=* -export AIRFLOW__API__ACCESS_CONTROL_ALLOW_ORIGIN=http://127.0.0.1:28080 -``` - -Create your local environment and adjust the `WEBSERVER_URL` if needed. - -```bash -cp .env.example .env -``` - -## Installation - -Clone the repository and use the package manager [yarn](https://yarnpkg.com) to install dependencies and get the project running. 
- -```bash -yarn install -yarn start -``` - -Other useful commands include: - -```bash -yarn lint -``` - -```bash -yarn test -``` - -## Contributing - -Be sure to check out our [contribution guide](docs/CONTRIBUTING.md) diff --git a/airflow/ui/docs/CONTRIBUTING.md b/airflow/ui/docs/CONTRIBUTING.md deleted file mode 100644 index ca9002e60d6da..0000000000000 --- a/airflow/ui/docs/CONTRIBUTING.md +++ /dev/null @@ -1,52 +0,0 @@ - - -# Contributing to the UI - -## Learn - -If you're new to modern frontend development or parts of our stack, you may want to check out these resources to understand our codebase: - -- TypeScript is an extension of javascript to add type-checking to our app. Files ending in `.ts` or `.tsx` will be type-checked. Check out the [handbook](https://www.typescriptlang.org/docs/handbook/typescript-in-5-minutes-func.html) for an introduction or feel free to keep this [cheatsheet](https://github.com/typescript-cheatsheets/react) open while developing. - -- React powers our entire app, so it would be valuable to learn JSX, the html-in-js templates React utilizes. Files that contain JSX will end in `.tsx` instead of `.ts`. Check out their official [tutorial](https://reactjs.org/tutorial/tutorial.html#overview) for a basic overview. - -- Chakra-UI is our component library and theming system. You'll notice we have no traditional css nor html tags. This is all handled in Chakra with importing standard components like `` or `` that are styled globally in `src/theme.ts` file and then by passing styles as component props. Check out their [docs](https://chakra-ui.com/docs/getting-started) to see all the included components and hooks. - -- Testing is done with React Testing Library. We follow their idea of "The more your tests resemble the way your software is used, -the more confidence they can give you." Keep their [cheatsheet](https://testing-library.com/docs/react-testing-library/cheatsheet) open when writing tests - -- Neutrino handles our App's configuration and Webpack build. Check out their [docs](https://neutrinojs.org/api/) if you need to customize it. - -- State management is handled with [Context](https://reactjs.org/docs/context.html) and [react-query](https://react-query.tanstack.com/). Context is used for App-level state that doesn't change frequently (authentication, dark/light mode). React Query handles all the state and side effects (loading, error, caching, etc.) of async data from the API. - -## Project Structure - -- `src/index.tsx` is the entry point of the app. Here you will find all the root level Providers that expose functionality to the rest of the app like the Chakra component library, routing, authentication or API queries. -- `.neutrinorc.js` is the main config file. Although some custom typescript or linting may need to be changed in `tsconfig.json` or `.eslintrc.js`, respectively -- `src/components` are React components that will be shared across the app -- `src/views` are React components that are specific to a certain url route -- `src/interfaces` are custom-defined TypeScript types/interfaces -- `src/utils` contains various helper functions that are shared throughout the app -- `src/auth` has the Context for authentication -- `src/api` contains all of the actual API requests as custom hooks around react-query - -## Find open issues - -Take a look at our [project board](https://github.com/apache/airflow/projects/9) for unassigned issues in the `Next Up` column. 
If you're interested in one, leave a comment saying you'd like it to be assigned to you. diff --git a/airflow/ui/jest.config.js b/airflow/ui/jest.config.js deleted file mode 100644 index 2f89890f578b9..0000000000000 --- a/airflow/ui/jest.config.js +++ /dev/null @@ -1,28 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* -* Most testing config happens in .neutrinorc.js instead -*/ -const neutrino = require('neutrino'); - -process.env.NODE_ENV = process.env.NODE_ENV || 'test'; -process.env.WEBSERVER_URL = process.env.WEBSERVER_URL || 'http://localhost:9999'; - -module.exports = neutrino().jest(); diff --git a/airflow/ui/package.json b/airflow/ui/package.json deleted file mode 100644 index 046cc6700f6f5..0000000000000 --- a/airflow/ui/package.json +++ /dev/null @@ -1,62 +0,0 @@ -{ - "name": "ui", - "version": "1.0.0", - "license": "Apache-2.0", - "scripts": { - "start": "webpack-dev-server --mode development --open", - "build": "webpack --mode production", - "test": "jest", - "lint": "eslint --format codeframe --ext mjs,jsx,js,tsx,ts src test && tsc" - }, - "dependencies": { - "@chakra-ui/react": "^1.6.1", - "@emotion/react": "^11.1.5", - "@emotion/styled": "^11.1.5", - "@neutrinojs/copy": "^9.5.0", - "@vvo/tzdb": "^6.7.0", - "axios": "^0.21.2", - "dayjs": "^1.10.4", - "dotenv": "^8.2.0", - "framer-motion": "^3.10.0", - "humps": "^2.0.1", - "react": "^16", - "react-dom": "^16", - "react-hot-loader": "^4", - "react-icons": "^4.2.0", - "react-query": "^3.12.3", - "react-router-dom": "^5.2.0", - "react-select": "^4.3.0", - "react-table": "^7.7.0", - "use-react-router": "^1.0.7" - }, - "devDependencies": { - "@neutrinojs/eslint": "^9.5.0", - "@neutrinojs/jest": "^9.5.0", - "@neutrinojs/react": "^9.5.0", - "@testing-library/jest-dom": "^5.11.9", - "@testing-library/react": "^11.2.5", - "@types/humps": "^2.0.0", - "@types/jest": "^26.0.20", - "@types/react": "^17.0.3", - "@types/react-dom": "^17.0.2", - "@types/react-router-dom": "^5.1.7", - "@types/react-select": "^4.0.15", - "@types/react-table": "^7.7.0", - "eslint": "^7", - "eslint-config-airbnb-typescript": "^12.3.1", - "eslint-plugin-import": "^2.22.1", - "eslint-plugin-jsx-a11y": "^6.4.1", - "eslint-plugin-react-hooks": "^4.2.0", - "history": "^5.0.0", - "jest": "^26", - "neutrino": "^9.5.0", - "neutrinojs-typescript": "^1.1.6", - "neutrinojs-typescript-eslint": "^1.3.1", - "nock": "^13.0.11", - "react-test-renderer": "^17.0.1", - "typescript": "^4.2.3", - "webpack": "^4", - "webpack-cli": "^3", - "webpack-dev-server": "^3" - } -} diff --git a/airflow/ui/src/App.tsx b/airflow/ui/src/App.tsx deleted file mode 100644 index 26ffb7557b730..0000000000000 --- a/airflow/ui/src/App.tsx +++ /dev/null @@ -1,114 +0,0 @@ -/*! 
- * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import { hot } from 'react-hot-loader'; -import React from 'react'; -import { Route, Redirect, Switch } from 'react-router-dom'; - -import PrivateRoute from 'providers/auth/PrivateRoute'; - -import Pipelines from 'views/Pipelines'; - -import Details from 'views/Pipeline/runs/Details'; -import Code from 'views/Pipeline/runs/Code'; -import TaskTries from 'views/Pipeline/runs/TaskTries'; -import TaskDuration from 'views/Pipeline/runs/TaskDuration'; -import LandingTimes from 'views/Pipeline/runs/LandingTimes'; - -import Graph from 'views/Pipeline/run/Graph'; -import Gantt from 'views/Pipeline/run/Gantt'; - -import TIDetails from 'views/Pipeline/ti/Details'; -import RenderedTemplate from 'views/Pipeline/ti/RenderedTemplate'; -import RenderedK8s from 'views/Pipeline/ti/RenderedK8s'; -import Log from 'views/Pipeline/ti/Log'; -import XCom from 'views/Pipeline/ti/XCom'; - -import EventLogs from 'views/Activity/EventLogs'; -import Runs from 'views/Activity/Runs'; -import Jobs from 'views/Activity/Jobs'; -import TaskInstances from 'views/Activity/TaskInstances'; -import TaskReschedules from 'views/Activity/TaskReschedules'; -import SLAMisses from 'views/Activity/SLAMisses'; -import XComs from 'views/Activity/XComs'; - -import Config from 'views/Config'; -import Variables from 'views/Config/Variables'; -import Connections from 'views/Config/Connections'; -import Pools from 'views/Config/Pools'; - -import Access from 'views/Access'; -import Users from 'views/Access/Users'; -import Roles from 'views/Access/Roles'; -import Permissions from 'views/Access/Permissions'; - -import Docs from 'views/Docs'; -import NotFound from 'views/NotFound'; - -const App = () => ( - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -); - -export default hot(module)(App); diff --git a/airflow/ui/src/api/defaults.ts b/airflow/ui/src/api/defaults.ts deleted file mode 100644 index 54c61c74d6a81..0000000000000 --- a/airflow/ui/src/api/defaults.ts +++ /dev/null @@ -1,28 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -export const defaultVersion = { version: '', gitVersion: '' }; - -export const defaultDags = { dags: [], totalEntries: 0 }; - -export const defaultDagRuns = { dagRuns: [], totalEntries: 0 }; - -export const defaultTaskInstances = { taskInstances: [], totalEntries: 0 }; - -export const defaultConfig = { sections: [] }; diff --git a/airflow/ui/src/api/index.ts b/airflow/ui/src/api/index.ts deleted file mode 100644 index ee3ab5b55eed3..0000000000000 --- a/airflow/ui/src/api/index.ts +++ /dev/null @@ -1,208 +0,0 @@ -/* eslint-disable no-console */ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import axios, { AxiosResponse } from 'axios'; -import { - useMutation, useQuery, useQueryClient, setLogger, -} from 'react-query'; -import humps from 'humps'; -import { useToast } from '@chakra-ui/react'; - -import type { - Config, Dag, DagRun, Version, -} from 'interfaces'; -import type { - DagsResponse, - DagRunsResponse, - TaskInstancesResponse, - TriggerRunRequest, -} from 'interfaces/api'; - -axios.defaults.baseURL = `${process.env.WEBSERVER_URL}/api/v1`; -axios.interceptors.response.use( - (res) => (res.data ? humps.camelizeKeys(res.data) as unknown as AxiosResponse : res), -); - -// turn off logging, retry and refetch on tests -const isTest = process.env.NODE_ENV === 'test'; - -setLogger({ - log: isTest ? () => {} : console.log, - warn: isTest ? () => {} : console.warn, - error: isTest ? () => {} : console.warn, -}); - -const toastDuration = 3000; -const refetchInterval = isTest ? false : 1000; - -interface PageProps { - offset?: number; - limit?: number -} - -export function useDags({ offset = 0, limit }: PageProps) { - return useQuery( - ['dags', offset], - (): Promise => axios.get('/dags', { - params: { offset, limit }, - }), - { - refetchInterval, - retry: !isTest, - }, - ); -} - -export function useDagRuns(dagId: Dag['dagId'], dateMin?: string) { - return useQuery( - ['dagRun', dagId], - (): Promise => axios.get(`dags/${dagId}/dagRuns${dateMin ? `?start_date_gte=${dateMin}` : ''}`), - { refetchInterval }, - ); -} - -export function useTaskInstances(dagId: Dag['dagId'], dagRunId: DagRun['dagRunId'], dateMin?: string) { - return useQuery( - ['taskInstance', dagRunId], - (): Promise => ( - axios.get(`dags/${dagId}/dagRuns/${dagRunId}/taskInstances${dateMin ? 
`?start_date_gte=${dateMin}` : ''}`) - ), - ); -} - -export function useVersion() { - return useQuery( - 'version', - (): Promise => axios.get('/version'), - ); -} - -export function useConfig() { - return useQuery('config', (): Promise => axios.get('/config')); -} - -export function useTriggerRun(dagId: Dag['dagId']) { - const queryClient = useQueryClient(); - const toast = useToast(); - return useMutation( - (trigger: TriggerRunRequest) => axios.post(`dags/${dagId}/dagRuns`, humps.decamelizeKeys(trigger)), - { - onSettled: (res, error) => { - if (error) { - toast({ - title: 'Error triggering DAG', - description: (error as Error).message, - status: 'error', - duration: toastDuration, - isClosable: true, - }); - } else { - toast({ - title: 'DAG Triggered', - status: 'success', - duration: toastDuration, - isClosable: true, - }); - const dagRunData = queryClient.getQueryData(['dagRun', dagId]) as unknown as DagRunsResponse; - if (dagRunData) { - queryClient.setQueryData(['dagRun', dagId], { - dagRuns: [...dagRunData.dagRuns, res], - totalEntries: dagRunData.totalEntries += 1, - }); - } else { - queryClient.setQueryData(['dagRun', dagId], { - dagRuns: [res], - totalEntries: 1, - }); - } - } - queryClient.invalidateQueries(['dagRun', dagId]); - }, - }, - ); -} - -export function useSaveDag(dagId: Dag['dagId'], offset: number) { - const queryClient = useQueryClient(); - const toast = useToast(); - return useMutation( - (updatedValues: Record) => axios.patch(`dags/${dagId}`, humps.decamelizeKeys(updatedValues)), - { - onMutate: async (updatedValues: Record) => { - await queryClient.cancelQueries(['dag', dagId]); - const previousDag = queryClient.getQueryData(['dag', dagId]) as Dag; - const previousDags = queryClient.getQueryData(['dags', offset]) as DagsResponse; - - const newDags = previousDags.dags.map((dag) => ( - dag.dagId === dagId ? { ...dag, ...updatedValues } : dag - )); - const newDag = { - ...previousDag, - ...updatedValues, - }; - - // optimistically set the dag before the async request - queryClient.setQueryData(['dag', dagId], () => newDag); - queryClient.setQueryData(['dags', offset], (old) => ({ - ...(old as Dag[]), - ...{ - dags: newDags, - totalEntries: previousDags.totalEntries, - }, - })); - return { [dagId]: previousDag, dags: previousDags }; - }, - onSettled: (res, error, variables, context) => { - const previousDag = (context as any)[dagId] as Dag; - const previousDags = (context as any).dags as DagsResponse; - // rollback to previous cache on error - if (error) { - queryClient.setQueryData(['dag', dagId], previousDag); - queryClient.setQueryData(['dags', offset], previousDags); - toast({ - title: 'Error updating pipeline', - description: (error as Error).message, - status: 'error', - duration: toastDuration, - isClosable: true, - }); - } else { - // check if server response is different from our optimistic update - if (JSON.stringify(res) !== JSON.stringify(previousDag)) { - queryClient.setQueryData(['dag', dagId], res); - queryClient.setQueryData(['dags', offset], { - dags: previousDags.dags.map((dag) => ( - dag.dagId === dagId ? 
res : dag - )), - totalEntries: previousDags.totalEntries, - }); - } - toast({ - title: 'Pipeline Updated', - status: 'success', - duration: toastDuration, - isClosable: true, - }); - } - queryClient.invalidateQueries(['dag', dagId]); - }, - }, - ); -} diff --git a/airflow/ui/src/components/AppContainer/AppHeader.tsx b/airflow/ui/src/components/AppContainer/AppHeader.tsx deleted file mode 100644 index ff3fd54c015c5..0000000000000 --- a/airflow/ui/src/components/AppContainer/AppHeader.tsx +++ /dev/null @@ -1,124 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import React from 'react'; -import { Link } from 'react-router-dom'; -import { - Avatar, - Flex, - Icon, - Menu, - MenuButton, - MenuDivider, - MenuList, - MenuItem, - useColorMode, - useColorModeValue, -} from '@chakra-ui/react'; -import { - MdWbSunny, - MdBrightness2, - MdAccountCircle, - MdExitToApp, - MdQueryBuilder, -} from 'react-icons/md'; - -import { useAuthContext } from 'providers/auth/context'; -import { useDateContext, HOURS_24 } from 'providers/DateProvider'; - -import ApacheAirflowLogo from 'components/icons/ApacheAirflowLogo'; -import TimezoneDropdown from './TimezoneDropdown'; - -interface Props { - bodyBg: string; - overlayBg: string; - breadcrumb?: React.ReactNode; -} - -const AppHeader: React.FC = ({ bodyBg, overlayBg, breadcrumb }) => { - const { toggleColorMode } = useColorMode(); - const { dateFormat, toggle24Hour } = useDateContext(); - const headerHeight = '56px'; - const { hasValidAuthToken, logout } = useAuthContext(); - const darkLightIcon = useColorModeValue(MdBrightness2, MdWbSunny); - const darkLightText = useColorModeValue(' Dark ', ' Light '); - - const handleOpenProfile = () => window.alert('This will take you to your user profile view.'); - - return ( - - {breadcrumb} - {!breadcrumb && ( - - - - )} - {hasValidAuthToken && ( - - - - - - - - - - Your Profile - - - - Set - {darkLightText} - Mode - - {/* Clock config should move to User Profile Settings when that page exists */} - - - Use - {dateFormat === HOURS_24 ? ' 12 hour ' : ' 24 hour '} - clock - - - - - Logout - - - - - )} - - ); -}; - -export default AppHeader; diff --git a/airflow/ui/src/components/AppContainer/AppNav.tsx b/airflow/ui/src/components/AppContainer/AppNav.tsx deleted file mode 100644 index 44aa322f29ec3..0000000000000 --- a/airflow/ui/src/components/AppContainer/AppNav.tsx +++ /dev/null @@ -1,111 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import React from 'react'; -import { Link } from 'react-router-dom'; -import { Box } from '@chakra-ui/react'; -import { - FiActivity, - FiBookOpen, - FiSettings, - FiUsers, -} from 'react-icons/fi'; - -import { useAuthContext } from 'providers/auth/context'; - -import PinwheelLogo from 'components/icons/PinwheelLogo'; -import PipelineIcon from 'components/icons/PipelineIcon'; - -import AppNavBtn from './AppNavBtn'; - -interface Props { - bodyBg: string; - overlayBg: string; -} - -const AppNav: React.FC = ({ bodyBg, overlayBg }) => { - const { hasValidAuthToken } = useAuthContext(); - - const navItems = [ - { - label: 'Pipelines', - icon: PipelineIcon, - path: '/pipelines', - activePath: '/pipelines', - }, - { - label: 'Activity', - icon: FiActivity, - path: '/activity/event-logs', - activePath: '/activity', - }, - { - label: 'Config', - icon: FiSettings, - path: '/config', - activePath: '/config', - }, - { - label: 'access', - icon: FiUsers, - path: '/access', - activePath: '/access', - }, - { - label: 'Docs', - icon: FiBookOpen, - path: '/docs', - activePath: '/docs', - }, - ]; - - return ( - - - - - {hasValidAuthToken && navItems.map((item) => ( - - ))} - - ); -}; - -export default AppNav; diff --git a/airflow/ui/src/components/AppContainer/AppNavBtn.tsx b/airflow/ui/src/components/AppContainer/AppNavBtn.tsx deleted file mode 100644 index b6e560f1affe7..0000000000000 --- a/airflow/ui/src/components/AppContainer/AppNavBtn.tsx +++ /dev/null @@ -1,87 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -import React from 'react'; -import { Link, useLocation } from 'react-router-dom'; -import { - Box, - Icon, - Tooltip, -} from '@chakra-ui/react'; - -import type { IconType } from 'react-icons/lib'; - -interface Props { - navItem: { - label: string; - icon: IconType | typeof Icon; - path?: string; - activePath?: string; - href?: string; - }; -} - -const AppNavBtn: React.FC = ({ navItem }) => { - const location = useLocation(); - const { - label, icon, path, href, activePath, - } = navItem; - const isHome = activePath === '/'; - const isActive = activePath && ((isHome && location.pathname === '/') || (!isHome && location.pathname.includes(activePath))); - - return ( - - - - - - ); -}; - -export default AppNavBtn; diff --git a/airflow/ui/src/components/AppContainer/TimezoneDropdown.tsx b/airflow/ui/src/components/AppContainer/TimezoneDropdown.tsx deleted file mode 100644 index cb829f0370fe3..0000000000000 --- a/airflow/ui/src/components/AppContainer/TimezoneDropdown.tsx +++ /dev/null @@ -1,83 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -import React, { useRef } from 'react'; -import { - Box, - Button, - Menu, - MenuButton, - MenuList, - Tooltip, -} from '@chakra-ui/react'; -import { getTimeZones } from '@vvo/tzdb'; - -import Select from 'components/MultiSelect'; -import { useDateContext } from 'providers/DateProvider'; - -interface Option { value: string, label: string } - -const TimezoneDropdown: React.FC = () => { - const { timezone, setTimezone, formatDate } = useDateContext(); - const menuRef = useRef(null); - - const timezones = getTimeZones(); - - let currentTimezone; - const options = timezones.map(({ name, currentTimeFormat, group }) => { - const label = `${currentTimeFormat.substring(0, 6)} ${name.replace(/_/g, ' ')}`; - if (name === timezone || group.includes(timezone)) currentTimezone = { label, value: name }; - return { label, value: name }; - }); - - const onChangeTimezone = (newTimezone: Option | null) => { - if (newTimezone) { - setTimezone(newTimezone.value); - // Close the dropdown on a successful change - menuRef?.current?.click(); - } - }; - - return ( - - - - - {formatDate()} - - - - - { - const inputStyles = useMultiStyleConfig('Input', {}); - return ( - - - {children} - - - ); - }, - MultiValueContainer: ({ - children, - innerRef, - innerProps, - data: { isFixed }, - }) => ( - - {children} - - ), - MultiValueLabel: ({ children, innerRef, innerProps }) => ( - - {children} - - ), - MultiValueRemove: ({ - children, innerRef, innerProps, data: { isFixed }, - }) => { - if (isFixed) { - return null; - } - - return ( - - {children} - - ); - }, - IndicatorSeparator: ({ innerProps }) => ( - - ), - ClearIndicator: ({ innerProps }) => ( - - ), - DropdownIndicator: ({ innerProps }) => { - const { addon } = useStyles(); - - return ( -
- -
- ); - }, - // Menu components - MenuPortal: ({ children, ...portalProps }) => ( - - {children} - - ), - Menu: ({ children, ...menuProps }) => { - const menuStyles = useMultiStyleConfig('Menu', {}); - return ( - - {children} - - ); - }, - MenuList: ({ - innerRef, children, maxHeight, - }) => { - const { list } = useStyles(); - return ( - - {children} - - ); - }, - GroupHeading: ({ innerProps, children }) => { - const { groupTitle } = useStyles(); - return ( - - {children} - - ); - }, - Option: ({ - innerRef, innerProps, children, isFocused, isDisabled, - }) => { - const { item } = useStyles(); - interface ItemProps extends CSSWithMultiValues { - _disabled: CSSWithMultiValues, - _focus: CSSWithMultiValues, - } - return ( - )._focus.bg : 'transparent', - ...(isDisabled && (item as RecursiveCSSObject)._disabled), - }} - ref={innerRef} - {...innerProps} - {...(isDisabled && { disabled: true })} - > - {children} - - ); - }, - }, - ...components, - }} - styles={{ - ...chakraStyles, - ...styles, - }} - theme={(baseTheme) => ({ - ...baseTheme, - borderRadius: chakraTheme.radii.md, - colors: { - ...baseTheme.colors, - neutral50: placeholderColor, // placeholder text color - neutral40: placeholderColor, // noOptionsMessage color - }, - })} - {...props} - /> - ); -}; - -export default MultiSelect; diff --git a/airflow/ui/src/components/PipelineBreadcrumb.tsx b/airflow/ui/src/components/PipelineBreadcrumb.tsx deleted file mode 100644 index c3ba68bb293c2..0000000000000 --- a/airflow/ui/src/components/PipelineBreadcrumb.tsx +++ /dev/null @@ -1,93 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import React from 'react'; -import { Link } from 'react-router-dom'; -import { - Box, Flex, Heading, useColorModeValue, -} from '@chakra-ui/react'; - -import type { - Dag as DagType, - DagRun as DagRunType, - Task as TaskType, -} from 'interfaces'; - -interface Props { - dagId: DagType['dagId']; - dagRunId?: DagRunType['dagRunId']; - taskId?: TaskType['taskId']; -} - -const PipelineBreadcrumb: React.FC = ({ dagId, dagRunId, taskId }) => { - const dividerColor = useColorModeValue('gray.100', 'gray.700'); - - return ( - - - PIPELINE - - {!dagRunId && dagId} - {dagRunId && ( - - {dagId} - - )} - - - {dagRunId && ( - <> - / - - RUN - - {!taskId && dagRunId} - {taskId && ( - - {dagRunId} - - )} - - - - )} - {taskId && ( - <> - / - - TASK INSTANCE - {taskId} - - - )} - - ); -}; - -export default PipelineBreadcrumb; diff --git a/airflow/ui/src/components/SectionNav.tsx b/airflow/ui/src/components/SectionNav.tsx deleted file mode 100644 index 47ae594ff7cbe..0000000000000 --- a/airflow/ui/src/components/SectionNav.tsx +++ /dev/null @@ -1,61 +0,0 @@ -/*! 
- * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import React from 'react'; -import { - Box, - useColorModeValue, -} from '@chakra-ui/react'; - -import SectionNavBtn from 'components/SectionNavBtn'; - -interface Props { - currentView: string; - navItems: { - label: string; - path: string; - }[] -} - -const SectionNav: React.FC = ({ currentView, navItems }) => { - const bg = useColorModeValue('gray.100', 'gray.700'); - return ( - - - - {navItems.map((item) => ( - - ))} - - - - ); -}; - -export default SectionNav; diff --git a/airflow/ui/src/components/SectionNavBtn.tsx b/airflow/ui/src/components/SectionNavBtn.tsx deleted file mode 100644 index 16ab65f70ef43..0000000000000 --- a/airflow/ui/src/components/SectionNavBtn.tsx +++ /dev/null @@ -1,48 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import React from 'react'; -import { Link } from 'react-router-dom'; -import { Button } from '@chakra-ui/react'; - -interface Props { - item: { - label: string; - path: string; - }; - currentView: string; -} - -const SectionNavBtn: React.FC = ({ item, currentView }) => { - const { label, path } = item; - return ( - - ); -}; - -export default SectionNavBtn; diff --git a/airflow/ui/src/components/SectionWrapper.tsx b/airflow/ui/src/components/SectionWrapper.tsx deleted file mode 100644 index 4f8e145d554f0..0000000000000 --- a/airflow/ui/src/components/SectionWrapper.tsx +++ /dev/null @@ -1,86 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import React from 'react'; -import { - Box, - Heading, - useColorModeValue, -} from '@chakra-ui/react'; - -import AppContainer from 'components/AppContainer'; -import SectionNav from 'components/SectionNav'; - -interface Props { - currentSection: string; - currentView: string; - navItems: { - label: string; - path: string; - }[] - toolBar?: React.ReactNode; -} - -const SectionWrapper: React.FC = ({ - children, currentSection, currentView, navItems, toolBar, -}) => { - const heading = useColorModeValue('gray.400', 'gray.500'); - const border = useColorModeValue('gray.100', 'gray.700'); - const toolbarBg = useColorModeValue('white', 'gray.800'); - return ( - - - {currentSection} - / - - {currentView} - - )} - > - - {toolBar && ( - - {toolBar} - - )} - {children} - - ); -}; - -export default SectionWrapper; diff --git a/airflow/ui/src/components/Table.tsx b/airflow/ui/src/components/Table.tsx deleted file mode 100644 index bcee2fa74405a..0000000000000 --- a/airflow/ui/src/components/Table.tsx +++ /dev/null @@ -1,187 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Custom wrapper of react-table using Chakra UI components -*/ - -import React, { useEffect } from 'react'; -import { - Flex, - Table as ChakraTable, - Thead, - Tbody, - Tr, - Th, - Td, - IconButton, - Text, - useColorModeValue, -} from '@chakra-ui/react'; -import { - useTable, useSortBy, Column, usePagination, SortingRule, -} from 'react-table'; -import { - MdKeyboardArrowLeft, MdKeyboardArrowRight, -} from 'react-icons/md'; -import { - TiArrowUnsorted, TiArrowSortedDown, TiArrowSortedUp, -} from 'react-icons/ti'; - -interface Props { - data: any[]; - columns: Column[]; - /* - * manualPagination is when you need to do server-side pagination. 
- * Leave blank for client-side only - */ - manualPagination?: { - offset: number; - setOffset: (off: number) => void; - totalEntries: number; - }; - /* - * setSortBy is for custom sorting such as server-side sorting - */ - setSortBy?: (sortBy: SortingRule[]) => void; - pageSize?: number; -} - -const Table: React.FC = ({ - data, columns, manualPagination, pageSize = 25, setSortBy, -}) => { - const { totalEntries, offset, setOffset } = manualPagination || {}; - const oddColor = useColorModeValue('gray.50', 'gray.900'); - const hoverColor = useColorModeValue('gray.100', 'gray.700'); - - const pageCount = totalEntries ? (Math.ceil(totalEntries / pageSize) || 1) : data.length; - - const lowerCount = (offset || 0) + 1; - const upperCount = lowerCount + data.length - 1; - - const { - getTableProps, - getTableBodyProps, - allColumns, - prepareRow, - page, - canPreviousPage, - canNextPage, - nextPage, - previousPage, - state: { pageIndex, sortBy }, - } = useTable( - { - columns, - data, - pageCount, - manualPagination: !!manualPagination, - manualSortBy: !!setSortBy, - initialState: { - pageIndex: offset ? offset / pageSize : 0, - pageSize, - }, - }, - useSortBy, - usePagination, - ); - - const handleNext = () => { - nextPage(); - if (setOffset) setOffset((pageIndex + 1) * pageSize); - }; - - const handlePrevious = () => { - previousPage(); - if (setOffset) setOffset((pageIndex - 1 || 0) * pageSize); - }; - - useEffect(() => { - if (setSortBy) setSortBy(sortBy); - }, [sortBy, setSortBy]); - - return ( - <> - -
- - {allColumns.map((column) => ( - - ))} - - - - {!data.length && ( - - - - )} - {page.map((row) => { - prepareRow(row); - return ( - - {row.cells.map((cell) => ( - - ))} - - ); - })} - - - - - - - - - - - {lowerCount} - - - {upperCount} - {' of '} - {totalEntries} - - - - ); -}; - -export default Table; diff --git a/airflow/ui/src/components/TriggerRunModal.tsx b/airflow/ui/src/components/TriggerRunModal.tsx deleted file mode 100644 index c988c82e7d5d8..0000000000000 --- a/airflow/ui/src/components/TriggerRunModal.tsx +++ /dev/null @@ -1,83 +0,0 @@ -/*! - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import React, { ChangeEvent, useState } from 'react'; -import { - Button, - FormControl, - FormLabel, - Modal, - ModalHeader, - ModalFooter, - ModalCloseButton, - ModalOverlay, - ModalContent, - ModalBody, - Textarea, -} from '@chakra-ui/react'; - -import type { Dag } from 'interfaces'; -import { useTriggerRun } from 'api'; - -interface Props { - dagId: Dag['dagId']; - isOpen: boolean; - onClose: () => void; -} - -const TriggerRunModal: React.FC = ({ dagId, isOpen, onClose }) => { - const mutation = useTriggerRun(dagId); - const [config, setConfig] = useState('{}'); - - const onTrigger = () => { - mutation.mutate({ - conf: JSON.parse(config), - executionDate: new Date(), - }); - onClose(); - }; - - return ( - - - - - Trigger Run: - {' '} - {dagId} - - - - - Configuration JSON (Optional) - ', @@ -197,14 +201,48 @@ def test_trigger_dag_params_conf(admin_client, request_conf, expected_conf): ) +def test_trigger_dag_params_render(admin_client, dag_maker, session, app, monkeypatch): + """ + Test that textarea in Trigger DAG UI is pre-populated + with param value set in DAG. + """ + account = {"name": "account_name_1", "country": "usa"} + expected_conf = {"accounts": [account]} + expected_dag_conf = json.dumps(expected_conf, indent=4).replace('"', """) + DAG_ID = "params_dag" + param = Param( + [account], + schema={ + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "default": account, + "properties": {"name": {"type": "string"}, "country": {"type": "string"}}, + "required": ["name", "country"], + }, + }, + ) + with monkeypatch.context() as m: + with dag_maker(dag_id=DAG_ID, serialized=True, session=session, params={"accounts": param}): + EmptyOperator(task_id="task1") + + m.setattr(app, "dag_bag", dag_maker.dagbag) + resp = admin_client.get(f"trigger?dag_id={DAG_ID}") + + check_content_in_response( + f'', resp + ) + + def test_trigger_endpoint_uses_existing_dagbag(admin_client): """ Test that Trigger Endpoint uses the DagBag already created in views.py instead of creating a new one. 
""" - url = 'trigger?dag_id=example_bash_operator' + url = "trigger?dag_id=example_bash_operator" resp = admin_client.post(url, data={}, follow_redirects=True) - check_content_in_response('example_bash_operator', resp) + check_content_in_response("example_bash_operator", resp) def test_viewer_cant_trigger_dag(app): @@ -221,7 +259,7 @@ def test_viewer_cant_trigger_dag(app): (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), ], ) as client: - url = 'trigger?dag_id=example_bash_operator' + url = "trigger?dag_id=example_bash_operator" resp = client.get(url, follow_redirects=True) response_data = resp.data.decode() assert "Access is Denied" in response_data diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py index 8da4090d674e9..aca8e2aeeebdd 100644 --- a/tests/www/views/test_views_variable.py +++ b/tests/www/views/test_views_variable.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import io from unittest import mock @@ -24,13 +26,18 @@ from airflow.security import permissions from airflow.utils.session import create_session from tests.test_utils.api_connexion_utils import create_user -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + _check_last_log, + check_content_in_response, + check_content_not_in_response, + client_with_login, +) VARIABLE = { - 'key': 'test_key', - 'val': 'text_val', - 'description': 'test_description', - 'is_encrypted': True, + "key": "test_key", + "val": "text_val", + "description": "test_description", + "is_encrypted": True, } @@ -63,18 +70,18 @@ def client_variable_reader(app, user_variable_reader): def test_can_handle_error_on_decrypt(session, admin_client): # create valid variable - admin_client.post('/variable/add', data=VARIABLE, follow_redirects=True) + admin_client.post("/variable/add", data=VARIABLE, follow_redirects=True) # update the variable with a wrong value, given that is encrypted - session.query(Variable).filter(Variable.key == VARIABLE['key']).update( - {'val': 'failed_value_not_encrypted'}, + session.query(Variable).filter(Variable.key == VARIABLE["key"]).update( + {"val": "failed_value_not_encrypted"}, synchronize_session=False, ) session.commit() # retrieve Variables page, should not fail and contain the Invalid # label for the variable - resp = admin_client.get('/variable/list', follow_redirects=True) + resp = admin_client.get("/variable/list", follow_redirects=True) check_content_in_response( 'Invalid', resp, @@ -88,66 +95,67 @@ def test_xss_prevention(admin_client): def test_import_variables_no_file(admin_client): - resp = admin_client.post('/variable/varimport', follow_redirects=True) - check_content_in_response('Missing file or syntax error.', resp) + resp = admin_client.post("/variable/varimport", follow_redirects=True) + check_content_in_response("Missing file or syntax error.", resp) def test_import_variables_failed(session, admin_client): content = '{"str_key": "str_value"}' - with mock.patch('airflow.models.Variable.set') as set_mock: + with mock.patch("airflow.models.Variable.set") as set_mock: set_mock.side_effect = UnicodeEncodeError assert session.query(Variable).count() == 0 - bytes_content = io.BytesIO(bytes(content, encoding='utf-8')) + bytes_content = io.BytesIO(bytes(content, encoding="utf-8")) resp = admin_client.post( - 
'/variable/varimport', data={'file': (bytes_content, 'test.json')}, follow_redirects=True + "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True ) - check_content_in_response('1 variable(s) failed to be updated.', resp) + check_content_in_response("1 variable(s) failed to be updated.", resp) def test_import_variables_success(session, admin_client): assert session.query(Variable).count() == 0 content = '{"str_key": "str_value", "int_key": 60, "list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}' - bytes_content = io.BytesIO(bytes(content, encoding='utf-8')) + bytes_content = io.BytesIO(bytes(content, encoding="utf-8")) resp = admin_client.post( - '/variable/varimport', data={'file': (bytes_content, 'test.json')}, follow_redirects=True + "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True ) - check_content_in_response('4 variable(s) successfully updated.', resp) + check_content_in_response("4 variable(s) successfully updated.", resp) + _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None) def test_import_variables_anon(session, app): assert session.query(Variable).count() == 0 content = '{"str_key": "str_value}' - bytes_content = io.BytesIO(bytes(content, encoding='utf-8')) + bytes_content = io.BytesIO(bytes(content, encoding="utf-8")) resp = app.test_client().post( - '/variable/varimport', data={'file': (bytes_content, 'test.json')}, follow_redirects=True + "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True ) - check_content_not_in_response('variable(s) successfully updated.', resp) - check_content_in_response('Sign In', resp) + check_content_not_in_response("variable(s) successfully updated.", resp) + check_content_in_response("Sign In", resp) def test_import_variables_form_shown(app, admin_client): - resp = admin_client.get('/variable/list/') - check_content_in_response('Import Variables', resp) + resp = admin_client.get("/variable/list/") + check_content_in_response("Import Variables", resp) def test_import_variables_form_hidden(app, client_variable_reader): - resp = client_variable_reader.get('/variable/list/') - check_content_not_in_response('Import Variables', resp) + resp = client_variable_reader.get("/variable/list/") + check_content_not_in_response("Import Variables", resp) def test_description_retrieval(session, admin_client): # create valid variable - admin_client.post('/variable/add', data=VARIABLE, follow_redirects=True) + admin_client.post("/variable/add", data=VARIABLE, follow_redirects=True) row = session.query(Variable.key, Variable.description).first() - assert row.key == 'test_key' and row.description == 'test_description' + assert row.key == "test_key" and row.description == "test_description" @pytest.fixture()
-                {column.render('Header')}
-                {column.isSorted && (
-                  column.isSortedDesc ? (
-                  ) : (
-                  )
-                )}
-                {(!column.isSorted && column.canSort) && ()}
-                No Data found.
-                {cell.render('Cell')}